Example #1
0
    def wait_for_previous_lookups(self, server_names, server_to_deferred):
        """Waits for any previous key lookups for the given servers to finish.

        Args:
            server_names (list): list of server_names we want to look up
            server_to_deferred (dict): server_name to deferred which gets
                resolved once we've finished looking up keys for that server
        """
        while True:
            wait_on = [
                self.key_downloads[server_name]
                for server_name in server_names
                if server_name in self.key_downloads
            ]
            if wait_on:
                with PreserveLoggingContext():
                    yield defer.DeferredList(wait_on)
            else:
                break

        for server_name, deferred in server_to_deferred.items():
            d = ObservableDeferred(preserve_context_over_deferred(deferred))
            self.key_downloads[server_name] = d

            def rm(r, server_name):
                self.key_downloads.pop(server_name, None)
                return r

            d.addBoth(rm, server_name)
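All of these examples revolve around the same helper: an ObservableDeferred wraps a single Deferred so that any number of callers can each obtain their own Deferred (via observe()) that fires with the same result, without their callbacks interfering with one another. A minimal standalone sketch of that idea, using only Twisted (the class name MiniObservableDeferred and its exact semantics are illustrative, not Synapse's actual implementation):

from twisted.internet import defer


class MiniObservableDeferred(object):
    """Illustrative stand-in for ObservableDeferred: wraps one Deferred so
    several callers can each observe() the same eventual result."""

    def __init__(self, deferred, consumeErrors=False):
        self._result = None      # becomes (succeeded, value_or_failure)
        self._observers = []

        def callback(result):
            self._result = (True, result)
            for observer in self._observers:
                observer.callback(result)
            return result

        def errback(failure):
            self._result = (False, failure)
            for observer in self._observers:
                observer.errback(failure)
            # with consumeErrors=True the failure is swallowed here, so only
            # the observers (which are expected to handle it) ever see it
            if not consumeErrors:
                return failure

        deferred.addCallbacks(callback, errback)

    def observe(self):
        """Return a fresh Deferred that fires with the wrapped result."""
        if self._result is None:
            observer = defer.Deferred()
            self._observers.append(observer)
            return observer
        succeeded, value = self._result
        return defer.succeed(value) if succeeded else defer.fail(value)

The real class additionally proxies attribute access to the wrapped Deferred and exposes observers(), which several of the examples below rely on (for example d.addBoth(...), noify/notify_deferred.callback(...) and notify_deferred.observers()).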
Example #2
0
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
            fn (function): A function which returns a tuple of
            (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        if txn_key in self.transactions:
            observable = self.transactions[txn_key][0]
        else:
            # execute the function instead.
            deferred = run_in_background(fn, *args, **kwargs)

            observable = ObservableDeferred(deferred)
            self.transactions[txn_key] = (observable, self.clock.time_msec())

            # if the request fails with an exception, remove it
            # from the transaction map. This is done to ensure that we don't
            # cache transient errors like rate-limiting errors, etc.
            def remove_from_map(err):
                self.transactions.pop(txn_key, None)
                # we deliberately do not propagate the error any further, as we
                # expect the observers to have reported it.

            deferred.addErrback(remove_from_map)

        return make_deferred_yieldable(observable.observe())
Example #3
0
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
            fn (function): A function which returns a tuple of
            (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        try:
            return self.transactions[txn_key][0].observe()
        except (KeyError, IndexError):
            pass  # execute the function instead.

        deferred = fn(*args, **kwargs)

        # We don't add an errback to the raw deferred, so we ask ObservableDeferred
        # to swallow the error. This is fine as the error will still be reported
        # to the observers.
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.transactions[txn_key] = (observable, self.clock.time_msec())
        return observable.observe()
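The point of caching the ObservableDeferred keyed by txn_key is idempotency: a client that retries a request with the same transaction key gets the original response back instead of re-running the handler. A small usage sketch, assuming txn_cache is an object exposing the fetch_or_execute method above (send_message and the response shape are purely illustrative):

from twisted.internet import defer

calls = []


def send_message(body):
    # stands in for the expensive, side-effecting request handler
    calls.append(body)
    return defer.succeed((200, {"event_id": "$event-%d" % len(calls)}))


@defer.inlineCallbacks
def demo(txn_cache):
    first = yield txn_cache.fetch_or_execute("txn-1", send_message, "hello")
    retry = yield txn_cache.fetch_or_execute("txn-1", send_message, "hello")
    assert first == retry   # the cached (code, body) tuple is replayed
    assert len(calls) == 1  # send_message only ran once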
Example #4
0
    def get_server_verify_key(self, server_name, key_ids):
        """Finds a verification key for the server with one of the key ids.
        Tries to fetch the key from a trusted perspective server first.
        Args:
            server_name(str): The name of the server to fetch a key for.
            key_ids (list of str): The key_ids to check for.
        """
        cached = yield self.store.get_server_verify_keys(server_name, key_ids)

        if cached:
            defer.returnValue(cached[0])
            return

        download = self.key_downloads.get(server_name)

        if download is None:
            download = self._get_server_verify_key_impl(server_name, key_ids)
            download = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self.key_downloads[server_name] = download

            @download.addBoth
            def callback(ret):
                del self.key_downloads[server_name]
                return ret

        r = yield download.observe()
        defer.returnValue(r)
Example #5
0
    def set(self, key, deferred):
        """Set the entry for the given key to the given deferred.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
        you should wrap normal synapse deferreds with
        logcontext.run_in_background).

        Can return either a new Deferred (which also doesn't follow the synapse
        logcontext rules), or, if *deferred* was already complete, the actual
        result. You will probably want to make_deferred_yieldable the result.

        Args:
            key (hashable):
            deferred (twisted.internet.defer.Deferred[T]):

        Returns:
            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
                result.
        """
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            if self.timeout_sec:
                self.clock.call_later(
                    self.timeout_sec,
                    self.pending_result_cache.pop, key, None,
                )
            else:
                self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
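In practice set() is paired with a get() on the same pending_result_cache, so that concurrent requests for the same key share a single computation. A sketch of that pairing, assuming response_cache also exposes a get() that returns a Deferred for an in-flight entry or None; the helper name is illustrative, and the logcontext handling is simplified (real callers wrap the computation with run_in_background, as the docstring asks):

from twisted.internet import defer


def get_or_compute(response_cache, key, compute, *args, **kwargs):
    """Illustrative wrapper: reuse an in-flight result for `key` if there is
    one, otherwise start the computation and publish it via set()."""
    pending = response_cache.get(key)
    if pending is not None:
        return pending
    deferred = defer.maybeDeferred(compute, *args, **kwargs)
    # set() returns either a new Deferred or, if `deferred` had already
    # completed, the result itself; callers typically pass the return value
    # through make_deferred_yieldable before yielding on it
    return response_cache.set(key, deferred)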
Example #6
0
    def get_server_verify_key(self, server_name, key_ids):
        """Finds a verification key for the server with one of the key ids.
        Tries to fetch the key from a trusted perspective server first.
        Args:
            server_name(str): The name of the server to fetch a key for.
            key_ids (list of str): The key_ids to check for.
        """
        cached = yield self.store.get_server_verify_keys(server_name, key_ids)

        if cached:
            defer.returnValue(cached[0])
            return

        download = self.key_downloads.get(server_name)

        if download is None:
            download = self._get_server_verify_key_impl(server_name, key_ids)
            download = ObservableDeferred(download, consumeErrors=True)
            self.key_downloads[server_name] = download

            @download.addBoth
            def callback(ret):
                del self.key_downloads[server_name]
                return ret

        r = yield download.observe()
        defer.returnValue(r)
Example #7
0
    def wait_for_previous_lookups(self, server_names, server_to_deferred):
        """Waits for any previous key lookups for the given servers to finish.

        Args:
            server_names (list): list of server_names we want to look up
            server_to_deferred (dict): server_name to deferred which gets
                resolved once we've finished looking up keys for that server
        """
        while True:
            wait_on = [
                self.key_downloads[server_name] for server_name in server_names
                if server_name in self.key_downloads
            ]
            if wait_on:
                yield defer.DeferredList(wait_on)
            else:
                break

        for server_name, deferred in server_to_deferred.items():
            d = ObservableDeferred(deferred)
            self.key_downloads[server_name] = d

            def rm(r, server_name):
                self.key_downloads.pop(server_name, None)
                return r

            d.addBoth(rm, server_name)
Example #8
0
    def add_to_queue(self, room_id, events_and_contexts, backfilled,
                     current_state):
        """Add events to the queue, with the given persist_event options.
        """
        queue = self._event_persist_queues.setdefault(room_id, deque())
        if queue:
            end_item = queue[-1]
            if end_item.current_state or current_state:
                # We persist events with current_state set to True one at a time
                pass
            if end_item.backfilled == backfilled:
                end_item.events_and_contexts.extend(events_and_contexts)
                return end_item.deferred.observe()

        deferred = ObservableDeferred(defer.Deferred())

        queue.append(
            self._EventPersistQueueItem(
                events_and_contexts=events_and_contexts,
                backfilled=backfilled,
                current_state=current_state,
                deferred=deferred,
            ))

        return deferred.observe()
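Everything batched onto the tail queue item shares that item's ObservableDeferred, so whichever task drains the queue is responsible for firing it once the batch has been persisted. A sketch of that consumer side, assuming a caller-supplied persist_events callable and relying on the fact that ObservableDeferred forwards callback()/errback() to the wrapped Deferred (the names here are illustrative, not the actual drain loop):

from twisted.internet import defer
from twisted.python.failure import Failure


@defer.inlineCallbacks
def drain_room_queue(queue, persist_events):
    """Illustrative consumer: persist each queued batch, then resolve its
    shared deferred so every caller that was merged into it wakes up."""
    while queue:
        item = queue.popleft()
        try:
            result = yield persist_events(
                item.events_and_contexts, backfilled=item.backfilled,
            )
        except Exception:
            item.deferred.errback(Failure())
        else:
            item.deferred.callback(result)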
Example #9
0
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
            fn (function): A function which returns a tuple of
            (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        try:
            return self.transactions[txn_key][0].observe()
        except (KeyError, IndexError):
            pass  # execute the function instead.

        deferred = fn(*args, **kwargs)

        # if the request fails with a Twisted failure, remove it
        # from the transaction map. This is done to ensure that we don't
        # cache transient errors like rate-limiting errors, etc.
        def remove_from_map(err):
            self.transactions.pop(txn_key, None)
            return err
        deferred.addErrback(remove_from_map)

        # We don't add any other errbacks to the raw deferred, so we ask
        # ObservableDeferred to swallow the error. This is fine as the error will
        # still be reported to the observers.
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.transactions[txn_key] = (observable, self.clock.time_msec())
        return observable.observe()
Example #10
0
class _NotifierUserStream(object):
    """This represents a user connected to the event stream.
    It tracks the most recent stream token for that user.
    At a given point a user may have a number of streams listening for
    events.

    This listener will also keep track of which rooms it is listening in
    so that it can remove itself from the indexes in the Notifier class.
    """

    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())

    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id
        )
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)

    def remove(self, notifier):
        """ Remove this listener from all the indexes in the Notifier
        it knows about.
        """

        for room in self.rooms:
            lst = notifier.room_to_user_streams.get(room, set())
            lst.discard(self)

        notifier.user_to_user_stream.pop(self.user_id)

    def count_listeners(self):
        return len(self.notify_deferred.observers())

    def new_listener(self, token):
        """Returns a deferred that is resolved when there is a new token
        greater than the given token.
        """
        if self.current_token.is_after(token):
            return _NotificationListener(defer.succeed(self.current_token))
        else:
            return _NotificationListener(self.notify_deferred.observe())
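Each call to notify() swaps in a fresh ObservableDeferred before firing the old one, so a waiter observes the current deferred, waits for it to fire with the new token, and loops if that token is still not new enough. A sketch of such a waiter against a stream object like the one above (the helper name is illustrative):

from twisted.internet import defer


@defer.inlineCallbacks
def wait_for_newer_token(user_stream, token):
    """Illustrative waiter: return once the stream's token has advanced past
    the one the caller already has."""
    current = user_stream.current_token
    while not current.is_after(token):
        # observe() gives a fresh Deferred on the *current* notify_deferred;
        # notify() fires it with the updated current_token
        current = yield user_stream.notify_deferred.observe()
    defer.returnValue(current)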
Example #11
0
class _NotifierUserStream(object):
    """This represents a user connected to the event stream.
    It tracks the most recent stream token for that user.
    At a given point a user may have a number of streams listening for
    events.

    This listener will also keep track of which rooms it is listening in
    so that it can remove itself from the indexes in the Notifier class.
    """

    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())

    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id
        )
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)

    def remove(self, notifier):
        """ Remove this listener from all the indexes in the Notifier
        it knows about.
        """

        for room in self.rooms:
            lst = notifier.room_to_user_streams.get(room, set())
            lst.discard(self)

        notifier.user_to_user_stream.pop(self.user_id)

    def count_listeners(self):
        return len(self.notify_deferred.observers())

    def new_listener(self, token):
        """Returns a deferred that is resolved when there is a new token
        greater than the given token.
        """
        if self.current_token.is_after(token):
            return _NotificationListener(defer.succeed(self.current_token))
        else:
            return _NotificationListener(self.notify_deferred.observe())
Example #12
0
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Example #13
0
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            # Add temp cache_context so inspect.getcallargs doesn't explode
            if self.add_cache_context:
                kwargs["cache_context"] = None

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key, callback=invalidate_callback)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:
                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__, cache_key,
                                cached_result, actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)
                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    obj, *args, **kwargs
                )

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret, callback=invalidate_callback)

                return preserve_context_over_deferred(ret.observe())
Example #14
0
    def notify_replication(self):
        """Notify the any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

        for cb in self.replication_callbacks:
            preserve_fn(cb)()
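The replication_deferred follows the same swap-then-fire pattern: notify_replication() installs a fresh ObservableDeferred and then fires the old one, so anything waiting on an observer wakes up exactly once per notification. A minimal sketch of the waiting side (a hypothetical helper, not the actual replication code):

from twisted.internet import defer


@defer.inlineCallbacks
def stream_replication_events(notifier, handle_update):
    """Illustrative loop: wake up on every notify_replication() call and let
    handle_update (a caller-supplied callable) inspect the new state."""
    while True:
        yield notifier.replication_deferred.observe()
        yield defer.maybeDeferred(handle_update)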
Example #15
0
    def set(self, key, deferred):
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
Example #16
0
    def set(self, key, deferred):
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
Example #17
0
    def __init__(self, hs):
        self.user_to_user_stream = {}
        self.room_to_user_streams = {}

        self.hs = hs
        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.replication_callbacks = []

        self.clock = hs.get_clock()
        self.appservice_handler = hs.get_application_service_handler()

        if hs.should_send_federation():
            self.federation_sender = hs.get_federation_sender()
        else:
            self.federation_sender = None

        self.state_handler = hs.get_state_handler()

        self.clock.looping_call(self.remove_expired_streams,
                                self.UNUSED_STREAM_EXPIRY_MS)

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)

            return sum(stream.count_listeners() for stream in all_user_streams)

        LaterGauge("synapse_notifier_listeners", "", [], count_listeners)

        LaterGauge(
            "synapse_notifier_rooms",
            "",
            [],
            lambda: count(bool, self.room_to_user_streams.values()),
        )
        LaterGauge(
            "synapse_notifier_users",
            "",
            [],
            lambda: len(self.user_to_user_stream),
        )
Example #18
0
    def __init__(self,
                 user,
                 rooms,
                 current_token,
                 time_now_ms,
                 appservice=None):
        self.user = str(user)
        self.appservice = appservice
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        self.notify_deferred = ObservableDeferred(defer.Deferred())
Example #19
0
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            cached = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = self.cache.get(tuple(key)).observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)
                    cached[arg] = res
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = self.cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    self.cache.update(sequence, tuple(key), observer)

                    def invalidate(f, key):
                        self.cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached[arg] = res

            return preserve_context_over_deferred(defer.gatherResults(
                cached.values(),
                consumeErrors=True,
            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
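The subtle step in this cachedList-style wrapper is turning one batch Deferred, which fires with a dict of results, into one cacheable Deferred per key. That splitting step can be sketched on its own with plain Twisted (all names here are illustrative):

from twisted.internet import defer


def split_batch_result(batch_deferred, keys):
    """Return {key: Deferred}, where each per-key Deferred fires with that
    key's entry from the dict the batch Deferred eventually produces (or
    with the batch failure, if it fails)."""
    per_key = dict((key, defer.Deferred()) for key in keys)

    def fan_out(result_dict):
        for key, d in per_key.items():
            d.callback(result_dict.get(key))
        return result_dict

    def fan_out_err(failure):
        for d in per_key.values():
            d.errback(failure)
        return failure

    batch_deferred.addCallbacks(fan_out, fan_out_err)
    return per_key

The wrapper above goes further: it wraps each per-key Deferred in its own ObservableDeferred, inserts it into the cache, and invalidates the entry again if the batch fails.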
Example #20
0
    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id)
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred
        self.notify_deferred = ObservableDeferred(defer.Deferred())
        notify_deferred.callback(self.current_token)
Example #21
0
    def __init__(self, hs):
        self.hs = hs

        self.user_to_user_stream = {}
        self.room_to_user_streams = {}
        self.appservice_to_user_streams = {}

        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.clock = hs.get_clock()

        hs.get_distributor().observe("user_joined_room",
                                     self._user_joined_room)

        self.clock.looping_call(self.remove_expired_streams,
                                self.UNUSED_STREAM_EXPIRY_MS)

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)
            for x in self.appservice_to_user_streams.values():
                all_user_streams |= x

            return sum(stream.count_listeners() for stream in all_user_streams)

        metrics.register_callback("listeners", count_listeners)

        metrics.register_callback(
            "rooms",
            lambda: count(bool, self.room_to_user_streams.values()),
        )
        metrics.register_callback(
            "users",
            lambda: len(self.user_to_user_stream),
        )
        metrics.register_callback(
            "appservices",
            lambda: count(bool, self.appservice_to_user_streams.values()),
        )
Example #22
0
    def get_remote_media(self, server_name, media_id):
        key = (server_name, media_id)
        download = self.downloads.get(key)
        if download is None:
            download = self._get_remote_media_impl(server_name, media_id)
            download = ObservableDeferred(download, consumeErrors=True)
            self.downloads[key] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[key]
                return media_info

        return download.observe()
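The effect is that two concurrent requests for the same (server_name, media_id) share one download; the second caller simply observes the Deferred the first one started. A small usage sketch, assuming media_repo exposes the method above (the server name and media id are made up):

from twisted.internet import defer


@defer.inlineCallbacks
def fetch_same_media_twice(media_repo):
    d1 = media_repo.get_remote_media("remote.example.org", "abc123")
    d2 = media_repo.get_remote_media("remote.example.org", "abc123")
    info1, info2 = yield defer.gatherResults([d1, d2], consumeErrors=True)
    assert info1 is info2  # the very same media_info is handed to both observers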
Example #23
0
    def get_remote_media(self, server_name, media_id):
        key = (server_name, media_id)
        download = self.downloads.get(key)
        if download is None:
            download = self._get_remote_media_impl(server_name, media_id)
            download = ObservableDeferred(download, consumeErrors=True)
            self.downloads[key] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[key]
                return media_info

        return download.observe()
Example #24
0
    def notify_replication(self):
        """Notify the any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

            # the callbacks may well outlast the current request, so we run
            # them in the sentinel logcontext.
            #
            # (ideally it would be up to the callbacks to know if they were
            # starting off background processes and drop the logcontext
            # accordingly, but that requires more changes)
            for cb in self.replication_callbacks:
                cb()
Example #25
0
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token

        # The last token for which we should wake up any streams that have a
        # token that comes before it. This gets updated every time we get poked.
        # We start it at the current token since if we get any streams
        # that have a token from before we have no idea whether they should be
        # woken up or not, so let's just wake them up.
        self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Example #26
0
    def set(self, key, deferred):
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            if self.timeout_sec:
                self.clock.call_later(
                    self.timeout_sec,
                    self.pending_result_cache.pop, key, None,
                )
            else:
                self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
Example #27
0
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
            try:
                cached_result_d = cache.get(cache_key)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:

                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(
                            obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__,
                                cache_key,
                                cached_result,
                                actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)

                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(preserve_context_over_fn,
                                          self.function_to_call, obj, *args,
                                          **kwargs)

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret)

                return preserve_context_over_deferred(ret.observe())
Example #28
0
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            cache_key = get_cache_key(args, kwargs)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key,
                                            callback=invalidate_callback)

                if isinstance(cached_result_d, ObservableDeferred):
                    observer = cached_result_d.observe()
                else:
                    observer = cached_result_d

            except KeyError:
                ret = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call), obj, *args,
                    **kwargs)

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                # If our cache_key is a string on py2, try to convert to ascii
                # to save a bit of space in large caches. Py3 does this
                # internally automatically.
                if six.PY2 and isinstance(cache_key, string_types):
                    cache_key = to_ascii(cache_key)

                result_d = ObservableDeferred(ret, consumeErrors=True)
                cache.set(cache_key, result_d, callback=invalidate_callback)
                observer = result_d.observe()

            if isinstance(observer, defer.Deferred):
                return logcontext.make_deferred_yieldable(observer)
            else:
                return observer
Example #29
0
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Example #30
0
    def set(self, time_now_ms, key, deferred):
        self.rotate(time_now_ms)

        result = ObservableDeferred(deferred)

        self.pending_result_cache[key] = result

        def shuffle_along(r):
            # When the deferred completes we shuffle it along to the first
            # generation of the result cache, so that it will eventually
            # expire from the rotation of that cache.
            self.next_result_cache[key] = result
            self.pending_result_cache.pop(key, None)

        result.observe().addBoth(shuffle_along)

        return result.observe()
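Here set() parks the in-flight ObservableDeferred in pending_result_cache and, once it completes, shuffles it into next_result_cache; rotate() (not shown) is expected to age completed generations out over time. An illustrative sketch of that two-generation expiry idea, labelled clearly as a guess at the shape rather than the actual class these methods belong to:

class TwoGenerationCache(object):
    """Illustrative two-generation expiry: completed entries survive at most
    two rotations, with one rotation per expiry window."""

    def __init__(self, expiry_ms):
        self.expiry_ms = expiry_ms
        self.rotated_at_ms = 0
        self.pending_result_cache = {}   # still-running ObservableDeferreds
        self.next_result_cache = {}      # completed, current generation
        self.prev_result_cache = {}      # completed, previous generation

    def rotate(self, time_now_ms):
        # age the generations: "next" becomes "prev", the old "prev" is dropped
        if time_now_ms > self.rotated_at_ms + self.expiry_ms:
            self.rotated_at_ms = time_now_ms
            self.prev_result_cache = self.next_result_cache
            self.next_result_cache = {}

    def get(self, time_now_ms, key):
        self.rotate(time_now_ms)
        for generation in (self.pending_result_cache,
                           self.next_result_cache,
                           self.prev_result_cache):
            if key in generation:
                return generation[key].observe()
        return None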
Example #31
0
    def set(self, time_now_ms, key, deferred):
        self.rotate(time_now_ms)

        result = ObservableDeferred(deferred)

        self.pending_result_cache[key] = result

        def shuffle_along(r):
            # When the deferred completes we shuffle it along to the first
            # generation of the result cache, so that it will eventually
            # expire from the rotation of that cache.
            self.next_result_cache[key] = result
            self.pending_result_cache.pop(key, None)

        result.observe().addBoth(shuffle_along)

        return result.observe()
Example #32
0
    def set(self, key, deferred):
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            if self.timeout_sec:
                self.clock.call_later(
                    self.timeout_sec,
                    self.pending_result_cache.pop,
                    key,
                    None,
                )
            else:
                self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
Example #33
0
    def __init__(self, user, rooms, current_token, time_now_ms,
                 appservice=None):
        self.user = str(user)
        self.appservice = appservice
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        self.notify_deferred = ObservableDeferred(defer.Deferred())
Example #34
0
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
            try:
                cached_result_d = cache.get(cache_key)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:
                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__, cache_key,
                                cached_result, actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)
                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    obj, *args, **kwargs
                )

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret)

                return preserve_context_over_deferred(ret.observe())
Example #35
0
    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id)
        self.last_notified_token = self.current_token
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)
Example #36
0
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            cached = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = self.cache.get(tuple(key)).observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)
                    cached[arg] = res
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = self.cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    self.cache.update(sequence, tuple(key), observer)

                    def invalidate(f, key):
                        self.cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached[arg] = res

            return preserve_context_over_deferred(defer.gatherResults(
                cached.values(),
                consumeErrors=True,
            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
Example #37
0
    def __init__(self, hs):
        self.hs = hs

        self.user_to_user_stream = {}
        self.room_to_user_streams = {}
        self.appservice_to_user_streams = {}

        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.clock = hs.get_clock()

        hs.get_distributor().observe(
            "user_joined_room", self._user_joined_room
        )

        self.clock.looping_call(
            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
        )

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)
            for x in self.appservice_to_user_streams.values():
                all_user_streams |= x

            return sum(stream.count_listeners() for stream in all_user_streams)
        metrics.register_callback("listeners", count_listeners)

        metrics.register_callback(
            "rooms",
            lambda: count(bool, self.room_to_user_streams.values()),
        )
        metrics.register_callback(
            "users",
            lambda: len(self.user_to_user_stream),
        )
        metrics.register_callback(
            "appservices",
            lambda: count(bool, self.appservice_to_user_streams.values()),
        )
Example #38
0
    def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
        """Add events to the queue, with the given persist_event options.
        """
        queue = self._event_persist_queues.setdefault(room_id, deque())
        if queue:
            end_item = queue[-1]
            if end_item.current_state or current_state:
                # We persist events with current_state set to True one at a time
                pass
            if end_item.backfilled == backfilled:
                end_item.events_and_contexts.extend(events_and_contexts)
                return end_item.deferred.observe()

        deferred = ObservableDeferred(defer.Deferred())

        queue.append(self._EventPersistQueueItem(
            events_and_contexts=events_and_contexts,
            backfilled=backfilled,
            current_state=current_state,
            deferred=deferred,
        ))

        return deferred.observe()
Example #39
0
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
            fn (function): A function which returns a tuple of
            (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        try:
            return self.transactions[txn_key][0].observe()
        except (KeyError, IndexError):
            pass  # execute the function instead.

        deferred = fn(*args, **kwargs)
        observable = ObservableDeferred(deferred)
        self.transactions[txn_key] = (observable, self.clock.time_msec())
        return observable.observe()
Example #40
0
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
            fn (function): A function which returns a tuple of
            (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        try:
            return self.transactions[txn_key][0].observe()
        except (KeyError, IndexError):
            pass  # execute the function instead.

        deferred = fn(*args, **kwargs)

        # if the request fails with a Twisted failure, remove it
        # from the transaction map. This is done to ensure that we don't
        # cache transient errors like rate-limiting errors, etc.
        def remove_from_map(err):
            self.transactions.pop(txn_key, None)
            return err

        deferred.addErrback(remove_from_map)

        # We don't add any other errbacks to the raw deferred, so we ask
        # ObservableDeferred to swallow the error. This is fine as the error will
        # still be reported to the observers.
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.transactions[txn_key] = (observable, self.clock.time_msec())
        return observable.observe()
Example #41
0
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token

        # The last token for which we should wake up any streams that have a
        # token that comes before it. This gets updated every time we get poked.
        # We start it at the current token since if we get any streams
        # that have a token from before we have no idea whether they should be
        # woken up or not, so let's just wake them up.
        self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Example #42
0
    def notify_replication(self):
        """Notify the any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

            # the callbacks may well outlast the current request, so we run
            # them in the sentinel logcontext.
            #
            # (ideally it would be up to the callbacks to know if they were
            # starting off background processes and drop the logcontext
            # accordingly, but that requires more changes)
            for cb in self.replication_callbacks:
                cb()
Example #43
0
    def set(self, key, deferred):
        """Set the entry for the given key to the given deferred.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
        you should wrap normal synapse deferreds with
        logcontext.run_in_background).

        Can return either a new Deferred (which also doesn't follow the synapse
        logcontext rules), or, if *deferred* was already complete, the actual
        result. You will probably want to make_deferred_yieldable the result.

        Args:
            key (hashable):
            deferred (twisted.internet.defer.Deferred[T]):

        Returns:
            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
                result.
        """
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            if self.timeout_sec:
                self.clock.call_later(
                    self.timeout_sec,
                    self.pending_result_cache.pop,
                    key,
                    None,
                )
            else:
                self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
Example #44
0
    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id
        )
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred
        self.notify_deferred = ObservableDeferred(defer.Deferred())
        notify_deferred.callback(self.current_token)
Example #45
0
    def test_prefill(self):
        callcount = [0]

        d = defer.succeed(123)

        class A(object):
            @cached()
            def func(self, key):
                callcount[0] += 1
                return d

        a = A()

        a.func.prefill(("foo", ), ObservableDeferred(d))

        self.assertEquals(a.func("foo").result, d.result)
        self.assertEquals(callcount[0], 0)
Example #46
0
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug(("Matching attrib '%s' with value '%s' against"
                              " pattern '%s'") % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib),
                                           pattern):
                        match = False
                        continue
            if match:
                logger.warn("URL %s blocked by url_blacklist entry %s", url,
                            entry)
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN)

        # first check the memory cache - good to handle all the clients on this
        # HS thundering away to preview the same URL at the same time.
        og = self.cache.get(url)
        if og:
            respond_with_json_bytes(request,
                                    200,
                                    json.dumps(og),
                                    send_cors=True)
            return

        # then check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (cache_result
                and cache_result["download_ts"] + cache_result["expires"] > ts
                and cache_result["response_code"] / 100 == 2):
            respond_with_json_bytes(request,
                                    200,
                                    cache_result["og"].encode('utf-8'),
                                    send_cors=True)
            return

        # Ensure only one download for a given URL is active at a time
        download = self.downloads.get(url)
        if download is None:
            download = self._download_url(url, requester.user)
            download = ObservableDeferred(download, consumeErrors=True)
            self.downloads[url] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[url]
                return media_info

        media_info = yield download.observe()

        # FIXME: we should probably update our cache now anyway, so that
        # even if the OG calculation raises, we don't keep hammering on the
        # remote server.  For now, leave it uncached to aid debugging OG
        # calculation problems

        logger.debug("got media_info of '%s'" % media_info)

        if _is_media(media_info['media_type']):
            dims = yield self.media_repo._generate_local_thumbnails(
                media_info['filesystem_id'],
                media_info,
                url_cache=True,
            )

            og = {
                "og:description":
                media_info['download_name'],
                "og:image":
                "mxc://%s/%s" %
                (self.server_name, media_info['filesystem_id']),
                "og:image:type":
                media_info['media_type'],
                "matrix:image:size":
                media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif _is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            file = open(media_info['filename'])
            body = file.read()
            file.close()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(r'.*; *charset=(.*?)(;|$)',
                             media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

            og = decode_and_calc_og(body, media_info['uri'], encoding)

            # pre-cache the image for posterity
            # FIXME: it might be cleaner to use the same flow as the main /preview_url
            # request itself and benefit from the same caching etc.  But for now we
            # just rely on the caching on the master request to speed things up.
            if 'og:image' in og and og['og:image']:
                image_info = yield self._download_url(
                    _rebase_url(og['og:image'], media_info['uri']),
                    requester.user)

                if _is_media(image_info['media_type']):
                    # TODO: make sure we don't choke on white-on-transparent images
                    dims = yield self.media_repo._generate_local_thumbnails(
                        image_info['filesystem_id'],
                        image_info,
                        url_cache=True,
                    )
                    if dims:
                        og["og:image:width"] = dims['width']
                        og["og:image:height"] = dims['height']
                    else:
                        logger.warn("Couldn't get dims for %s" %
                                    og["og:image"])

                    og["og:image"] = "mxc://%s/%s" % (
                        self.server_name, image_info['filesystem_id'])
                    og["og:image:type"] = image_info['media_type']
                    og["matrix:image:size"] = image_info['media_length']
                else:
                    del og["og:image"]
        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        # store OG in ephemeral in-memory cache
        self.cache[url] = og

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"],
            json.dumps(og),
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
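
The handler above lets every concurrent preview request for the same URL share a single in-flight download through ObservableDeferred. The following is a minimal, Twisted-free sketch of that deduplication idea using only asyncio from the standard library; PreviewCache and its fake _download helper are illustrative stand-ins, not synapse APIs.

import asyncio

class PreviewCache:
    """Keeps at most one in-flight download per URL (analogue of self.downloads)."""

    def __init__(self):
        self._inflight = {}        # url -> asyncio.Task

    async def fetch(self, url):
        task = self._inflight.get(url)
        if task is None:
            task = asyncio.ensure_future(self._download(url))
            self._inflight[url] = task
            # analogue of the addBoth callback above: drop the entry once done
            task.add_done_callback(lambda _t: self._inflight.pop(url, None))
        # every concurrent caller awaits the same task, like download.observe()
        return await task

    async def _download(self, url):
        await asyncio.sleep(0.05)  # stand-in for the real HTTP fetch
        return {"url": url, "media_type": "text/html"}

async def main():
    cache = PreviewCache()
    a, b, c = await asyncio.gather(*(cache.fetch("https://example.com") for _ in range(3)))
    assert a is b is c             # one download served all three callers

asyncio.run(main())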
Example #47
0
class _NotifierUserStream(object):
    """This represents a user connected to the event stream.
    It tracks the most recent stream token for that user.
    At a given point a user may have a number of streams listening for
    events.

    This listener will also keep track of which rooms it is listening in
    so that it can remove itself from the indexes in the Notifier class.
    """

    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token

        # The last token for which we should wake up any streams that have a
        # token that comes before it. This gets updated every time we get poked.
        # We start it at the current token since if we get any streams
        # that have a token from before we have no idea whether they should be
        # woken up or not, so let's just wake them up.
        self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())

    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id
        )
        self.last_notified_token = self.current_token
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        users_woken_by_stream_counter.inc(stream_key)

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)

    def remove(self, notifier):
        """ Remove this listener from all the indexes in the Notifier
        it knows about.
        """

        for room in self.rooms:
            lst = notifier.room_to_user_streams.get(room, set())
            lst.discard(self)

        notifier.user_to_user_stream.pop(self.user_id)

    def count_listeners(self):
        return len(self.notify_deferred.observers())

    def new_listener(self, token):
        """Returns a deferred that is resolved when there is a new token
        greater than the given token.

        Args:
            token: The token from which we are streaming from, i.e. we shouldn't
                notify for things that happened before this.
        """
        # Immediately wake up the stream if something has already happened
        # since their last token.
        if self.last_notified_token.is_after(token):
            return _NotificationListener(defer.succeed(self.current_token))
        else:
            return _NotificationListener(self.notify_deferred.observe())
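
Example #47's notify() wakes listeners by firing the current ObservableDeferred and immediately installing a fresh one, while new_listener() short-circuits when the stream has already advanced past the caller's token. A compact asyncio analogue of that hand-off, under the assumption that a plain Future can stand in for ObservableDeferred, might look like this (UserStream is an illustrative name, not synapse code):

import asyncio

class UserStream:
    def __init__(self, current_token=0):
        self.current_token = current_token
        self._wakeup = None                      # analogue of self.notify_deferred

    def _wakeup_future(self):
        if self._wakeup is None:
            self._wakeup = asyncio.get_running_loop().create_future()
        return self._wakeup

    def notify(self, new_token):
        self.current_token = new_token
        old, self._wakeup = self._wakeup, None   # install a fresh future...
        if old is not None and not old.done():
            old.set_result(new_token)            # ...then wake the old listeners

    async def new_listener(self, token):
        if self.current_token > token:           # already past the given token
            return self.current_token
        return await self._wakeup_future()

async def main():
    stream = UserStream()
    waiter = asyncio.create_task(stream.new_listener(0))
    await asyncio.sleep(0)                       # let the waiter start waiting
    stream.notify(1)
    assert await waiter == 1

asyncio.run(main())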
Example #48
0
    def notify_replication(self):
        """Notify any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)
Example #49
0
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            results = {}
            cached_defers = {}
            missing = []

            # If the cache takes a single arg then that is used as the key,
            # otherwise a tuple is used.
            if num_args == 1:

                def cache_get(arg):
                    return cache.get(arg, callback=invalidate_callback)
            else:
                key = list(keyargs)

                def cache_get(arg):
                    key[self.list_pos] = arg
                    return cache.get(tuple(key), callback=invalidate_callback)

            for arg in list_args:
                try:
                    res = cache_get(arg)

                    if not isinstance(res, ObservableDeferred):
                        results[arg] = res
                    elif not res.has_succeeded():
                        res = res.observe()
                        res.addCallback(lambda r, arg: (arg, r), arg)
                        cached_defers[arg] = res
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.append(arg)

            if missing:
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call),
                    **args_to_call)

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    if num_args == 1:
                        cache.set(arg, observer, callback=invalidate_callback)

                        def invalidate(f, key):
                            cache.invalidate(key)
                            return f

                        observer.addErrback(invalidate, arg)
                    else:
                        key = list(keyargs)
                        key[self.list_pos] = arg
                        cache.set(tuple(key),
                                  observer,
                                  callback=invalidate_callback)

                        def invalidate(f, key):
                            cache.invalidate(key)
                            return f

                        observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached_defers[arg] = res

            if cached_defers:

                def update_results_dict(res):
                    results.update(res)
                    return results

                return logcontext.make_deferred_yieldable(
                    defer.gatherResults(
                        list(cached_defers.values()),
                        consumeErrors=True,
                    ).addCallback(update_results_dict).addErrback(
                        unwrapFirstError))
            else:
                return results
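
Example #49's wrapper splits a batched lookup into cache hits, hits still in flight, and misses, then issues one call for all of the misses and inserts a per-key deferred back into the cache. Stripped of the Deferred plumbing, the core hit/miss flow could be sketched like this (cache, fetch_many and batched_get are assumed names for illustration only):

import asyncio

cache = {}

async def fetch_many(keys):
    # stand-in for self.function_to_call: one backend call covering every miss
    await asyncio.sleep(0)
    return {k: "value-for-%s" % k for k in keys}

async def batched_get(keys):
    results = {k: cache[k] for k in keys if k in cache}    # cache hits
    missing = [k for k in keys if k not in cache]          # cache misses
    if missing:
        fetched = await fetch_many(missing)
        for k in missing:
            cache[k] = fetched.get(k)                      # populate per-key
            results[k] = cache[k]
    return results

print(asyncio.run(batched_get(["a", "b"])))
print(asyncio.run(batched_get(["a", "c"])))                # "a" is now a cache hit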
Example #50
0
class _NotifierUserStream(object):
    """This represents a user connected to the event stream.
    It tracks the most recent stream token for that user.
    At a given point a user may have a number of streams listening for
    events.

    This listener will also keep track of which rooms it is listening in
    so that it can remove itself from the indexes in the Notifier class.
    """
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token

        # The last token for which we should wake up any streams that have a
        # token that comes before it. This gets updated every time we get poked.
        # We start it at the current token since if we get any streams
        # that have a token from before we have no idea whether they should be
        # woken up or not, so let's just wake them up.
        self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())

    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id)
        self.last_notified_token = self.current_token
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        users_woken_by_stream_counter.inc(stream_key)

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)

    def remove(self, notifier):
        """ Remove this listener from all the indexes in the Notifier
        it knows about.
        """

        for room in self.rooms:
            lst = notifier.room_to_user_streams.get(room, set())
            lst.discard(self)

        notifier.user_to_user_stream.pop(self.user_id)

    def count_listeners(self):
        return len(self.notify_deferred.observers())

    def new_listener(self, token):
        """Returns a deferred that is resolved when there is a new token
        greater than the given token.

        Args:
            token: The token from which we are streaming from, i.e. we shouldn't
                notify for things that happened before this.
        """
        # Immediately wake up the stream if something has already happened
        # since their last token.
        if self.last_notified_token.is_after(token):
            return _NotificationListener(defer.succeed(self.current_token))
        else:
            return _NotificationListener(self.notify_deferred.observe())
Example #51
0
class Notifier(object):
    """ This class is responsible for notifying any listeners when there are
    new events available for it.

    Primarily used from the /events stream.
    """

    UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000

    def __init__(self, hs):
        self.user_to_user_stream = {}
        self.room_to_user_streams = {}
        self.appservice_to_user_streams = {}

        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.clock = hs.get_clock()
        self.appservice_handler = hs.get_application_service_handler()
        self.state_handler = hs.get_state_handler()

        self.clock.looping_call(
            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
        )

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)
            for x in self.appservice_to_user_streams.values():
                all_user_streams |= x

            return sum(stream.count_listeners() for stream in all_user_streams)
        metrics.register_callback("listeners", count_listeners)

        metrics.register_callback(
            "rooms",
            lambda: count(bool, self.room_to_user_streams.values()),
        )
        metrics.register_callback(
            "users",
            lambda: len(self.user_to_user_stream),
        )
        metrics.register_callback(
            "appservices",
            lambda: count(bool, self.appservice_to_user_streams.values()),
        )

    def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                          extra_users=[]):
        """ Used by handlers to inform the notifier something has happened
        in the room, room event wise.

        This triggers the notifier to wake up any listeners that are
        listening to the room, and any listeners for the users in the
        `extra_users` param.

        The events can be persisted out of order. The notifier will wait
        until all previous events have been persisted before notifying
        the client streams.
        """
        with PreserveLoggingContext():
            self.pending_new_room_events.append((
                room_stream_id, event, extra_users
            ))
            self._notify_pending_new_room_events(max_room_stream_id)

            self.notify_replication()

    def _notify_pending_new_room_events(self, max_room_stream_id):
        """Notify for the room events that were queued waiting for a previous
        event to be persisted.
        Args:
            max_room_stream_id(int): The highest stream_id below which all
                events have been persisted.
        """
        pending = self.pending_new_room_events
        self.pending_new_room_events = []
        for room_stream_id, event, extra_users in pending:
            if room_stream_id > max_room_stream_id:
                self.pending_new_room_events.append((
                    room_stream_id, event, extra_users
                ))
            else:
                self._on_new_room_event(event, room_stream_id, extra_users)

    def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
        """Notify any user streams that are interested in this room event"""
        # poke any interested application service.
        self.appservice_handler.notify_interested_services(event)

        app_streams = set()

        for appservice in self.appservice_to_user_streams:
            # TODO (kegan): Redundant appservice listener checks?
            # App services will already be in the room_to_user_streams set, but
            # that isn't enough. They need to be checked here in order to
            # receive *invites* for users they are interested in. Does this
            # make the room_to_user_streams check somewhat obsolete?
            if appservice.is_interested(event):
                app_user_streams = self.appservice_to_user_streams.get(
                    appservice, set()
                )
                app_streams |= app_user_streams

        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
            self._user_joined_room(event.state_key, event.room_id)

        self.on_new_event(
            "room_key", room_stream_id,
            users=extra_users,
            rooms=[event.room_id],
            extra_streams=app_streams,
        )

    def on_new_event(self, stream_key, new_token, users=[], rooms=[],
                     extra_streams=set()):
        """ Used to inform listeners that something has happend event wise.

        Will wake up all listeners for the given users and rooms.
        """
        with PreserveLoggingContext():
            user_streams = set()

            for user in users:
                user_stream = self.user_to_user_stream.get(str(user))
                if user_stream is not None:
                    user_streams.add(user_stream)

            for room in rooms:
                user_streams |= self.room_to_user_streams.get(room, set())

            time_now_ms = self.clock.time_msec()
            for user_stream in user_streams:
                try:
                    user_stream.notify(stream_key, new_token, time_now_ms)
                except Exception:
                    logger.exception("Failed to notify listener")

            self.notify_replication()

    def on_new_replication_data(self):
        """Used to inform replication listeners that something has happend
        without waking up any of the normal user event streams"""
        with PreserveLoggingContext():
            self.notify_replication()

    @defer.inlineCallbacks
    def wait_for_events(self, user_id, timeout, callback, room_ids=None,
                        from_token=StreamToken.START):
        """Wait until the callback returns a non empty response or the
        timeout fires.
        """
        user_stream = self.user_to_user_stream.get(user_id)
        if user_stream is None:
            appservice = yield self.store.get_app_service_by_user_id(user_id)
            current_token = yield self.event_sources.get_current_token()
            if room_ids is None:
                rooms = yield self.store.get_rooms_for_user(user_id)
                room_ids = [room.room_id for room in rooms]
            user_stream = _NotifierUserStream(
                user_id=user_id,
                rooms=room_ids,
                appservice=appservice,
                current_token=current_token,
                time_now_ms=self.clock.time_msec(),
            )
            self._register_with_keys(user_stream)

        result = None
        if timeout:
            # Will be set to a _NotificationListener that we'll be waiting on.
            # Allows us to cancel it.
            listener = None

            def timed_out():
                if listener:
                    listener.deferred.cancel()
            timer = self.clock.call_later(timeout / 1000., timed_out)

            prev_token = from_token
            while not result:
                try:
                    current_token = user_stream.current_token

                    result = yield callback(prev_token, current_token)
                    if result:
                        break

                    # Now we wait for the _NotifierUserStream to be told there
                    # is a new token.
                    # We need to supply the token we supplied to callback so
                    # that we don't miss any current_token updates.
                    prev_token = current_token
                    listener = user_stream.new_listener(prev_token)
                    with PreserveLoggingContext():
                        yield listener.deferred
                except defer.CancelledError:
                    break

            self.clock.cancel_call_later(timer, ignore_errs=True)
        else:
            current_token = user_stream.current_token
            result = yield callback(from_token, current_token)

        defer.returnValue(result)

    @defer.inlineCallbacks
    def get_events_for(self, user, pagination_config, timeout,
                       only_keys=None,
                       is_guest=False, explicit_room_id=None):
        """ For the given user and rooms, return any new events for them. If
        there are no new events wait for up to `timeout` milliseconds for any
        new events to happen before returning.

        If `only_keys` is not None, only events from those keys will be sent down.

        If explicit_room_id is not set, the user's joined rooms will be polled
        for events.
        If explicit_room_id is set, that room will be polled for events only if
        it is world readable or the user has joined the room.
        """
        from_token = pagination_config.from_token
        if not from_token:
            from_token = yield self.event_sources.get_current_token()

        limit = pagination_config.limit

        room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
        is_peeking = not is_joined

        @defer.inlineCallbacks
        def check_for_updates(before_token, after_token):
            if not after_token.is_after(before_token):
                defer.returnValue(EventStreamResult([], (from_token, from_token)))

            events = []
            end_token = from_token

            for name, source in self.event_sources.sources.items():
                keyname = "%s_key" % name
                before_id = getattr(before_token, keyname)
                after_id = getattr(after_token, keyname)
                if before_id == after_id:
                    continue
                if only_keys and name not in only_keys:
                    continue

                new_events, new_key = yield source.get_new_events(
                    user=user,
                    from_key=getattr(from_token, keyname),
                    limit=limit,
                    is_guest=is_peeking,
                    room_ids=room_ids,
                )

                if name == "room":
                    new_events = yield filter_events_for_client(
                        self.store,
                        user.to_string(),
                        new_events,
                        is_peeking=is_peeking,
                    )

                events.extend(new_events)
                end_token = end_token.copy_and_replace(keyname, new_key)

            defer.returnValue(EventStreamResult(events, (from_token, end_token)))

        user_id_for_stream = user.to_string()
        if is_peeking:
            # Internally, the notifier keeps an event stream per user_id.
            # This is used by both /sync and /events.
            # We want /events to be used for peeking independently of /sync,
            # without polluting its contents. So we invent an illegal user ID
            # (which thus cannot clash with any real users) for keying peeking
            # over /events.
            #
            # I am sorry for what I have done.
            user_id_for_stream = "_PEEKING_%s_%s" % (
                explicit_room_id, user_id_for_stream
            )

        result = yield self.wait_for_events(
            user_id_for_stream,
            timeout,
            check_for_updates,
            room_ids=room_ids,
            from_token=from_token,
        )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _get_room_ids(self, user, explicit_room_id):
        joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
        joined_room_ids = [r.room_id for r in joined_rooms]
        if explicit_room_id:
            if explicit_room_id in joined_room_ids:
                defer.returnValue(([explicit_room_id], True))
            if (yield self._is_world_readable(explicit_room_id)):
                defer.returnValue(([explicit_room_id], False))
            raise AuthError(403, "Non-joined access not allowed")
        defer.returnValue((joined_room_ids, True))

    @defer.inlineCallbacks
    def _is_world_readable(self, room_id):
        state = yield self.state_handler.get_current_state(
            room_id,
            EventTypes.RoomHistoryVisibility
        )
        if state and "history_visibility" in state.content:
            defer.returnValue(state.content["history_visibility"] == "world_readable")
        else:
            defer.returnValue(False)

    @log_function
    def remove_expired_streams(self):
        time_now_ms = self.clock.time_msec()
        expired_streams = []
        expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
        for stream in self.user_to_user_stream.values():
            if stream.count_listeners():
                continue
            if stream.last_notified_ms < expire_before_ts:
                expired_streams.append(stream)

        for expired_stream in expired_streams:
            expired_stream.remove(self)

    @log_function
    def _register_with_keys(self, user_stream):
        self.user_to_user_stream[user_stream.user_id] = user_stream

        for room in user_stream.rooms:
            s = self.room_to_user_streams.setdefault(room, set())
            s.add(user_stream)

        if user_stream.appservice:
            self.appservice_to_user_streams.setdefault(
                user_stream.appservice, set()
            ).add(user_stream)

    def _user_joined_room(self, user_id, room_id):
        new_user_stream = self.user_to_user_stream.get(user_id)
        if new_user_stream is not None:
            room_streams = self.room_to_user_streams.setdefault(room_id, set())
            room_streams.add(new_user_stream)
            new_user_stream.rooms.add(room_id)

    def notify_replication(self):
        """Notify the any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

    @defer.inlineCallbacks
    def wait_for_replication(self, callback, timeout):
        """Wait for an event to happen.

        Args:
            callback: Gets called whenever an event happens. If this returns a
                truthy value then ``wait_for_replication`` returns, otherwise
                it waits for another event.
            timeout: How many milliseconds to wait for the callback to return
                a truthy value.

        Returns:
            A deferred that resolves with the value returned by the callback.
        """
        listener = _NotificationListener(None)

        def timed_out():
            listener.deferred.cancel()

        timer = self.clock.call_later(timeout / 1000., timed_out)
        while True:
            listener.deferred = self.replication_deferred.observe()
            result = yield callback()
            if result:
                break

            try:
                with PreserveLoggingContext():
                    yield listener.deferred
            except defer.CancelledError:
                break

        self.clock.cancel_call_later(timer, ignore_errs=True)

        defer.returnValue(result)
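
check_for_updates() in Example #51 walks every event source, skips sources whose stream key has not moved between the two tokens, and folds each source's new key back into the end token. A small self-contained sketch of that per-source loop, with plain dicts standing in for StreamToken and callables standing in for the event sources, might look like this:

def collect_updates(sources, before_token, after_token, from_token):
    events, end_token = [], dict(from_token)
    for name, source in sources.items():
        key = "%s_key" % name
        if before_token[key] == after_token[key]:
            continue                          # nothing new on this stream
        new_events, new_key = source(from_token[key])
        events.extend(new_events)
        end_token[key] = new_key              # copy_and_replace analogue
    return events, end_token

sources = {
    "room": lambda from_key: (["$event1"], 7),
    "typing": lambda from_key: ([], 3),
}
before = {"room_key": 5, "typing_key": 3}
after = {"room_key": 7, "typing_key": 3}
print(collect_updates(sources, before, after, dict(before)))
# -> (['$event1'], {'room_key': 7, 'typing_key': 3})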
Example #52
0
class Notifier(object):
    """ This class is responsible for notifying any listeners when there are
    new events available for it.

    Primarily used from the /events stream.
    """

    UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000

    def __init__(self, hs):
        self.user_to_user_stream = {}
        self.room_to_user_streams = {}

        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.replication_callbacks = []

        self.clock = hs.get_clock()
        self.appservice_handler = hs.get_application_service_handler()

        if hs.should_send_federation():
            self.federation_sender = hs.get_federation_sender()
        else:
            self.federation_sender = None

        self.state_handler = hs.get_state_handler()

        self.clock.looping_call(
            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
        )

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)

            return sum(stream.count_listeners() for stream in all_user_streams)
        metrics.register_callback("listeners", count_listeners)

        metrics.register_callback(
            "rooms",
            lambda: count(bool, self.room_to_user_streams.values()),
        )
        metrics.register_callback(
            "users",
            lambda: len(self.user_to_user_stream),
        )

    def add_replication_callback(self, cb):
        """Add a callback that will be called when some new data is available.
        Callback is not given any arguments.
        """
        self.replication_callbacks.append(cb)

    def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                          extra_users=[]):
        """ Used by handlers to inform the notifier something has happened
        in the room, room event wise.

        This triggers the notifier to wake up any listeners that are
        listening to the room, and any listeners for the users in the
        `extra_users` param.

        The events can be persisted out of order. The notifier will wait
        until all previous events have been persisted before notifying
        the client streams.
        """
        self.pending_new_room_events.append((
            room_stream_id, event, extra_users
        ))
        self._notify_pending_new_room_events(max_room_stream_id)

        self.notify_replication()

    def _notify_pending_new_room_events(self, max_room_stream_id):
        """Notify for the room events that were queued waiting for a previous
        event to be persisted.
        Args:
            max_room_stream_id(int): The highest stream_id below which all
                events have been persisted.
        """
        pending = self.pending_new_room_events
        self.pending_new_room_events = []
        for room_stream_id, event, extra_users in pending:
            if room_stream_id > max_room_stream_id:
                self.pending_new_room_events.append((
                    room_stream_id, event, extra_users
                ))
            else:
                self._on_new_room_event(event, room_stream_id, extra_users)

    def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
        """Notify any user streams that are interested in this room event"""
        # poke any interested application service.
        run_in_background(self._notify_app_services, room_stream_id)

        if self.federation_sender:
            self.federation_sender.notify_new_events(room_stream_id)

        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
            self._user_joined_room(event.state_key, event.room_id)

        self.on_new_event(
            "room_key", room_stream_id,
            users=extra_users,
            rooms=[event.room_id],
        )

    @defer.inlineCallbacks
    def _notify_app_services(self, room_stream_id):
        try:
            yield self.appservice_handler.notify_interested_services(room_stream_id)
        except Exception:
            logger.exception("Error notifying application services of event")

    def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
        """ Used to inform listeners that something has happend event wise.

        Will wake up all listeners for the given users and rooms.
        """
        with PreserveLoggingContext():
            with Measure(self.clock, "on_new_event"):
                user_streams = set()

                for user in users:
                    user_stream = self.user_to_user_stream.get(str(user))
                    if user_stream is not None:
                        user_streams.add(user_stream)

                for room in rooms:
                    user_streams |= self.room_to_user_streams.get(room, set())

                time_now_ms = self.clock.time_msec()
                for user_stream in user_streams:
                    try:
                        user_stream.notify(stream_key, new_token, time_now_ms)
                    except Exception:
                        logger.exception("Failed to notify listener")

                self.notify_replication()

    def on_new_replication_data(self):
        """Used to inform replication listeners that something has happend
        without waking up any of the normal user event streams"""
        self.notify_replication()

    @defer.inlineCallbacks
    def wait_for_events(self, user_id, timeout, callback, room_ids=None,
                        from_token=StreamToken.START):
        """Wait until the callback returns a non empty response or the
        timeout fires.
        """
        user_stream = self.user_to_user_stream.get(user_id)
        if user_stream is None:
            current_token = yield self.event_sources.get_current_token()
            if room_ids is None:
                room_ids = yield self.store.get_rooms_for_user(user_id)
            user_stream = _NotifierUserStream(
                user_id=user_id,
                rooms=room_ids,
                current_token=current_token,
                time_now_ms=self.clock.time_msec(),
            )
            self._register_with_keys(user_stream)

        result = None
        prev_token = from_token
        if timeout:
            end_time = self.clock.time_msec() + timeout

            while not result:
                try:
                    now = self.clock.time_msec()
                    if end_time <= now:
                        break

                    # Now we wait for the _NotifierUserStream to be told there
                    # is a new token.
                    listener = user_stream.new_listener(prev_token)
                    add_timeout_to_deferred(
                        listener.deferred,
                        (end_time - now) / 1000.,
                    )
                    with PreserveLoggingContext():
                        yield listener.deferred

                    current_token = user_stream.current_token

                    result = yield callback(prev_token, current_token)
                    if result:
                        break

                    # Update the prev_token to the current_token since nothing
                    # has happened between the old prev_token and the current_token
                    prev_token = current_token
                except DeferredTimeoutError:
                    break
                except defer.CancelledError:
                    break

        if result is None:
            # This happens if there was no timeout or if the timeout has
            # already expired.
            current_token = user_stream.current_token
            result = yield callback(prev_token, current_token)

        defer.returnValue(result)

    @defer.inlineCallbacks
    def get_events_for(self, user, pagination_config, timeout,
                       only_keys=None,
                       is_guest=False, explicit_room_id=None):
        """ For the given user and rooms, return any new events for them. If
        there are no new events wait for up to `timeout` milliseconds for any
        new events to happen before returning.

        If `only_keys` is not None, only events from those keys will be sent down.

        If explicit_room_id is not set, the user's joined rooms will be polled
        for events.
        If explicit_room_id is set, that room will be polled for events only if
        it is world readable or the user has joined the room.
        """
        from_token = pagination_config.from_token
        if not from_token:
            from_token = yield self.event_sources.get_current_token()

        limit = pagination_config.limit

        room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
        is_peeking = not is_joined

        @defer.inlineCallbacks
        def check_for_updates(before_token, after_token):
            if not after_token.is_after(before_token):
                defer.returnValue(EventStreamResult([], (from_token, from_token)))

            events = []
            end_token = from_token

            for name, source in self.event_sources.sources.items():
                keyname = "%s_key" % name
                before_id = getattr(before_token, keyname)
                after_id = getattr(after_token, keyname)
                if before_id == after_id:
                    continue
                if only_keys and name not in only_keys:
                    continue

                new_events, new_key = yield source.get_new_events(
                    user=user,
                    from_key=getattr(from_token, keyname),
                    limit=limit,
                    is_guest=is_peeking,
                    room_ids=room_ids,
                    explicit_room_id=explicit_room_id,
                )

                if name == "room":
                    new_events = yield filter_events_for_client(
                        self.store,
                        user.to_string(),
                        new_events,
                        is_peeking=is_peeking,
                    )
                elif name == "presence":
                    now = self.clock.time_msec()
                    new_events[:] = [
                        {
                            "type": "m.presence",
                            "content": format_user_presence_state(event, now),
                        }
                        for event in new_events
                    ]

                events.extend(new_events)
                end_token = end_token.copy_and_replace(keyname, new_key)

            defer.returnValue(EventStreamResult(events, (from_token, end_token)))

        user_id_for_stream = user.to_string()
        if is_peeking:
            # Internally, the notifier keeps an event stream per user_id.
            # This is used by both /sync and /events.
            # We want /events to be used for peeking independently of /sync,
            # without polluting its contents. So we invent an illegal user ID
            # (which thus cannot clash with any real users) for keying peeking
            # over /events.
            #
            # I am sorry for what I have done.
            user_id_for_stream = "_PEEKING_%s_%s" % (
                explicit_room_id, user_id_for_stream
            )

        result = yield self.wait_for_events(
            user_id_for_stream,
            timeout,
            check_for_updates,
            room_ids=room_ids,
            from_token=from_token,
        )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _get_room_ids(self, user, explicit_room_id):
        joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
        if explicit_room_id:
            if explicit_room_id in joined_room_ids:
                defer.returnValue(([explicit_room_id], True))
            if (yield self._is_world_readable(explicit_room_id)):
                defer.returnValue(([explicit_room_id], False))
            raise AuthError(403, "Non-joined access not allowed")
        defer.returnValue((joined_room_ids, True))

    @defer.inlineCallbacks
    def _is_world_readable(self, room_id):
        state = yield self.state_handler.get_current_state(
            room_id,
            EventTypes.RoomHistoryVisibility,
            "",
        )
        if state and "history_visibility" in state.content:
            defer.returnValue(state.content["history_visibility"] == "world_readable")
        else:
            defer.returnValue(False)

    @log_function
    def remove_expired_streams(self):
        time_now_ms = self.clock.time_msec()
        expired_streams = []
        expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
        for stream in self.user_to_user_stream.values():
            if stream.count_listeners():
                continue
            if stream.last_notified_ms < expire_before_ts:
                expired_streams.append(stream)

        for expired_stream in expired_streams:
            expired_stream.remove(self)

    @log_function
    def _register_with_keys(self, user_stream):
        self.user_to_user_stream[user_stream.user_id] = user_stream

        for room in user_stream.rooms:
            s = self.room_to_user_streams.setdefault(room, set())
            s.add(user_stream)

    def _user_joined_room(self, user_id, room_id):
        new_user_stream = self.user_to_user_stream.get(user_id)
        if new_user_stream is not None:
            room_streams = self.room_to_user_streams.setdefault(room_id, set())
            room_streams.add(new_user_stream)
            new_user_stream.rooms.add(room_id)

    def notify_replication(self):
        """Notify the any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

            # the callbacks may well outlast the current request, so we run
            # them in the sentinel logcontext.
            #
            # (ideally it would be up to the callbacks to know if they were
            # starting off background processes and drop the logcontext
            # accordingly, but that requires more changes)
            for cb in self.replication_callbacks:
                cb()

    @defer.inlineCallbacks
    def wait_for_replication(self, callback, timeout):
        """Wait for an event to happen.

        Args:
            callback: Gets called whenever an event happens. If this returns a
                truthy value then ``wait_for_replication`` returns, otherwise
                it waits for another event.
            timeout: How many milliseconds to wait for the callback to return
                a truthy value.

        Returns:
            A deferred that resolves with the value returned by the callback.
        """
        listener = _NotificationListener(None)

        end_time = self.clock.time_msec() + timeout

        while True:
            listener.deferred = self.replication_deferred.observe()
            result = yield callback()
            if result:
                break

            now = self.clock.time_msec()
            if end_time <= now:
                break

            add_timeout_to_deferred(
                listener.deferred,
                (end_time - now) / 1000.,
            )
            try:
                with PreserveLoggingContext():
                    yield listener.deferred
            except DeferredTimeoutError:
                break
            except defer.CancelledError:
                break

        defer.returnValue(result)
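
wait_for_replication() in Example #52 keeps re-running the callback until it returns something truthy, sleeping on the replication deferred between attempts and bailing out once a wall-clock deadline passes. Here is a self-contained asyncio sketch of that deadline loop, with an asyncio.Event playing the role of the replication wake-up; all names are illustrative rather than synapse APIs:

import asyncio

async def wait_with_deadline(callback, wakeup, timeout_ms):
    loop = asyncio.get_running_loop()
    end = loop.time() + timeout_ms / 1000.0
    result = None
    while True:
        result = await callback()
        if result:
            break
        remaining = end - loop.time()
        if remaining <= 0:
            break
        try:
            await asyncio.wait_for(wakeup.wait(), remaining)
            wakeup.clear()                     # re-arm for the next notification
        except asyncio.TimeoutError:
            break
    return result

async def main():
    calls = {"n": 0}

    async def callback():
        calls["n"] += 1
        return "data" if calls["n"] >= 2 else None

    wakeup = asyncio.Event()
    task = asyncio.create_task(wait_with_deadline(callback, wakeup, 500))
    await asyncio.sleep(0.01)
    wakeup.set()                               # simulate notify_replication()
    print(await task)                          # -> "data"

asyncio.run(main())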
Example #53
0
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug((
                    "Matching attrib '%s' with value '%s' against"
                    " pattern '%s'"
                ) % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
                        match = False
                        continue
            if match:
                logger.warn(
                    "URL %s blocked by url_blacklist entry %s", url, entry
                )
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN
                )

        # first check the memory cache - good to handle all the clients on this
        # HS thundering away to preview the same URL at the same time.
        og = self.cache.get(url)
        if og:
            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
            return

        # then check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (
            cache_result and
            cache_result["download_ts"] + cache_result["expires"] > ts and
            cache_result["response_code"] / 100 == 2
        ):
            respond_with_json_bytes(
                request, 200, cache_result["og"].encode('utf-8'),
                send_cors=True
            )
            return

        # Ensure only one download for a given URL is active at a time
        download = self.downloads.get(url)
        if download is None:
            download = self._download_url(url, requester.user)
            download = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self.downloads[url] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[url]
                return media_info
        media_info = yield download.observe()

        # FIXME: we should probably update our cache now anyway, so that
        # even if the OG calculation raises, we don't keep hammering on the
        # remote server.  For now, leave it uncached to aid debugging OG
        # calculation problems

        logger.debug("got media_info of '%s'" % media_info)

        if self._is_media(media_info['media_type']):
            dims = yield self.media_repo._generate_local_thumbnails(
                media_info['filesystem_id'], media_info
            )

            og = {
                "og:description": media_info['download_name'],
                "og:image": "mxc://%s/%s" % (
                    self.server_name, media_info['filesystem_id']
                ),
                "og:image:type": media_info['media_type'],
                "matrix:image:size": media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif self._is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            from lxml import etree

            with open(media_info['filename']) as f:
                body = f.read()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

            try:
                parser = etree.HTMLParser(recover=True, encoding=encoding)
                tree = etree.fromstring(body, parser)
                og = yield self._calc_og(tree, media_info, requester)
            except UnicodeDecodeError:
                # blindly try decoding the body as utf-8, which seems to fix
                # the charset mismatches on https://google.com
                parser = etree.HTMLParser(recover=True, encoding=encoding)
                tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
                og = yield self._calc_og(tree, media_info, requester)

        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        # store OG in ephemeral in-memory cache
        self.cache[url] = og

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"],
            json.dumps(og),
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
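
The blacklist check at the top of Example #53 compares each attribute of the split URL against a per-entry pattern, treating patterns that start with '^' as regular expressions and everything else as shell-style globs. A standalone sketch of that matching rule (is_blacklisted is an assumed helper name, not part of synapse) could look like this:

import fnmatch
import re
from urllib.parse import urlsplit

def is_blacklisted(url, blacklist):
    parts = urlsplit(url)
    for entry in blacklist:
        # an entry matches only if every attribute it names matches its pattern
        if all(
            getattr(parts, attrib) is not None and (
                re.match(pattern, getattr(parts, attrib))
                if pattern.startswith("^")
                else fnmatch.fnmatch(getattr(parts, attrib), pattern)
            )
            for attrib, pattern in entry.items()
        ):
            return True
    return False

blacklist = [{"netloc": "*.internal"}, {"scheme": "^ftp$"}]
print(is_blacklisted("http://db.internal/preview", blacklist))  # True
print(is_blacklisted("https://example.com/page", blacklist))    # False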
Example #54
0
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            results = {}
            cached_defers = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = cache.get(tuple(key), callback=invalidate_callback)
                    if not res.has_succeeded():
                        res = res.observe()
                        res.addCallback(lambda r, arg: (arg, r), arg)
                        cached_defers[arg] = res
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    cache.update(
                        sequence, tuple(key), observer,
                        callback=invalidate_callback
                    )

                    def invalidate(f, key):
                        cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached_defers[arg] = res

            if cached_defers:
                def update_results_dict(res):
                    results.update(res)
                    return results

                return preserve_context_over_deferred(defer.gatherResults(
                    cached_defers.values(),
                    consumeErrors=True,
                ).addCallback(update_results_dict).addErrback(
                    unwrapFirstError
                ))
            else:
                return results