Example no. 1
    def _start_pushers(self, pushers):
        if not self.start_pushers:
            logger.info(
                "Not starting pushers because they are disabled in the config")
            return
        logger.info("Starting %d pushers", len(pushers))
        for pusherdict in pushers:
            try:
                p = self.pusher_factory.create_pusher(pusherdict)
            except Exception:
                logger.exception("Couldn't start a pusher: caught Exception")
                continue
            if p:
                appid_pushkey = "%s:%s" % (
                    pusherdict['app_id'],
                    pusherdict['pushkey'],
                )
                byuser = self.pushers.setdefault(pusherdict['user_name'], {})

                if appid_pushkey in byuser:
                    byuser[appid_pushkey].on_stop()
                byuser[appid_pushkey] = p
                preserve_fn(p.on_started)()

        logger.info("Started pushers")
Example no. 2
    def on_rdata(self, stream_name, token, rows):
        super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)

        if stream_name == "events":
            max_stream_id = self.store.get_room_max_stream_ordering()
            preserve_fn(self.appservice_handler.notify_interested_services)(
                max_stream_id)
Example no. 3
    def _push_update(self, room_id, user_id, typing):
        users = yield self.state.get_current_user_in_room(room_id)
        domains = set(get_domain_from_id(u) for u in users)

        deferreds = []
        for domain in domains:
            if domain == self.server_name:
                preserve_fn(self._push_update_local)(room_id=room_id,
                                                     user_id=user_id,
                                                     typing=typing)
            else:
                deferreds.append(
                    preserve_fn(self.federation.send_edu)(
                        destination=domain,
                        edu_type="m.typing",
                        content={
                            "room_id": room_id,
                            "user_id": user_id,
                            "typing": typing,
                        },
                        key=(room_id, user_id),
                    ))

        yield preserve_context_over_deferred(
            defer.DeferredList(deferreds, consumeErrors=True))
Example no. 4
    def _push_update(self, room_id, user_id, typing):
        users = yield self.state.get_current_user_in_room(room_id)
        domains = set(get_domain_from_id(u) for u in users)

        deferreds = []
        for domain in domains:
            if domain == self.server_name:
                preserve_fn(self._push_update_local)(
                    room_id=room_id,
                    user_id=user_id,
                    typing=typing
                )
            else:
                deferreds.append(preserve_fn(self.federation.send_edu)(
                    destination=domain,
                    edu_type="m.typing",
                    content={
                        "room_id": room_id,
                        "user_id": user_id,
                        "typing": typing,
                    },
                    key=(room_id, user_id),
                ))

        yield preserve_context_over_deferred(
            defer.DeferredList(deferreds, consumeErrors=True)
        )
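Examples no. 3 and 4 above show the fan-out-then-wait variant of the same idea: each remote send is started with preserve_fn, the resulting deferreds are collected, and the caller waits for all of them at once. A compressed sketch of that shape, with a hypothetical send_to(domain) standing in for federation.send_edu and the logcontext helpers assumed to come from synapse.util.logcontext:

from twisted.internet import defer

from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def push_to_domains(domains, send_to):
    # Start one request per remote domain without waiting in between.
    deferreds = [preserve_fn(send_to)(domain) for domain in domains]

    # Wait for all of them; preserve_context_over_deferred restores the
    # caller's logcontext when the DeferredList fires, and consumeErrors
    # keeps individual failures from being reported as unhandled.
    yield preserve_context_over_deferred(
        defer.DeferredList(deferreds, consumeErrors=True)
    )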
Example no. 5
    def _renew_attestations(self):
        """Called periodically to check if we need to update any of our attestations
        """

        now = self.clock.time_msec()

        rows = yield self.store.get_attestations_need_renewals(
            now + UPDATE_ATTESTATION_TIME_MS)

        @defer.inlineCallbacks
        def _renew_attestation(group_id, user_id):
            attestation = self.attestations.create_attestation(
                group_id, user_id)

            if self.is_mine_id(group_id):
                destination = get_domain_from_id(user_id)
            else:
                destination = get_domain_from_id(group_id)

            yield self.transport_client.renew_group_attestation(
                destination,
                group_id,
                user_id,
                content={"attestation": attestation},
            )

            yield self.store.update_attestation_renewal(
                group_id, user_id, attestation)

        for row in rows:
            group_id = row["group_id"]
            user_id = row["user_id"]

            preserve_fn(_renew_attestation)(group_id, user_id)
Example no. 6
def get_badge_count(store, user_id):
    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
        preserve_fn(store.get_invited_rooms_for_user)(user_id),
        preserve_fn(store.get_rooms_for_user)(user_id),
    ], consumeErrors=True))

    my_receipts_by_room = yield store.get_receipts_for_user(
        user_id, "m.read",
    )

    badge = len(invites)

    for r in joins:
        if r.room_id in my_receipts_by_room:
            last_unread_event_id = my_receipts_by_room[r.room_id]

            notifs = yield (
                store.get_unread_event_push_actions_by_room_for_user(
                    r.room_id, user_id, last_unread_event_id
                )
            )
            # return one badge count per conversation, as count per
            # message is so noisy as to be almost useless
            badge += 1 if notifs["notify_count"] else 0
    defer.returnValue(badge)
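The example above leans on defer.gatherResults returning its results in the order the deferreds were passed, which is what lets the pair be unpacked straight into invites, joins. A stripped-down sketch of that idiom, reusing the same two store calls:

from twisted.internet import defer

from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def count_invites_and_joins(store, user_id):
    # gatherResults preserves argument order, so the combined result can be
    # unpacked positionally once it fires.
    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
        preserve_fn(store.get_invited_rooms_for_user)(user_id),
        preserve_fn(store.get_rooms_for_user)(user_id),
    ], consumeErrors=True))
    defer.returnValue((len(invites), len(joins)))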
Example no. 7
    def _handle_timeouts(self):
        """Checks the presence of users that have timed out and updates as
        appropriate.
        """
        now = self.clock.time_msec()

        with Measure(self.clock, "presence_handle_timeouts"):
            # Fetch the list of users that *may* have timed out. Things may have
            # changed since the timeout was set, so we won't necessarily have to
            # take any action.
            users_to_check = self.wheel_timer.fetch(now)

            states = [
                self.user_to_current_state.get(
                    user_id, UserPresenceState.default(user_id)
                )
                for user_id in set(users_to_check)
            ]

            timers_fired_counter.inc_by(len(states))

            changes = handle_timeouts(
                states,
                is_mine_fn=self.hs.is_mine_id,
                user_to_num_current_syncs=self.user_to_num_current_syncs,
                now=now,
            )

        preserve_fn(self._update_states)(changes)
Example no. 8
    def copy_to_backup(self, path):
        """Copy a file from the primary to backup media store, if configured.

        Args:
            path(str): Relative path to write file to
        """
        if self.backup_base_path:
            primary_fname = os.path.join(self.primary_base_path, path)
            backup_fname = os.path.join(self.backup_base_path, path)

            # We can either wait for successful writing to the backup repository
            # or write in the background and immediately return
            if self.synchronous_backup_media_store:
                yield make_deferred_yieldable(
                    threads.deferToThread(
                        shutil.copyfile,
                        primary_fname,
                        backup_fname,
                    ))
            else:
                preserve_fn(threads.deferToThread)(
                    shutil.copyfile,
                    primary_fname,
                    backup_fname,
                )
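Example no. 8 contrasts the two ways of running a blocking file copy on a worker thread: wait for it via make_deferred_yieldable, or start it with preserve_fn and return immediately. A minimal sketch of the same choice, with a hypothetical sync flag and the helpers assumed to come from synapse.util.logcontext:

import shutil

from twisted.internet import defer, threads

from synapse.util.logcontext import make_deferred_yieldable, preserve_fn


@defer.inlineCallbacks
def copy_file(src, dst, sync):
    if sync:
        # Wait (asynchronously) until the thread has finished the copy.
        yield make_deferred_yieldable(
            threads.deferToThread(shutil.copyfile, src, dst)
        )
    else:
        # Start the copy on a thread and return straight away.
        preserve_fn(threads.deferToThread)(shutil.copyfile, src, dst)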
Example no. 9
        def poke_pushers(results):
            pushers_rows = set(
                map(tuple, results.get("pushers", {}).get("rows", []))
            )
            deleted_pushers_rows = set(
                map(tuple, results.get("deleted_pushers", {}).get("rows", []))
            )
            for row in sorted(pushers_rows | deleted_pushers_rows):
                if row in deleted_pushers_rows:
                    user_id, app_id, pushkey = row[1:4]
                    stop_pusher(user_id, app_id, pushkey)
                elif row in pushers_rows:
                    user_id = row[1]
                    app_id = row[5]
                    pushkey = row[8]
                    yield start_pusher(user_id, app_id, pushkey)

            stream = results.get("events")
            if stream:
                min_stream_id = stream["rows"][0][0]
                max_stream_id = stream["position"]
                preserve_fn(pusher_pool.on_new_notifications)(
                    min_stream_id, max_stream_id
                )

            stream = results.get("receipts")
            if stream:
                rows = stream["rows"]
                affected_room_ids = set(row[1] for row in rows)
                min_stream_id = rows[0][0]
                max_stream_id = stream["position"]
                preserve_fn(pusher_pool.on_new_receipts)(
                    min_stream_id, max_stream_id, affected_room_ids
                )
Example no. 10
        def poke_pushers(results):
            pushers_rows = set(
                map(tuple,
                    results.get("pushers", {}).get("rows", [])))
            deleted_pushers_rows = set(
                map(tuple,
                    results.get("deleted_pushers", {}).get("rows", [])))
            for row in sorted(pushers_rows | deleted_pushers_rows):
                if row in deleted_pushers_rows:
                    user_id, app_id, pushkey = row[1:4]
                    stop_pusher(user_id, app_id, pushkey)
                elif row in pushers_rows:
                    user_id = row[1]
                    app_id = row[5]
                    pushkey = row[8]
                    yield start_pusher(user_id, app_id, pushkey)

            stream = results.get("events")
            if stream and stream["rows"]:
                min_stream_id = stream["rows"][0][0]
                max_stream_id = stream["position"]
                preserve_fn(pusher_pool.on_new_notifications)(min_stream_id,
                                                              max_stream_id)

            stream = results.get("receipts")
            if stream and stream["rows"]:
                rows = stream["rows"]
                affected_room_ids = set(row[1] for row in rows)
                min_stream_id = rows[0][0]
                max_stream_id = stream["position"]
                preserve_fn(pusher_pool.on_new_receipts)(min_stream_id,
                                                         max_stream_id,
                                                         affected_room_ids)
Example no. 11
    def _handle_timeouts(self):
        logger.info("Checking for typing timeouts")

        now = self.clock.time_msec()

        members = set(self.wheel_timer.fetch(now))

        for member in members:
            if not self.is_typing(member):
                # Nothing to do if they're no longer typing
                continue

            until = self._member_typing_until.get(member, None)
            if not until or until < now:
                logger.info("Timing out typing for: %s", member.user_id)
                preserve_fn(self._stopped_typing)(member)
                continue

            # Check if we need to resend a keep alive over federation for this
            # user.
            if self.hs.is_mine_id(member.user_id):
                last_fed_poke = self._member_last_federation_poke.get(member, None)
                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
                    preserve_fn(self._push_remote)(
                        member=member,
                        typing=True
                    )
Example no. 12
    def notify_interested_services(self, current_id):
        """Notifies (pushes) all application services interested in this event.

        Pushing is done asynchronously, so this method won't block for any
        prolonged length of time.

        Args:
            current_id(int): The current maximum ID.
        """
        services = yield self.store.get_app_services()
        if not services or not self.notify_appservices:
            return

        self.current_max = max(self.current_max, current_id)
        if self.is_processing:
            return

        with Measure(self.clock, "notify_interested_services"):
            self.is_processing = True
            try:
                upper_bound = self.current_max
                limit = 100
                while True:
                    upper_bound, events = yield self.store.get_new_events_for_appservice(
                        upper_bound, limit)

                    if not events:
                        break

                    for event in events:
                        # Gather interested services
                        services = yield self._get_services_for_event(event)
                        if len(services) == 0:
                            continue  # no services need notifying

                        # Do we know this user exists? If not, poke the user
                        # query API for all services which match that user regex.
                        # This needs to block as these user queries need to be
                        # made BEFORE pushing the event.
                        yield self._check_user_exists(event.sender)
                        if event.type == EventTypes.Member:
                            yield self._check_user_exists(event.state_key)

                        if not self.started_scheduler:
                            self.scheduler.start().addErrback(log_failure)
                            self.started_scheduler = True

                        # Fork off pushes to these services
                        for service in services:
                            preserve_fn(self.scheduler.submit_event_for_as)(
                                service, event)

                    yield self.store.set_appservice_last_pos(upper_bound)

                    if len(events) < limit:
                        break
            finally:
                self.is_processing = False
Example no. 13
    def notify_replication(self):
        """Notify any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

        for cb in self.replication_callbacks:
            preserve_fn(cb)()
Example no. 14
    def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender, rows)
            preserve_fn(self.update_token)(token)

        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
            self.federation_sender.notify_new_events(token)
Example no. 15
    def store_file(self, path, file_info):
        if not file_info.server_name and not self.store_local:
            return defer.succeed(None)

        if file_info.server_name and not self.store_remote:
            return defer.succeed(None)

        if self.store_synchronous:
            return self.backend.store_file(path, file_info)
        else:
            # TODO: Handle errors.
            preserve_fn(self.backend.store_file)(path, file_info)
            return defer.succeed(None)
Example no. 16
    def send(self, service, events):
        try:
            txn = yield self.store.create_appservice_txn(
                service=service, events=events
            )
            service_is_up = yield self._is_service_up(service)
            if service_is_up:
                sent = yield txn.send(self.as_api)
                if sent:
                    yield txn.complete(self.store)
                else:
                    preserve_fn(self._start_recoverer)(service)
        except Exception as e:
            logger.exception(e)
            preserve_fn(self._start_recoverer)(service)
Example no. 17
    def send_nonmember_event(self, requester, event, context, ratelimit=True):
        """
        Persists and notifies local clients and federation of an event.

        Args:
            event (FrozenEvent) the event to send.
            context (Context) the context of the event.
            ratelimit (bool): Whether to rate limit this send.
            is_guest (bool): Whether the sender is a guest.
        """
        if event.type == EventTypes.Member:
            raise SynapseError(
                500, "Tried to send member event through non-member codepath")

        # We check here if we are currently being rate limited, so that we
        # don't do unnecessary work. We check again just before we actually
        # send the event.
        time_now = self.clock.time()
        allowed, time_allowed = self.ratelimiter.send_message(
            event.sender,
            time_now,
            msg_rate_hz=self.hs.config.rc_messages_per_second,
            burst_count=self.hs.config.rc_message_burst_count,
            update=False,
        )
        if not allowed:
            raise LimitExceededError(retry_after_ms=int(
                1000 * (time_allowed - time_now)), )

        user = UserID.from_string(event.sender)

        assert self.hs.is_mine(user), "User must be our own: %s" % (user, )

        if event.is_state():
            prev_state = yield self.deduplicate_state_event(event, context)
            if prev_state is not None:
                defer.returnValue(prev_state)

        yield self.handle_new_client_event(
            requester=requester,
            event=event,
            context=context,
            ratelimit=ratelimit,
        )

        if event.type == EventTypes.Message:
            presence = self.hs.get_presence_handler()
            # We don't want to block sending messages on any presence code. This
            # matters as sometimes presence code can take a while.
            preserve_fn(presence.bump_presence_active_time)(user)
Example no. 18
    def get_room_events_stream_for_rooms(self,
                                         room_ids,
                                         from_key,
                                         to_key,
                                         limit=0,
                                         order='DESC'):
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id)

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        for rm_ids in (room_ids[i:i + 20]
                       for i in xrange(0, len(room_ids), 20)):
            res = yield preserve_context_over_deferred(
                defer.gatherResults([
                    preserve_fn(self.get_room_events_stream_for_room)(
                        room_id,
                        from_key,
                        to_key,
                        limit,
                        order=order,
                    ) for room_id in rm_ids
                ]))
            results.update(dict(zip(rm_ids, res)))

        defer.returnValue(results)
Example no. 19
    def send(self, service, events):
        try:
            txn = yield self.store.create_appservice_txn(
                service=service,
                events=events
            )
            service_is_up = yield self._is_service_up(service)
            if service_is_up:
                sent = yield txn.send(self.as_api)
                if sent:
                    yield txn.complete(self.store)
                else:
                    preserve_fn(self._start_recoverer)(service)
        except Exception as e:
            logger.exception(e)
            preserve_fn(self._start_recoverer)(service)
Example no. 20
    def verify_json_objects_for_server(self, server_and_json):
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json (list): List of pairs of (server_name, json_object)

        Returns:
            List<Deferred>: for each input pair, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        # a list of VerifyKeyRequests
        verify_requests = []
        handle = preserve_fn(_handle_key_deferred)

        def process(server_name, json_object):
            """Process an entry in the request list

            Given a (server_name, json_object) pair from the request list,
            adds a key request to verify_requests, and returns a deferred which will
            complete or fail (in the sentinel context) when verification completes.
            """
            key_ids = signature_ids(json_object, server_name)

            if not key_ids:
                return defer.fail(
                    SynapseError(
                        400,
                        "Not signed by %s" % (server_name,),
                        Codes.UNAUTHORIZED,
                    )
                )

            logger.debug("Verifying for %s with key_ids %s",
                         server_name, key_ids)

            # add the key request to the queue, but don't start it off yet.
            verify_request = VerifyKeyRequest(
                server_name, key_ids, json_object, defer.Deferred(),
            )
            verify_requests.append(verify_request)

            # now run _handle_key_deferred, which will wait for the key request
            # to complete and then do the verification.
            #
            # We want _handle_key_request to log to the right context, so we
            # wrap it with preserve_fn (aka run_in_background)
            return handle(verify_request)

        results = [
            process(server_name, json_object)
            for server_name, json_object in server_and_json
        ]

        if verify_requests:
            run_in_background(self._start_key_lookups, verify_requests)

        return results
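Example no. 20 mixes the two spellings of the same operation: preserve_fn(f) wraps f and the result is then called, while run_in_background(f, *args) wraps and calls in one step. A tiny sketch of that equivalence, assuming both helpers live in synapse.util.logcontext as they do in the snippets above:

from synapse.util.logcontext import run_in_background


def fire_off(do_work, arg):
    # Equivalent to preserve_fn(do_work)(arg): starts do_work in the current
    # logcontext and returns without waiting on the resulting deferred.
    run_in_background(do_work, arg)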
Example no. 21
    def get_keys_from_perspectives(self, server_name_and_key_ids):
        @defer.inlineCallbacks
        def get_key(perspective_name, perspective_keys):
            try:
                result = yield self.get_server_verify_key_v2_indirect(
                    server_name_and_key_ids, perspective_name, perspective_keys
                )
                defer.returnValue(result)
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    perspective_name,
                    type(e).__name__, str(e.message),
                )
                defer.returnValue({})

        results = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(get_key)(p_name, p_keys)
                for p_name, p_keys in self.perspective_servers.items()
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        defer.returnValue(union_of_keys)
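The .addErrback(unwrapFirstError) in the example above matters because, with consumeErrors=True, gatherResults reports the first failure wrapped in a FirstError; the errback unwraps it so callers see the original exception. A sketch of that arrangement, assuming unwrapFirstError is importable from synapse.util as in the code these snippets come from:

from twisted.internet import defer

from synapse.util import unwrapFirstError
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def run_all(fetchers):
    # Run every fetcher in parallel; on failure, unwrapFirstError re-raises
    # the underlying exception rather than the wrapping FirstError.
    results = yield preserve_context_over_deferred(defer.gatherResults(
        [preserve_fn(f)() for f in fetchers],
        consumeErrors=True,
    )).addErrback(unwrapFirstError)
    defer.returnValue(results)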
Example no. 22
    def get_events(self, destinations, room_id, event_ids, return_local=True):
        """Fetch events from some remote destinations, checking if we already
        have them.

        Args:
            destinations (list)
            room_id (str)
            event_ids (list)
            return_local (bool): Whether to include events we already have in
                the DB in the returned list of events

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a list of event ids that we failed to fetch.
        """
        if return_local:
            seen_events = yield self.store.get_events(event_ids,
                                                      allow_rejected=True)
            signed_events = seen_events.values()
        else:
            seen_events = yield self.store.have_events(event_ids)
            signed_events = []

        failed_to_fetch = set()

        missing_events = set(event_ids)
        for k in seen_events:
            missing_events.discard(k)

        if not missing_events:
            defer.returnValue((signed_events, failed_to_fetch))

        def random_server_list():
            srvs = list(destinations)
            random.shuffle(srvs)
            return srvs

        batch_size = 20
        missing_events = list(missing_events)
        for i in xrange(0, len(missing_events), batch_size):
            batch = set(missing_events[i:i + batch_size])

            deferreds = [
                preserve_fn(self.get_pdu)(
                    destinations=random_server_list(),
                    event_id=e_id,
                ) for e_id in batch
            ]

            res = yield preserve_context_over_deferred(
                defer.DeferredList(deferreds, consumeErrors=True))
            for success, result in res:
                if success:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        defer.returnValue((signed_events, failed_to_fetch))
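Example no. 22 also depends on the shape of defer.DeferredList results: the combined deferred fires with a list of (success, result) pairs rather than plain values, which is why each entry is unpacked before use. A small sketch of that unpacking, with a hypothetical fetch(event_id) in place of get_pdu:

from twisted.internet import defer

from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def fetch_many(event_ids, fetch):
    deferreds = [preserve_fn(fetch)(e_id) for e_id in event_ids]

    # DeferredList fires with (success, result) pairs; failed entries carry a
    # Failure in `result`, so only the successes are collected here.
    res = yield preserve_context_over_deferred(
        defer.DeferredList(deferreds, consumeErrors=True)
    )
    defer.returnValue([result for success, result in res if success])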
Example no. 23
    def get_keys_from_perspectives(self, server_name_and_key_ids):
        @defer.inlineCallbacks
        def get_key(perspective_name, perspective_keys):
            try:
                result = yield self.get_server_verify_key_v2_indirect(
                    server_name_and_key_ids, perspective_name,
                    perspective_keys)
                defer.returnValue(result)
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    perspective_name,
                    type(e).__name__,
                    str(e.message),
                )
                defer.returnValue({})

        results = yield logcontext.make_deferred_yieldable(
            defer.gatherResults(
                [
                    preserve_fn(get_key)(p_name, p_keys)
                    for p_name, p_keys in self.perspective_servers.items()
                ],
                consumeErrors=True,
            ).addErrback(unwrapFirstError))

        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        defer.returnValue(union_of_keys)
Example no. 24
    def _handle_timeouts(self):
        """Checks the presence of users that have timed out and updates as
        appropriate.
        """
        logger.info("Handling presence timeouts")
        now = self.clock.time_msec()

        try:
            with Measure(self.clock, "presence_handle_timeouts"):
                # Fetch the list of users that *may* have timed out. Things may have
                # changed since the timeout was set, so we won't necessarily have to
                # take any action.
                users_to_check = set(self.wheel_timer.fetch(now))

                # Check whether the lists of syncing processes from an external
                # process have expired.
                expired_process_ids = [
                    process_id for process_id, last_update
                    in self.external_process_last_updated_ms.items()
                    if now - last_update > EXTERNAL_PROCESS_EXPIRY
                ]
                for process_id in expired_process_ids:
                    users_to_check.update(
                        self.external_process_last_updated_ms.pop(process_id, ())
                    )
                    self.external_process_last_update.pop(process_id)

                states = [
                    self.user_to_current_state.get(
                        user_id, UserPresenceState.default(user_id)
                    )
                    for user_id in users_to_check
                ]

                timers_fired_counter.inc_by(len(states))

                changes = handle_timeouts(
                    states,
                    is_mine_fn=self.is_mine_id,
                    syncing_user_ids=self.get_currently_syncing_users(),
                    now=now,
                )

            preserve_fn(self._update_states)(changes)
        except:
            logger.exception("Exception in _handle_timeouts loop")
Example no. 25
    def _handle_timeouts(self):
        """Checks the presence of users that have timed out and updates as
        appropriate.
        """
        logger.info("Handling presence timeouts")
        now = self.clock.time_msec()

        try:
            with Measure(self.clock, "presence_handle_timeouts"):
                # Fetch the list of users that *may* have timed out. Things may have
                # changed since the timeout was set, so we won't necessarily have to
                # take any action.
                users_to_check = set(self.wheel_timer.fetch(now))

                # Check whether the lists of syncing processes from an external
                # process have expired.
                expired_process_ids = [
                    process_id for process_id, last_update
                    in self.external_process_last_updated_ms.items()
                    if now - last_update > EXTERNAL_PROCESS_EXPIRY
                ]
                for process_id in expired_process_ids:
                    users_to_check.update(
                        self.external_process_last_updated_ms.pop(process_id, ())
                    )
                    self.external_process_last_update.pop(process_id)

                states = [
                    self.user_to_current_state.get(
                        user_id, UserPresenceState.default(user_id)
                    )
                    for user_id in users_to_check
                ]

                timers_fired_counter.inc_by(len(states))

                changes = handle_timeouts(
                    states,
                    is_mine_fn=self.is_mine_id,
                    syncing_user_ids=self.get_currently_syncing_users(),
                    now=now,
                )

            preserve_fn(self._update_states)(changes)
        except:
            logger.exception("Exception in _handle_timeouts loop")
Example no. 26
    def verify_json_objects_for_server(self, server_and_json):
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json (list): List of pairs of (server_name, json_object)

        Returns:
            List<Deferred>: for each input pair, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        verify_requests = []

        for server_name, json_object in server_and_json:

            key_ids = signature_ids(json_object, server_name)
            if not key_ids:
                logger.warn("Request from %s: no supported signature keys",
                            server_name)
                deferred = defer.fail(SynapseError(
                    400,
                    "Not signed with a supported algorithm",
                    Codes.UNAUTHORIZED,
                ))
            else:
                deferred = defer.Deferred()

            logger.debug("Verifying for %s with key_ids %s",
                         server_name, key_ids)

            verify_request = VerifyKeyRequest(
                server_name, key_ids, json_object, deferred
            )

            verify_requests.append(verify_request)

        preserve_fn(self._start_key_lookups)(verify_requests)

        # Pass those keys to handle_key_deferred so that the json object
        # signatures can be verified
        handle = preserve_fn(_handle_key_deferred)
        return [
            handle(rq) for rq in verify_requests
        ]
Example no. 27
    def claim_one_time_keys(self, query, timeout):
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            if self.is_mine_id(user_id):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            device_keys = remote_queries[destination]
            try:
                limiter = yield get_retry_limiter(destination, self.clock,
                                                  self.store)
                with limiter:
                    remote_result = yield self.federation.claim_client_keys(
                        destination, {"one_time_keys": device_keys},
                        timeout=timeout)
                    for user_id, keys in remote_result["one_time_keys"].items():
                        if user_id in device_keys:
                            json_result[user_id] = keys
            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code,
                    "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503,
                    "message": "Not ready for retry",
                }
            except Exception as e:
                # include ConnectionRefused and other errors
                failures[destination] = {"status": 503, "message": e.message}

        yield preserve_context_over_deferred(
            defer.gatherResults([
                preserve_fn(claim_client_keys)(destination)
                for destination in remote_queries
            ]))

        defer.returnValue({"one_time_keys": json_result, "failures": failures})
Example no. 28
    def claim_one_time_keys(self, query, timeout):
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            if self.is_mine_id(user_id):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            device_keys = remote_queries[destination]
            try:
                limiter = yield get_retry_limiter(
                    destination, self.clock, self.store
                )
                with limiter:
                    remote_result = yield self.federation.claim_client_keys(
                        destination,
                        {"one_time_keys": device_keys},
                        timeout=timeout
                    )
                    for user_id, keys in remote_result["one_time_keys"].items():
                        if user_id in device_keys:
                            json_result[user_id] = keys
            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code, "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503, "message": "Not ready for retry",
                }

        yield preserve_context_over_deferred(defer.gatherResults([
            preserve_fn(claim_client_keys)(destination)
            for destination in remote_queries
        ]))

        defer.returnValue({
            "one_time_keys": json_result,
            "failures": failures
        })
Example no. 29
    def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
        """Notify any user streams that are interested in this room event"""
        # poke any interested application service.
        preserve_fn(
            self.appservice_handler.notify_interested_services)(room_stream_id)

        if self.federation_sender:
            self.federation_sender.notify_new_events(room_stream_id)

        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
            self._user_joined_room(event.state_key, event.room_id)

        self.on_new_event(
            "room_key",
            room_stream_id,
            users=extra_users,
            rooms=[event.room_id],
        )
Example no. 30
    def _verify_objects(self, verify_requests):
        """Does the work of verify_json_[objects_]for_server


        Args:
            verify_requests (iterable[VerifyJsonRequest]):
                Iterable of verification requests.

        Returns:
            List<Deferred[None]>: for each input item, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        # a list of VerifyJsonRequests which are awaiting a key lookup
        key_lookups = []
        handle = preserve_fn(_handle_key_deferred)

        def process(verify_request):
            """Process an entry in the request list

            Adds a key request to key_lookups, and returns a deferred which
            will complete or fail (in the sentinel context) when verification completes.
            """
            if not verify_request.key_ids:
                return defer.fail(
                    SynapseError(
                        400,
                        "Not signed by %s" % (verify_request.server_name,),
                        Codes.UNAUTHORIZED,
                    )
                )

            logger.debug(
                "Verifying %s for %s with key_ids %s, min_validity %i",
                verify_request.request_name,
                verify_request.server_name,
                verify_request.key_ids,
                verify_request.minimum_valid_until_ts,
            )

            # add the key request to the queue, but don't start it off yet.
            key_lookups.append(verify_request)

            # now run _handle_key_deferred, which will wait for the key request
            # to complete and then do the verification.
            #
            # We want _handle_key_request to log to the right context, so we
            # wrap it with preserve_fn (aka run_in_background)
            return handle(verify_request)

        results = [process(r) for r in verify_requests]

        if key_lookups:
            run_in_background(self._start_key_lookups, key_lookups)

        return results
Example no. 31
    def get_server_verify_key_v2_direct(self, server_name, key_ids):
        keys = {}

        for requested_key_id in key_ids:
            if requested_key_id in keys:
                continue

            (response, tls_certificate) = yield fetch_server_key(
                server_name,
                self.hs.tls_server_context_factory,
                path=(b"/_matrix/key/v2/server/%s" %
                      (urllib.quote(requested_key_id), )).encode("ascii"),
            )

            if (u"signatures" not in response
                    or server_name not in response[u"signatures"]):
                raise KeyLookupError(
                    "Key response not signed by remote server")

            if "tls_fingerprints" not in response:
                raise KeyLookupError("Key response missing TLS fingerprints")

            certificate_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1, tls_certificate)
            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

            response_sha256_fingerprints = set()
            for fingerprint in response[u"tls_fingerprints"]:
                if u"sha256" in fingerprint:
                    response_sha256_fingerprints.add(fingerprint[u"sha256"])

            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                raise KeyLookupError(
                    "TLS certificate not allowed by fingerprints")

            response_keys = yield self.process_v2_response(
                from_server=server_name,
                requested_ids=[requested_key_id],
                response_json=response,
            )

            keys.update(response_keys)

        yield logcontext.make_deferred_yieldable(
            defer.gatherResults(
                [
                    preserve_fn(self.store_keys)(
                        server_name=key_server_name,
                        from_server=server_name,
                        verify_keys=verify_keys,
                    ) for key_server_name, verify_keys in keys.items()
                ],
                consumeErrors=True).addErrback(unwrapFirstError))

        defer.returnValue(keys)
Example no. 32
    def get_server_verify_key_v2_direct(self, server_name, key_ids):
        keys = {}

        for requested_key_id in key_ids:
            if requested_key_id in keys:
                continue

            (response, tls_certificate) = yield fetch_server_key(
                server_name, self.hs.tls_server_context_factory,
                path=(b"/_matrix/key/v2/server/%s" % (
                    urllib.quote(requested_key_id),
                )).encode("ascii"),
            )

            if (u"signatures" not in response
                    or server_name not in response[u"signatures"]):
                raise ValueError("Key response not signed by remote server")

            if "tls_fingerprints" not in response:
                raise ValueError("Key response missing TLS fingerprints")

            certificate_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1, tls_certificate
            )
            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

            response_sha256_fingerprints = set()
            for fingerprint in response[u"tls_fingerprints"]:
                if u"sha256" in fingerprint:
                    response_sha256_fingerprints.add(fingerprint[u"sha256"])

            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                raise ValueError("TLS certificate not allowed by fingerprints")

            response_keys = yield self.process_v2_response(
                from_server=server_name,
                requested_ids=[requested_key_id],
                response_json=response,
            )

            keys.update(response_keys)

        yield defer.gatherResults(
            [
                preserve_fn(self.store_keys)(
                    server_name=key_server_name,
                    from_server=server_name,
                    verify_keys=verify_keys,
                )
                for key_server_name, verify_keys in keys.items()
            ],
            consumeErrors=True
        ).addErrback(unwrapFirstError)

        defer.returnValue(keys)
Example no. 33
    def _test_preserve_fn(self, function):
        sentinel_context = LoggingContext.current_context()

        callback_completed = [False]

        @defer.inlineCallbacks
        def cb():
            context_one.test_key = "one"
            yield function()
            self._check_test_key("one")

            callback_completed[0] = True

        with LoggingContext() as context_one:
            context_one.test_key = "one"

            # fire off function, but don't wait on it.
            logcontext.preserve_fn(cb)()

            self._check_test_key("one")

        # now wait for the function under test to have run, and check that
        # the logcontext is left in a sane state.
        d2 = defer.Deferred()

        def check_logcontext():
            if not callback_completed[0]:
                reactor.callLater(0.01, check_logcontext)
                return

            # make sure that the context was reset before it got thrown back
            # into the reactor
            try:
                self.assertIs(LoggingContext.current_context(),
                              sentinel_context)
                d2.callback(None)
            except BaseException:
                d2.errback(twisted.python.failure.Failure())

        reactor.callLater(0.01, check_logcontext)

        # test is done once d2 finishes
        return d2
Example no. 34
    def _start_pushers(self, pushers):
        logger.info("Starting %d pushers", len(pushers))
        for pusherdict in pushers:
            try:
                p = self._create_pusher(pusherdict)
            except PusherConfigException:
                logger.exception(
                    "Couldn't start a pusher: caught PusherConfigException")
                continue
            if p:
                fullid = "%s:%s:%s" % (pusherdict['app_id'],
                                       pusherdict['pushkey'],
                                       pusherdict['user_name'])
                if fullid in self.pushers:
                    self.pushers[fullid].stop()
                self.pushers[fullid] = p
                preserve_fn(p.start)()

        logger.info("Started pushers")
Example no. 35
    def _start_pushers(self, pushers):
        logger.info("Starting %d pushers", len(pushers))
        for pusherdict in pushers:
            try:
                p = self._create_pusher(pusherdict)
            except PusherConfigException:
                logger.exception("Couldn't start a pusher: caught PusherConfigException")
                continue
            if p:
                fullid = "%s:%s:%s" % (
                    pusherdict['app_id'],
                    pusherdict['pushkey'],
                    pusherdict['user_name']
                )
                if fullid in self.pushers:
                    self.pushers[fullid].stop()
                self.pushers[fullid] = p
                preserve_fn(p.start)()

        logger.info("Started pushers")
Example no. 36
    def get_keys_from_store(self, server_name_and_key_ids):
        res = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(self.store.get_server_verify_keys)(
                    server_name, key_ids
                ).addCallback(lambda ks, server: (server, ks), server_name)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        defer.returnValue(dict(res))
Example no. 37
    def _process_presence_inner(self, states):
        """Given a list of states populate self.pending_presence_by_dest and
        poke to send a new transaction to each destination

        Args:
            states (list(UserPresenceState))
        """
        hosts_and_states = yield get_interested_remotes(
            self.store, states, self.state)

        for destinations, states in hosts_and_states:
            for destination in destinations:
                if not self.can_send_to(destination):
                    continue

                self.pending_presence_by_dest.setdefault(
                    destination,
                    {}).update({state.user_id: state
                                for state in states})

                preserve_fn(self._attempt_new_transaction)(destination)
Example no. 38
    def _enqueue_events(self, events, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
        """Fetches events from the database using the _event_fetch_list. This
        allows batch and bulk fetching of events - it allows us to fetch events
        without having to create a new transaction for each request for events.
        """
        if not events:
            defer.returnValue({})

        events_d = defer.Deferred()
        with self._event_fetch_lock:
            self._event_fetch_list.append(
                (events, events_d)
            )

            self._event_fetch_lock.notify()

            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
                self._event_fetch_ongoing += 1
                should_start = True
            else:
                should_start = False

        if should_start:
            with PreserveLoggingContext():
                self.runWithConnection(
                    self._do_fetch
                )

        with PreserveLoggingContext():
            rows = yield events_d

        if not allow_rejected:
            rows[:] = [r for r in rows if not r["rejects"]]

        res = yield defer.gatherResults(
            [
                preserve_fn(self._get_event_from_row)(
                    row["internal_metadata"], row["json"], row["redacts"],
                    check_redacted=check_redacted,
                    get_prev_content=get_prev_content,
                    rejected_reason=row["rejects"],
                )
                for row in rows
            ],
            consumeErrors=True
        )

        defer.returnValue({
            e.event_id: e
            for e in res if e
        })
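Example no. 38 reaches for the PreserveLoggingContext context manager instead of preserve_fn: the current logcontext is set aside while the generator yields to the reactor and is reinstated when the deferred fires and the with block exits. A minimal sketch of that usage, assuming the same Synapse helper:

from twisted.internet import defer

from synapse.util.logcontext import PreserveLoggingContext


@defer.inlineCallbacks
def wait_outside_context(events_d):
    # Leave the current logcontext while waiting on a deferred that was
    # created elsewhere; it is restored when the block exits.
    with PreserveLoggingContext():
        rows = yield events_d
    defer.returnValue(rows)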
Example no. 39
    def claim_one_time_keys(self, query, timeout):
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            device_keys = remote_queries[destination]
            try:
                remote_result = yield self.federation.claim_client_keys(
                    destination, {"one_time_keys": device_keys},
                    timeout=timeout)
                for user_id, keys in remote_result["one_time_keys"].items():
                    if user_id in device_keys:
                        json_result[user_id] = keys
            except Exception as e:
                failures[destination] = _exception_to_failure(e)

        yield make_deferred_yieldable(
            defer.gatherResults([
                preserve_fn(claim_client_keys)(destination)
                for destination in remote_queries
            ]))

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join(("%s for %s:%s" % (key_id, user_id, device_id)
                      for user_id, user_keys in json_result.iteritems()
                      for device_id, device_keys in user_keys.iteritems()
                      for key_id, _ in device_keys.iteritems())),
        )

        defer.returnValue({"one_time_keys": json_result, "failures": failures})
Example no. 40
                    def handle_event(event):
                        # Gather interested services
                        services = yield self._get_services_for_event(event)
                        if len(services) == 0:
                            return  # no services need notifying

                        # Do we know this user exists? If not, poke the user
                        # query API for all services which match that user regex.
                        # This needs to block as these user queries need to be
                        # made BEFORE pushing the event.
                        yield self._check_user_exists(event.sender)
                        if event.type == EventTypes.Member:
                            yield self._check_user_exists(event.state_key)

                        if not self.started_scheduler:
                            self.scheduler.start().addErrback(log_failure)
                            self.started_scheduler = True

                        # Fork off pushes to these services
                        for service in services:
                            preserve_fn(self.scheduler.submit_event_for_as)(
                                service, event)
Example no. 41
    def handle_queue(self, room_id, per_item_callback):
        """Attempts to handle the queue for a room if not already being handled.

        The given callback will be invoked for each item in the queue, of
        type _EventPersistQueueItem. The per_item_callback will continuously
        be called with new items, unless the queue becomes empty. The return
        value of the function will be given to the deferreds waiting on the
        item; exceptions will be passed to the deferreds as well.

        This function should therefore be called whenever anything is added
        to the queue.

        If another callback is currently handling the queue then it will not be
        invoked.
        """

        if room_id in self._currently_persisting_rooms:
            return

        self._currently_persisting_rooms.add(room_id)

        @defer.inlineCallbacks
        def handle_queue_loop():
            try:
                queue = self._get_drainining_queue(room_id)
                for item in queue:
                    try:
                        ret = yield per_item_callback(item)
                        item.deferred.callback(ret)
                    except Exception as e:
                        item.deferred.errback(e)
            finally:
                queue = self._event_persist_queues.pop(room_id, None)
                if queue:
                    self._event_persist_queues[room_id] = queue
                self._currently_persisting_rooms.discard(room_id)

        preserve_fn(handle_queue_loop)()
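Example no. 41 adds a common guard around the fire-and-forget kick-off: a set records which rooms already have a drain loop running, so preserve_fn only ever starts one loop per room. A compressed sketch of that guard, with a hypothetical drain(room_id) callback:

from twisted.internet import defer

from synapse.util.logcontext import preserve_fn


class RoomQueueDrainer(object):
    def __init__(self):
        self._currently_draining = set()

    def kick_off(self, room_id, drain):
        # Only one background loop per room; later callers return immediately.
        if room_id in self._currently_draining:
            return
        self._currently_draining.add(room_id)

        @defer.inlineCallbacks
        def loop():
            try:
                yield drain(room_id)
            finally:
                self._currently_draining.discard(room_id)

        # Start the loop without waiting for it to finish.
        preserve_fn(loop)()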
Example no. 42
    def _start_pushers(self, pushers):
        if not self.start_pushers:
            logger.info("Not starting pushers because they are disabled in the config")
            return
        logger.info("Starting %d pushers", len(pushers))
        for pusherdict in pushers:
            try:
                p = pusher.create_pusher(self.hs, pusherdict)
            except:
                logger.exception("Couldn't start a pusher: caught Exception")
                continue
            if p:
                appid_pushkey = "%s:%s" % (
                    pusherdict['app_id'],
                    pusherdict['pushkey'],
                )
                byuser = self.pushers.setdefault(pusherdict['user_name'], {})

                if appid_pushkey in byuser:
                    byuser[appid_pushkey].on_stop()
                byuser[appid_pushkey] = p
                preserve_fn(p.on_started)()

        logger.info("Started pushers")
Example no. 43
    def get_keys_from_server(self, server_name_and_key_ids):
        @defer.inlineCallbacks
        def get_key(server_name, key_ids):
            limiter = yield get_retry_limiter(
                server_name,
                self.clock,
                self.store,
            )
            with limiter:
                keys = None
                try:
                    keys = yield self.get_server_verify_key_v2_direct(
                        server_name, key_ids
                    )
                except Exception as e:
                    logger.info(
                        "Unable to get key %r for %r directly: %s %s",
                        key_ids, server_name,
                        type(e).__name__, str(e.message),
                    )

                if not keys:
                    keys = yield self.get_server_verify_key_v1_direct(
                        server_name, key_ids
                    )

                    keys = {server_name: keys}

            defer.returnValue(keys)

        results = yield preserve_context_over_deferred(defer.gatherResults(
            [
                preserve_fn(get_key)(server_name, key_ids)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)

        merged = {}
        for result in results:
            merged.update(result)

        defer.returnValue({
            server_name: keys
            for server_name, keys in merged.items()
            if keys
        })
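A condensed sketch of the fan-out-and-merge shape above, with the log-context plumbing (preserve_fn / preserve_context_over_deferred) omitted; fetch_one is a hypothetical callable standing in for the per-server key fetch and is assumed to return a Deferred resolving to {server_name: {key_id: key}}.

from twisted.internet import defer


@defer.inlineCallbacks
def fetch_all_keys(fetch_one, server_name_and_key_ids):
    # one request per server, run concurrently
    per_server = yield defer.gatherResults(
        [
            defer.maybeDeferred(fetch_one, server_name, key_ids)
            for server_name, key_ids in server_name_and_key_ids
        ],
        consumeErrors=True,
    )

    # merge the per-server dicts and drop servers for which nothing was found
    merged = {}
    for result in per_server:
        merged.update(result)

    defer.returnValue({name: keys for name, keys in merged.items() if keys})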
Esempio n. 44
0
    def verify_json_objects_for_server(self, server_and_json):
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json (list): List of pairs of (server_name, json_object)

        Returns:
            List<Deferred>: for each input pair, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        verify_requests = []

        for server_name, json_object in server_and_json:

            key_ids = signature_ids(json_object, server_name)
            if not key_ids:
                logger.warn("Request from %s: no supported signature keys",
                            server_name)
                deferred = defer.fail(SynapseError(
                    400,
                    "Not signed with a supported algorithm",
                    Codes.UNAUTHORIZED,
                ))
            else:
                deferred = defer.Deferred()

            logger.debug("Verifying for %s with key_ids %s",
                         server_name, key_ids)

            verify_request = VerifyKeyRequest(
                server_name, key_ids, json_object, deferred
            )

            verify_requests.append(verify_request)

        run_in_background(self._start_key_lookups, verify_requests)

        # Pass those keys to handle_key_deferred so that the json object
        # signatures can be verified
        handle = preserve_fn(_handle_key_deferred)
        return [
            handle(rq) for rq in verify_requests
        ]
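A hypothetical caller of the method above, showing how the returned per-request deferreds line up with the input pairs; `keyring` is assumed to expose verify_json_objects_for_server as defined here.

import logging

from twisted.internet import defer

logger = logging.getLogger(__name__)


@defer.inlineCallbacks
def check_signatures(keyring, server_and_json):
    # one Deferred per (server_name, json_object) pair, in input order
    deferreds = keyring.verify_json_objects_for_server(server_and_json)

    for (server_name, _json), d in zip(server_and_json, deferreds):
        try:
            yield d
        except Exception as e:
            logger.warning("Signature check failed for %s: %s", server_name, e)
            raise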
Esempio n. 45
0
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            cache_key = get_cache_key(args, kwargs)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key, callback=invalidate_callback)

                if isinstance(cached_result_d, ObservableDeferred):
                    observer = cached_result_d.observe()
                else:
                    observer = cached_result_d

            except KeyError:
                ret = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call),
                    obj, *args, **kwargs
                )

                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                # If our cache_key is a string on py2, try to convert to ascii
                # to save a bit of space in large caches. Py3 does this
                # internally automatically.
                if six.PY2 and isinstance(cache_key, string_types):
                    cache_key = to_ascii(cache_key)

                result_d = ObservableDeferred(ret, consumeErrors=True)
                cache.set(cache_key, result_d, callback=invalidate_callback)
                observer = result_d.observe()

            if isinstance(observer, defer.Deferred):
                return logcontext.make_deferred_yieldable(observer)
            else:
                return observer
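A pared-down sketch of the cache-hit / cache-miss split in the decorator above, for a single-argument function; it has none of the invalidation, logcontext or ObservableDeferred machinery.

from twisted.internet import defer


def simple_deferred_cache(func):
    cache = {}

    def wrapped(key):
        try:
            # cache hit: hand back the stored Deferred
            return cache[key]
        except KeyError:
            # cache miss: call through and cache the resulting Deferred
            d = defer.maybeDeferred(func, key)

            def on_err(f):
                cache.pop(key, None)  # don't cache failures
                return f

            d.addErrback(on_err)
            cache[key] = d
            return d

    return wrapped

Handing the same Deferred to every caller is fragile, since each caller's callbacks chain onto a single result; that is why the real code wraps the result in an ObservableDeferred and gives each caller its own observer.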
Esempio n. 46
0
    def store_keys(self, server_name, from_server, verify_keys):
        """Store a collection of verify keys for a given server

        Args:
            server_name(str): The name of the server the keys are for.
            from_server(str): The server the keys were downloaded from.
            verify_keys(dict): A mapping of key_id to VerifyKey.
        Returns:
            A deferred that completes when the keys are stored.
        """
        # TODO(markjh): Store whether the keys have expired.
        yield defer.gatherResults(
            [
                preserve_fn(self.store.store_server_verify_key)(
                    server_name, server_name, key.time_added, key
                )
                for key_id, key in verify_keys.items()
            ],
            consumeErrors=True,
        ).addErrback(unwrapFirstError)
Esempio n. 47
0
def wrap_async_request_handler(h):
    """Wraps an async request handler so that it calls request.processing.

    This helps ensure that work done by the request handler after the request is completed
    is correctly recorded against the request metrics/logs.

    The handler method must have a signature of "handle_foo(self, request)",
    where "request" must be a SynapseRequest.

    The handler may return a deferred, in which case the completion of the request isn't
    logged until the deferred completes.
    """
    @defer.inlineCallbacks
    def wrapped_async_request_handler(self, request):
        with request.processing():
            yield h(self, request)

    # we need to preserve_fn here, because the synchronous render method won't yield for
    # us (obviously)
    return preserve_fn(wrapped_async_request_handler)
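A minimal sketch of the same wrapping idea with a stand-in context manager; the preserve_fn / logcontext handling is deliberately omitted, and `request` is whatever object the handler expects.

from contextlib import contextmanager

from twisted.internet import defer


@contextmanager
def processing(request):
    # stand-in for SynapseRequest.processing(): marks when handling
    # actually finishes, even if that is after the response was sent
    print("start processing %r" % (request,))
    try:
        yield
    finally:
        print("done processing %r" % (request,))


def wrap_handler(h):
    @defer.inlineCallbacks
    def wrapped(self, request):
        with processing(request):
            yield h(self, request)

    return wrapped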
Esempio n. 48
0
    def on_new_notifications(self, min_stream_id, max_stream_id):
        yield run_on_reactor()
        try:
            users_affected = yield self.store.get_push_action_users_in_range(
                min_stream_id, max_stream_id
            )

            deferreds = []

            for u in users_affected:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
                            preserve_fn(p.on_new_notifications)(
                                min_stream_id, max_stream_id
                            )
                        )

            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
        except Exception:
            logger.exception("Exception in pusher on_new_notifications")
Esempio n. 49
0
    def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
        yield run_on_reactor()
        try:
            # Need to subtract 1 from the minimum because the lower bound here
            # is not inclusive
            updated_receipts = yield self.store.get_all_updated_receipts(
                min_stream_id - 1, max_stream_id
            )
            # This returns a tuple, user_id is at index 3
            users_affected = set([r[3] for r in updated_receipts])

            deferreds = []

            for u in users_affected:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
                        )

            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
        except Exception:
            logger.exception("Exception in pusher on_new_receipts")
Esempio n. 50
0
    def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
                                         order='DESC'):
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id
        )

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
            res = yield defer.gatherResults([
                preserve_fn(self.get_room_events_stream_for_room)(
                    room_id, from_key, to_key, limit, order=order,
                )
                for room_id in rm_ids
            ])
            results.update(dict(zip(rm_ids, res)))

        defer.returnValue(results)
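The chunking pattern used above, extracted into a hypothetical standalone helper: process a list in batches of N, gathering each batch before starting the next.

from twisted.internet import defer


@defer.inlineCallbacks
def run_in_batches(func, items, batch_size=20):
    # calls func(item) for every item, batch_size at a time, and
    # returns {item: result} once all batches have completed
    results = {}
    for i in range(0, len(items), batch_size):
        batch = items[i:i + batch_size]
        res = yield defer.gatherResults(
            [defer.maybeDeferred(func, item) for item in batch],
            consumeErrors=True,
        )
        results.update(dict(zip(batch, res)))
    defer.returnValue(results)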
Esempio n. 51
0
    def query_devices(self, query_body, timeout):
        """ Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }
        """
        device_keys_query = query_body.get("device_keys", {})

        # separate users by domain.
        # make a map from domain to user_id to device_ids
        local_query = {}
        remote_queries = {}

        for user_id, device_ids in device_keys_query.items():
            if self.is_mine_id(user_id):
                local_query[user_id] = device_ids
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_ids

        # do the queries
        failures = {}
        results = {}
        if local_query:
            local_result = yield self.query_local_devices(local_query)
            for user_id, keys in local_result.items():
                if user_id in local_query:
                    results[user_id] = keys

        @defer.inlineCallbacks
        def do_remote_query(destination):
            destination_query = remote_queries[destination]
            try:
                limiter = yield get_retry_limiter(
                    destination, self.clock, self.store
                )
                with limiter:
                    remote_result = yield self.federation.query_client_keys(
                        destination,
                        {"device_keys": destination_query},
                        timeout=timeout
                    )

                for user_id, keys in remote_result["device_keys"].items():
                    if user_id in destination_query:
                        results[user_id] = keys

            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code, "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503, "message": "Not ready for retry",
                }

        yield preserve_context_over_deferred(defer.gatherResults([
            preserve_fn(do_remote_query)(destination)
            for destination in remote_queries
        ]))

        defer.returnValue({
            "device_keys": results, "failures": failures,
        })
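A sketch of just the local/remote split described in the docstring, with a naive domain parse standing in for is_mine_id and get_domain_from_id.

def split_device_query(device_keys_query, my_domain):
    # splits {user_id: [device_id]} into a local query plus a
    # {domain: {user_id: [device_id]}} map of remote queries
    local_query = {}
    remote_queries = {}
    for user_id, device_ids in device_keys_query.items():
        # Matrix user IDs look like "@alice:example.com"
        domain = user_id.split(":", 1)[1]
        if domain == my_domain:
            local_query[user_id] = device_ids
        else:
            remote_queries.setdefault(domain, {})[user_id] = device_ids
    return local_query, remote_queries


# split_device_query({"@alice:hs1": [], "@bob:hs2": ["DEV1"]}, "hs1")
# -> ({"@alice:hs1": []}, {"hs2": {"@bob:hs2": ["DEV1"]}})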
Esempio n. 52
0
    def _send_syncing_users_regularly(self):
        # Only send an update if we aren't in the middle of sending one.
        if not self._sending_sync:
            preserve_fn(self._send_syncing_users_now)()
Esempio n. 53
0
    def enqueue(self, service, event):
        # queue the event for this service and kick off a send
        self.queued_events.setdefault(service.id, []).append(event)
        preserve_fn(self._send_request)(service)
Esempio n. 54
0
    def _on_enter(self, request_id):
        time_now = self.clock.time_msec()
        self.request_times[:] = [
            r for r in self.request_times
            if time_now - r < self.window_size
        ]

        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
        if queue_size > self.reject_limit:
            raise LimitExceededError(
                retry_after_ms=int(
                    self.window_size / self.sleep_limit
                ),
            )

        self.request_times.append(time_now)

        def queue_request():
            if len(self.current_processing) > self.concurrent_requests:
                logger.debug("Ratelimit [%s]: Queue req", id(request_id))
                queue_defer = defer.Deferred()
                self.ready_request_queue[request_id] = queue_defer
                return queue_defer
            else:
                return defer.succeed(None)

        logger.debug(
            "Ratelimit [%s]: len(self.request_times)=%d",
            id(request_id), len(self.request_times),
        )

        if len(self.request_times) > self.sleep_limit:
            logger.debug(
                "Ratelimit [%s]: sleeping req",
                id(request_id),
            )
            ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)

            self.sleeping_requests.add(request_id)

            def on_wait_finished(_):
                logger.debug(
                    "Ratelimit [%s]: Finished sleeping",
                    id(request_id),
                )
                self.sleeping_requests.discard(request_id)
                queue_defer = queue_request()
                return queue_defer

            ret_defer.addBoth(on_wait_finished)
        else:
            ret_defer = queue_request()

        def on_start(r):
            logger.debug(
                "Ratelimit [%s]: Processing req",
                id(request_id),
            )
            self.current_processing.add(request_id)
            return r

        def on_err(r):
            self.current_processing.discard(request_id)
            return r

        def on_both(r):
            # Ensure that we've properly cleaned up.
            self.sleeping_requests.discard(request_id)
            self.ready_request_queue.pop(request_id, None)
            return r

        ret_defer.addCallbacks(on_start, on_err)
        ret_defer.addBoth(on_both)
        return ret_defer
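The sliding-window bookkeeping at the heart of _on_enter, as a hypothetical standalone counter; the sleep/queue/reject decisions layered on top are left out.

import time


class SlidingWindowCounter(object):
    def __init__(self, window_size_ms):
        self.window_size = window_size_ms
        self.request_times = []

    def record(self):
        now = time.time() * 1000
        # drop timestamps that have fallen out of the window, then count
        # the new request
        self.request_times[:] = [
            t for t in self.request_times if now - t < self.window_size
        ]
        self.request_times.append(now)
        return len(self.request_times)

A caller would compare the returned count against sleep_limit and reject_limit thresholds, which is what the method above does before deciding to sleep, queue or reject the request.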
Esempio n. 55
0
        def wrapped(*args, **kwargs):
            # If we're passed a cache_context then we'll want to call its
            # invalidate() whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            results = {}

            def update_results_dict(res, arg):
                results[arg] = res

            # list of deferreds to wait for
            cached_defers = []

            missing = set()

            # If the cache takes a single arg then that is used as the key,
            # otherwise a tuple is used.
            if num_args == 1:
                def arg_to_cache_key(arg):
                    return arg
            else:
                keylist = list(keyargs)

                def arg_to_cache_key(arg):
                    keylist[self.list_pos] = arg
                    return tuple(keylist)

            for arg in list_args:
                try:
                    res = cache.get(arg_to_cache_key(arg),
                                    callback=invalidate_callback)
                    if not isinstance(res, ObservableDeferred):
                        results[arg] = res
                    elif not res.has_succeeded():
                        res = res.observe()
                        res.addCallback(update_results_dict, arg)
                        cached_defers.append(res)
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.add(arg)

            if missing:
                # we need an observable deferred for each entry in the list,
                # which we put in the cache. Each deferred resolves with the
                # relevant result for that key.
                deferreds_map = {}
                for arg in missing:
                    deferred = defer.Deferred()
                    deferreds_map[arg] = deferred
                    key = arg_to_cache_key(arg)
                    observable = ObservableDeferred(deferred)
                    cache.set(key, observable, callback=invalidate_callback)

                def complete_all(res):
                    # the wrapped function has completed. It returns a
                    # dict. We can now resolve the observable deferreds in
                    # the cache and update our own result map.
                    for e in missing:
                        val = res.get(e, None)
                        deferreds_map[e].callback(val)
                        results[e] = val

                def errback(f):
                    # the wrapped function has failed. Invalidate any cache
                    # entries we're supposed to be populating, and fail
                    # their deferreds.
                    for e in missing:
                        key = arg_to_cache_key(e)
                        cache.invalidate(key)
                        deferreds_map[e].errback(f)

                    # return the failure, to propagate to our caller.
                    return f

                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = list(missing)

                cached_defers.append(defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call),
                    **args_to_call
                ).addCallbacks(complete_all, errback))

            if cached_defers:
                d = defer.gatherResults(
                    cached_defers,
                    consumeErrors=True,
                ).addCallbacks(
                    lambda _: results,
                    unwrapFirstError
                )
                return logcontext.make_deferred_yieldable(d)
            else:
                return results
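The batch-lookup shape of the decorator above, reduced to a hypothetical synchronous helper: read each key from the cache, fetch all the misses in a single call, and return the merged map.

def batch_get(cache, keys, fetch_missing):
    # cache is a plain dict; fetch_missing(missing_keys) must return a
    # dict covering (at most) the requested keys
    results = {}
    missing = []
    for key in keys:
        try:
            results[key] = cache[key]
        except KeyError:
            missing.append(key)

    if missing:
        fetched = fetch_missing(missing)
        for key in missing:
            value = fetched.get(key)
            cache[key] = value
            results[key] = value

    return results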
Esempio n. 56
0
def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
    """ Returns dict of user_id -> list of events that user is allowed to
    see.

    Args:
        user_tuples (list[(str, bool)]): (user_id, is_peeking) for each user
            to be checked. is_peeking should be true if:
            * the user is not currently a member of the room, and
            * the user has not been a member of the room since the
              given events
        events ([synapse.events.EventBase]): list of events to filter
    """
    forgotten = yield defer.gatherResults([
        preserve_fn(store.who_forgot_in_room)(
            room_id,
        )
        for room_id in frozenset(e.room_id for e in events)
    ], consumeErrors=True)

    # Set of membership event_ids that have been forgotten
    event_id_forgotten = frozenset(
        row["event_id"] for rows in forgotten for row in rows
    )

    ignore_dict_content = yield store.get_global_account_data_by_type_for_users(
        "m.ignored_user_list", user_ids=[user_id for user_id, _ in user_tuples]
    )

    # FIXME: This will explode if people upload something incorrect.
    ignore_dict = {
        user_id: frozenset(
            content.get("ignored_users", {}).keys() if content else []
        )
        for user_id, content in ignore_dict_content.items()
    }

    def allowed(event, user_id, is_peeking, ignore_list):
        """
        Args:
            event (synapse.events.EventBase): event to check
            user_id (str)
            is_peeking (bool)
            ignore_list (list): list of users to ignore
        """
        if not event.is_state() and event.sender in ignore_list:
            return False

        state = event_id_to_state[event.event_id]

        # get the room_visibility at the time of the event.
        visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
        if visibility_event:
            visibility = visibility_event.content.get("history_visibility", "shared")
        else:
            visibility = "shared"

        if visibility not in VISIBILITY_PRIORITY:
            visibility = "shared"

        # if it was world_readable, it's easy: everyone can read it
        if visibility == "world_readable":
            return True

        # Always allow history visibility events on boundaries. This is done
        # by setting the effective visibility to the least restrictive
        # of the old vs new.
        if event.type == EventTypes.RoomHistoryVisibility:
            prev_content = event.unsigned.get("prev_content", {})
            prev_visibility = prev_content.get("history_visibility", None)

            if prev_visibility not in VISIBILITY_PRIORITY:
                prev_visibility = "shared"

            new_priority = VISIBILITY_PRIORITY.index(visibility)
            old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
            if old_priority < new_priority:
                visibility = prev_visibility

        # likewise, if the event is the user's own membership event, use
        # the 'most joined' membership
        membership = None
        if event.type == EventTypes.Member and event.state_key == user_id:
            membership = event.content.get("membership", None)
            if membership not in MEMBERSHIP_PRIORITY:
                membership = "leave"

            prev_content = event.unsigned.get("prev_content", {})
            prev_membership = prev_content.get("membership", None)
            if prev_membership not in MEMBERSHIP_PRIORITY:
                prev_membership = "leave"

            new_priority = MEMBERSHIP_PRIORITY.index(membership)
            old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
            if old_priority < new_priority:
                membership = prev_membership

        # otherwise, get the user's membership at the time of the event.
        if membership is None:
            membership_event = state.get((EventTypes.Member, user_id), None)
            if membership_event:
                if membership_event.event_id not in event_id_forgotten:
                    membership = membership_event.membership

        # if the user was a member of the room at the time of the event,
        # they can see it.
        if membership == Membership.JOIN:
            return True

        if visibility == "joined":
            # we weren't a member at the time of the event, so we can't
            # see this event.
            return False

        elif visibility == "invited":
            # user can also see the event if they were *invited* at the time
            # of the event.
            return membership == Membership.INVITE

        else:
            # visibility is shared: user can also see the event if they have
            # become a member since the event
            #
            # XXX: if the user has subsequently joined and then left again,
            # ideally we would share history up to the point they left. But
            # we don't know when they left.
            return not is_peeking

    defer.returnValue({
        user_id: [
            event
            for event in events
            if allowed(event, user_id, is_peeking, ignore_dict.get(user_id, []))
        ]
        for user_id, is_peeking in user_tuples
    })
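The "least restrictive of old vs new" rule for history-visibility boundary events, as a standalone sketch; the priority ordering is an assumption chosen to match the comparison above (least restrictive first).

# assumed ordering, least restrictive first
VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")


def effective_visibility(new_visibility, prev_visibility):
    # unknown values fall back to "shared", as in allowed() above
    if new_visibility not in VISIBILITY_PRIORITY:
        new_visibility = "shared"
    if prev_visibility not in VISIBILITY_PRIORITY:
        prev_visibility = "shared"

    new_priority = VISIBILITY_PRIORITY.index(new_visibility)
    old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
    return prev_visibility if old_priority < new_priority else new_visibility


# effective_visibility("joined", "shared") -> "shared"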
Esempio n. 57
0
    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
                                       include_none=False):
        """Takes a list of PDUs and checks the signatures and hashs of each
        one. If a PDU fails its signature check then we check if we have it in
        the database and if not then request if from the originating server of
        that PDU.

        If a PDU fails its content hash check then it is redacted.

        The given list of PDUs are not modified, instead the function returns
        a new list.

        Args:
            pdu (list)
            outlier (bool)

        Returns:
            Deferred : A list of PDUs that have valid signatures and hashes.
        """
        deferreds = self._check_sigs_and_hashes(pdus)

        @defer.inlineCallbacks
        def handle_check_result(pdu, deferred):
            try:
                res = yield logcontext.make_deferred_yieldable(deferred)
            except SynapseError:
                res = None

            if not res:
                # Check local db.
                res = yield self.store.get_event(
                    pdu.event_id,
                    allow_rejected=True,
                    allow_none=True,
                )

            if not res and pdu.origin != origin:
                try:
                    res = yield self.get_pdu(
                        destinations=[pdu.origin],
                        event_id=pdu.event_id,
                        outlier=outlier,
                        timeout=10000,
                    )
                except SynapseError:
                    pass

            if not res:
                logger.warn(
                    "Failed to find copy of %s with valid signature",
                    pdu.event_id,
                )

            defer.returnValue(res)

        handle = logcontext.preserve_fn(handle_check_result)
        deferreds2 = [
            handle(pdu, deferred)
            for pdu, deferred in zip(pdus, deferreds)
        ]

        valid_pdus = yield logcontext.make_deferred_yieldable(
            defer.gatherResults(
                deferreds2,
                consumeErrors=True,
            )
        ).addErrback(unwrapFirstError)

        if include_none:
            defer.returnValue(valid_pdus)
        else:
            defer.returnValue([p for p in valid_pdus if p])