@defer.inlineCallbacks
def _do_get_well_known(self, server_name):
    """Actually fetch and parse a .well-known, without checking the cache

    Args:
        server_name (bytes): name of the server, from the requested url

    Returns:
        Deferred[Tuple[bytes|None|object, int]]:
            result, cache period, where result is one of:
             - the new server name from the .well-known (as a `bytes`)
             - None if there was no .well-known file.
             - INVALID_WELL_KNOWN if the .well-known was invalid
    """
    uri = b"https://%s/.well-known/matrix/server" % (server_name, )
    uri_str = uri.decode("ascii")
    logger.info("Fetching %s", uri_str)
    try:
        response = yield make_deferred_yieldable(
            self._well_known_agent.request(b"GET", uri),
        )
        body = yield make_deferred_yieldable(readBody(response))
        if response.code != 200:
            raise Exception("Non-200 response %s" % (response.code, ))

        parsed_body = json.loads(body.decode('utf-8'))
        logger.info("Response from .well-known: %s", parsed_body)
        if not isinstance(parsed_body, dict):
            raise Exception("not a dict")
        if "m.server" not in parsed_body:
            raise Exception("Missing key 'm.server'")
    except Exception as e:
        logger.info("Error fetching %s: %s", uri_str, e)

        # add some randomness to the TTL to avoid a stampeding herd every hour
        # after startup
        cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
        cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
        defer.returnValue((None, cache_period))

    result = parsed_body["m.server"].encode("ascii")

    cache_period = _cache_period_from_headers(
        response.headers,
        time_now=self._reactor.seconds,
    )
    if cache_period is None:
        cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
        # add some randomness to the TTL to avoid a stampeding herd every 24 hours
        # after startup
        cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
    else:
        cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)

    defer.returnValue((result, cache_period))
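# For reference, the happy path above expects a body like the following
# (hostname is an illustrative value); the "m.server" entry becomes the
# delegated server name:
import json

example_body = b'{"m.server": "synapse.example.com:443"}'
assert json.loads(example_body.decode("utf-8"))["m.server"] == "synapse.example.com:443"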
@defer.inlineCallbacks
def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
    try:
        # Need to subtract 1 from the minimum because the lower bound here
        # is not inclusive
        updated_receipts = yield self.store.get_all_updated_receipts(
            min_stream_id - 1, max_stream_id
        )
        # This returns a tuple, user_id is at index 3
        users_affected = set([r[3] for r in updated_receipts])

        deferreds = []

        for u in users_affected:
            if u in self.pushers:
                for p in self.pushers[u].values():
                    deferreds.append(
                        run_in_background(
                            p.on_new_receipts,
                            min_stream_id, max_stream_id,
                        )
                    )

        yield make_deferred_yieldable(
            defer.gatherResults(deferreds, consumeErrors=True),
        )
    except Exception:
        logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
    @defer.inlineCallbacks
    def get_key(perspective_name, perspective_keys):
        try:
            result = yield self.get_server_verify_key_v2_indirect(
                server_name_and_key_ids, perspective_name, perspective_keys
            )
            defer.returnValue(result)
        except KeyLookupError as e:
            logger.warning(
                "Key lookup failed from %r: %s", perspective_name, e,
            )
        except Exception as e:
            logger.exception(
                "Unable to get key from %r: %s %s",
                perspective_name,
                type(e).__name__, str(e),
            )

        defer.returnValue({})

    results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
        [
            run_in_background(get_key, p_name, p_keys)
            for p_name, p_keys in self.perspective_servers.items()
        ],
        consumeErrors=True,
    ).addErrback(unwrapFirstError))

    union_of_keys = {}
    for result in results:
        for server_name, keys in result.items():
            union_of_keys.setdefault(server_name, {}).update(keys)

    defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_raw(self, uri, args={}, headers=None):
    """ Gets raw text from the given URI.

    Args:
        uri (str): The URI to request, not including query parameters
        args (dict): A dictionary used to create query strings, defaults
            to an empty dict.
            **Note**: The value of each key is assumed to be an iterable
            and *not* a string.
        headers (dict[str, List[str]]|None): If not None, a map from
            header name to a list of values for that header
    Returns:
        Deferred: Succeeds when we get *any* 2xx HTTP response, with
        the HTTP body as text.
    Raises:
        HttpResponseException on a non-2xx HTTP response.
    """
    if len(args):
        query_bytes = urllib.parse.urlencode(args, True)
        uri = "%s?%s" % (uri, query_bytes)

    actual_headers = {
        b"User-Agent": [self.user_agent],
    }
    if headers:
        actual_headers.update(headers)

    response = yield self.request(
        "GET",
        uri,
        headers=Headers(actual_headers),
    )

    body = yield make_deferred_yieldable(readBody(response))

    if 200 <= response.code < 300:
        defer.returnValue(body)
    else:
        raise HttpResponseException(response.code, response.phrase, body)
@defer.inlineCallbacks
def request(self, method, uri, data=b'', headers=None):
    # A small wrapper around self.agent.request() so we can easily attach
    # counters to it
    outgoing_requests_counter.labels(method).inc()

    # log request but strip `access_token` (AS requests for example include this)
    logger.info("Sending request %s %s", method, redact_uri(uri))

    try:
        request_deferred = treq.request(
            method, uri, agent=self.agent, data=data, headers=headers
        )
        request_deferred = timeout_deferred(
            request_deferred, 60, self.hs.get_reactor(),
            cancelled_to_request_timed_out_error,
        )
        response = yield make_deferred_yieldable(request_deferred)

        incoming_responses_counter.labels(method, response.code).inc()
        logger.info(
            "Received response to %s %s: %s", method, redact_uri(uri), response.code
        )
        defer.returnValue(response)
    except Exception as e:
        incoming_responses_counter.labels(method, "ERR").inc()
        logger.info(
            "Error sending request to %s %s: %s %s",
            method, redact_uri(uri), type(e).__name__, e.args[0]
        )
        raise
def concurrently_execute(func, args, limit):
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    Args:
        func (func): Function to execute, should return a deferred.
        args (list): List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit (int): Maximum number of concurrent executions.

    Returns:
        deferred: Resolved when all function invocations have finished.
    """
    it = iter(args)

    @defer.inlineCallbacks
    def _concurrently_execute_inner():
        try:
            while True:
                yield func(next(it))
        except StopIteration:
            pass

    return logcontext.make_deferred_yieldable(defer.gatherResults([
        run_in_background(_concurrently_execute_inner)
        for _ in range(limit)
    ], consumeErrors=True)).addErrback(unwrapFirstError)
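# A minimal usage sketch for concurrently_execute (hedged: the function and
# variable names below are illustrative, and this assumes the helper is
# importable, e.g. from synapse.util.async_helpers, with a logcontext in
# scope as run_in_background requires):
from twisted.internet import defer

@defer.inlineCallbacks
def _resize_all_thumbnails(media_ids):
    @defer.inlineCallbacks
    def _resize_one(media_id):
        # ... deferred per-item work would go here ...
        yield defer.succeed(media_id)

    # at most five resize jobs in flight at any one time
    yield concurrently_execute(_resize_one, media_ids, limit=5)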
@defer.inlineCallbacks
def first_lookup():
    with LoggingContext("11") as context_11:
        context_11.request = "11"

        res_deferreds = kr.verify_json_objects_for_server(
            [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
        )

        # the unsigned json should be rejected pretty quickly
        self.assertTrue(res_deferreds[1].called)
        try:
            yield res_deferreds[1]
            self.assertFalse("unsigned json didn't cause a failure")
        except SynapseError:
            pass

        self.assertFalse(res_deferreds[0].called)
        res_deferreds[0].addBoth(self.check_context, None)

        yield logcontext.make_deferred_yieldable(res_deferreds[0])

        # let verify_json_objects_for_server finish its work before we kill the
        # logcontext
        yield self.clock.sleep(0)
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}, headers=None):
    """
    Args:
        uri (str):
        args (dict[str, str|List[str]]): query params
        headers (dict[str, List[str]]|None): If not None, a map from
            header name to a list of values for that header

    Returns:
        Deferred[object]: parsed json
    """

    # TODO: Do we ever want to log message contents?
    logger.debug("post_urlencoded_get_json args: %s", args)

    query_bytes = urllib.urlencode(encode_urlencode_args(args), True)

    actual_headers = {
        b"Content-Type": [b"application/x-www-form-urlencoded"],
        b"User-Agent": [self.user_agent],
    }
    if headers:
        actual_headers.update(headers)

    response = yield self.request(
        "POST",
        uri.encode("ascii"),
        headers=Headers(actual_headers),
        bodyProducer=FileBodyProducer(StringIO(query_bytes))
    )

    body = yield make_deferred_yieldable(readBody(response))

    defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def handle_check_result(pdu, deferred):
    try:
        res = yield logcontext.make_deferred_yieldable(deferred)
    except SynapseError:
        res = None

    if not res:
        # Check local db.
        res = yield self.store.get_event(
            pdu.event_id,
            allow_rejected=True,
            allow_none=True,
        )

    if not res and pdu.origin != origin:
        try:
            res = yield self.get_pdu(
                destinations=[pdu.origin],
                event_id=pdu.event_id,
                outlier=outlier,
                timeout=10000,
            )
        except SynapseError:
            pass

    if not res:
        logger.warn(
            "Failed to find copy of %s with valid signature",
            pdu.event_id,
        )

    defer.returnValue(res)
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
    yield run_on_reactor()
    try:
        # Need to subtract 1 from the minimum because the lower bound here
        # is not inclusive
        updated_receipts = yield self.store.get_all_updated_receipts(
            min_stream_id - 1, max_stream_id
        )
        # This returns a tuple, user_id is at index 3
        users_affected = set([r[3] for r in updated_receipts])

        deferreds = []

        for u in users_affected:
            if u in self.pushers:
                for p in self.pushers[u].values():
                    deferreds.append(
                        run_in_background(
                            p.on_new_receipts,
                            min_stream_id, max_stream_id,
                        )
                    )

        yield make_deferred_yieldable(
            defer.gatherResults(deferreds, consumeErrors=True),
        )
    except Exception:
        logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
def copy_to_backup(self, path):
    """Copy a file from the primary to backup media store, if configured.

    Args:
        path(str): Relative path to write file to
    """
    if self.backup_base_path:
        primary_fname = os.path.join(self.primary_base_path, path)
        backup_fname = os.path.join(self.backup_base_path, path)

        # We can either wait for successful writing to the backup repository
        # or write in the background and immediately return
        if self.synchronous_backup_media_store:
            yield make_deferred_yieldable(threads.deferToThread(
                shutil.copyfile, primary_fname, backup_fname,
            ))
        else:
            preserve_fn(threads.deferToThread)(
                shutil.copyfile, primary_fname, backup_fname,
            )
@defer.inlineCallbacks
def write_to_file_and_backup(self, source, path):
    """Write `source` to the on disk media store, and also the backup store
    if configured.

    Args:
        source: A file like object that should be written
        path (str): Relative path to write file to

    Returns:
        Deferred[str]: the file path written to in the primary media store
    """
    fname = os.path.join(self.primary_base_path, path)

    # Write to the main repository
    yield make_deferred_yieldable(threads.deferToThread(
        self._write_file_synchronously, source, fname,
    ))

    # Write to backup repository
    yield self.copy_to_backup(path)

    defer.returnValue(fname)
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                    t_width, t_height, t_method, t_type):
    input_path = self.filepaths.remote_media_filepath(server_name, file_id)

    thumbnailer = Thumbnailer(input_path)
    t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
        self._generate_thumbnail,
        thumbnailer, t_width, t_height, t_method, t_type
    ))

    if t_byte_source:
        try:
            output_path = yield self.write_to_file_and_backup(
                t_byte_source,
                self.filepaths.remote_media_thumbnail_rel(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
            )
        finally:
            t_byte_source.close()

        logger.info("Stored thumbnail in file %r", output_path)

        t_len = os.path.getsize(output_path)

        yield self.store.store_remote_media_thumbnail(
            server_name, media_id, file_id,
            t_width, t_height, t_type, t_method, t_len
        )

        defer.returnValue(output_path)
def fire(self, *args, **kwargs):
    """Invokes every callable in the observer list, passing in the args and
    kwargs. Exceptions thrown by observers are logged but ignored. It is
    not an error to fire a signal with no observers.

    Returns a Deferred that will complete when all the observers have
    completed."""

    def do(observer):
        def eb(failure):
            logger.warning(
                "%s signal observer %s failed: %r",
                self.name, observer, failure,
                exc_info=(
                    failure.type,
                    failure.value,
                    failure.getTracebackObject()))
        return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)

    deferreds = [
        run_in_background(do, o)
        for o in self.observers
    ]

    return make_deferred_yieldable(defer.gatherResults(
        deferreds, consumeErrors=True,
    ))
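# The same fan-out-and-wait pattern in plain Twisted, runnable standalone
# (a hedged sketch: it deliberately drops the logcontext handling that
# run_in_background provides in the method above):
from twisted.internet import defer

def fire_all(observers, *args):
    def one(obs):
        # log-and-ignore failures, as fire() does
        return defer.maybeDeferred(obs, *args).addErrback(lambda f: None)
    return defer.gatherResults([one(o) for o in observers], consumeErrors=True)

# fire_all([lambda x: x * 2, lambda x: defer.succeed(x)], 21) fires once both finish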
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
    """
    Args:
        server_name_and_key_ids (list[(str, iterable[str])]):
            list of (server_name, iterable[key_id]) tuples to fetch keys for

    Returns:
        Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
            server_name -> key_id -> VerifyKey
    """
    res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
        [
            run_in_background(
                self.store.get_server_verify_keys,
                server_name, key_ids,
            ).addCallback(lambda ks, server: (server, ks), server_name)
            for server_name, key_ids in server_name_and_key_ids
        ],
        consumeErrors=True,
    ).addErrback(unwrapFirstError))

    defer.returnValue(dict(res))
def validate_hash(self, password, stored_hash):
    """Validates that self.hash(password) == stored_hash.

    Args:
        password (unicode): Password to hash.
        stored_hash (unicode): Expected hash value.

    Returns:
        Deferred(bool): Whether self.hash(password) == stored_hash.
    """

    def _do_validate_hash():
        # Normalise the Unicode in the password
        pw = unicodedata.normalize("NFKC", password)

        return bcrypt.checkpw(
            pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
            stored_hash.encode('utf8')
        )

    if stored_hash:
        return make_deferred_yieldable(
            threads.deferToThreadPool(
                self.hs.get_reactor(),
                self.hs.get_reactor().getThreadPool(),
                _do_validate_hash,
            ),
        )
    else:
        return defer.succeed(False)
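# Why the NFKC normalisation above matters (stdlib-only illustration):
# visually identical passwords can arrive as different code-point
# sequences, and bcrypt compares raw bytes.
import unicodedata

precomposed = u"caf\u00e9"   # 'café' with a single precomposed é
decomposed = u"cafe\u0301"   # 'café' as 'e' + combining acute accent
assert precomposed != decomposed
assert (unicodedata.normalize("NFKC", precomposed)
        == unicodedata.normalize("NFKC", decomposed))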
@defer.inlineCallbacks
def get_events(self, destinations, room_id, event_ids, return_local=True):
    """Fetch events from some remote destinations, checking if we already
    have them.

    Args:
        destinations (list)
        room_id (str)
        event_ids (list)
        return_local (bool): Whether to include events we already have in
            the DB in the returned list of events

    Returns:
        Deferred: A deferred resolving to a 2-tuple where the first is a list of
        events and the second is a list of event ids that we failed to fetch.
    """
    if return_local:
        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
        signed_events = seen_events.values()
    else:
        seen_events = yield self.store.have_events(event_ids)
        signed_events = []

    failed_to_fetch = set()

    missing_events = set(event_ids)
    for k in seen_events:
        missing_events.discard(k)

    if not missing_events:
        defer.returnValue((signed_events, failed_to_fetch))

    def random_server_list():
        srvs = list(destinations)
        random.shuffle(srvs)
        return srvs

    batch_size = 20
    missing_events = list(missing_events)
    for i in xrange(0, len(missing_events), batch_size):
        batch = set(missing_events[i:i + batch_size])

        deferreds = [
            preserve_fn(self.get_pdu)(
                destinations=random_server_list(),
                event_id=e_id,
            )
            for e_id in batch
        ]

        res = yield make_deferred_yieldable(
            defer.DeferredList(deferreds, consumeErrors=True)
        )

        for success, result in res:
            if success and result:
                signed_events.append(result)
                batch.discard(result.event_id)

        # We removed all events we successfully fetched from `batch`
        failed_to_fetch.update(batch)

    defer.returnValue((signed_events, failed_to_fetch))
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
    """Fetches the response for this transaction, or executes the given function
    to produce a response for this transaction.

    Args:
        txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
        fn (function): A function which returns a tuple of
            (response_code, response_dict).
        *args: Arguments to pass to fn.
        **kwargs: Keyword arguments to pass to fn.
    Returns:
        Deferred which resolves to a tuple of (response_code, response_dict).
    """
    if txn_key in self.transactions:
        observable = self.transactions[txn_key][0]
    else:
        # execute the function instead.
        deferred = run_in_background(fn, *args, **kwargs)

        observable = ObservableDeferred(deferred)
        self.transactions[txn_key] = (observable, self.clock.time_msec())

        # if the request fails with an exception, remove it
        # from the transaction map. This is done to ensure that we don't
        # cache transient errors like rate-limiting errors, etc.
        def remove_from_map(err):
            self.transactions.pop(txn_key, None)
            # we deliberately do not propagate the error any further, as we
            # expect the observers to have reported it.

        deferred.addErrback(remove_from_map)

    return make_deferred_yieldable(observable.observe())
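# Hedged usage sketch: how a PUT handler might use the cache above for
# idempotency (handler and helper names here are illustrative; in synapse
# the key is typically derived from the access token and transaction id):
#
#     def on_PUT(self, request, txn_id):
#         return self.txns.fetch_or_execute(
#             self._make_txn_key(request, txn_id),
#             self._do_send_message, request,
#         )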
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
             retry_on_dns_fail=True, max_size=None, ignore_backoff=False):
    """GETs a file from a given homeserver
    Args:
        destination (str): The remote server to send the HTTP request to.
        path (str): The HTTP path to GET.
        output_stream (file): File to write the response body to.
        args (dict): Optional dictionary used to create the query string.
        ignore_backoff (bool): true to ignore the historical backoff data
            and try the request anyway.

    Returns:
        Deferred: resolves with an (int,dict) tuple of the file length and
        a dict of the response headers.

        Fails with ``HttpResponseException`` if we get an HTTP response code
        >= 300

        Fails with ``NotRetryingDestination`` if we are not yet ready
        to retry this server.

        Fails with ``FederationDeniedError`` if this destination
        is not on our federation whitelist
    """
    request = MatrixFederationRequest(
        method="GET",
        destination=destination,
        path=path,
        query=args,
    )

    response = yield self._send_request(
        request,
        retry_on_dns_fail=retry_on_dns_fail,
        ignore_backoff=ignore_backoff,
    )

    headers = dict(response.headers.getAllRawHeaders())

    try:
        d = _readBodyToFile(response, output_stream, max_size)
        d.addTimeout(self.default_timeout, self.hs.get_reactor())
        length = yield make_deferred_yieldable(d)
    except Exception as e:
        logger.warn(
            "{%s} [%s] Error reading response: %s",
            request.txn_id,
            request.destination,
            e,
        )
        raise

    logger.info(
        "{%s} [%s] Completed: %d %s [%d bytes]",
        request.txn_id,
        request.destination,
        response.code,
        response.phrase.decode('ascii', errors='replace'),
        length,
    )
    defer.returnValue((length, headers))
def validate_hash(self, password, stored_hash):
    """Validates that self.hash(password) == stored_hash.

    Args:
        password (str): Password to hash.
        stored_hash (str): Expected hash value.

    Returns:
        Deferred(bool): Whether self.hash(password) == stored_hash.
    """

    def _do_validate_hash():
        return bcrypt.checkpw(
            password.encode('utf8') + self.hs.config.password_pepper,
            stored_hash.encode('utf8')
        )

    if stored_hash:
        return make_deferred_yieldable(
            threads.deferToThreadPool(
                self.hs.get_reactor(),
                self.hs.get_reactor().getThreadPool(),
                _do_validate_hash,
            ),
        )
    else:
        return defer.succeed(False)
@defer.inlineCallbacks
def request(self, method, uri, *args, **kwargs):
    # A small wrapper around self.agent.request() so we can easily attach
    # counters to it
    outgoing_requests_counter.inc(method)

    logger.info("Sending request %s %s", method, uri)

    try:
        request_deferred = self.agent.request(method, uri, *args, **kwargs)
        add_timeout_to_deferred(
            request_deferred, 60, cancelled_to_request_timed_out_error,
        )
        response = yield make_deferred_yieldable(request_deferred)

        incoming_responses_counter.inc(method, response.code)
        logger.info(
            "Received response to %s %s: %s", method, uri, response.code
        )
        defer.returnValue(response)
    except Exception as e:
        incoming_responses_counter.inc(method, "ERR")
        logger.info(
            "Error sending request to %s %s: %s %s",
            method, uri, type(e).__name__, e.message
        )
        raise e
@defer.inlineCallbacks
def get_events_from_store_or_dest(self, destination, room_id, event_ids):
    """Fetch events from a remote destination, checking if we already have them.

    Args:
        destination (str)
        room_id (str)
        event_ids (list)

    Returns:
        Deferred: A deferred resolving to a 2-tuple where the first is a list of
        events and the second is a list of event ids that we failed to fetch.
    """
    seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
    signed_events = list(seen_events.values())

    failed_to_fetch = set()

    missing_events = set(event_ids)
    for k in seen_events:
        missing_events.discard(k)

    if not missing_events:
        defer.returnValue((signed_events, failed_to_fetch))

    logger.debug(
        "Fetching unknown state/auth events %s for room %s",
        missing_events,
        event_ids,
    )

    room_version = yield self.store.get_room_version(room_id)

    batch_size = 20
    missing_events = list(missing_events)
    for i in range(0, len(missing_events), batch_size):
        batch = set(missing_events[i:i + batch_size])

        deferreds = [
            run_in_background(
                self.get_pdu,
                destinations=[destination],
                event_id=e_id,
                room_version=room_version,
            )
            for e_id in batch
        ]

        res = yield make_deferred_yieldable(
            defer.DeferredList(deferreds, consumeErrors=True)
        )

        for success, result in res:
            if success and result:
                signed_events.append(result)
                batch.discard(result.event_id)

        # We removed all events we successfully fetched from `batch`
        failed_to_fetch.update(batch)

    defer.returnValue((signed_events, failed_to_fetch))
def on_PUT(self, request, event_id):
    result = self.response_cache.get(event_id)
    if not result:
        result = self.response_cache.set(
            event_id, self._handle_request(request)
        )
    else:
        logger.warn("Returning cached response")

    return make_deferred_yieldable(result)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
    keys = {}

    for requested_key_id in key_ids:
        if requested_key_id in keys:
            continue

        (response, tls_certificate) = yield fetch_server_key(
            server_name, self.hs.tls_client_options_factory,
            path=("/_matrix/key/v2/server/%s" % (
                urllib.parse.quote(requested_key_id),
            )).encode("ascii"),
        )

        if (u"signatures" not in response
                or server_name not in response[u"signatures"]):
            raise KeyLookupError("Key response not signed by remote server")

        if "tls_fingerprints" not in response:
            raise KeyLookupError("Key response missing TLS fingerprints")

        certificate_bytes = crypto.dump_certificate(
            crypto.FILETYPE_ASN1, tls_certificate
        )
        sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
        sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

        response_sha256_fingerprints = set()
        for fingerprint in response[u"tls_fingerprints"]:
            if u"sha256" in fingerprint:
                response_sha256_fingerprints.add(fingerprint[u"sha256"])

        if sha256_fingerprint_b64 not in response_sha256_fingerprints:
            raise KeyLookupError("TLS certificate not allowed by fingerprints")

        response_keys = yield self.process_v2_response(
            from_server=server_name,
            requested_ids=[requested_key_id],
            response_json=response,
        )

        keys.update(response_keys)

    yield logcontext.make_deferred_yieldable(defer.gatherResults(
        [
            run_in_background(
                self.store_keys,
                server_name=key_server_name,
                from_server=server_name,
                verify_keys=verify_keys,
            )
            for key_server_name, verify_keys in keys.items()
        ],
        consumeErrors=True
    ).addErrback(unwrapFirstError))

    defer.returnValue(keys)
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
    local_query = []
    remote_queries = {}

    for user_id, device_keys in query.get("one_time_keys", {}).items():
        # we use UserID.from_string to catch invalid user ids
        if self.is_mine(UserID.from_string(user_id)):
            for device_id, algorithm in device_keys.items():
                local_query.append((user_id, device_id, algorithm))
        else:
            domain = get_domain_from_id(user_id)
            remote_queries.setdefault(domain, {})[user_id] = device_keys

    results = yield self.store.claim_e2e_one_time_keys(local_query)

    json_result = {}
    failures = {}
    for user_id, device_keys in results.items():
        for device_id, keys in device_keys.items():
            for key_id, json_bytes in keys.items():
                json_result.setdefault(user_id, {})[device_id] = {
                    key_id: json.loads(json_bytes)
                }

    @defer.inlineCallbacks
    def claim_client_keys(destination):
        device_keys = remote_queries[destination]
        try:
            remote_result = yield self.federation.claim_client_keys(
                destination,
                {"one_time_keys": device_keys},
                timeout=timeout
            )
            for user_id, keys in remote_result["one_time_keys"].items():
                if user_id in device_keys:
                    json_result[user_id] = keys
        except Exception as e:
            failures[destination] = _exception_to_failure(e)

    yield make_deferred_yieldable(defer.gatherResults([
        run_in_background(claim_client_keys, destination)
        for destination in remote_queries
    ], consumeErrors=True))

    logger.info(
        "Claimed one-time-keys: %s",
        ",".join((
            "%s for %s:%s" % (key_id, user_id, device_id)
            for user_id, user_keys in iteritems(json_result)
            for device_id, device_keys in iteritems(user_keys)
            for key_id, _ in iteritems(device_keys)
        )),
    )

    defer.returnValue({
        "one_time_keys": json_result,
        "failures": failures
    })
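# Illustrative shape of the `query` argument above (the user id, device id
# and algorithm are hypothetical values); this mirrors the body of the
# Matrix client-server /keys/claim request:
example_query = {
    "one_time_keys": {
        "@alice:example.com": {
            "JLAFKJWSCS": "signed_curve25519",
        },
    },
}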
@defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None, headers=None):
    """GETs a file from a given URL
    Args:
        url (str): The URL to GET
        output_stream (file): File to write the response body to.
        headers (dict[str, List[str]]|None): If not None, a map from
            header name to a list of values for that header
    Returns:
        A (int,dict,string,int) tuple of the file length, dict of the response
        headers, absolute URI of the response and HTTP response code.
    """

    actual_headers = {
        b"User-Agent": [self.user_agent],
    }
    if headers:
        actual_headers.update(headers)

    response = yield self.request(
        "GET",
        url,
        headers=Headers(actual_headers),
    )

    resp_headers = dict(response.headers.getAllRawHeaders())

    # note: the size check and messages use the max_size argument (the
    # original referenced a non-existent self.max_size attribute here)
    if (b"Content-Length" in resp_headers
            and int(resp_headers[b"Content-Length"][0]) > max_size):
        logger.warn("Requested URL is too large > %r bytes" % (max_size, ))
        raise SynapseError(
            502,
            "Requested file is too large > %r bytes" % (max_size, ),
            Codes.TOO_LARGE,
        )

    if response.code > 299:
        logger.warn("Got %d when downloading %s" % (response.code, url))
        raise SynapseError(502, "Got error %d" % (response.code, ), Codes.UNKNOWN)

    # TODO: if our Content-Type is HTML or something, just read the first
    # N bytes into RAM rather than saving it all to disk only to read it
    # straight back in again

    try:
        length = yield make_deferred_yieldable(
            _readBodyToFile(response, output_stream, max_size)
        )
    except SynapseError:
        # This can happen e.g. because the body is too large.
        raise
    except Exception as e:
        raise_from(
            SynapseError(502, ("Failed to download remote body: %s" % e)), e
        )

    defer.returnValue((
        length,
        resp_headers,
        response.request.absoluteURI.decode("ascii"),
        response.code,
    ))
@defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None, headers=None):
    """GETs a file from a given URL
    Args:
        url (str): The URL to GET
        output_stream (file): File to write the response body to.
        headers (dict[str, List[str]]|None): If not None, a map from
            header name to a list of values for that header
    Returns:
        A (int,dict,string,int) tuple of the file length, dict of the response
        headers, absolute URI of the response and HTTP response code.
    """

    actual_headers = {b"User-Agent": [self.user_agent]}
    if headers:
        actual_headers.update(headers)

    response = yield self.request("GET", url, headers=Headers(actual_headers))

    resp_headers = dict(response.headers.getAllRawHeaders())

    # note: the size check and messages use the max_size argument (the
    # original referenced a non-existent self.max_size attribute here)
    if (
        b'Content-Length' in resp_headers
        and int(resp_headers[b'Content-Length'][0]) > max_size
    ):
        logger.warn("Requested URL is too large > %r bytes" % (max_size,))
        raise SynapseError(
            502,
            "Requested file is too large > %r bytes" % (max_size,),
            Codes.TOO_LARGE,
        )

    if response.code > 299:
        logger.warn("Got %d when downloading %s" % (response.code, url))
        raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)

    # TODO: if our Content-Type is HTML or something, just read the first
    # N bytes into RAM rather than saving it all to disk only to read it
    # straight back in again

    try:
        length = yield make_deferred_yieldable(
            _readBodyToFile(response, output_stream, max_size)
        )
    except Exception as e:
        logger.exception("Failed to download body")
        raise SynapseError(
            502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
        )

    defer.returnValue(
        (
            length,
            resp_headers,
            response.request.absoluteURI.decode('ascii'),
            response.code,
        )
    )
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
    keys = {}

    for requested_key_id in key_ids:
        if requested_key_id in keys:
            continue

        (response, tls_certificate) = yield fetch_server_key(
            server_name, self.hs.tls_server_context_factory,
            path=(b"/_matrix/key/v2/server/%s" % (
                urllib.quote(requested_key_id),
            )).encode("ascii"),
        )

        if (u"signatures" not in response
                or server_name not in response[u"signatures"]):
            raise KeyLookupError("Key response not signed by remote server")

        if "tls_fingerprints" not in response:
            raise KeyLookupError("Key response missing TLS fingerprints")

        certificate_bytes = crypto.dump_certificate(
            crypto.FILETYPE_ASN1, tls_certificate
        )
        sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
        sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

        response_sha256_fingerprints = set()
        for fingerprint in response[u"tls_fingerprints"]:
            if u"sha256" in fingerprint:
                response_sha256_fingerprints.add(fingerprint[u"sha256"])

        if sha256_fingerprint_b64 not in response_sha256_fingerprints:
            raise KeyLookupError("TLS certificate not allowed by fingerprints")

        response_keys = yield self.process_v2_response(
            from_server=server_name,
            requested_ids=[requested_key_id],
            response_json=response,
        )

        keys.update(response_keys)

    yield logcontext.make_deferred_yieldable(defer.gatherResults(
        [
            run_in_background(
                self.store_keys,
                server_name=key_server_name,
                from_server=server_name,
                verify_keys=verify_keys,
            )
            for key_server_name, verify_keys in keys.items()
        ],
        consumeErrors=True
    ).addErrback(unwrapFirstError))

    defer.returnValue(keys)
@defer.inlineCallbacks
def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
             timeout=None, ignore_backoff=False):
    """ GETs some json from the given host homeserver and path

    Args:
        destination (str): The remote server to send the HTTP request to.
        path (str): The HTTP path.
        args (dict|None): A dictionary used to create query strings, defaults
            to None.
        timeout (int): How long to try (in ms) the destination for before
            giving up. None indicates no timeout and that the request will
            be retried.
        ignore_backoff (bool): true to ignore the historical backoff data
            and try the request anyway.

    Returns:
        Deferred: Succeeds when we get a 2xx HTTP response. The result
        will be the decoded JSON body.

        Fails with ``HTTPRequestException`` if we get an HTTP response
        code >= 300.

        Fails with ``NotRetryingDestination`` if we are not yet ready
        to retry this server.

        Fails with ``FederationDeniedError`` if this destination
        is not on our federation whitelist
    """
    logger.debug("get_json args: %s", args)

    logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)

    response = yield self._request(
        destination,
        "GET",
        path,
        query=args,
        retry_on_dns_fail=retry_on_dns_fail,
        timeout=timeout,
        ignore_backoff=ignore_backoff,
    )

    if 200 <= response.code < 300:
        # We need to update the transactions table to say it was sent?
        check_content_type_is_json(response.headers)

    with logcontext.PreserveLoggingContext():
        d = treq.json_content(response)
        d.addTimeout(self.default_timeout, self.hs.get_reactor())
        body = yield make_deferred_yieldable(d)

    defer.returnValue(body)
def store_file(self, path, file_info):
    """See StorageProvider.store_file"""

    def _store_file():
        boto3.resource('s3').Bucket(self.bucket).upload_file(
            Filename=os.path.join(self.cache_directory, path),
            Key=path,
            ExtraArgs={"StorageClass": self.storage_class},
        )

    # reactor.callInThread returns None, so there would be nothing to wait
    # on; twisted.internet.threads.deferToThread returns a Deferred that
    # fires when the upload completes.
    return make_deferred_yieldable(threads.deferToThread(_store_file))
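# The general pattern (plain Twisted, runnable): hand any blocking call to
# a worker thread and get back a Deferred that fires with its return value,
# keeping the reactor thread free. The function below is a stand-in for a
# blocking SDK call such as boto3's upload_file.
from twisted.internet import threads

def _blocking_work():
    return 42

d = threads.deferToThread(_blocking_work)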
@defer.inlineCallbacks
def on_context_state_request(self, origin, room_id, event_id):
    if not event_id:
        raise NotImplementedError("Specify an event")

    in_room = yield self.auth.check_host_in_room(room_id, origin)
    if not in_room:
        raise AuthError(403, "Host not in room.")

    result = self._state_resp_cache.get((room_id, event_id))
    if not result:
        with (yield self._server_linearizer.queue((origin, room_id))):
            d = self._state_resp_cache.set(
                (room_id, event_id),
                preserve_fn(self._on_context_state_request_compute)(
                    room_id, event_id
                )
            )
            resp = yield make_deferred_yieldable(d)
    else:
        resp = yield make_deferred_yieldable(result)

    defer.returnValue((200, resp))
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
                                     order='DESC'):
    """Get new room events in stream ordering since `from_key`.

    Args:
        room_ids (list[str])
        from_key (str): Token from which no events are returned before
        to_key (str): Token from which no events are returned after. (This
            is typically the current stream token)
        limit (int): Maximum number of events to return
        order (str): Either "DESC" or "ASC". Determines which events are
            returned when the result is limited. If "DESC" then the most
            recent `limit` events are returned, otherwise returns the
            oldest `limit` events.

    Returns:
        Deferred[dict[str,tuple[list[FrozenEvent], str]]]
            A map from room id to a tuple containing:
                - list of recent events in the room
                - stream ordering key for the start of the chunk of events returned.
    """
    from_id = RoomStreamToken.parse_stream_token(from_key).stream

    room_ids = yield self._events_stream_cache.get_entities_changed(
        room_ids, from_id
    )

    if not room_ids:
        defer.returnValue({})

    results = {}
    room_ids = list(room_ids)
    for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
        res = yield make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self.get_room_events_stream_for_room,
                    room_id, from_key, to_key, limit, order=order,
                )
                for room_id in rm_ids
            ],
            consumeErrors=True,
        ))
        results.update(dict(zip(rm_ids, res)))

    defer.returnValue(results)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
    """Fetches events from the database using the _event_fetch_list. This
    allows batch and bulk fetching of events - it allows us to fetch events
    without having to create a new transaction for each request for events.
    """
    if not events:
        defer.returnValue({})

    events_d = defer.Deferred()
    with self._event_fetch_lock:
        self._event_fetch_list.append(
            (events, events_d)
        )

        self._event_fetch_lock.notify()

        if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
            self._event_fetch_ongoing += 1
            should_start = True
        else:
            should_start = False

    if should_start:
        run_as_background_process(
            "fetch_events",
            self.runWithConnection,
            self._do_fetch,
        )

    logger.debug("Loading %d events", len(events))
    with PreserveLoggingContext():
        rows = yield events_d
    logger.debug("Loaded %d events (%d rows)", len(events), len(rows))

    if not allow_rejected:
        rows[:] = [r for r in rows if not r["rejects"]]

    res = yield make_deferred_yieldable(defer.gatherResults(
        [
            run_in_background(
                self._get_event_from_row,
                row["internal_metadata"], row["json"], row["redacts"],
                rejected_reason=row["rejects"],
            )
            for row in rows
        ],
        consumeErrors=True
    ))

    defer.returnValue({
        e.event.event_id: e
        for e in res if e
    })
@defer.inlineCallbacks
def test_make_deferred_yieldable_on_non_deferred(self):
    """Check that make_deferred_yieldable does the right thing when its
    argument isn't actually a deferred"""

    with LoggingContext() as context_one:
        context_one.request = "one"

        d1 = logcontext.make_deferred_yieldable("bum")
        self._check_test_key("one")

        r = yield d1
        self.assertEqual(r, "bum")
        self._check_test_key("one")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self):
    sentinel_context = LoggingContext.current_context()

    with LoggingContext() as context_one:
        context_one.request = "one"

        d1 = logcontext.make_deferred_yieldable(_chained_deferred_function())
        # make sure that the context was reset by make_deferred_yieldable
        self.assertIs(LoggingContext.current_context(), sentinel_context)

        yield d1

        # now it should be restored
        self._check_test_key("one")
@defer.inlineCallbacks
def request(self, method, uri, data=b'', headers=None):
    """
    Args:
        method (str): HTTP method to use.
        uri (str): URI to query.
        data (bytes): Data to send in the request body, if applicable.
        headers (t.w.http_headers.Headers): Request headers.

    Raises:
        SynapseError: If the IP is blacklisted.
    """
    # A small wrapper around self.agent.request() so we can easily attach
    # counters to it
    outgoing_requests_counter.labels(method).inc()

    # log request but strip `access_token` (AS requests for example include this)
    logger.info("Sending request %s %s", method, redact_uri(uri))

    try:
        request_deferred = treq.request(
            method, uri, agent=self.agent, data=data, headers=headers,
            **self._extra_treq_args
        )
        request_deferred = timeout_deferred(
            request_deferred, 60, self.hs.get_reactor(),
            cancelled_to_request_timed_out_error,
        )
        response = yield make_deferred_yieldable(request_deferred)

        incoming_responses_counter.labels(method, response.code).inc()
        logger.info(
            "Received response to %s %s: %s", method, redact_uri(uri), response.code
        )
        defer.returnValue(response)
    except Exception as e:
        incoming_responses_counter.labels(method, "ERR").inc()
        logger.info(
            "Error sending request to %s %s: %s %s",
            method, redact_uri(uri), type(e).__name__, e.args[0],
        )
        raise
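# Stock-Twisted near-equivalent of the timeout wrapping above, for
# reference (synapse's timeout_deferred helper additionally maps
# cancellation onto its own request-timed-out error):
from twisted.internet import defer, reactor

d = defer.Deferred()
d.addTimeout(60, reactor)  # cancel d if it has not fired within 60 seconds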
@defer.inlineCallbacks
def second_lookup():
    with LoggingContext("12") as context_12:
        context_12.request = "12"

        self.http_client.post_json.reset_mock()
        self.http_client.post_json.return_value = defer.Deferred()

        res_deferreds_2 = kr.verify_json_objects_for_server(
            [("server10", json1)]
        )
        res_deferreds_2[0].addBoth(self.check_context, None)
        yield logcontext.make_deferred_yieldable(res_deferreds_2[0])

        # let verify_json_objects_for_server finish its work before we kill the
        # logcontext
        yield self.clock.sleep(0)
def get_file(destination, path, output_stream, args=None, max_size=None):
    """
    Returns tuple[int,dict,str,int] of file length, response headers,
    absolute URI, and response code.
    """
    def write_to(r):
        data, response = r
        output_stream.write(data)
        return response

    d = Deferred()
    d.addCallback(write_to)
    self.fetches.append((d, destination, path, args))
    return make_deferred_yieldable(d)
def wrap(self, key, callback, *args, **kwargs):
    """Wrap together a *get* and *set* call, taking care of logcontexts

    First looks up the key in the cache, and if it is present makes it
    follow the synapse logcontext rules and returns it.

    Otherwise, makes a call to *callback(*args, **kwargs)*, which should
    follow the synapse logcontext rules, and adds the result to the cache.

    Example usage:

        @defer.inlineCallbacks
        def handle_request(request):
            # etc
            defer.returnValue(result)

        result = yield response_cache.wrap(
            key,
            handle_request,
            request,
        )

    Args:
        key (hashable): key to get/set in the cache

        callback (callable): function to call if the key is not found in
            the cache

        *args: positional parameters to pass to the callback, if it is used

        **kwargs: named parameters to pass to the callback, if it is used

    Returns:
        twisted.internet.defer.Deferred: yieldable result
    """
    result = self.get(key)
    if not result:
        logger.info("[%s]: no cached result for [%s], calculating new one",
                    self._name, key)
        d = run_in_background(callback, *args, **kwargs)
        result = self.set(key, d)
    elif not isinstance(result, defer.Deferred) or result.called:
        logger.info("[%s]: using completed cached result for [%s]",
                    self._name, key)
    else:
        logger.info("[%s]: using incomplete cached result for [%s]",
                    self._name, key)
    return make_deferred_yieldable(result)
@defer.inlineCallbacks
def get_prev_state_ids(self, store):
    """Gets the prev state IDs

    Returns:
        Deferred[dict[(str, str), str]|None]: Returns None if state_group
            is None, which happens when the associated event is an outlier.
    """
    if not self._fetching_state_deferred:
        self._fetching_state_deferred = run_in_background(
            self._fill_out_state, store,
        )

    yield make_deferred_yieldable(self._fetching_state_deferred)

    defer.returnValue(self._prev_state_ids)
def yieldable_gather_results(func, iter, *args, **kwargs):
    """Executes the function with each argument concurrently.

    Args:
        func (func): Function to execute that returns a Deferred
        iter (iter): An iterable that yields items that get passed as the
            first argument to the function
        *args: Arguments to be passed to each call to func

    Returns
        Deferred[list]: Resolved when all functions have been invoked, or
            errors if one of the function calls fails.
    """
    return logcontext.make_deferred_yieldable(defer.gatherResults([
        run_in_background(func, item, *args, **kwargs)
        for item in iter
    ], consumeErrors=True)).addErrback(unwrapFirstError)
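# Hedged usage sketch for yieldable_gather_results (function and variable
# names are illustrative; assumes a logcontext is in scope, as
# run_in_background requires):
from twisted.internet import defer

@defer.inlineCallbacks
def _notify_all(user_ids, room_id):
    @defer.inlineCallbacks
    def _notify(user_id, room_id):
        # ... deferred per-user work would go here ...
        yield defer.succeed(None)

    # calls _notify(user_id, room_id) for every user id, all in parallel
    yield yieldable_gather_results(_notify, user_ids, room_id)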
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
    services = yield self._get_services_for_3pn(protocol)

    results = yield make_deferred_yieldable(defer.DeferredList([
        run_in_background(
            self.appservice_api.query_3pe,
            service, kind, protocol, fields,
        )
        for service in services
    ], consumeErrors=True))

    ret = []
    for (success, result) in results:
        if success:
            ret.extend(result)

    defer.returnValue(ret)
def wrapped(*args, **kwargs):
    # If we're passed a cache_context then we'll want to call its invalidate()
    # whenever we are invalidated
    invalidate_callback = kwargs.pop("on_invalidate", None)

    cache_key = get_cache_key(args, kwargs)

    # Add our own `cache_context` to argument list if the wrapped function
    # has asked for one
    if self.add_cache_context:
        kwargs["cache_context"] = _CacheContext(cache, cache_key)

    try:
        cached_result_d = cache.get(cache_key, callback=invalidate_callback)

        if isinstance(cached_result_d, ObservableDeferred):
            observer = cached_result_d.observe()
        else:
            observer = cached_result_d

    except KeyError:
        ret = defer.maybeDeferred(
            logcontext.preserve_fn(self.function_to_call),
            obj, *args, **kwargs
        )

        def onErr(f):
            cache.invalidate(cache_key)
            return f

        ret.addErrback(onErr)

        # If our cache_key is a string on py2, try to convert to ascii
        # to save a bit of space in large caches. Py3 does this
        # internally automatically.
        if six.PY2 and isinstance(cache_key, string_types):
            cache_key = to_ascii(cache_key)

        result_d = ObservableDeferred(ret, consumeErrors=True)
        cache.set(cache_key, result_d, callback=invalidate_callback)
        observer = result_d.observe()

    if isinstance(observer, defer.Deferred):
        return logcontext.make_deferred_yieldable(observer)
    else:
        return observer
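# Hedged context note: `wrapped` above is the inner function produced by
# synapse's @cached descriptor. A store method using it typically looks
# something like this (method and query names are illustrative):
#
#     @cached(max_entries=5000)
#     def get_user_displayname(self, user_id):
#         return self._simple_select_one_onecol(...)
#
# Repeated calls with the same user_id then share one ObservableDeferred
# until the entry is invalidated.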
@defer.inlineCallbacks
def update_state(self, state_group, prev_state_ids, current_state_ids,
                 prev_group, delta_ids):
    """Replace the state in the context
    """

    # We need to make sure we wait for any ongoing fetching of state
    # to complete so that the updated state doesn't get clobbered
    if self._fetching_state_deferred:
        yield make_deferred_yieldable(self._fetching_state_deferred)

    self.state_group = state_group
    self._prev_state_ids = prev_state_ids
    self.prev_group = prev_group
    self._current_state_ids = current_state_ids
    self.delta_ids = delta_ids

    # We need to ensure that we've marked as having fetched the state
    self._fetching_state_deferred = defer.succeed(None)
@defer.inlineCallbacks
def get_current_state_ids(self, store):
    """Gets the current state IDs

    Returns:
        Deferred[dict[(str, str), str]|None]: Returns None if state_group
            is None, which happens when the associated event is an outlier.
            Maps a (type, state_key) to the event ID of the state event matching
            this tuple.
    """
    if not self._fetching_state_deferred:
        self._fetching_state_deferred = run_in_background(
            self._fill_out_state, store,
        )

    yield make_deferred_yieldable(self._fetching_state_deferred)

    defer.returnValue(self._current_state_ids)
@defer.inlineCallbacks
def respond_with_file(request, media_type, file_path, file_size=None,
                      upload_name=None):
    logger.debug("Responding with %r", file_path)

    if os.path.isfile(file_path):
        if file_size is None:
            stat = os.stat(file_path)
            file_size = stat.st_size

        add_file_headers(request, media_type, file_size, upload_name)
        with open(file_path, "rb") as f:
            yield logcontext.make_deferred_yieldable(
                FileSender().beginFileTransfer(f, request)
            )

        finish_request(request)
    else:
        respond_404(request)
def test_send_single_event_with_queue(self):
    d = defer.Deferred()
    self.txn_ctrl.send = Mock(
        side_effect=lambda x, y: make_deferred_yieldable(d)
    )
    service = Mock(id=4)
    event = Mock(event_id="first")
    event2 = Mock(event_id="second")
    event3 = Mock(event_id="third")

    # Send an event and don't resolve it just yet.
    self.queuer.enqueue(service, event)
    # Send more events: expect send() to NOT be called multiple times.
    self.queuer.enqueue(service, event2)
    self.queuer.enqueue(service, event3)
    self.txn_ctrl.send.assert_called_with(service, [event])
    self.assertEquals(1, self.txn_ctrl.send.call_count)
    # Resolve the send event: expect the queued events to be sent
    d.callback(service)
    self.txn_ctrl.send.assert_called_with(service, [event2, event3])
    self.assertEquals(2, self.txn_ctrl.send.call_count)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None):
    """ Puts some json to the given URI.

    Args:
        uri (str): The URI to request, not including query parameters
        json_body (dict): The JSON to put in the HTTP body,
        args (dict): A dictionary used to create query strings, defaults
            to an empty dict.
            **Note**: The value of each key is assumed to be an iterable
            and *not* a string.
        headers (dict[str, List[str]]|None): If not None, a map from
            header name to a list of values for that header
    Returns:
        Deferred: Succeeds when we get *any* 2xx HTTP response, with the
        HTTP body as JSON.
    Raises:
        HttpResponseException On a non-2xx HTTP response.

        ValueError: if the response was not JSON
    """
    if len(args):
        query_bytes = urllib.parse.urlencode(args, True)
        uri = "%s?%s" % (uri, query_bytes)

    json_str = encode_canonical_json(json_body)

    actual_headers = {
        b"Content-Type": [b"application/json"],
        b"User-Agent": [self.user_agent],
    }
    if headers:
        actual_headers.update(headers)

    response = yield self.request(
        "PUT",
        uri,
        headers=Headers(actual_headers),
        data=json_str
    )

    body = yield make_deferred_yieldable(readBody(response))

    if 200 <= response.code < 300:
        defer.returnValue(json.loads(body))
    else:
        raise HttpResponseException(response.code, response.phrase, body)
def on_REPLICATE(self, cmd):
    stream_name = cmd.stream_name
    token = cmd.token

    if stream_name == "ALL":
        # Subscribe to all streams we're publishing to.
        deferreds = [
            run_in_background(
                self.subscribe_to_stream,
                stream, token,
            )
            for stream in iterkeys(self.streamer.streams_by_name)
        ]

        return make_deferred_yieldable(
            defer.gatherResults(deferreds, consumeErrors=True)
        )
    else:
        return self.subscribe_to_stream(stream_name, token)
def store_keys(self, server_name, from_server, verify_keys):
    """Store a collection of verify keys for a given server
    Args:
        server_name(str): The name of the server the keys are for.
        from_server(str): The server the keys were downloaded from.
        verify_keys(dict): A mapping of key_id to VerifyKey.
    Returns:
        A deferred that completes when the keys are stored.
    """
    # TODO(markjh): Store whether the keys have expired.
    return logcontext.make_deferred_yieldable(defer.gatherResults(
        [
            run_in_background(
                self.store.store_server_verify_key,
                server_name, server_name, key.time_added, key
            )
            for key_id, key in verify_keys.items()
        ],
        consumeErrors=True,
    ).addErrback(unwrapFirstError))
@defer.inlineCallbacks
def store_file(self, source, file_info):
    """Write `source` to the on disk media store, and also any other
    configured storage providers

    Args:
        source: A file like object that should be written
        file_info (FileInfo): Info about the file to store

    Returns:
        Deferred[str]: the file path written to in the primary media store
    """
    with self.store_into_file(file_info) as (f, fname, finish_cb):
        # Write to the main repository
        yield make_deferred_yieldable(threads.deferToThread(
            _write_file_synchronously, source, f,
        ))

        yield finish_cb()

    defer.returnValue(fname)