def wrapped(*args, **kwargs):
    """Look up every element of a list-valued argument in the cache.

    Each element of the ``self.list_name`` argument is looked up in
    ``self.cache`` under a key built from the ``self.arg_names`` values with
    that element substituted at position ``self.list_pos``. All misses are
    fetched with a single call to ``self.function_to_call`` whose list
    argument is restricted to the missing elements, and each missing element
    is then given its own cache entry.

    NOTE(review): ``self`` and ``obj`` are free variables taken from the
    enclosing scope (not visible here) - presumably the descriptor and the
    bound instance; confirm.

    Returns:
        Deferred resolving to a dict mapping each list element to its result.
    """
    arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
    keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
    list_args = arg_dict[self.list_name]

    # cached is a dict arg -> deferred, where deferred results in a
    # 2-tuple (`arg`, `result`)
    cached = {}
    missing = []
    for arg in list_args:
        key = list(keyargs)
        key[self.list_pos] = arg

        try:
            res = self.cache.get(tuple(key)).observe()
            # Tag the cached value with its list element so the gathered
            # results can be turned into a dict at the end.
            res.addCallback(lambda r, arg: (arg, r), arg)
            cached[arg] = res
        except KeyError:
            missing.append(arg)

    if missing:
        # Snapshot the cache sequence before calling the function; it is
        # handed to cache.update below (presumably so that invalidations made
        # while the call is in flight can be detected - confirm).
        sequence = self.cache.sequence
        args_to_call = dict(arg_dict)
        args_to_call[self.list_name] = missing

        ret_d = defer.maybeDeferred(
            preserve_context_over_fn,
            self.function_to_call,
            **args_to_call
        )

        ret_d = ObservableDeferred(ret_d)

        # We need to create deferreds for each arg in the list so that
        # we can insert the new deferred into the cache.
        for arg in missing:
            with PreserveLoggingContext():
                observer = ret_d.observe()
            # The function is expected to return a mapping keyed by list
            # element; elements it omits resolve to None.
            observer.addCallback(lambda r, arg: r.get(arg, None), arg)

            observer = ObservableDeferred(observer)

            key = list(keyargs)
            key[self.list_pos] = arg
            self.cache.update(sequence, tuple(key), observer)

            # On failure, drop this element's entry so the error is not cached.
            def invalidate(f, key):
                self.cache.invalidate(key)
                return f
            observer.addErrback(invalidate, tuple(key))

            res = observer.observe()
            res.addCallback(lambda r, arg: (arg, r), arg)

            cached[arg] = res

    # Wait for every (arg, result) pair and build the result dict; the first
    # failure (if any) is unwrapped from the gatherResults FirstError.
    return preserve_context_over_deferred(defer.gatherResults(
        cached.values(),
        consumeErrors=True,
    ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
def wait_for_previous_lookups(self, server_names, server_to_deferred):
    """Block until earlier key lookups for ``server_names`` have completed,
    then register the new lookups.

    This is a generator; each ``yield`` hands a DeferredList of the in-flight
    lookups back to the driving coroutine machinery.

    Args:
        server_names (list): servers whose keys we are about to look up.
        server_to_deferred (dict): maps server_name to a deferred that fires
            once our own lookup for that server is finished.
    """
    # Keep waiting until none of the requested servers has a lookup in flight.
    while True:
        in_flight = [
            self.key_downloads[name]
            for name in server_names
            if name in self.key_downloads
        ]
        if not in_flight:
            break
        yield defer.DeferredList(in_flight)

    def _forget(outcome, name):
        # Once the lookup settles (success or failure), stop advertising it.
        self.key_downloads.pop(name, None)
        return outcome

    for name, lookup_d in server_to_deferred.items():
        observable = ObservableDeferred(lookup_d)
        self.key_downloads[name] = observable
        observable.addBoth(_forget, name)
def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
    """Add events to the queue, with the given persist_event options.

    If the last queued item for the room has the same ``backfilled`` flag,
    and neither that item nor this request has ``current_state`` set, the
    events are appended to that item rather than queued separately.

    Returns:
        The result of ``.observe()`` on the ObservableDeferred belonging to
        the queue item the events were placed in.
    """
    queue = self._event_persist_queues.setdefault(room_id, deque())
    if queue:
        end_item = queue[-1]
        # We persist events with current_state set to True one at a time, so
        # they must never be merged into (or receive) another batch. Merging
        # would also lose this call's current_state flag, since the merged
        # item keeps only end_item's flag.
        batchable = not (end_item.current_state or current_state)
        if batchable and end_item.backfilled == backfilled:
            end_item.events_and_contexts.extend(events_and_contexts)
            return end_item.deferred.observe()

    deferred = ObservableDeferred(defer.Deferred())

    queue.append(
        self._EventPersistQueueItem(
            events_and_contexts=events_and_contexts,
            backfilled=backfilled,
            current_state=current_state,
            deferred=deferred,
        )
    )

    return deferred.observe()
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
    """Fetches the response for this transaction, or executes the given
    function to produce a response for this transaction.

    Only successful results are kept: if ``fn``'s deferred fails, the
    transaction is removed from the map so a retry with the same
    ``txn_key`` runs ``fn`` again instead of replaying a (possibly
    transient) error such as rate limiting.

    Args:
        txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
        fn (function): A function which returns a tuple of
            (response_code, response_dict).
        *args: Arguments to pass to fn.
        **kwargs: Keyword arguments to pass to fn.
    Returns:
        Deferred which resolves to a tuple of (response_code, response_dict).
    """
    try:
        return self.transactions[txn_key][0].observe()
    except (KeyError, IndexError):
        pass  # execute the function instead.

    deferred = fn(*args, **kwargs)
    observable = ObservableDeferred(deferred)

    # Store the entry *before* attaching the errback below: if `deferred` has
    # already failed, the errback runs immediately and must find the entry in
    # order to remove it.
    self.transactions[txn_key] = (observable, self.clock.time_msec())

    def remove_from_map(err):
        self.transactions.pop(txn_key, None)
        # Deliberately swallow the failure here: ObservableDeferred has
        # already handed it to its observers, and returning None stops it
        # being reported as an unhandled error on the raw deferred.

    deferred.addErrback(remove_from_map)

    return observable.observe()
def get_server_verify_key(self, server_name, key_ids):
    """Find a verification key for ``server_name`` matching one of ``key_ids``.

    Keys already in the store are returned directly. Otherwise a download is
    started, unless one for this server is already running, in which case
    its result is shared.

    Args:
        server_name(str): The name of the server to fetch a key for.
        key_ids (list of str): The key_ids to check for.
    """
    stored = yield self.store.get_server_verify_keys(server_name, key_ids)
    if stored:
        defer.returnValue(stored[0])

    in_flight = self.key_downloads.get(server_name)
    if in_flight is None:
        in_flight = ObservableDeferred(
            self._get_server_verify_key_impl(server_name, key_ids),
            consumeErrors=True,
        )
        self.key_downloads[server_name] = in_flight

        def _forget(outcome):
            # The download has settled; let the next caller start afresh.
            del self.key_downloads[server_name]
            return outcome

        in_flight.addBoth(_forget)

    key = yield in_flight.observe()
    defer.returnValue(key)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
    """Return the stored response for ``txn_key``, or run ``fn`` to make one.

    The first call for a key runs ``fn`` in the background and remembers the
    result; later calls with the same key share that result. A failed
    execution is forgotten so that a retry runs ``fn`` again.

    Args:
        txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
        fn (function): A function which returns a tuple of
            (response_code, response_dict).
        *args: Arguments to pass to fn.
        **kwargs: Keyword arguments to pass to fn.
    Returns:
        Deferred which resolves to a tuple of (response_code, response_dict).
    """
    if txn_key in self.transactions:
        return make_deferred_yieldable(self.transactions[txn_key][0].observe())

    def _forget_txn(failure):
        # Drop failed requests so transient errors (rate limiting, etc.) are
        # not replayed. The failure is intentionally not propagated: the
        # observers are expected to have reported it.
        self.transactions.pop(txn_key, None)

    raw_d = run_in_background(fn, *args, **kwargs)
    shared = ObservableDeferred(raw_d)
    self.transactions[txn_key] = (shared, self.clock.time_msec())
    raw_d.addErrback(_forget_txn)

    return make_deferred_yieldable(shared.observe())
def __init__(self, user_id, rooms, current_token, time_now_ms):
    """Track the stream position for ``user_id`` across ``rooms``.

    Args:
        user_id: the user this stream belongs to.
        rooms: iterable of rooms; stored as a set.
        current_token: the starting stream token.
        time_now_ms (int): creation time, recorded as the last-notified time.
    """
    self.user_id = user_id
    self.rooms = set(rooms)
    self.current_token = current_token
    self.last_notified_ms = time_now_ms

    # Create the waiter deferred outside the caller's logcontext.
    with PreserveLoggingContext():
        waiter = ObservableDeferred(defer.Deferred())
    self.notify_deferred = waiter
def wrapped(*args, **kwargs):
    """Return the cached result for these arguments, computing it on a miss.

    The cache key is the tuple of the ``self.arg_names`` argument values. On
    a hit the cached ObservableDeferred is observed; on a miss
    ``self.function_to_call`` is invoked and its (observable) result is
    stored in ``cache``.

    NOTE(review): ``self``, ``obj`` and ``cache`` are free variables from the
    enclosing scope (not visible here); confirm their roles there.

    Returns:
        A deferred for the (cached or freshly computed) result.
    """
    # If we're passed a cache_context then we'll want to call its invalidate()
    # whenever we are invalidated
    invalidate_callback = kwargs.pop("on_invalidate", None)

    # Add temp cache_context so inspect.getcallargs doesn't explode
    if self.add_cache_context:
        kwargs["cache_context"] = None

    arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
    cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)

    # Add our own `cache_context` to argument list if the wrapped function
    # has asked for one
    if self.add_cache_context:
        kwargs["cache_context"] = _CacheContext(cache, cache_key)

    try:
        cached_result_d = cache.get(cache_key, callback=invalidate_callback)

        observer = cached_result_d.observe()
        if DEBUG_CACHES:
            # Debug mode: recompute on every hit and raise if the cached
            # value no longer matches a fresh call.
            @defer.inlineCallbacks
            def check_result(cached_result):
                actual_result = yield self.function_to_call(obj, *args, **kwargs)
                if actual_result != cached_result:
                    logger.error(
                        "Stale cache entry %s%r: cached: %r, actual %r",
                        self.orig.__name__, cache_key,
                        cached_result, actual_result,
                    )
                    raise ValueError("Stale cache entry")
                defer.returnValue(cached_result)
            observer.addCallback(check_result)

        return preserve_context_over_deferred(observer)
    except KeyError:
        # Get the sequence number of the cache before reading from the
        # database so that we can tell if the cache is invalidated
        # while the SELECT is executing (SYN-369)
        sequence = cache.sequence

        ret = defer.maybeDeferred(
            preserve_context_over_fn,
            self.function_to_call,
            obj, *args, **kwargs
        )

        # Don't leave failures in the cache.
        def onErr(f):
            cache.invalidate(cache_key)
            return f

        # NOTE(review): if the call fails synchronously, `ret` is already
        # failed here, so onErr runs immediately - before cache.update below
        # stores the entry. Whether the failure then stays cached depends on
        # cache.update; TODO confirm.
        ret.addErrback(onErr)

        ret = ObservableDeferred(ret, consumeErrors=True)
        cache.update(sequence, cache_key, ret, callback=invalidate_callback)

        return preserve_context_over_deferred(ret.observe())
def notify_replication(self):
    """Notify the any replication listeners that there's a new event"""
    with PreserveLoggingContext():
        # Swap in a fresh waiter, then wake everyone on the old one.
        fired, self.replication_deferred = (
            self.replication_deferred,
            ObservableDeferred(defer.Deferred()),
        )
        fired.callback(None)

    for callback in self.replication_callbacks:
        preserve_fn(callback)()
def set(self, key, deferred):
    """Register ``deferred`` as the pending result for ``key``.

    The entry is removed from ``pending_result_cache`` as soon as the
    deferred completes, whether it succeeds or fails.

    Returns:
        A new observer of the shared result.
    """
    observable = ObservableDeferred(deferred, consumeErrors=True)
    self.pending_result_cache[key] = observable

    def _evict(outcome):
        self.pending_result_cache.pop(key, None)
        return outcome

    observable.addBoth(_evict)
    return observable.observe()
def __init__(self, hs):
    """Set up notifier bookkeeping and register the notifier gauges.

    Args:
        hs: the homeserver; used for the event sources, datastore, clock,
            appservice handler, state handler and, when this instance sends
            federation, the federation sender.
    """
    self.user_to_user_stream = {}
    self.room_to_user_streams = {}

    self.hs = hs
    self.event_sources = hs.get_event_sources()
    self.store = hs.get_datastore()
    self.pending_new_room_events = []

    self.replication_callbacks = []

    self.clock = hs.get_clock()
    self.appservice_handler = hs.get_application_service_handler()

    self.federation_sender = (
        hs.get_federation_sender() if hs.should_send_federation() else None
    )

    self.state_handler = hs.get_state_handler()

    self.clock.looping_call(
        self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
    )

    self.replication_deferred = ObservableDeferred(defer.Deferred())

    # Not cheap, but it only runs when the metrics page is rendered, which is
    # likely at most about once a minute when it is being scraped.
    def count_listeners():
        unique_streams = set()
        for room_streams in self.room_to_user_streams.values():
            unique_streams |= room_streams
        for user_stream in self.user_to_user_stream.values():
            unique_streams.add(user_stream)
        return sum(s.count_listeners() for s in unique_streams)

    gauges = (
        ("synapse_notifier_listeners", count_listeners),
        (
            "synapse_notifier_rooms",
            lambda: count(bool, self.room_to_user_streams.values()),
        ),
        ("synapse_notifier_users", lambda: len(self.user_to_user_stream)),
    )
    for gauge_name, gauge_fn in gauges:
        LaterGauge(gauge_name, "", [], gauge_fn)
def __init__(self, user, rooms, current_token, time_now_ms, appservice=None):
    """Track the stream position for ``user`` across ``rooms``.

    Args:
        user: the user; stored as its string form.
        rooms: iterable of rooms; stored as a set.
        current_token: the starting stream token.
        time_now_ms (int): creation time, recorded as the last-notified time.
        appservice: optional application service this stream belongs to.
    """
    self.user, self.appservice = str(user), appservice
    self.rooms = set(rooms)
    self.current_token, self.last_notified_ms = current_token, time_now_ms
    self.notify_deferred = ObservableDeferred(defer.Deferred())
def get_remote_media(self, server_name, media_id):
    """Fetch remote media, sharing one download between concurrent callers.

    Returns:
        A new observer of the (possibly already running) download.
    """
    key = (server_name, media_id)
    in_flight = self.downloads.get(key)
    if in_flight is not None:
        return in_flight.observe()

    in_flight = ObservableDeferred(
        self._get_remote_media_impl(server_name, media_id),
        consumeErrors=True,
    )
    self.downloads[key] = in_flight

    def _forget(media_info):
        # Download settled; the next request starts a new one.
        del self.downloads[key]
        return media_info

    in_flight.addBoth(_forget)
    return in_flight.observe()
def notify(self, stream_key, stream_id, time_now_ms):
    """Notify any listeners for this user of a new event from an event source.

    Args:
        stream_key(str): The stream the event came from.
        stream_id(str): The new id for the stream the event came from.
        time_now_ms(int): The current time in milliseconds.
    """
    self.current_token = self.current_token.copy_and_advance(
        stream_key, stream_id
    )
    self.last_notified_ms = time_now_ms

    # Replace the waiter first, then wake everyone on the old one with the
    # new token.
    previous, self.notify_deferred = (
        self.notify_deferred,
        ObservableDeferred(defer.Deferred()),
    )
    previous.callback(self.current_token)
def __init__(self, hs):
    """Set up notifier bookkeeping and register the notifier metrics.

    Args:
        hs: the homeserver; used for the event sources, datastore, clock and
            distributor.
    """
    self.hs = hs

    self.user_to_user_stream = {}
    self.room_to_user_streams = {}
    self.appservice_to_user_streams = {}

    self.event_sources = hs.get_event_sources()
    self.store = hs.get_datastore()
    self.pending_new_room_events = []

    self.clock = hs.get_clock()

    hs.get_distributor().observe("user_joined_room", self._user_joined_room)

    self.clock.looping_call(
        self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
    )

    self.replication_deferred = ObservableDeferred(defer.Deferred())

    # Not cheap, but it only runs when the metrics page is rendered, which is
    # likely at most about once a minute when it is being scraped.
    def count_listeners():
        unique_streams = set()
        for room_streams in self.room_to_user_streams.values():
            unique_streams |= room_streams
        unique_streams.update(self.user_to_user_stream.values())
        for service_streams in self.appservice_to_user_streams.values():
            unique_streams |= service_streams
        return sum(s.count_listeners() for s in unique_streams)

    callbacks = (
        ("listeners", count_listeners),
        ("rooms", lambda: count(bool, self.room_to_user_streams.values())),
        ("users", lambda: len(self.user_to_user_stream)),
        (
            "appservices",
            lambda: count(bool, self.appservice_to_user_streams.values()),
        ),
    )
    for metric_name, metric_fn in callbacks:
        metrics.register_callback(metric_name, metric_fn)
def notify_replication(self):
    """Notify the any replication listeners that there's a new event"""
    with PreserveLoggingContext():
        # Swap in a fresh waiter, then wake everyone on the old one.
        fired, self.replication_deferred = (
            self.replication_deferred,
            ObservableDeferred(defer.Deferred()),
        )
        fired.callback(None)

        # The callbacks may well outlast the current request, so they are
        # run from inside the sentinel logcontext. (Ideally each callback
        # would decide whether to drop the logcontext for any background
        # work it starts, but that needs wider changes.)
        for callback in self.replication_callbacks:
            callback()
def __init__(self, user_id, rooms, current_token, time_now_ms):
    """Track the stream position for ``user_id`` across ``rooms``.

    Args:
        user_id: the user this stream belongs to.
        rooms: iterable of rooms; stored as a set.
        current_token: the starting stream token; also the initial
            last-notified token.
        time_now_ms (int): creation time, recorded as the last-notified time.
    """
    self.user_id = user_id
    self.rooms = set(rooms)
    self.current_token = current_token

    # Streams whose token precedes last_notified_token should be woken. It
    # starts at the current token: for any older token we cannot tell whether
    # a wake-up is needed, so we err on the side of waking.
    self.last_notified_token = current_token
    self.last_notified_ms = time_now_ms

    with PreserveLoggingContext():
        waiter = ObservableDeferred(defer.Deferred())
    self.notify_deferred = waiter
def wrapped(*args, **kwargs):
    """Return the cached result for these arguments, computing it on a miss.

    The cache key is the tuple of the ``self.arg_names`` argument values.

    NOTE(review): ``self``, ``obj`` and ``cache`` are free variables from the
    enclosing scope (not visible here); confirm their roles there.

    Returns:
        A deferred for the (cached or freshly computed) result.
    """
    arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
    cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
    try:
        cached_result_d = cache.get(cache_key)

        observer = cached_result_d.observe()
        if DEBUG_CACHES:
            # Debug mode: recompute on every hit and raise if the cached
            # value no longer matches a fresh call.
            @defer.inlineCallbacks
            def check_result(cached_result):
                actual_result = yield self.function_to_call(
                    obj, *args, **kwargs)
                if actual_result != cached_result:
                    logger.error(
                        "Stale cache entry %s%r: cached: %r, actual %r",
                        self.orig.__name__, cache_key,
                        cached_result, actual_result,
                    )
                    raise ValueError("Stale cache entry")
                defer.returnValue(cached_result)
            observer.addCallback(check_result)

        return preserve_context_over_deferred(observer)
    except KeyError:
        # Get the sequence number of the cache before reading from the
        # database so that we can tell if the cache is invalidated
        # while the SELECT is executing (SYN-369)
        sequence = cache.sequence

        ret = defer.maybeDeferred(preserve_context_over_fn,
                                  self.function_to_call,
                                  obj, *args, **kwargs)

        # Don't leave failures in the cache.
        def onErr(f):
            cache.invalidate(cache_key)
            return f

        # NOTE(review): a synchronous failure makes onErr run here, before
        # cache.update below stores the entry; whether the failure then
        # stays cached depends on cache.update - TODO confirm.
        ret.addErrback(onErr)

        ret = ObservableDeferred(ret, consumeErrors=True)
        cache.update(sequence, cache_key, ret)

        return preserve_context_over_deferred(ret.observe())
def wrapped(*args, **kwargs):
    """Return the cached result for these arguments, computing it on a miss.

    A cache entry may be either an ObservableDeferred (observed before being
    returned) or a plain value (returned as-is).

    NOTE(review): ``self``, ``obj``, ``cache`` and ``get_cache_key`` are free
    variables from the enclosing scope (not visible here); confirm there.

    Returns:
        A logcontext-yieldable Deferred if the result is still a Deferred,
        otherwise the result value itself.
    """
    # If we're passed a cache_context then we'll want to call its invalidate()
    # whenever we are invalidated
    invalidate_callback = kwargs.pop("on_invalidate", None)

    cache_key = get_cache_key(args, kwargs)

    # Add our own `cache_context` to argument list if the wrapped function
    # has asked for one
    if self.add_cache_context:
        kwargs["cache_context"] = _CacheContext(cache, cache_key)

    try:
        cached_result_d = cache.get(cache_key, callback=invalidate_callback)

        if isinstance(cached_result_d, ObservableDeferred):
            observer = cached_result_d.observe()
        else:
            observer = cached_result_d

    except KeyError:
        ret = defer.maybeDeferred(
            logcontext.preserve_fn(self.function_to_call),
            obj, *args, **kwargs)

        # Don't leave failures in the cache. onErr looks `cache_key` up when
        # it runs, so it sees the ascii-converted key if the rebinding below
        # has already happened.
        def onErr(f):
            cache.invalidate(cache_key)
            return f

        # NOTE(review): a synchronous failure makes onErr run here, before
        # cache.set below stores the entry; whether the failure then stays
        # cached depends on cache.set - TODO confirm.
        ret.addErrback(onErr)

        # If our cache_key is a string on py2, try to convert to ascii
        # to save a bit of space in large caches. Py3 does this
        # internally automatically.
        if six.PY2 and isinstance(cache_key, string_types):
            cache_key = to_ascii(cache_key)

        result_d = ObservableDeferred(ret, consumeErrors=True)
        cache.set(cache_key, result_d, callback=invalidate_callback)
        observer = result_d.observe()

    if isinstance(observer, defer.Deferred):
        return logcontext.make_deferred_yieldable(observer)
    else:
        return observer
def set(self, time_now_ms, key, deferred):
    """Register ``deferred`` as the pending result for ``key``.

    ``rotate(time_now_ms)`` is called first. When the deferred completes,
    the entry moves from ``pending_result_cache`` into ``next_result_cache``.

    Returns:
        A new observer of the shared result.
    """
    self.rotate(time_now_ms)

    observable = ObservableDeferred(deferred)
    self.pending_result_cache[key] = observable

    def _promote(outcome):
        # Move the finished result into the first generation of the result
        # cache, from which it will eventually expire as generations rotate.
        self.next_result_cache[key] = observable
        self.pending_result_cache.pop(key, None)

    observable.observe().addBoth(_promote)
    return observable.observe()
def test_prefill(self):
    """A prefilled cache entry is served without calling the wrapped function."""
    callcount = [0]

    d = defer.succeed(123)

    class A(object):
        @cached()
        def func(self, key):
            callcount[0] += 1
            return d

    a = A()

    a.func.prefill(("foo", ), ObservableDeferred(d))

    # assertEqual rather than the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(a.func("foo").result, d.result)
    self.assertEqual(callcount[0], 0)
def set(self, key, deferred):
    """Register ``deferred`` as the pending result for ``key``.

    When the deferred completes, the entry is removed from
    ``pending_result_cache``. If ``timeout_sec`` is set, removal is scheduled
    that many seconds later via ``self.clock.call_later``; otherwise it
    happens immediately.

    Returns:
        A new observer of the shared result.
    """
    observable = ObservableDeferred(deferred, consumeErrors=True)
    self.pending_result_cache[key] = observable

    def _evict(outcome):
        if not self.timeout_sec:
            self.pending_result_cache.pop(key, None)
        else:
            self.clock.call_later(
                self.timeout_sec,
                self.pending_result_cache.pop,
                key,
                None,
            )
        return outcome

    observable.addBoth(_evict)
    return observable.observe()
def notify(self, stream_key, stream_id, time_now_ms):
    """Notify any listeners for this user of a new event from an event source.

    Args:
        stream_key(str): The stream the event came from.
        stream_id(str): The new id for the stream the event came from.
        time_now_ms(int): The current time in milliseconds.
    """
    advanced = self.current_token.copy_and_advance(stream_key, stream_id)
    self.current_token = advanced
    self.last_notified_token = advanced
    self.last_notified_ms = time_now_ms

    previous = self.notify_deferred

    users_woken_by_stream_counter.labels(stream_key).inc()

    # Install a fresh waiter, then wake everyone on the old one.
    with PreserveLoggingContext():
        self.notify_deferred = ObservableDeferred(defer.Deferred())
        previous.callback(advanced)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
    """Fetches the response for this transaction, or executes the given
    function to produce a response for this transaction.

    Only successful results are kept: if ``fn``'s deferred fails, the
    transaction is removed from the map so a retry with the same
    ``txn_key`` runs ``fn`` again instead of replaying a (possibly
    transient) error. The failure is also consumed on the raw deferred so
    it is not reported as an unhandled error.

    Args:
        txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
        fn (function): A function which returns a tuple of
            (response_code, response_dict).
        *args: Arguments to pass to fn.
        **kwargs: Keyword arguments to pass to fn.
    Returns:
        Deferred which resolves to a tuple of (response_code, response_dict).
    """
    try:
        return self.transactions[txn_key][0].observe()
    except (KeyError, IndexError):
        pass  # execute the function instead.

    deferred = fn(*args, **kwargs)
    observable = ObservableDeferred(deferred)

    # Store the entry *before* attaching the errback below: if `deferred` has
    # already failed, the errback runs immediately and must find the entry in
    # order to remove it.
    self.transactions[txn_key] = (observable, self.clock.time_msec())

    def remove_from_map(err):
        self.transactions.pop(txn_key, None)
        # Deliberately swallow the failure here: ObservableDeferred has
        # already handed it to its observers, and returning None stops it
        # being reported as an unhandled error on the raw deferred.

    deferred.addErrback(remove_from_map)

    return observable.observe()
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
    """Fetches the response for this transaction, or executes the given
    function to produce a response for this transaction.

    If ``fn``'s deferred fails, the transaction is removed from the map, so
    transient errors (rate limiting, etc.) are not cached.

    Args:
        txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
        fn (function): A function which returns a tuple of
            (response_code, response_dict).
        *args: Arguments to pass to fn.
        **kwargs: Keyword arguments to pass to fn.
    Returns:
        Deferred which resolves to a tuple of (response_code, response_dict).
    """
    try:
        return self.transactions[txn_key][0].observe()
    except (KeyError, IndexError):
        pass  # execute the function instead.

    deferred = fn(*args, **kwargs)
    observable = ObservableDeferred(deferred)

    # The entry must be stored *before* the errback is attached. If `deferred`
    # has already failed (e.g. an early rate-limit error), the errback runs
    # immediately; attached first, it would pop nothing and the failure
    # would then be cached permanently.
    self.transactions[txn_key] = (observable, self.clock.time_msec())

    def remove_from_map(err):
        self.transactions.pop(txn_key, None)
        # Deliberately swallow the failure here: ObservableDeferred has
        # already handed it to its observers, and returning None stops it
        # being reported as an unhandled error on the raw deferred.

    deferred.addErrback(remove_from_map)

    return observable.observe()
def set(self, key, deferred):
    """Set the entry for the given key to the given deferred.

    *deferred* should run its callbacks in the sentinel logcontext (ie, you
    should wrap normal synapse deferreds with logcontext.run_in_background).

    Can return either a new Deferred (which also doesn't follow the synapse
    logcontext rules), or, if *deferred* was already complete, the actual
    result. You will probably want to make_deferred_yieldable the result.

    When *deferred* completes, the entry is removed from
    ``pending_result_cache``: after ``timeout_sec`` seconds if that is set,
    otherwise immediately.

    Args:
        key (hashable):
        deferred (twisted.internet.defer.Deferred[T]):

    Returns:
        twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
            result.
    """
    observable = ObservableDeferred(deferred, consumeErrors=True)
    self.pending_result_cache[key] = observable

    def _evict(outcome):
        if not self.timeout_sec:
            self.pending_result_cache.pop(key, None)
        else:
            self.clock.call_later(
                self.timeout_sec,
                self.pending_result_cache.pop,
                key,
                None,
            )
        return outcome

    observable.addBoth(_evict)
    return observable.observe()
def _async_render_GET(self, request): # XXX: if get_user_by_req fails, what should we do in an async render? requester = yield self.auth.get_user_by_req(request) url = request.args.get("url")[0] if "ts" in request.args: ts = int(request.args.get("ts")[0]) else: ts = self.clock.time_msec() url_tuple = urlparse.urlsplit(url) for entry in self.url_preview_url_blacklist: match = True for attrib in entry: pattern = entry[attrib] value = getattr(url_tuple, attrib) logger.debug(("Matching attrib '%s' with value '%s' against" " pattern '%s'") % (attrib, value, pattern)) if value is None: match = False continue if pattern.startswith('^'): if not re.match(pattern, getattr(url_tuple, attrib)): match = False continue else: if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern): match = False continue if match: logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN) # first check the memory cache - good to handle all the clients on this # HS thundering away to preview the same URL at the same time. 
og = self.cache.get(url) if og: respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) return # then check the URL cache in the DB (which will also provide us with # historical previews, if we have any) cache_result = yield self.store.get_url_cache(url, ts) if (cache_result and cache_result["download_ts"] + cache_result["expires"] > ts and cache_result["response_code"] / 100 == 2): respond_with_json_bytes(request, 200, cache_result["og"].encode('utf-8'), send_cors=True) return # Ensure only one download for a given URL is active at a time download = self.downloads.get(url) if download is None: download = self._download_url(url, requester.user) download = ObservableDeferred(download, consumeErrors=True) self.downloads[url] = download @download.addBoth def callback(media_info): del self.downloads[url] return media_info media_info = yield download.observe() # FIXME: we should probably update our cache now anyway, so that # even if the OG calculation raises, we don't keep hammering on the # remote server. 
For now, leave it uncached to aid debugging OG # calculation problems logger.debug("got media_info of '%s'" % media_info) if _is_media(media_info['media_type']): dims = yield self.media_repo._generate_local_thumbnails( media_info['filesystem_id'], media_info, url_cache=True, ) og = { "og:description": media_info['download_name'], "og:image": "mxc://%s/%s" % (self.server_name, media_info['filesystem_id']), "og:image:type": media_info['media_type'], "matrix:image:size": media_info['media_length'], } if dims: og["og:image:width"] = dims['width'] og["og:image:height"] = dims['height'] else: logger.warn("Couldn't get dims for %s" % url) # define our OG response for this media elif _is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM file = open(media_info['filename']) body = file.read() file.close() # clobber the encoding from the content-type, or default to utf-8 # XXX: this overrides any <meta/> or XML charset headers in the body # which may pose problems, but so far seems to work okay. match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) encoding = match.group(1) if match else "utf-8" og = decode_and_calc_og(body, media_info['uri'], encoding) # pre-cache the image for posterity # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. 
if 'og:image' in og and og['og:image']: image_info = yield self._download_url( _rebase_url(og['og:image'], media_info['uri']), requester.user) if _is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images dims = yield self.media_repo._generate_local_thumbnails( image_info['filesystem_id'], image_info, url_cache=True, ) if dims: og["og:image:width"] = dims['width'] og["og:image:height"] = dims['height'] else: logger.warn("Couldn't get dims for %s" % og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, image_info['filesystem_id']) og["og:image:type"] = image_info['media_type'] og["matrix:image:size"] = image_info['media_length'] else: del og["og:image"] else: logger.warn("Failed to find any OG data in %s", url) og = {} logger.debug("Calculated OG for %s as %s" % (url, og)) # store OG in ephemeral in-memory cache self.cache[url] = og # store OG in history-aware DB cache yield self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], media_info["expires"], json.dumps(og), media_info["filesystem_id"], media_info["created_ts"], ) respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
def wrapped(*args, **kwargs):
    """Look up every element of a list-valued argument in the cache.

    Values already in the cache are used directly. Misses are fetched with a
    single call to ``self.function_to_call``, whose list argument is
    restricted to the missing elements; each missing element then gets its
    own cache entry.

    NOTE(review): ``self``, ``obj``, ``cache`` and ``num_args`` are free
    variables from the enclosing scope (not visible here); confirm there.

    Returns:
        A dict mapping each list element to its result when everything was
        already available; otherwise a logcontext-yieldable Deferred
        resolving to that dict.
    """
    # If we're passed a cache_context then we'll want to call its invalidate()
    # whenever we are invalidated
    invalidate_callback = kwargs.pop("on_invalidate", None)

    arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
    keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
    list_args = arg_dict[self.list_name]

    # results maps arg -> already-available value; cached_defers maps
    # arg -> deferred resolving to a 2-tuple (`arg`, `result`)
    results = {}
    cached_defers = {}
    missing = []

    # If the cache takes a single arg then that is used as the key,
    # otherwise a tuple is used.
    if num_args == 1:
        def cache_get(arg):
            return cache.get(arg, callback=invalidate_callback)
    else:
        # Reused key template; only the list_pos slot changes per lookup.
        key = list(keyargs)

        def cache_get(arg):
            key[self.list_pos] = arg
            return cache.get(tuple(key), callback=invalidate_callback)

    for arg in list_args:
        try:
            res = cache_get(arg)

            if not isinstance(res, ObservableDeferred):
                # Plain cached value.
                results[arg] = res
            elif not res.has_succeeded():
                # Still pending (or failed): observe it and tag with its arg.
                res = res.observe()
                res.addCallback(lambda r, arg: (arg, r), arg)
                cached_defers[arg] = res
            else:
                results[arg] = res.get_result()
        except KeyError:
            missing.append(arg)

    if missing:
        args_to_call = dict(arg_dict)
        args_to_call[self.list_name] = missing

        ret_d = defer.maybeDeferred(
            logcontext.preserve_fn(self.function_to_call),
            **args_to_call
        )

        ret_d = ObservableDeferred(ret_d)

        # We need to create deferreds for each arg in the list so that
        # we can insert the new deferred into the cache.
        for arg in missing:
            observer = ret_d.observe()
            # The function is expected to return a mapping keyed by list
            # element; elements it omits resolve to None.
            observer.addCallback(lambda r, arg: r.get(arg, None), arg)

            observer = ObservableDeferred(observer)

            # On failure, drop this element's entry so the error is not
            # cached.
            if num_args == 1:
                cache.set(
                    arg, observer,
                    callback=invalidate_callback
                )

                def invalidate(f, key):
                    cache.invalidate(key)
                    return f
                observer.addErrback(invalidate, arg)
            else:
                key = list(keyargs)
                key[self.list_pos] = arg
                cache.set(
                    tuple(key), observer,
                    callback=invalidate_callback
                )

                def invalidate(f, key):
                    cache.invalidate(key)
                    return f
                observer.addErrback(invalidate, tuple(key))

            res = observer.observe()
            res.addCallback(lambda r, arg: (arg, r), arg)

            cached_defers[arg] = res

    if cached_defers:
        # Merge the (arg, result) pairs into the already-available results.
        def update_results_dict(res):
            results.update(res)
            return results

        return logcontext.make_deferred_yieldable(defer.gatherResults(
            list(cached_defers.values()),
            consumeErrors=True,
        ).addCallback(update_results_dict).addErrback(
            unwrapFirstError
        ))
    else:
        return results
