Example 1
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            cached = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = self.cache.get(tuple(key)).observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)
                    cached[arg] = res
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = self.cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    self.cache.update(sequence, tuple(key), observer)

                    def invalidate(f, key):
                        self.cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached[arg] = res

            return preserve_context_over_deferred(defer.gatherResults(
                cached.values(),
                consumeErrors=True,
            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
Example 2
    @defer.inlineCallbacks
    def wait_for_previous_lookups(self, server_names, server_to_deferred):
        """Waits for any previous key lookups for the given servers to finish.

        Args:
            server_names (list): list of server_names we want to lookup
            server_to_deferred (dict): server_name to deferred which gets
                resolved once we've finished looking up keys for that server
        """
        while True:
            wait_on = [
                self.key_downloads[server_name] for server_name in server_names
                if server_name in self.key_downloads
            ]
            if wait_on:
                yield defer.DeferredList(wait_on)
            else:
                break

        for server_name, deferred in server_to_deferred.items():
            d = ObservableDeferred(deferred)
            self.key_downloads[server_name] = d

            def rm(r, server_name):
                self.key_downloads.pop(server_name, None)
                return r

            d.addBoth(rm, server_name)
Example 3
    def add_to_queue(self, room_id, events_and_contexts, backfilled,
                     current_state):
        """Add events to the queue, with the given persist_event options.
        """
        queue = self._event_persist_queues.setdefault(room_id, deque())
        if queue:
            end_item = queue[-1]
            if end_item.current_state or current_state:
                # We persist events with current_state set to True one at a
                # time, so don't merge them into the existing queue item.
                pass
            elif end_item.backfilled == backfilled:
                end_item.events_and_contexts.extend(events_and_contexts)
                return end_item.deferred.observe()

        deferred = ObservableDeferred(defer.Deferred())

        queue.append(
            self._EventPersistQueueItem(
                events_and_contexts=events_and_contexts,
                backfilled=backfilled,
                current_state=current_state,
                deferred=deferred,
            ))

        return deferred.observe()
Example 4
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
                called again at a later point in time.
            fn (function): A function which returns a tuple of
                (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        try:
            return self.transactions[txn_key][0].observe()
        except (KeyError, IndexError):
            pass  # execute the function instead.

        deferred = fn(*args, **kwargs)

        # We don't add an errback to the raw deferred, so we ask ObservableDeferred
        # to swallow the error. This is fine as the error will still be reported
        # to the observers.
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.transactions[txn_key] = (observable, self.clock.time_msec())
        return observable.observe()
Example 5
    @defer.inlineCallbacks
    def get_server_verify_key(self, server_name, key_ids):
        """Finds a verification key for the server with one of the key ids.
        Tries to fetch the key from a trusted perspective server first.
        Args:
            server_name(str): The name of the server to fetch a key for.
            key_ids (list of str): The key_ids to check for.
        """
        cached = yield self.store.get_server_verify_keys(server_name, key_ids)

        if cached:
            defer.returnValue(cached[0])

        download = self.key_downloads.get(server_name)

        if download is None:
            download = self._get_server_verify_key_impl(server_name, key_ids)
            download = ObservableDeferred(download, consumeErrors=True)
            self.key_downloads[server_name] = download

            @download.addBoth
            def callback(ret):
                del self.key_downloads[server_name]
                return ret

        r = yield download.observe()
        defer.returnValue(r)
Example 6
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
                called again at a later point in time.
            fn (function): A function which returns a tuple of
                (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        if txn_key in self.transactions:
            observable = self.transactions[txn_key][0]
        else:
            # execute the function instead.
            deferred = run_in_background(fn, *args, **kwargs)

            observable = ObservableDeferred(deferred)
            self.transactions[txn_key] = (observable, self.clock.time_msec())

            # if the request fails with an exception, remove it
            # from the transaction map. This is done to ensure that we don't
            # cache transient errors like rate-limiting errors, etc.
            def remove_from_map(err):
                self.transactions.pop(txn_key, None)
                # we deliberately do not propagate the error any further, as we
                # expect the observers to have reported it.

            deferred.addErrback(remove_from_map)

        return make_deferred_yieldable(observable.observe())
Example 7
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        # Create the deferred in the sentinel logcontext (PreserveLoggingContext
        # with no arguments) so that callbacks fired from it do not run in the
        # calling request's logcontext.
        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Example 8
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            # Add temp cache_context so inspect.getcallargs doesn't explode
            if self.add_cache_context:
                kwargs["cache_context"] = None

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key, callback=invalidate_callback)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:
                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__, cache_key,
                                cached_result, actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)
                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    obj, *args, **kwargs
                )

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret, callback=invalidate_callback)

                return preserve_context_over_deferred(ret.observe())
Example 9
    def notify_replication(self):
        """Notify any replication listeners that there's a new event."""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

        for cb in self.replication_callbacks:
            preserve_fn(cb)()
Example 10
    def set(self, key, deferred):
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
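
A minimal sketch of the shared-result pattern this example implements (the
names slow_lookup, pending, and get_thing are hypothetical, and the import
path of ObservableDeferred varies between Synapse versions):

    from twisted.internet import defer

    # Assumed import path; older Synapse exposes ObservableDeferred from
    # synapse.util.async, newer versions from synapse.util.async_helpers.
    from synapse.util.async import ObservableDeferred

    def slow_lookup(key):
        # Hypothetical stand-in for an expensive asynchronous operation.
        return defer.succeed("result for %s" % (key,))

    pending = {}

    def get_thing(key):
        result = pending.get(key)
        if result is None:
            # The first caller starts the lookup; concurrent callers share it.
            result = ObservableDeferred(slow_lookup(key), consumeErrors=True)
            pending[key] = result

            def remove(r):
                # Drop the entry once the lookup completes, succeed or fail.
                pending.pop(key, None)
                return r

            result.addBoth(remove)

        # observe() hands each caller its own Deferred, so one caller's
        # errback cannot swallow the result seen by another.
        return result.observe()
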
Example 11
    def __init__(self, hs):
        self.user_to_user_stream = {}
        self.room_to_user_streams = {}

        self.hs = hs
        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.replication_callbacks = []

        self.clock = hs.get_clock()
        self.appservice_handler = hs.get_application_service_handler()

        if hs.should_send_federation():
            self.federation_sender = hs.get_federation_sender()
        else:
            self.federation_sender = None

        self.state_handler = hs.get_state_handler()

        self.clock.looping_call(self.remove_expired_streams,
                                self.UNUSED_STREAM_EXPIRY_MS)

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)

            return sum(stream.count_listeners() for stream in all_user_streams)

        LaterGauge("synapse_notifier_listeners", "", [], count_listeners)

        LaterGauge(
            "synapse_notifier_rooms",
            "",
            [],
            lambda: count(bool, self.room_to_user_streams.values()),
        )
        LaterGauge(
            "synapse_notifier_users",
            "",
            [],
            lambda: len(self.user_to_user_stream),
        )
Example 12
    def __init__(self,
                 user,
                 rooms,
                 current_token,
                 time_now_ms,
                 appservice=None):
        self.user = str(user)
        self.appservice = appservice
        self.rooms = set(rooms)
        self.current_token = current_token
        self.last_notified_ms = time_now_ms

        self.notify_deferred = ObservableDeferred(defer.Deferred())
Example 13
    def get_remote_media(self, server_name, media_id):
        key = (server_name, media_id)
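        # If a download for this media is already in flight, share it rather
        # than starting a second one; every caller observes the same deferred.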
        download = self.downloads.get(key)
        if download is None:
            download = self._get_remote_media_impl(server_name, media_id)
            download = ObservableDeferred(download, consumeErrors=True)
            self.downloads[key] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[key]
                return media_info

        return download.observe()
Example 14
    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id)
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred
        self.notify_deferred = ObservableDeferred(defer.Deferred())
        notify_deferred.callback(self.current_token)
Example 15
    def __init__(self, hs):
        self.hs = hs

        self.user_to_user_stream = {}
        self.room_to_user_streams = {}
        self.appservice_to_user_streams = {}

        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.clock = hs.get_clock()

        hs.get_distributor().observe("user_joined_room",
                                     self._user_joined_room)

        self.clock.looping_call(self.remove_expired_streams,
                                self.UNUSED_STREAM_EXPIRY_MS)

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)
            for x in self.appservice_to_user_streams.values():
                all_user_streams |= x

            return sum(stream.count_listeners() for stream in all_user_streams)

        metrics.register_callback("listeners", count_listeners)

        metrics.register_callback(
            "rooms",
            lambda: count(bool, self.room_to_user_streams.values()),
        )
        metrics.register_callback(
            "users",
            lambda: len(self.user_to_user_stream),
        )
        metrics.register_callback(
            "appservices",
            lambda: count(bool, self.appservice_to_user_streams.values()),
        )
Example 16
    def notify_replication(self):
        """Notify the any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

            # the callbacks may well outlast the current request, so we run
            # them in the sentinel logcontext.
            #
            # (ideally it would be up to the callbacks to know if they were
            # starting off background processes and drop the logcontext
            # accordingly, but that requires more changes)
            for cb in self.replication_callbacks:
                cb()
Example 17
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token

        # The last token for which we should wake up any streams that have a
        # token that comes before it. This gets updated every time we get poked.
        # We start it at the current token, since if we see any streams with an
        # earlier token we have no idea whether they should be woken up or not,
        # so let's just wake them up.
        self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Example 18
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
            try:
                cached_result_d = cache.get(cache_key)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:

                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(
                            obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__,
                                cache_key,
                                cached_result,
                                actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)

                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(preserve_context_over_fn,
                                          self.function_to_call, obj, *args,
                                          **kwargs)

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret)

                return preserve_context_over_deferred(ret.observe())
Example 19
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            cache_key = get_cache_key(args, kwargs)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key,
                                            callback=invalidate_callback)

                # cache.get may return either an ObservableDeferred (an entry
                # that is still pending) or, for completed entries, the result
                # itself.
                if isinstance(cached_result_d, ObservableDeferred):
                    observer = cached_result_d.observe()
                else:
                    observer = cached_result_d

            except KeyError:
                ret = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call), obj, *args,
                    **kwargs)

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                # If our cache_key is a string on py2, try to convert to ascii
                # to save a bit of space in large caches. Py3 does this
                # internally automatically.
                if six.PY2 and isinstance(cache_key, string_types):
                    cache_key = to_ascii(cache_key)

                result_d = ObservableDeferred(ret, consumeErrors=True)
                cache.set(cache_key, result_d, callback=invalidate_callback)
                observer = result_d.observe()

            if isinstance(observer, defer.Deferred):
                return logcontext.make_deferred_yieldable(observer)
            else:
                return observer
Example 20
    def set(self, time_now_ms, key, deferred):
        self.rotate(time_now_ms)

        result = ObservableDeferred(deferred)

        self.pending_result_cache[key] = result

        def shuffle_along(r):
            # When the deferred completes we shuffle it along to the first
            # generation of the result cache. So that it will eventually
            # expire from the rotation of that cache.
            self.next_result_cache[key] = result
            self.pending_result_cache.pop(key, None)

        result.observe().addBoth(shuffle_along)

        return result.observe()
Example 21
    def test_prefill(self):
        callcount = [0]

        d = defer.succeed(123)

        class A(object):
            @cached()
            def func(self, key):
                callcount[0] += 1
                return d

        a = A()

        a.func.prefill(("foo", ), ObservableDeferred(d))

        # The prefilled entry is served straight from the cache, so func is
        # never invoked and callcount stays at zero.
        self.assertEquals(a.func("foo").result, d.result)
        self.assertEquals(callcount[0], 0)
Example 22
    def set(self, key, deferred):
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            # If a timeout is configured, keep the completed entry around for
            # timeout_sec so that retries can still be served from the cache.
            if self.timeout_sec:
                self.clock.call_later(
                    self.timeout_sec,
                    self.pending_result_cache.pop,
                    key,
                    None,
                )
            else:
                self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
Example 23
    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id)
        self.last_notified_token = self.current_token
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)
Example 24
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
                called again at a later point in time.
            fn (function): A function which returns a tuple of
                (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        try:
            return self.transactions[txn_key][0].observe()
        except (KeyError, IndexError):
            pass  # execute the function instead.

        deferred = fn(*args, **kwargs)
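        # Wrap the deferred so that every caller for this transaction key can
        # share the result via observe().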
        observable = ObservableDeferred(deferred)
        self.transactions[txn_key] = (observable, self.clock.time_msec())
        return observable.observe()
Example 25
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
                called again at a later point in time.
            fn (function): A function which returns a tuple of
                (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        try:
            return self.transactions[txn_key][0].observe()
        except (KeyError, IndexError):
            pass  # execute the function instead.

        deferred = fn(*args, **kwargs)

        # if the request fails with a Twisted failure, remove it
        # from the transaction map. This is done to ensure that we don't
        # cache transient errors like rate-limiting errors, etc.
        def remove_from_map(err):
            self.transactions.pop(txn_key, None)
            return err

        deferred.addErrback(remove_from_map)

        # We don't add any other errbacks to the raw deferred, so we ask
        # ObservableDeferred to swallow the error. This is fine as the error will
        # still be reported to the observers.
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.transactions[txn_key] = (observable, self.clock.time_msec())
        return observable.observe()
Example 26
    def set(self, key, deferred):
        """Set the entry for the given key to the given deferred.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
        you should wrap normal synapse deferreds with
        logcontext.run_in_background).

        Can return either a new Deferred (which also doesn't follow the synapse
        logcontext rules), or, if *deferred* was already complete, the actual
        result. You will probably want to make_deferred_yieldable the result.

        Args:
            key (hashable):
            deferred (twisted.internet.defer.Deferred[T]):

        Returns:
            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
                result.
        """
        result = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = result

        def remove(r):
            if self.timeout_sec:
                self.clock.call_later(
                    self.timeout_sec,
                    self.pending_result_cache.pop,
                    key,
                    None,
                )
            else:
                self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(remove)
        return result.observe()
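
A hedged usage sketch for the docstring above (handle_request and
compute_response are hypothetical names; the logcontext helpers have
historically lived in synapse.util.logcontext):

    from twisted.internet import defer
    from synapse.util.logcontext import (
        make_deferred_yieldable,
        run_in_background,
    )

    def compute_response(key):
        # Hypothetical expensive handler.
        return defer.succeed((200, {"key": key}))

    @defer.inlineCallbacks
    def handle_request(response_cache, key):
        # Detach the computation from the current logcontext, as the
        # docstring for set() asks.
        result = response_cache.set(
            key, run_in_background(compute_response, key),
        )
        # set() may hand back the plain result if the deferred had already
        # completed, so only wrap actual Deferreds.
        if isinstance(result, defer.Deferred):
            result = yield make_deferred_yieldable(result)
        defer.returnValue(result)
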
Example 27
    @defer.inlineCallbacks
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        # Check the requested URL against the configured preview blacklist
        # before doing any cache lookups or network requests.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug(("Matching attrib '%s' with value '%s' against"
                              " pattern '%s'") % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib),
                                           pattern):
                        match = False
                        continue
            if match:
                logger.warn("URL %s blocked by url_blacklist entry %s", url,
                            entry)
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN)

        # first check the memory cache - good to handle all the clients on this
        # HS thundering away to preview the same URL at the same time.
        og = self.cache.get(url)
        if og:
            respond_with_json_bytes(request,
                                    200,
                                    json.dumps(og),
                                    send_cors=True)
            return

        # then check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (cache_result
                and cache_result["download_ts"] + cache_result["expires"] > ts
                and cache_result["response_code"] / 100 == 2):
            respond_with_json_bytes(request,
                                    200,
                                    cache_result["og"].encode('utf-8'),
                                    send_cors=True)
            return

        # Ensure only one download for a given URL is active at a time
        download = self.downloads.get(url)
        if download is None:
            download = self._download_url(url, requester.user)
            download = ObservableDeferred(download, consumeErrors=True)
            self.downloads[url] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[url]
                return media_info

        media_info = yield download.observe()

        # FIXME: we should probably update our cache now anyway, so that
        # even if the OG calculation raises, we don't keep hammering on the
        # remote server.  For now, leave it uncached to aid debugging OG
        # calculation problems

        logger.debug("got media_info of '%s'" % media_info)

        if _is_media(media_info['media_type']):
            dims = yield self.media_repo._generate_local_thumbnails(
                media_info['filesystem_id'],
                media_info,
                url_cache=True,
            )

            og = {
                "og:description": media_info['download_name'],
                "og:image": "mxc://%s/%s" % (
                    self.server_name, media_info['filesystem_id']
                ),
                "og:image:type": media_info['media_type'],
                "matrix:image:size": media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif _is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            with open(media_info['filename']) as body_file:
                body = body_file.read()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(r'.*; *charset=(.*?)(;|$)',
                             media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

            og = decode_and_calc_og(body, media_info['uri'], encoding)

            # pre-cache the image for posterity
            # FIXME: it might be cleaner to use the same flow as the main /preview_url
            # request itself and benefit from the same caching etc.  But for now we
            # just rely on the caching on the master request to speed things up.
            if 'og:image' in og and og['og:image']:
                image_info = yield self._download_url(
                    _rebase_url(og['og:image'], media_info['uri']),
                    requester.user)

                if _is_media(image_info['media_type']):
                    # TODO: make sure we don't choke on white-on-transparent images
                    dims = yield self.media_repo._generate_local_thumbnails(
                        image_info['filesystem_id'],
                        image_info,
                        url_cache=True,
                    )
                    if dims:
                        og["og:image:width"] = dims['width']
                        og["og:image:height"] = dims['height']
                    else:
                        logger.warn("Couldn't get dims for %s" %
                                    og["og:image"])

                    og["og:image"] = "mxc://%s/%s" % (
                        self.server_name, image_info['filesystem_id'])
                    og["og:image:type"] = image_info['media_type']
                    og["matrix:image:size"] = image_info['media_length']
                else:
                    del og["og:image"]
        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        # store OG in ephemeral in-memory cache
        self.cache[url] = og

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"],
            json.dumps(og),
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
Example 28
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # results holds the values for args that were already complete in
            # the cache; cached_defers is a dict arg -> deferred, where each
            # deferred resolves to a 2-tuple (`arg`, `result`)
            results = {}
            cached_defers = {}
            missing = []

            # If the cache takes a single arg then that is used as the key,
            # otherwise a tuple is used.
            if num_args == 1:

                def cache_get(arg):
                    return cache.get(arg, callback=invalidate_callback)
            else:
                key = list(keyargs)

                def cache_get(arg):
                    key[self.list_pos] = arg
                    return cache.get(tuple(key), callback=invalidate_callback)

            for arg in list_args:
                try:
                    res = cache_get(arg)

                    if not isinstance(res, ObservableDeferred):
                        results[arg] = res
                    elif not res.has_succeeded():
                        res = res.observe()
                        res.addCallback(lambda r, arg: (arg, r), arg)
                        cached_defers[arg] = res
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.append(arg)

            if missing:
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call),
                    **args_to_call)

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    if num_args == 1:
                        cache.set(arg, observer, callback=invalidate_callback)

                        def invalidate(f, key):
                            cache.invalidate(key)
                            return f

                        observer.addErrback(invalidate, arg)
                    else:
                        key = list(keyargs)
                        key[self.list_pos] = arg
                        cache.set(tuple(key),
                                  observer,
                                  callback=invalidate_callback)

                        def invalidate(f, key):
                            cache.invalidate(key)
                            return f

                        observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached_defers[arg] = res

            if cached_defers:

                def update_results_dict(res):
                    results.update(res)
                    return results

                return logcontext.make_deferred_yieldable(
                    defer.gatherResults(
                        list(cached_defers.values()),
                        consumeErrors=True,
                    ).addCallback(update_results_dict).addErrback(
                        unwrapFirstError))
            else:
                return results
Example 29
    @defer.inlineCallbacks
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        # XXX: we could move this into _do_preview if we wanted.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug((
                    "Matching attrib '%s' with value '%s' against"
                    " pattern '%s'"
                ) % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
                        match = False
                        continue
            if match:
                logger.warn(
                    "URL %s blocked by url_blacklist entry %s", url, entry
                )
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN
                )

        # the in-memory cache:
        # * ensures that only one request is active at a time
        # * takes load off the DB for the thundering herds
        # * also caches any failures (unlike the DB) so we don't keep
        #    requesting the same endpoint

        observable = self._cache.get(url)

        if not observable:
            download = preserve_fn(self._do_preview)(
                url, requester.user, ts,
            )
            observable = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self._cache[url] = observable
        else:
            logger.info("Returning cached response")

        og = yield make_deferred_yieldable(observable.observe())
        respond_with_json_bytes(request, 200, og, send_cors=True)
Example 30
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # results holds the values for args that were already complete in
            # the cache; cached_defers is a dict arg -> deferred, where each
            # deferred resolves to a 2-tuple (`arg`, `result`)
            results = {}
            cached_defers = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = cache.get(tuple(key), callback=invalidate_callback)
                    if not res.has_succeeded():
                        res = res.observe()
                        res.addCallback(lambda r, arg: (arg, r), arg)
                        cached_defers[arg] = res
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    cache.update(
                        sequence, tuple(key), observer,
                        callback=invalidate_callback
                    )

                    def invalidate(f, key):
                        cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached_defers[arg] = res

            if cached_defers:
                def update_results_dict(res):
                    results.update(res)
                    return results

                return preserve_context_over_deferred(defer.gatherResults(
                    cached_defers.values(),
                    consumeErrors=True,
                ).addCallback(update_results_dict).addErrback(
                    unwrapFirstError
                ))
            else:
                return results