Example #1
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            # Add temp cache_context so inspect.getcallargs doesn't explode
            if self.add_cache_context:
                kwargs["cache_context"] = None

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key, callback=invalidate_callback)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:
                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__, cache_key,
                                cached_result, actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)
                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    obj, *args, **kwargs
                )

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret, callback=invalidate_callback)

                return preserve_context_over_deferred(ret.observe())
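Example #1 follows a populate-on-miss pattern: look the key up in the cache, and on a KeyError run the wrapped function, store the result, and invalidate the entry if the call fails. Below is a deliberately simplified, hypothetical sketch of that idea using only Twisted primitives; cached_call, _cache and compute are illustrative names, and unlike the real code (which shares in-flight ObservableDeferreds between callers and tracks a cache sequence number) it only caches values once they have resolved.

from twisted.internet import defer

_cache = {}  # illustrative in-memory cache: key -> resolved value


def cached_call(key, compute, *args, **kwargs):
    """Return a Deferred for key, computing and caching the value on a miss.

    Failures are never cached, which mirrors the invalidate-on-error
    behaviour of the decorator above.
    """
    if key in _cache:
        return defer.succeed(_cache[key])

    d = defer.maybeDeferred(compute, *args, **kwargs)

    def on_success(result):
        _cache[key] = result  # only populate the cache once the call succeeds
        return result

    d.addCallback(on_success)
    return d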
Example #2
    def get_keys_from_perspectives(self, server_name_and_key_ids):
        @defer.inlineCallbacks
        def get_key(perspective_name, perspective_keys):
            try:
                result = yield self.get_server_verify_key_v2_indirect(
                    server_name_and_key_ids, perspective_name, perspective_keys
                )
                defer.returnValue(result)
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    perspective_name,
                    type(e).__name__, str(e.message),
                )
                defer.returnValue({})

        results = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(get_key)(p_name, p_keys)
                for p_name, p_keys in self.perspective_servers.items()
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        defer.returnValue(union_of_keys)
Example #3
@defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
    """Fetch the keys for a remote server."""

    factory = SynapseKeyClientFactory()
    factory.path = path
    endpoint = matrix_federation_endpoint(
        reactor, server_name, ssl_context_factory, timeout=30
    )

    for i in range(5):
        try:
            protocol = yield preserve_context_over_fn(
                endpoint.connect, factory
            )
            server_response, server_certificate = yield preserve_context_over_deferred(
                protocol.remote_key
            )
            defer.returnValue((server_response, server_certificate))
            return
        except SynapseKeyClientError as e:
            logger.exception("Error getting key for %r" % (server_name,))
            if e.status.startswith("4"):
                # Don't retry for 4xx responses.
                raise IOError("Cannot get key for %r" % server_name)
        except Exception as e:
            logger.exception(e)
    raise IOError("Cannot get key for %r" % server_name)
Example #4
    def fire(self, *args, **kwargs):
        """Invokes every callable in the observer list, passing in the args and
        kwargs. Exceptions thrown by observers are logged but ignored. It is
        not an error to fire a signal with no observers.

        Returns a Deferred that will complete when all the observers have
        completed."""
        def do(observer):
            def eb(failure):
                logger.warning("%s signal observer %s failed: %r",
                               self.name,
                               observer,
                               failure,
                               exc_info=(failure.type, failure.value,
                                         failure.getTracebackObject()))
                if not self.suppress_failures:
                    return failure

            return defer.maybeDeferred(observer, *args,
                                       **kwargs).addErrback(eb)

        with PreserveLoggingContext():
            deferreds = [do(observer) for observer in self.observers]

            d = defer.gatherResults(deferreds, consumeErrors=True)

        d.addErrback(unwrapFirstError)

        return preserve_context_over_deferred(d)
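The fire() method above fans each call out to every observer with defer.maybeDeferred, logs per-observer failures in an errback, and waits for the whole batch with defer.gatherResults(..., consumeErrors=True). A minimal, self-contained sketch of that fan-out, without Synapse's logging-context machinery, might look like the following; fire_all and observers are illustrative names.

import logging

from twisted.internet import defer

logger = logging.getLogger(__name__)


def fire_all(observers, *args, **kwargs):
    """Call every observer, log failures, and wait for all of them."""
    def log_failure(failure):
        logger.warning("observer failed: %r", failure)
        # Returning None here swallows the failure, so one broken
        # observer does not fail the whole batch.

    deferreds = [
        defer.maybeDeferred(observer, *args, **kwargs).addErrback(log_failure)
        for observer in observers
    ]
    return defer.gatherResults(deferreds, consumeErrors=True)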
Example #5
    def wait_for_previous_lookups(self, server_names, server_to_deferred):
        """Waits for any previous key lookups for the given servers to finish.

        Args:
            server_names (list): list of server_names we want to lookup
            server_to_deferred (dict): server_name to deferred which gets
                resolved once we've finished looking up keys for that server
        """
        while True:
            wait_on = [
                self.key_downloads[server_name]
                for server_name in server_names
                if server_name in self.key_downloads
            ]
            if wait_on:
                with PreserveLoggingContext():
                    yield defer.DeferredList(wait_on)
            else:
                break

        for server_name, deferred in server_to_deferred.items():
            d = ObservableDeferred(preserve_context_over_deferred(deferred))
            self.key_downloads[server_name] = d

            def rm(r, server_name):
                self.key_downloads.pop(server_name, None)
                return r

            d.addBoth(rm, server_name)
Example #6
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            cached = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = self.cache.get(tuple(key)).observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)
                    cached[arg] = res
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = self.cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    self.cache.update(sequence, tuple(key), observer)

                    def invalidate(f, key):
                        self.cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached[arg] = res

            return preserve_context_over_deferred(defer.gatherResults(
                cached.values(),
                consumeErrors=True,
            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
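The list-cache wrapper above splits the requested arguments into those already present in the cache and those that are missing, issues a single batched call for the missing ones, and merges both sets back into one result dict. A simplified, hypothetical synchronous sketch of that split-and-merge idea, ignoring Deferreds and invalidation, is shown here; get_many, cache and fetch_batch are illustrative stand-ins.

def get_many(keys, cache, fetch_batch):
    """Return {key: value}, fetching uncached keys with one batched call.

    cache is a plain dict and fetch_batch takes a list of keys and returns
    a dict of results; both are stand-ins for the real cache machinery.
    """
    results = {}
    missing = []
    for key in keys:
        if key in cache:
            results[key] = cache[key]
        else:
            missing.append(key)

    if missing:
        fetched = fetch_batch(missing)
        for key in missing:
            value = fetched.get(key)
            cache[key] = value
            results[key] = value

    return results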
Example #7
    def wait_for_previous_lookups(self, server_names, server_to_deferred):
        """Waits for any previous key lookups for the given servers to finish.

        Args:
            server_names (list): list of server_names we want to lookup
            server_to_deferred (dict): server_name to deferred which gets
                resolved once we've finished looking up keys for that server
        """
        while True:
            wait_on = [
                self.key_downloads[server_name] for server_name in server_names
                if server_name in self.key_downloads
            ]
            if wait_on:
                with PreserveLoggingContext():
                    yield defer.DeferredList(wait_on)
            else:
                break

        for server_name, deferred in server_to_deferred.items():
            d = ObservableDeferred(preserve_context_over_deferred(deferred))
            self.key_downloads[server_name] = d

            def rm(r, server_name):
                self.key_downloads.pop(server_name, None)
                return r

            d.addBoth(rm, server_name)
Example #8
@defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
    """Fetch the keys for a remote server."""

    factory = SynapseKeyClientFactory()
    factory.path = path
    factory.host = server_name
    endpoint = matrix_federation_endpoint(reactor,
                                          server_name,
                                          ssl_context_factory,
                                          timeout=30)

    for i in range(5):
        try:
            protocol = yield preserve_context_over_fn(endpoint.connect,
                                                      factory)
            server_response, server_certificate = yield preserve_context_over_deferred(
                protocol.remote_key)
            defer.returnValue((server_response, server_certificate))
            return
        except SynapseKeyClientError as e:
            logger.exception("Error getting key for %r" % (server_name, ))
            if e.status.startswith("4"):
                # Don't retry for 4xx responses.
                raise IOError("Cannot get key for %r" % server_name)
        except Exception as e:
            logger.exception(e)
    raise IOError("Cannot get key for %r" % server_name)
Example #9
    def _push_update(self, room_id, user_id, typing):
        users = yield self.state.get_current_user_in_room(room_id)
        domains = set(get_domain_from_id(u) for u in users)

        deferreds = []
        for domain in domains:
            if domain == self.server_name:
                preserve_fn(self._push_update_local)(
                    room_id=room_id,
                    user_id=user_id,
                    typing=typing
                )
            else:
                deferreds.append(preserve_fn(self.federation.send_edu)(
                    destination=domain,
                    edu_type="m.typing",
                    content={
                        "room_id": room_id,
                        "user_id": user_id,
                        "typing": typing,
                    },
                    key=(room_id, user_id),
                ))

        yield preserve_context_over_deferred(
            defer.DeferredList(deferreds, consumeErrors=True)
        )
Example #10
    def get_room_events_stream_for_rooms(self,
                                         room_ids,
                                         from_key,
                                         to_key,
                                         limit=0,
                                         order='DESC'):
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id)

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        for rm_ids in (room_ids[i:i + 20]
                       for i in xrange(0, len(room_ids), 20)):
            res = yield preserve_context_over_deferred(
                defer.gatherResults([
                    preserve_fn(self.get_room_events_stream_for_room)(
                        room_id,
                        from_key,
                        to_key,
                        limit,
                        order=order,
                    ) for room_id in rm_ids
                ]))
            results.update(dict(zip(rm_ids, res)))

        defer.returnValue(results)
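Example #10 fetches the per-room event streams in chunks of 20, gathering each chunk in parallel and zipping the results back onto the room ids. A generic, hypothetical sketch of that chunked-gather pattern follows; fetch_many, fetch_one and chunk_size are illustrative names.

from twisted.internet import defer


@defer.inlineCallbacks
def fetch_many(keys, fetch_one, chunk_size=20):
    """Run fetch_one(key) for every key, chunk_size keys at a time."""
    results = {}
    keys = list(keys)
    for i in range(0, len(keys), chunk_size):
        chunk = keys[i:i + chunk_size]
        # Gather the whole chunk in parallel before moving on to the next one.
        chunk_results = yield defer.gatherResults(
            [defer.maybeDeferred(fetch_one, key) for key in chunk],
            consumeErrors=True,
        )
        results.update(dict(zip(chunk, chunk_results)))
    defer.returnValue(results)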
Example #11
    def get_keys_from_perspectives(self, server_name_and_key_ids):
        @defer.inlineCallbacks
        def get_key(perspective_name, perspective_keys):
            try:
                result = yield self.get_server_verify_key_v2_indirect(
                    server_name_and_key_ids, perspective_name, perspective_keys
                )
                defer.returnValue(result)
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    perspective_name,
                    type(e).__name__, str(e.message),
                )
                defer.returnValue({})

        results = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(get_key)(p_name, p_keys)
                for p_name, p_keys in self.perspective_servers.items()
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        defer.returnValue(union_of_keys)
Example #12
    def get_events(self, destinations, room_id, event_ids, return_local=True):
        """Fetch events from some remote destinations, checking if we already
        have them.

        Args:
            destinations (list)
            room_id (str)
            event_ids (list)
            return_local (bool): Whether to include events we already have in
                the DB in the returned list of events

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a list of event ids that we failed to fetch.
        """
        if return_local:
            seen_events = yield self.store.get_events(event_ids,
                                                      allow_rejected=True)
            signed_events = seen_events.values()
        else:
            seen_events = yield self.store.have_events(event_ids)
            signed_events = []

        failed_to_fetch = set()

        missing_events = set(event_ids)
        for k in seen_events:
            missing_events.discard(k)

        if not missing_events:
            defer.returnValue((signed_events, failed_to_fetch))

        def random_server_list():
            srvs = list(destinations)
            random.shuffle(srvs)
            return srvs

        batch_size = 20
        missing_events = list(missing_events)
        for i in xrange(0, len(missing_events), batch_size):
            batch = set(missing_events[i:i + batch_size])

            deferreds = [
                preserve_fn(self.get_pdu)(
                    destinations=random_server_list(),
                    event_id=e_id,
                ) for e_id in batch
            ]

            res = yield preserve_context_over_deferred(
                defer.DeferredList(deferreds, consumeErrors=True))
            for success, result in res:
                if success:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        defer.returnValue((signed_events, failed_to_fetch))
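Unlike gatherResults, defer.DeferredList with consumeErrors=True never fails as a whole: it resolves to a list of (success, result) pairs, which is why the loop above checks success before using the result and treats anything left in batch as failed. A small hedged sketch of that best-effort pattern; fetch_best_effort is an illustrative name.

from twisted.internet import defer


@defer.inlineCallbacks
def fetch_best_effort(deferreds):
    """Wait for every Deferred and separate successes from failures."""
    results = yield defer.DeferredList(deferreds, consumeErrors=True)
    successes = [value for ok, value in results if ok]
    failures = [value for ok, value in results if not ok]  # Failure objects
    defer.returnValue((successes, failures))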
Example #13
    def fire(self, *args, **kwargs):
        """Invokes every callable in the observer list, passing in the args and
        kwargs. Exceptions thrown by observers are logged but ignored. It is
        not an error to fire a signal with no observers.

        Returns a Deferred that will complete when all the observers have
        completed."""

        def do(observer):
            def eb(failure):
                logger.warning(
                    "%s signal observer %s failed: %r",
                    self.name, observer, failure,
                    exc_info=(
                        failure.type,
                        failure.value,
                        failure.getTracebackObject()))
                if not self.suppress_failures:
                    return failure
            return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)

        with PreserveLoggingContext():
            deferreds = [
                do(observer)
                for observer in self.observers
            ]

            d = defer.gatherResults(deferreds, consumeErrors=True)

        d.addErrback(unwrapFirstError)

        return preserve_context_over_deferred(d)
Example #14
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            cached = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = self.cache.get(tuple(key)).observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)
                    cached[arg] = res
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = self.cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    self.cache.update(sequence, tuple(key), observer)

                    def invalidate(f, key):
                        self.cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached[arg] = res

            return preserve_context_over_deferred(defer.gatherResults(
                cached.values(),
                consumeErrors=True,
            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
Example #15
@defer.inlineCallbacks
def get_badge_count(store, user_id):
    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
        preserve_fn(store.get_invited_rooms_for_user)(user_id),
        preserve_fn(store.get_rooms_for_user)(user_id),
    ], consumeErrors=True))

    my_receipts_by_room = yield store.get_receipts_for_user(
        user_id, "m.read",
    )

    badge = len(invites)

    for r in joins:
        if r.room_id in my_receipts_by_room:
            last_unread_event_id = my_receipts_by_room[r.room_id]

            notifs = yield (
                store.get_unread_event_push_actions_by_room_for_user(
                    r.room_id, user_id, last_unread_event_id
                )
            )
            # return one badge count per conversation, as count per
            # message is so noisy as to be almost useless
            badge += 1 if notifs["notify_count"] else 0
    defer.returnValue(badge)
Example #16
    def _push_update(self, room_id, user_id, typing):
        users = yield self.state.get_current_user_in_room(room_id)
        domains = set(get_domain_from_id(u) for u in users)

        deferreds = []
        for domain in domains:
            if domain == self.server_name:
                preserve_fn(self._push_update_local)(room_id=room_id,
                                                     user_id=user_id,
                                                     typing=typing)
            else:
                deferreds.append(
                    preserve_fn(self.federation.send_edu)(
                        destination=domain,
                        edu_type="m.typing",
                        content={
                            "room_id": room_id,
                            "user_id": user_id,
                            "typing": typing,
                        },
                        key=(room_id, user_id),
                    ))

        yield preserve_context_over_deferred(
            defer.DeferredList(deferreds, consumeErrors=True))
Example #17
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
            try:
                cached_result_d = cache.get(cache_key)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:

                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(
                            obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__,
                                cache_key,
                                cached_result,
                                actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)

                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(preserve_context_over_fn,
                                          self.function_to_call, obj, *args,
                                          **kwargs)

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret)

                return preserve_context_over_deferred(ret.observe())
Example #18
    def claim_one_time_keys(self, query, timeout):
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            if self.is_mine_id(user_id):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            device_keys = remote_queries[destination]
            try:
                limiter = yield get_retry_limiter(destination, self.clock,
                                                  self.store)
                with limiter:
                    remote_result = yield self.federation.claim_client_keys(
                        destination, {"one_time_keys": device_keys},
                        timeout=timeout)
                    for user_id, keys in remote_result["one_time_keys"].items():
                        if user_id in device_keys:
                            json_result[user_id] = keys
            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code,
                    "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503,
                    "message": "Not ready for retry",
                }
            except Exception as e:
                # include ConnectionRefused and other errors
                failures[destination] = {"status": 503, "message": e.message}

        yield preserve_context_over_deferred(
            defer.gatherResults([
                preserve_fn(claim_client_keys)(destination)
                for destination in remote_queries
            ]))

        defer.returnValue({"one_time_keys": json_result, "failures": failures})
Example #19
    def claim_one_time_keys(self, query, timeout):
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            if self.is_mine_id(user_id):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            device_keys = remote_queries[destination]
            try:
                limiter = yield get_retry_limiter(
                    destination, self.clock, self.store
                )
                with limiter:
                    remote_result = yield self.federation.claim_client_keys(
                        destination,
                        {"one_time_keys": device_keys},
                        timeout=timeout
                    )
                    for user_id, keys in remote_result["one_time_keys"].items():
                        if user_id in device_keys:
                            json_result[user_id] = keys
            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code, "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503, "message": "Not ready for retry",
                }

        yield preserve_context_over_deferred(defer.gatherResults([
            preserve_fn(claim_client_keys)(destination)
            for destination in remote_queries
        ]))

        defer.returnValue({
            "one_time_keys": json_result,
            "failures": failures
        })
Example #20
        def wrapped(*args, **kwargs):
            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
            try:
                cached_result_d = cache.get(cache_key)

                observer = cached_result_d.observe()
                if DEBUG_CACHES:
                    @defer.inlineCallbacks
                    def check_result(cached_result):
                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
                        if actual_result != cached_result:
                            logger.error(
                                "Stale cache entry %s%r: cached: %r, actual %r",
                                self.orig.__name__, cache_key,
                                cached_result, actual_result,
                            )
                            raise ValueError("Stale cache entry")
                        defer.returnValue(cached_result)
                    observer.addCallback(check_result)

                return preserve_context_over_deferred(observer)
            except KeyError:
                # Get the sequence number of the cache before reading from the
                # database so that we can tell if the cache is invalidated
                # while the SELECT is executing (SYN-369)
                sequence = cache.sequence

                ret = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    obj, *args, **kwargs
                )

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                ret = ObservableDeferred(ret, consumeErrors=True)
                cache.update(sequence, cache_key, ret)

                return preserve_context_over_deferred(ret.observe())
Example #21
    def get_server_verify_key_v2_direct(self, server_name, key_ids):
        keys = {}

        for requested_key_id in key_ids:
            if requested_key_id in keys:
                continue

            (response, tls_certificate) = yield fetch_server_key(
                server_name, self.hs.tls_server_context_factory,
                path=(b"/_matrix/key/v2/server/%s" % (
                    urllib.quote(requested_key_id),
                )).encode("ascii"),
            )

            if (u"signatures" not in response
                    or server_name not in response[u"signatures"]):
                raise KeyLookupError("Key response not signed by remote server")

            if "tls_fingerprints" not in response:
                raise KeyLookupError("Key response missing TLS fingerprints")

            certificate_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1, tls_certificate
            )
            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

            response_sha256_fingerprints = set()
            for fingerprint in response[u"tls_fingerprints"]:
                if u"sha256" in fingerprint:
                    response_sha256_fingerprints.add(fingerprint[u"sha256"])

            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                raise KeyLookupError("TLS certificate not allowed by fingerprints")

            response_keys = yield self.process_v2_response(
                from_server=server_name,
                requested_ids=[requested_key_id],
                response_json=response,
            )

            keys.update(response_keys)

        yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store_keys)(
                    server_name=key_server_name,
                    from_server=server_name,
                    verify_keys=verify_keys,
                )
                for key_server_name, verify_keys in keys.items()
            ],
            consumeErrors=True
        )).addErrback(unwrapFirstError)

        defer.returnValue(keys)
Example #22
    def get_server_verify_key_v2_direct(self, server_name, key_ids):
        keys = {}

        for requested_key_id in key_ids:
            if requested_key_id in keys:
                continue

            (response, tls_certificate) = yield fetch_server_key(
                server_name, self.hs.tls_server_context_factory,
                path=(b"/_matrix/key/v2/server/%s" % (
                    urllib.quote(requested_key_id),
                )).encode("ascii"),
            )

            if (u"signatures" not in response
                    or server_name not in response[u"signatures"]):
                raise KeyLookupError("Key response not signed by remote server")

            if "tls_fingerprints" not in response:
                raise KeyLookupError("Key response missing TLS fingerprints")

            certificate_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1, tls_certificate
            )
            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

            response_sha256_fingerprints = set()
            for fingerprint in response[u"tls_fingerprints"]:
                if u"sha256" in fingerprint:
                    response_sha256_fingerprints.add(fingerprint[u"sha256"])

            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                raise KeyLookupError("TLS certificate not allowed by fingerprints")

            response_keys = yield self.process_v2_response(
                from_server=server_name,
                requested_ids=[requested_key_id],
                response_json=response,
            )

            keys.update(response_keys)

        yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store_keys)(
                    server_name=key_server_name,
                    from_server=server_name,
                    verify_keys=verify_keys,
                )
                for key_server_name, verify_keys in keys.items()
            ],
            consumeErrors=True
        )).addErrback(unwrapFirstError)

        defer.returnValue(keys)
Example #23
    def get_keys_from_store(self, server_name_and_key_ids):
        res = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store.get_server_verify_keys)(
                    server_name, key_ids
                ).addCallback(lambda ks, server: (server, ks), server_name)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        defer.returnValue(dict(res))
Example #24
    def _enqueue_events(self, events, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
        """Fetches events from the database using the _event_fetch_list. This
        allows batch and bulk fetching of events - it allows us to fetch events
        without having to create a new transaction for each request for events.
        """
        if not events:
            defer.returnValue({})

        events_d = defer.Deferred()
        with self._event_fetch_lock:
            self._event_fetch_list.append(
                (events, events_d)
            )

            self._event_fetch_lock.notify()

            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
                self._event_fetch_ongoing += 1
                should_start = True
            else:
                should_start = False

        if should_start:
            self.runWithConnection(
                self._do_fetch
            )

        rows = yield preserve_context_over_deferred(events_d)

        if not allow_rejected:
            rows[:] = [r for r in rows if not r["rejects"]]

        res = yield defer.gatherResults(
            [
                self._get_event_from_row(
                    row["internal_metadata"], row["json"], row["redacts"],
                    check_redacted=check_redacted,
                    get_prev_content=get_prev_content,
                    rejected_reason=row["rejects"],
                )
                for row in rows
            ],
            consumeErrors=True
        )

        defer.returnValue({
            e.event_id: e
            for e in res if e
        })
Example #25
    def query_3pe(self, kind, protocol, fields):
        services = yield self._get_services_for_3pn(protocol)

        results = yield preserve_context_over_deferred(
            defer.DeferredList(
                [
                    preserve_fn(self.appservice_api.query_3pe)(
                        service, kind, protocol, fields,
                    )
                    for service in services
                ],
                consumeErrors=True,
            )
        )

        ret = []
        for (success, result) in results:
            if success:
                ret.extend(result)

        defer.returnValue(ret)
Example #26
    def get_keys_from_server(self, server_name_and_key_ids):
        @defer.inlineCallbacks
        def get_key(server_name, key_ids):
            limiter = yield get_retry_limiter(
                server_name,
                self.clock,
                self.store,
            )
            with limiter:
                keys = None
                try:
                    keys = yield self.get_server_verify_key_v2_direct(
                        server_name, key_ids
                    )
                except Exception as e:
                    logger.info(
                        "Unable to get key %r for %r directly: %s %s",
                        key_ids, server_name,
                        type(e).__name__, str(e.message),
                    )

                if not keys:
                    keys = yield self.get_server_verify_key_v1_direct(
                        server_name, key_ids
                    )

                    keys = {server_name: keys}

            defer.returnValue(keys)

        results = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(get_key)(server_name, key_ids)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        merged = {}
        for result in results:
            merged.update(result)

        defer.returnValue({
            server_name: keys
            for server_name, keys in merged.items()
            if keys
        })
Example #27
    def _enqueue_events(self,
                        events,
                        check_redacted=True,
                        get_prev_content=False,
                        allow_rejected=False):
        """Fetches events from the database using the _event_fetch_list. This
        allows batch and bulk fetching of events - it allows us to fetch events
        without having to create a new transaction for each request for events.
        """
        if not events:
            defer.returnValue({})

        events_d = defer.Deferred()
        with self._event_fetch_lock:
            self._event_fetch_list.append((events, events_d))

            self._event_fetch_lock.notify()

            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
                self._event_fetch_ongoing += 1
                should_start = True
            else:
                should_start = False

        if should_start:
            self.runWithConnection(self._do_fetch)

        rows = yield preserve_context_over_deferred(events_d)

        if not allow_rejected:
            rows[:] = [r for r in rows if not r["rejects"]]

        res = yield defer.gatherResults(
            [
                self._get_event_from_row(
                    row["internal_metadata"],
                    row["json"],
                    row["redacts"],
                    check_redacted=check_redacted,
                    get_prev_content=get_prev_content,
                    rejected_reason=row["rejects"],
                )
                for row in rows
            ],
            consumeErrors=True,
        )

        defer.returnValue({e.event_id: e for e in res if e})
Example #28
    def store_keys(self, server_name, from_server, verify_keys):
        """Store a collection of verify keys for a given server
        Args:
            server_name(str): The name of the server the keys are for.
            from_server(str): The server the keys were downloaded from.
            verify_keys(dict): A mapping of key_id to VerifyKey.
        Returns:
            A deferred that completes when the keys are stored.
        """
        # TODO(markjh): Store whether the keys have expired.
        yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store.store_server_verify_key)(
                    server_name, server_name, key.time_added, key
                )
                for key_id, key in verify_keys.items()
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)
Example #29
    def store_keys(self, server_name, from_server, verify_keys):
        """Store a collection of verify keys for a given server
        Args:
            server_name(str): The name of the server the keys are for.
            from_server(str): The server the keys were downloaded from.
            verify_keys(dict): A mapping of key_id to VerifyKey.
        Returns:
            A deferred that completes when the keys are stored.
        """
        # TODO(markjh): Store whether the keys have expired.
        yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store.store_server_verify_key)(
                    server_name, server_name, key.time_added, key
                )
                for key_id, key in verify_keys.items()
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)
Example #30
    def on_new_notifications(self, min_stream_id, max_stream_id):
        yield run_on_reactor()
        try:
            users_affected = yield self.store.get_push_action_users_in_range(
                min_stream_id, max_stream_id
            )

            deferreds = []

            for u in users_affected:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
                            preserve_fn(p.on_new_notifications)(
                                min_stream_id, max_stream_id
                            )
                        )

            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
        except:
            logger.exception("Exception in pusher on_new_notifications")
Example #31
    def on_new_notifications(self, min_stream_id, max_stream_id):
        yield run_on_reactor()
        try:
            users_affected = yield self.store.get_push_action_users_in_range(
                min_stream_id, max_stream_id
            )

            deferreds = []

            for u in users_affected:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
                            preserve_fn(p.on_new_notifications)(
                                min_stream_id, max_stream_id
                            )
                        )

            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
        except:
            logger.exception("Exception in pusher on_new_notifications")
Example #32
    def get_keys_from_store(self, server_name_and_key_ids):
        """

        Args:
            server_name_and_key_ids (list[(str, iterable[str])]):
                list of (server_name, iterable[key_id]) tuples to fetch keys for

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
                server_name -> key_id -> VerifyKey
        """
        res = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store.get_server_verify_keys)(
                    server_name, key_ids
                ).addCallback(lambda ks, server: (server, ks), server_name)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        defer.returnValue(dict(res))
Example #33
    def backfill(self, dest, context, limit, extremities):
        """Requests some more historic PDUs for the given context from the
        given destination server.

        Args:
            dest (str): The remote home server to ask.
            context (str): The context to backfill.
            limit (int): The maximum number of PDUs to return.
            extremities (list): List of PDU id and origins of the first pdus
                we have seen from the context

        Returns:
            Deferred: Results in the received PDUs.
        """
        logger.debug("backfill extrem=%s", extremities)

        # If there are no extremities then we've (probably) reached the start.
        if not extremities:
            return

        transaction_data = yield self.transport_layer.backfill(
            dest, context, extremities, limit)

        logger.debug("backfill transaction_data=%s", repr(transaction_data))

        pdus = [
            self.event_from_pdu_json(p, outlier=False)
            for p in transaction_data["pdus"]
        ]

        # FIXME: We should handle signature failures more gracefully.
        pdus[:] = yield preserve_context_over_deferred(
            defer.gatherResults(
                self._check_sigs_and_hashes(pdus),
                consumeErrors=True,
            )).addErrback(unwrapFirstError)

        defer.returnValue(pdus)
Example #34
    def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
        yield run_on_reactor()
        try:
            # Need to subtract 1 from the minimum because the lower bound here
            # is not inclusive
            updated_receipts = yield self.store.get_all_updated_receipts(
                min_stream_id - 1, max_stream_id
            )
            # This returns a tuple, user_id is at index 3
            users_affected = set([r[3] for r in updated_receipts])

            deferreds = []

            for u in users_affected:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
                        )

            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
        except:
            logger.exception("Exception in pusher on_new_receipts")
Example #35
    def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
                                         order='DESC'):
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id
        )

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
            res = yield preserve_context_over_deferred(defer.gatherResults([
                preserve_fn(self.get_room_events_stream_for_room)(
                    room_id, from_key, to_key, limit, order=order,
                )
                for room_id in rm_ids
            ]))
            results.update(dict(zip(rm_ids, res)))

        defer.returnValue(results)
Example #36
    def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
        yield run_on_reactor()
        try:
            # Need to subtract 1 from the minimum because the lower bound here
            # is not inclusive
            updated_receipts = yield self.store.get_all_updated_receipts(
                min_stream_id - 1, max_stream_id
            )
            # This returns a tuple, user_id is at index 3
            users_affected = set([r[3] for r in updated_receipts])

            deferreds = []

            for u in users_affected:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
                        )

            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
        except:
            logger.exception("Exception in pusher on_new_receipts")
Example #37
    def backfill(self, dest, context, limit, extremities):
        """Requests some more historic PDUs for the given context from the
        given destination server.

        Args:
            dest (str): The remote home server to ask.
            context (str): The context to backfill.
            limit (int): The maximum number of PDUs to return.
            extremities (list): List of PDU id and origins of the first pdus
                we have seen from the context

        Returns:
            Deferred: Results in the received PDUs.
        """
        logger.debug("backfill extrem=%s", extremities)

        # If there are no extremities then we've (probably) reached the start.
        if not extremities:
            return

        transaction_data = yield self.transport_layer.backfill(
            dest, context, extremities, limit)

        logger.debug("backfill transaction_data=%s", repr(transaction_data))

        pdus = [
            self.event_from_pdu_json(p, outlier=False)
            for p in transaction_data["pdus"]
        ]

        # FIXME: We should handle signature failures more gracefully.
        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
            self._check_sigs_and_hashes(pdus),
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        defer.returnValue(pdus)
Example #38
    def process_v2_response(self, from_server, response_json,
                            requested_ids=[], only_from_server=True):
        time_now_ms = self.clock.time_msec()
        response_keys = {}
        verify_keys = {}
        for key_id, key_data in response_json["verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.time_added = time_now_ms
                verify_keys[key_id] = verify_key

        old_verify_keys = {}
        for key_id, key_data in response_json["old_verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.expired = key_data["expired_ts"]
                verify_key.time_added = time_now_ms
                old_verify_keys[key_id] = verify_key

        results = {}
        server_name = response_json["server_name"]
        if only_from_server:
            if server_name != from_server:
                raise KeyLookupError(
                    "Expected a response for server %r not %r" % (
                        from_server, server_name
                    )
                )
        for key_id in response_json["signatures"].get(server_name, {}):
            if key_id not in response_json["verify_keys"]:
                raise KeyLookupError(
                    "Key response must include verification keys for all"
                    " signatures"
                )
            if key_id in verify_keys:
                verify_signed_json(
                    response_json,
                    server_name,
                    verify_keys[key_id]
                )

        signed_key_json = sign_json(
            response_json,
            self.config.server_name,
            self.config.signing_key[0],
        )

        signed_key_json_bytes = encode_canonical_json(signed_key_json)
        ts_valid_until_ms = signed_key_json[u"valid_until_ts"]

        updated_key_ids = set(requested_ids)
        updated_key_ids.update(verify_keys)
        updated_key_ids.update(old_verify_keys)

        response_keys.update(verify_keys)
        response_keys.update(old_verify_keys)

        yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store.store_server_keys_json)(
                    server_name=server_name,
                    key_id=key_id,
                    from_server=server_name,
                    ts_now_ms=time_now_ms,
                    ts_expires_ms=ts_valid_until_ms,
                    key_json_bytes=signed_key_json_bytes,
                )
                for key_id in updated_key_ids
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        results[server_name] = response_keys

        defer.returnValue(results)
Example #39
    def get_missing_events(self, destination, room_id, earliest_events_ids,
                           latest_events, limit, min_depth):
        """Tries to fetch events we are missing. This is called when we receive
        an event without having received all of its ancestors.

        Args:
            destination (str)
            room_id (str)
            earliest_events_ids (list): List of event ids. Effectively the
                events we expected to receive, but haven't. `get_missing_events`
                should only return events that didn't happen before these.
            latest_events (list): List of events we have received that we don't
                have all previous events for.
            limit (int): Maximum number of events to return.
            min_depth (int): Minimum depth of events to return.
        """
        try:
            content = yield self.transport_layer.get_missing_events(
                destination=destination,
                room_id=room_id,
                earliest_events=earliest_events_ids,
                latest_events=[e.event_id for e in latest_events],
                limit=limit,
                min_depth=min_depth,
            )

            events = [
                self.event_from_pdu_json(e) for e in content.get("events", [])
            ]

            signed_events = yield self._check_sigs_and_hash_and_fetch(
                destination, events, outlier=False)

            have_gotten_all_from_destination = True
        except HttpResponseException as e:
            if not e.code == 400:
                raise

            # We are probably hitting an old server that doesn't support
            # get_missing_events
            signed_events = []
            have_gotten_all_from_destination = False

        if len(signed_events) >= limit:
            defer.returnValue(signed_events)

        users = yield self.state.get_current_user_in_room(room_id)
        servers = set(get_domain_from_id(u) for u in users)

        servers.discard(self.server_name)

        failed_to_fetch = set()

        while len(signed_events) < limit:
            # Are we missing any?

            seen_events = set(earliest_events_ids)
            seen_events.update(e.event_id for e in signed_events if e)

            missing_events = {}
            for e in itertools.chain(latest_events, signed_events):
                if e.depth > min_depth:
                    missing_events.update({
                        e_id: e.depth
                        for e_id, _ in e.prev_events if e_id not in seen_events
                        and e_id not in failed_to_fetch
                    })

            if not missing_events:
                break

            have_seen = yield self.store.have_events(missing_events)

            for k in have_seen:
                missing_events.pop(k, None)

            if not missing_events:
                break

            # Okay, we haven't gotten everything yet. Let's get them.
            ordered_missing = sorted(missing_events.items(), key=lambda x: x[0])

            if have_gotten_all_from_destination:
                servers.discard(destination)

            def random_server_list():
                srvs = list(servers)
                random.shuffle(srvs)
                return srvs

            deferreds = [
                preserve_fn(self.get_pdu)(
                    destinations=random_server_list(),
                    event_id=e_id,
                )
                for e_id, depth in ordered_missing[:limit - len(signed_events)]
            ]

            res = yield preserve_context_over_deferred(
                defer.DeferredList(deferreds, consumeErrors=True))
            for (result, val), (e_id, _) in zip(res, ordered_missing):
                if result and val:
                    signed_events.append(val)
                else:
                    failed_to_fetch.add(e_id)

        defer.returnValue(signed_events)
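The core of the loop above is a fetch-or-record-failure step: one deferred per missing event id, gathered with defer.DeferredList(consumeErrors=True), with anything that failed remembered in failed_to_fetch so it is not retried. Below is a minimal, self-contained sketch of just that step; the ids are invented and get_pdu is replaced by a stub, so it only illustrates the shape of the pattern.

    from __future__ import print_function

    from twisted.internet import defer, task

    def fetch_event(event_id):
        # Stand-in for self.get_pdu(); succeeds for even-numbered ids only.
        if int(event_id[1:]) % 2 == 0:
            return defer.succeed({"event_id": event_id})
        return defer.fail(RuntimeError("fetch failed"))

    @defer.inlineCallbacks
    def fetch_missing(missing_ids):
        res = yield defer.DeferredList(
            [fetch_event(e_id) for e_id in missing_ids], consumeErrors=True,
        )
        fetched, failed_to_fetch = [], set()
        for (success, value), e_id in zip(res, missing_ids):
            if success and value:
                fetched.append(value)
            else:
                failed_to_fetch.add(e_id)
        defer.returnValue((fetched, failed_to_fetch))

    def main(reactor):
        return fetch_missing(["$0", "$1", "$2", "$3"]).addCallback(print)

    task.react(main)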
Example #40
    def get_events(self, destinations, room_id, event_ids, return_local=True):
        """Fetch events from some remote destinations, checking if we already
        have them.

        Args:
            destinations (list)
            room_id (str)
            event_ids (list)
            return_local (bool): Whether to include events we already have in
                the DB in the returned list of events

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a list of event ids that we failed to fetch.
        """
        if return_local:
            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
            signed_events = seen_events.values()
        else:
            seen_events = yield self.store.have_events(event_ids)
            signed_events = []

        failed_to_fetch = set()

        missing_events = set(event_ids)
        for k in seen_events:
            missing_events.discard(k)

        if not missing_events:
            defer.returnValue((signed_events, failed_to_fetch))

        def random_server_list():
            srvs = list(destinations)
            random.shuffle(srvs)
            return srvs

        batch_size = 20
        missing_events = list(missing_events)
        for i in xrange(0, len(missing_events), batch_size):
            batch = set(missing_events[i:i + batch_size])

            deferreds = [
                preserve_fn(self.get_pdu)(
                    destinations=random_server_list(),
                    event_id=e_id,
                )
                for e_id in batch
            ]

            res = yield preserve_context_over_deferred(
                defer.DeferredList(deferreds, consumeErrors=True)
            )
            for success, result in res:
                if success and result:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        defer.returnValue((signed_events, failed_to_fetch))
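The missing events are fetched in fixed-size rounds so that at most batch_size lookups are in flight at once. A stand-alone illustration of that batching (ids and sizes are arbitrary):

    def batches(items, batch_size=20):
        items = list(items)
        for i in range(0, len(items), batch_size):
            yield items[i:i + batch_size]

    event_ids = ["$event%d" % n for n in range(45)]
    for batch in batches(event_ids, batch_size=20):
        print("fetching %d events" % len(batch))  # prints 20, 20, 5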
Example #41
    def get_missing_events(self, destination, room_id, earliest_events_ids,
                           latest_events, limit, min_depth):
        """Tries to fetch events we are missing. This is called when we receive
        an event without having received all of its ancestors.

        Args:
            destination (str)
            room_id (str)
            earliest_events_ids (list): List of event ids. Effectively the
                events we expected to receive, but haven't. `get_missing_events`
                should only return events that didn't happen before these.
            latest_events (list): List of events we have received that we don't
                have all previous events for.
            limit (int): Maximum number of events to return.
            min_depth (int): Minimum depth of events to return.
        """
        try:
            content = yield self.transport_layer.get_missing_events(
                destination=destination,
                room_id=room_id,
                earliest_events=earliest_events_ids,
                latest_events=[e.event_id for e in latest_events],
                limit=limit,
                min_depth=min_depth,
            )

            events = [
                self.event_from_pdu_json(e)
                for e in content.get("events", [])
            ]

            signed_events = yield self._check_sigs_and_hash_and_fetch(
                destination, events, outlier=False
            )

            have_gotten_all_from_destination = True
        except HttpResponseException as e:
            if e.code != 400:
                raise

            # We are probably hitting an old server that doesn't support
            # get_missing_events
            signed_events = []
            have_gotten_all_from_destination = False

        if len(signed_events) >= limit:
            defer.returnValue(signed_events)

        users = yield self.state.get_current_user_in_room(room_id)
        servers = set(get_domain_from_id(u) for u in users)

        servers.discard(self.server_name)

        failed_to_fetch = set()

        while len(signed_events) < limit:
            # Are we missing any?

            seen_events = set(earliest_events_ids)
            seen_events.update(e.event_id for e in signed_events if e)

            missing_events = {}
            for e in itertools.chain(latest_events, signed_events):
                if e.depth > min_depth:
                    missing_events.update({
                        e_id: e.depth for e_id, _ in e.prev_events
                        if e_id not in seen_events
                        and e_id not in failed_to_fetch
                    })

            if not missing_events:
                break

            have_seen = yield self.store.have_events(missing_events)

            for k in have_seen:
                missing_events.pop(k, None)

            if not missing_events:
                break

            # Okay, we haven't gotten everything yet. Let's get them.
            ordered_missing = sorted(missing_events.items(), key=lambda x: x[0])

            if have_gotten_all_from_destination:
                servers.discard(destination)

            def random_server_list():
                srvs = list(servers)
                random.shuffle(srvs)
                return srvs

            deferreds = [
                preserve_fn(self.get_pdu)(
                    destinations=random_server_list(),
                    event_id=e_id,
                )
                for e_id, depth in ordered_missing[:limit - len(signed_events)]
            ]

            res = yield preserve_context_over_deferred(
                defer.DeferredList(deferreds, consumeErrors=True)
            )
            for (result, val), (e_id, _) in zip(res, ordered_missing):
                if result and val:
                    signed_events.append(val)
                else:
                    failed_to_fetch.add(e_id)

        defer.returnValue(signed_events)
Example #42
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            results = {}
            cached_defers = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = cache.get(tuple(key), callback=invalidate_callback)
                    if not res.has_succeeded():
                        res = res.observe()
                        res.addCallback(lambda r, arg: (arg, r), arg)
                        cached_defers[arg] = res
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.append(arg)

            if missing:
                sequence = cache.sequence
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(
                    preserve_context_over_fn,
                    self.function_to_call,
                    **args_to_call
                )

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    cache.update(
                        sequence, tuple(key), observer,
                        callback=invalidate_callback
                    )

                    def invalidate(f, key):
                        cache.invalidate(key)
                        return f
                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached_defers[arg] = res

            if cached_defers:
                def update_results_dict(res):
                    results.update(res)
                    return results

                return preserve_context_over_deferred(defer.gatherResults(
                    cached_defers.values(),
                    consumeErrors=True,
                ).addCallback(update_results_dict).addErrback(
                    unwrapFirstError
                ))
            else:
                return results
Example #43
    def _check_sigs_and_hash_and_fetch(self,
                                       origin,
                                       pdus,
                                       outlier=False,
                                       include_none=False):
        """Takes a list of PDUs and checks the signatures and hashs of each
        one. If a PDU fails its signature check then we check if we have it in
        the database and if not then request if from the originating server of
        that PDU.

        If a PDU fails its content hash check then it is redacted.

        The given list of PDUs are not modified, instead the function returns
        a new list.

        Args:
            origin (str)
            pdus (list)
            outlier (bool)

        Returns:
            Deferred: A list of PDUs that have valid signatures and hashes.
        """
        deferreds = self._check_sigs_and_hashes(pdus)

        def callback(pdu):
            return pdu

        def errback(failure, pdu):
            failure.trap(SynapseError)
            return None

        def try_local_db(res, pdu):
            if not res:
                # Check local db.
                return self.store.get_event(
                    pdu.event_id,
                    allow_rejected=True,
                    allow_none=True,
                )
            return res

        def try_remote(res, pdu):
            if not res and pdu.origin != origin:
                return self.get_pdu(
                    destinations=[pdu.origin],
                    event_id=pdu.event_id,
                    outlier=outlier,
                    timeout=10000,
                ).addErrback(lambda e: None)
            return res

        def warn(res, pdu):
            if not res:
                logger.warn(
                    "Failed to find copy of %s with valid signature",
                    pdu.event_id,
                )
            return res

        for pdu, deferred in zip(pdus, deferreds):
            deferred.addCallbacks(
                callback, errback, errbackArgs=[pdu]
            ).addCallback(
                try_local_db, pdu
            ).addCallback(
                try_remote, pdu
            ).addCallback(
                warn, pdu
            )

        valid_pdus = yield preserve_context_over_deferred(
            defer.gatherResults(deferreds, consumeErrors=True)
        ).addErrback(unwrapFirstError)

        if include_none:
            defer.returnValue(valid_pdus)
        else:
            defer.returnValue([p for p in valid_pdus if p])
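The callback chain above keeps passing the current result forward: the errback turns a failed signature check into None, and each later stage only does work when it has nothing yet. A stand-alone sketch of that falsy-fallback chaining, with the signature check, local store, and remote lookup replaced by stubs and dicts:

    from twisted.internet import defer

    LOCAL_DB = {"$a": {"event_id": "$a"}}
    REMOTE = {"$b": {"event_id": "$b"}}

    def check_signature(pdu):
        # Stand-in for the real check: nothing passes, so every PDU falls
        # through to the local-db and remote lookups.
        return None

    def try_local_db(res, pdu):
        return res if res else LOCAL_DB.get(pdu["event_id"])

    def try_remote(res, pdu):
        if res:
            return res
        return defer.succeed(REMOTE.get(pdu["event_id"]))

    def warn(res, pdu):
        if not res:
            print("no valid copy of %s" % pdu["event_id"])
        return res

    for pdu in ({"event_id": "$a"}, {"event_id": "$b"}, {"event_id": "$c"}):
        d = defer.maybeDeferred(check_signature, pdu)
        d.addCallback(try_local_db, pdu)
        d.addCallback(try_remote, pdu)
        d.addCallback(warn, pdu)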
Example #44
    def process_v2_response(self, from_server, response_json,
                            requested_ids=[], only_from_server=True):
        time_now_ms = self.clock.time_msec()
        response_keys = {}
        verify_keys = {}
        for key_id, key_data in response_json["verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.time_added = time_now_ms
                verify_keys[key_id] = verify_key

        old_verify_keys = {}
        for key_id, key_data in response_json["old_verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.expired = key_data["expired_ts"]
                verify_key.time_added = time_now_ms
                old_verify_keys[key_id] = verify_key

        results = {}
        server_name = response_json["server_name"]
        if only_from_server:
            if server_name != from_server:
                raise KeyLookupError(
                    "Expected a response for server %r not %r" % (
                        from_server, server_name
                    )
                )
        for key_id in response_json["signatures"].get(server_name, {}):
            if key_id not in response_json["verify_keys"]:
                raise KeyLookupError(
                    "Key response must include verification keys for all"
                    " signatures"
                )
            if key_id in verify_keys:
                verify_signed_json(
                    response_json,
                    server_name,
                    verify_keys[key_id]
                )

        signed_key_json = sign_json(
            response_json,
            self.config.server_name,
            self.config.signing_key[0],
        )

        signed_key_json_bytes = encode_canonical_json(signed_key_json)
        ts_valid_until_ms = signed_key_json[u"valid_until_ts"]

        updated_key_ids = set(requested_ids)
        updated_key_ids.update(verify_keys)
        updated_key_ids.update(old_verify_keys)

        response_keys.update(verify_keys)
        response_keys.update(old_verify_keys)

        yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store.store_server_keys_json)(
                    server_name=server_name,
                    key_id=key_id,
                    from_server=server_name,
                    ts_now_ms=time_now_ms,
                    ts_expires_ms=ts_valid_until_ms,
                    key_json_bytes=signed_key_json_bytes,
                )
                for key_id in updated_key_ids
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        results[server_name] = response_keys

        defer.returnValue(results)
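After verifying the remote's signatures, process_v2_response re-signs the response with the local server's key and canonicalises it before storing. A self-contained sketch of that sign/verify round trip with the signedjson and canonicaljson libraries; the server names and key version below are invented, and in the code above the key comes from self.config.signing_key[0]:

    from canonicaljson import encode_canonical_json
    from signedjson.key import generate_signing_key, get_verify_key
    from signedjson.sign import sign_json, verify_signed_json

    # Throwaway key for illustration only.
    signing_key = generate_signing_key("v1")

    response_json = {"server_name": "remote.example", "valid_until_ts": 0}

    # Sign the (already verified) response under our own server name ...
    signed = sign_json(response_json, "local.example", signing_key)
    key_json_bytes = encode_canonical_json(signed)

    # ... and the matching check, as a downstream consumer would perform it.
    verify_signed_json(signed, "local.example", get_verify_key(signing_key))
    print(key_json_bytes)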
Example #45
    def query_devices(self, query_body, timeout):
        """ Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }
        """
        device_keys_query = query_body.get("device_keys", {})

        # separate users by domain.
        # make a map from domain to user_id to device_ids
        local_query = {}
        remote_queries = {}

        for user_id, device_ids in device_keys_query.items():
            if self.is_mine_id(user_id):
                local_query[user_id] = device_ids
            else:
                remote_queries[user_id] = device_ids

        # First, get local devices.
        failures = {}
        results = {}
        if local_query:
            local_result = yield self.query_local_devices(local_query)
            for user_id, keys in local_result.items():
                if user_id in local_query:
                    results[user_id] = keys

        # Now attempt to get any remote devices from our local cache.
        remote_queries_not_in_cache = {}
        if remote_queries:
            query_list = []
            for user_id, device_ids in remote_queries.iteritems():
                if device_ids:
                    query_list.extend((user_id, device_id) for device_id in device_ids)
                else:
                    query_list.append((user_id, None))

            user_ids_not_in_cache, remote_results = (
                yield self.store.get_user_devices_from_cache(
                    query_list
                )
            )
            for user_id, devices in remote_results.iteritems():
                user_devices = results.setdefault(user_id, {})
                for device_id, device in devices.iteritems():
                    keys = device.get("keys", None)
                    device_display_name = device.get("device_display_name", None)
                    if keys:
                        result = dict(keys)
                        unsigned = result.setdefault("unsigned", {})
                        if device_display_name:
                            unsigned["device_display_name"] = device_display_name
                        user_devices[device_id] = result

            for user_id in user_ids_not_in_cache:
                domain = get_domain_from_id(user_id)
                r = remote_queries_not_in_cache.setdefault(domain, {})
                r[user_id] = remote_queries[user_id]

        # Now fetch any devices that we don't have in our cache
        @defer.inlineCallbacks
        def do_remote_query(destination):
            destination_query = remote_queries_not_in_cache[destination]
            try:
                limiter = yield get_retry_limiter(
                    destination, self.clock, self.store
                )
                with limiter:
                    remote_result = yield self.federation.query_client_keys(
                        destination,
                        {"device_keys": destination_query},
                        timeout=timeout
                    )

                for user_id, keys in remote_result["device_keys"].items():
                    if user_id in destination_query:
                        results[user_id] = keys

            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code, "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503, "message": "Not ready for retry",
                }
            except Exception as e:
                # include ConnectionRefused and other errors
                failures[destination] = {
                    "status": 503, "message": e.message
                }

        yield preserve_context_over_deferred(defer.gatherResults([
            preserve_fn(do_remote_query)(destination)
            for destination in remote_queries_not_in_cache
        ]))

        defer.returnValue({
            "device_keys": results, "failures": failures,
        })
Example #46
    def query_devices(self, query_body, timeout):
        """ Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }
        """
        device_keys_query = query_body.get("device_keys", {})

        # separate users by domain.
        # make a map from domain to user_id to device_ids
        local_query = {}
        remote_queries = {}

        for user_id, device_ids in device_keys_query.items():
            if self.is_mine_id(user_id):
                local_query[user_id] = device_ids
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_ids

        # do the queries
        failures = {}
        results = {}
        if local_query:
            local_result = yield self.query_local_devices(local_query)
            for user_id, keys in local_result.items():
                if user_id in local_query:
                    results[user_id] = keys

        @defer.inlineCallbacks
        def do_remote_query(destination):
            destination_query = remote_queries[destination]
            try:
                limiter = yield get_retry_limiter(
                    destination, self.clock, self.store
                )
                with limiter:
                    remote_result = yield self.federation.query_client_keys(
                        destination,
                        {"device_keys": destination_query},
                        timeout=timeout
                    )

                for user_id, keys in remote_result["device_keys"].items():
                    if user_id in destination_query:
                        results[user_id] = keys

            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code, "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503, "message": "Not ready for retry",
                }

        yield preserve_context_over_deferred(defer.gatherResults([
            preserve_fn(do_remote_query)(destination)
            for destination in remote_queries
        ]))

        defer.returnValue({
            "device_keys": results, "failures": failures,
        })
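Both query_devices variants fan out one do_remote_query per destination and fold per-destination errors into a failures dict instead of failing the whole request. A minimal runnable sketch of that shape; the destination names and the stubbed remote query are illustrative only:

    from __future__ import print_function

    from twisted.internet import defer, task

    @defer.inlineCallbacks
    def query_all(destinations):
        results, failures = {}, {}

        @defer.inlineCallbacks
        def do_remote_query(destination):
            try:
                if destination.endswith(".invalid"):
                    raise RuntimeError("unreachable")
                # Stand-in for the real federation query.
                res = yield defer.succeed({"device_keys": {}})
                results[destination] = res
            except Exception as e:
                failures[destination] = {"status": 503, "message": str(e)}

        yield defer.gatherResults([
            do_remote_query(destination) for destination in destinations
        ])
        defer.returnValue({"results": results, "failures": failures})

    def main(reactor):
        return query_all(["matrix.example", "down.invalid"]).addCallback(print)

    task.react(main)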
Example #47
def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
    """ Returns dict of user_id -> list of events that user is allowed to
    see.

    Args:
        user_tuples (list): list of (user_id, is_peeking) tuples for each user
            to be checked. is_peeking should be true if:
            * the user is not currently a member of the room, and
            * the user has not been a member of the room since the
              given events
        events ([synapse.events.EventBase]): list of events to filter
        event_id_to_state (dict): map from event id to the room state at that
            event
    """
    forgotten = yield preserve_context_over_deferred(defer.gatherResults([
        defer.maybeDeferred(
            preserve_fn(store.who_forgot_in_room),
            room_id,
        )
        for room_id in frozenset(e.room_id for e in events)
    ], consumeErrors=True))

    # Set of membership event_ids that have been forgotten
    event_id_forgotten = frozenset(
        row["event_id"] for rows in forgotten for row in rows
    )

    ignore_dict_content = yield store.get_global_account_data_by_type_for_users(
        "m.ignored_user_list", user_ids=[user_id for user_id, _ in user_tuples]
    )

    # FIXME: This will explode if people upload something incorrect.
    ignore_dict = {
        user_id: frozenset(
            content.get("ignored_users", {}).keys() if content else []
        )
        for user_id, content in ignore_dict_content.items()
    }

    def allowed(event, user_id, is_peeking, ignore_list):
        """
        Args:
            event (synapse.events.EventBase): event to check
            user_id (str)
            is_peeking (bool)
            ignore_list (list): list of users to ignore
        """
        if not event.is_state() and event.sender in ignore_list:
            return False

        state = event_id_to_state[event.event_id]

        # get the room_visibility at the time of the event.
        visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
        if visibility_event:
            visibility = visibility_event.content.get("history_visibility", "shared")
        else:
            visibility = "shared"

        if visibility not in VISIBILITY_PRIORITY:
            visibility = "shared"

        # if it was world_readable, it's easy: everyone can read it
        if visibility == "world_readable":
            return True

        # Always allow history visibility events on boundaries. This is done
        # by setting the effective visibility to the least restrictive
        # of the old vs new.
        if event.type == EventTypes.RoomHistoryVisibility:
            prev_content = event.unsigned.get("prev_content", {})
            prev_visibility = prev_content.get("history_visibility", None)

            if prev_visibility not in VISIBILITY_PRIORITY:
                prev_visibility = "shared"

            new_priority = VISIBILITY_PRIORITY.index(visibility)
            old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
            if old_priority < new_priority:
                visibility = prev_visibility

        # likewise, if the event is the user's own membership event, use
        # the 'most joined' membership
        membership = None
        if event.type == EventTypes.Member and event.state_key == user_id:
            membership = event.content.get("membership", None)
            if membership not in MEMBERSHIP_PRIORITY:
                membership = "leave"

            prev_content = event.unsigned.get("prev_content", {})
            prev_membership = prev_content.get("membership", None)
            if prev_membership not in MEMBERSHIP_PRIORITY:
                prev_membership = "leave"

            # Always allow the user to see their own leave events, otherwise
            # they won't see the room disappear if they reject the invite
            if membership == "leave" and (
                prev_membership == "join" or prev_membership == "invite"
            ):
                return True

            new_priority = MEMBERSHIP_PRIORITY.index(membership)
            old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
            if old_priority < new_priority:
                membership = prev_membership

        # otherwise, get the user's membership at the time of the event.
        if membership is None:
            membership_event = state.get((EventTypes.Member, user_id), None)
            if membership_event:
                if membership_event.event_id not in event_id_forgotten:
                    membership = membership_event.membership

        # if the user was a member of the room at the time of the event,
        # they can see it.
        if membership == Membership.JOIN:
            return True

        if visibility == "joined":
            # we weren't a member at the time of the event, so we can't
            # see this event.
            return False

        elif visibility == "invited":
            # user can also see the event if they were *invited* at the time
            # of the event.
            return membership == Membership.INVITE

        else:
            # visibility is shared: user can also see the event if they have
            # become a member since the event
            #
            # XXX: if the user has subsequently joined and then left again,
            # ideally we would share history up to the point they left. But
            # we don't know when they left.
            return not is_peeking

    defer.returnValue({
        user_id: [
            event
            for event in events
            if allowed(event, user_id, is_peeking, ignore_dict.get(user_id, []))
        ]
        for user_id, is_peeking in user_tuples
    })
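The "least restrictive of old vs new" rule at history-visibility boundaries compares positions in a priority ordering. The actual VISIBILITY_PRIORITY tuple is defined elsewhere in the codebase; the ordering below (least restrictive first) is an assumption used only to illustrate the comparison performed above:

    VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")

    def effective_visibility(new, prev):
        if new not in VISIBILITY_PRIORITY:
            new = "shared"
        if prev not in VISIBILITY_PRIORITY:
            prev = "shared"
        # A lower index means less restrictive; keep the less restrictive one.
        if VISIBILITY_PRIORITY.index(prev) < VISIBILITY_PRIORITY.index(new):
            return prev
        return new

    print(effective_visibility("joined", "shared"))           # shared
    print(effective_visibility("invited", "world_readable"))  # world_readable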
Example #48
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # cached is a dict arg -> deferred, where deferred results in a
            # 2-tuple (`arg`, `result`)
            results = {}
            cached_defers = {}
            missing = []
            for arg in list_args:
                key = list(keyargs)
                key[self.list_pos] = arg

                try:
                    res = cache.get(tuple(key), callback=invalidate_callback)
                    if not res.has_succeeded():
                        res = res.observe()
                        res.addCallback(lambda r, arg: (arg, r), arg)
                        cached_defers[arg] = res
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.append(arg)

            if missing:
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = missing

                ret_d = defer.maybeDeferred(preserve_context_over_fn,
                                            self.function_to_call,
                                            **args_to_call)

                ret_d = ObservableDeferred(ret_d)

                # We need to create deferreds for each arg in the list so that
                # we can insert the new deferred into the cache.
                for arg in missing:
                    with PreserveLoggingContext():
                        observer = ret_d.observe()
                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)

                    observer = ObservableDeferred(observer)

                    key = list(keyargs)
                    key[self.list_pos] = arg
                    cache.set(tuple(key),
                              observer,
                              callback=invalidate_callback)

                    def invalidate(f, key):
                        cache.invalidate(key)
                        return f

                    observer.addErrback(invalidate, tuple(key))

                    res = observer.observe()
                    res.addCallback(lambda r, arg: (arg, r), arg)

                    cached_defers[arg] = res

            if cached_defers:

                def update_results_dict(res):
                    results.update(res)
                    return results

                return preserve_context_over_deferred(
                    defer.gatherResults(
                        cached_defers.values(),
                        consumeErrors=True,
                    ).addCallback(update_results_dict).addErrback(
                        unwrapFirstError))
            else:
                return results
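A detail worth noting in both cachedList wrappers is the extra argument passed to addCallback: `res.addCallback(lambda r, arg: (arg, r), arg)` binds the current `arg` at call time, avoiding the classic late-binding closure bug inside the loop. A tiny sketch of the difference, using synchronously fired deferreds:

    from twisted.internet import defer

    results = []
    for arg in ["a", "b", "c"]:
        d = defer.succeed(arg.upper())
        # Passing `arg` as an extra addCallback argument captures it now;
        # a bare `lambda r: (arg, r)` would only ever see the final value "c".
        d.addCallback(lambda r, arg: (arg, r), arg)
        d.addCallback(results.append)

    print(results)  # [('a', 'A'), ('b', 'B'), ('c', 'C')]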
Example #49
        def handle_room(event):
            d = {
                "room_id": event.room_id,
                "membership": event.membership,
                "visibility": (
                    "public" if event.room_id in public_room_ids else "private"
                ),
            }

            if event.membership == Membership.INVITE:
                time_now = self.clock.time_msec()
                d["inviter"] = event.sender

                invite_event = yield self.store.get_event(event.event_id)
                d["invite"] = serialize_event(invite_event, time_now,
                                              as_client_event)

            rooms_ret.append(d)

            if event.membership not in (Membership.JOIN, Membership.LEAVE):
                return

            try:
                if event.membership == Membership.JOIN:
                    room_end_token = now_token.room_key
                    deferred_room_state = self.state_handler.get_current_state(
                        event.room_id)
                elif event.membership == Membership.LEAVE:
                    room_end_token = "s%d" % (event.stream_ordering, )
                    deferred_room_state = self.store.get_state_for_events(
                        [event.event_id], None)
                    deferred_room_state.addCallback(
                        lambda states: states[event.event_id])

                (messages, token), current_state = yield preserve_context_over_deferred(
                    defer.gatherResults([
                        preserve_fn(self.store.get_recent_events_for_room)(
                            event.room_id,
                            limit=limit,
                            end_token=room_end_token,
                        ),
                        deferred_room_state,
                    ])
                ).addErrback(unwrapFirstError)

                messages = yield filter_events_for_client(
                    self.store, user_id, messages)

                start_token = now_token.copy_and_replace("room_key", token[0])
                end_token = now_token.copy_and_replace("room_key", token[1])
                time_now = self.clock.time_msec()

                d["messages"] = {
                    "chunk": [
                        serialize_event(m, time_now, as_client_event)
                        for m in messages
                    ],
                    "start":
                    start_token.to_string(),
                    "end":
                    end_token.to_string(),
                }

                d["state"] = [
                    serialize_event(c, time_now, as_client_event)
                    for c in current_state.values()
                ]

                account_data_events = []
                tags = tags_by_room.get(event.room_id)
                if tags:
                    account_data_events.append({
                        "type": "m.tag",
                        "content": {
                            "tags": tags
                        },
                    })

                account_data = account_data_by_room.get(event.room_id, {})
                for account_data_type, content in account_data.items():
                    account_data_events.append({
                        "type": account_data_type,
                        "content": content,
                    })

                d["account_data"] = account_data_events
            except Exception:
                logger.exception("Failed to get snapshot")
Example #50
    def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
                                          perspective_name,
                                          perspective_keys):
        # TODO(mark): Set the minimum_valid_until_ts to that needed by
        # the events being validated or the current time if validating
        # an incoming request.
        query_response = yield self.client.post_json(
            destination=perspective_name,
            path=b"/_matrix/key/v2/query",
            data={
                u"server_keys": {
                    server_name: {
                        key_id: {
                            u"minimum_valid_until_ts": 0
                        } for key_id in key_ids
                    }
                    for server_name, key_ids in server_names_and_key_ids
                }
            },
            long_retries=True,
        )

        keys = {}

        responses = query_response["server_keys"]

        for response in responses:
            if (u"signatures" not in response
                    or perspective_name not in response[u"signatures"]):
                raise KeyLookupError(
                    "Key response not signed by perspective server"
                    " %r" % (perspective_name,)
                )

            verified = False
            for key_id in response[u"signatures"][perspective_name]:
                if key_id in perspective_keys:
                    verify_signed_json(
                        response,
                        perspective_name,
                        perspective_keys[key_id]
                    )
                    verified = True

            if not verified:
                logging.info(
                    "Response from perspective server %r not signed with a"
                    " known key, signed with: %r, known keys: %r",
                    perspective_name,
                    list(response[u"signatures"][perspective_name]),
                    list(perspective_keys)
                )
                raise KeyLookupError(
                    "Response not signed with a known key for perspective"
                    " server %r" % (perspective_name,)
                )

            processed_response = yield self.process_v2_response(
                perspective_name, response, only_from_server=False
            )

            for server_name, response_keys in processed_response.items():
                keys.setdefault(server_name, {}).update(response_keys)

        yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store_keys)(
                    server_name=server_name,
                    from_server=perspective_name,
                    verify_keys=response_keys,
                )
                for server_name, response_keys in keys.items()
            ],
            consumeErrors=True
        )).addErrback(unwrapFirstError)

        defer.returnValue(keys)
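The /_matrix/key/v2/query body built above is a nested map of server name -> key id -> query criteria. A small stand-alone sketch that builds and prints the same shape; the server names and key ids are examples only:

    import json

    server_names_and_key_ids = [
        ("remote-a.example", ["ed25519:a1"]),
        ("remote-b.example", ["ed25519:b1", "ed25519:old"]),
    ]

    body = {
        u"server_keys": {
            server_name: {
                key_id: {u"minimum_valid_until_ts": 0}
                for key_id in key_ids
            }
            for server_name, key_ids in server_names_and_key_ids
        }
    }

    print(json.dumps(body, indent=2, sort_keys=True))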