def _async_render_GET(self, request): # XXX: if get_user_by_req fails, what should we do in an async render? requester = yield self.auth.get_user_by_req(request) url = request.args.get("url")[0] if "ts" in request.args: ts = int(request.args.get("ts")[0]) else: ts = self.clock.time_msec() # XXX: we could move this into _do_preview if we wanted. url_tuple = urlparse.urlsplit(url) for entry in self.url_preview_url_blacklist: match = True for attrib in entry: pattern = entry[attrib] value = getattr(url_tuple, attrib) logger.debug(( "Matching attrib '%s' with value '%s' against" " pattern '%s'" ) % (attrib, value, pattern)) if value is None: match = False continue if pattern.startswith('^'): if not re.match(pattern, getattr(url_tuple, attrib)): match = False continue else: if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern): match = False continue if match: logger.warn( "URL %s blocked by url_blacklist entry %s", url, entry ) raise SynapseError( 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) # the in-memory cache: # * ensures that only one request is active at a time # * takes load off the DB for the thundering herds # * also caches any failures (unlike the DB) so we don't keep # requesting the same endpoint observable = self._cache.get(url) if not observable: download = preserve_fn(self._do_preview)( url, requester.user, ts, ) observable = ObservableDeferred( download, consumeErrors=True ) self._cache[url] = observable else: logger.info("Returning cached response") og = yield make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True)
def wrapped(*args, **kwargs):
    """Look up every element of a list-valued argument in the cache.

    Each element of the ``self.list_name`` argument is looked up under a key
    built from the ``self.arg_names`` values, with that element substituted
    at ``self.list_pos``. Misses are fetched with a single call to
    ``self.function_to_call``, whose list argument is restricted to the
    missing elements; each missing element then gets its own cache entry.

    NOTE(review): ``self``, ``obj`` and ``cache`` are free variables from the
    enclosing scope (not visible here); confirm there.

    Returns:
        A dict mapping each list element to its result when every entry had
        already succeeded; otherwise a deferred resolving to that dict.
    """
    # If we're passed a cache_context then we'll want to call its invalidate()
    # whenever we are invalidated
    invalidate_callback = kwargs.pop("on_invalidate", None)

    arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
    keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
    list_args = arg_dict[self.list_name]

    # results maps arg -> already-available value; cached_defers maps
    # arg -> deferred resolving to a 2-tuple (`arg`, `result`)
    results = {}
    cached_defers = {}
    missing = []
    for arg in list_args:
        key = list(keyargs)
        key[self.list_pos] = arg

        try:
            res = cache.get(tuple(key), callback=invalidate_callback)
            if not res.has_succeeded():
                # Still pending (or failed): observe it and tag with its arg.
                res = res.observe()
                res.addCallback(lambda r, arg: (arg, r), arg)
                cached_defers[arg] = res
            else:
                results[arg] = res.get_result()
        except KeyError:
            missing.append(arg)

    if missing:
        # Snapshot the cache sequence before calling the function; it is
        # passed to cache.update below (presumably so invalidations made while
        # the call is in flight can be detected - confirm against
        # cache.update).
        sequence = cache.sequence
        args_to_call = dict(arg_dict)
        args_to_call[self.list_name] = missing

        ret_d = defer.maybeDeferred(
            preserve_context_over_fn,
            self.function_to_call,
            **args_to_call
        )

        ret_d = ObservableDeferred(ret_d)

        # We need to create deferreds for each arg in the list so that
        # we can insert the new deferred into the cache.
        for arg in missing:
            with PreserveLoggingContext():
                observer = ret_d.observe()
            # The function is expected to return a mapping keyed by list
            # element; elements it omits resolve to None.
            observer.addCallback(lambda r, arg: r.get(arg, None), arg)

            observer = ObservableDeferred(observer)

            key = list(keyargs)
            key[self.list_pos] = arg
            cache.update(
                sequence, tuple(key), observer,
                callback=invalidate_callback
            )

            # On failure, drop this element's entry so the error is not
            # cached.
            def invalidate(f, key):
                cache.invalidate(key)
                return f
            observer.addErrback(invalidate, tuple(key))

            res = observer.observe()
            res.addCallback(lambda r, arg: (arg, r), arg)

            cached_defers[arg] = res

    if cached_defers:
        # Merge the (arg, result) pairs into the already-available results.
        def update_results_dict(res):
            results.update(res)
            return results

        return preserve_context_over_deferred(defer.gatherResults(
            cached_defers.values(),
            consumeErrors=True,
        ).addCallback(update_results_dict).addErrback(
            unwrapFirstError
        ))
    else:
        return results