Example #1
    def test_cancellation(self):
        linearizer = Linearizer()

        key = object()

        d1 = linearizer.queue(key)
        cm1 = yield d1

        d2 = linearizer.queue(key)
        self.assertFalse(d2.called)

        d3 = linearizer.queue(key)
        self.assertFalse(d3.called)

        d2.cancel()

        with cm1:
            pass

        self.assertTrue(d2.called)
        try:
            yield d2
            self.fail("Expected d2 to raise CancelledError")
        except CancelledError:
            pass

        with (yield d3):
            pass
Example #2
    def test_linearizer_is_queued(self):
        linearizer = Linearizer()

        key = object()

        d1 = linearizer.queue(key)
        cm1 = yield d1

        # Since d1 gets called immediately, "is_queued" should return false.
        self.assertFalse(linearizer.is_queued(key))

        d2 = linearizer.queue(key)
        self.assertFalse(d2.called)

        # Now d2 is queued up behind successful completion of cm1
        self.assertTrue(linearizer.is_queued(key))

        with cm1:
            self.assertFalse(d2.called)

            # cm1 still not done, so d2 still queued.
            self.assertTrue(linearizer.is_queued(key))

        # And now d2 is called and nothing is in the queue again
        self.assertFalse(linearizer.is_queued(key))

        with (yield d2):
            self.assertFalse(linearizer.is_queued(key))

        self.assertFalse(linearizer.is_queued(key))
Example #3
class ReadMarkerHandler:
    def __init__(self, hs: "HomeServer"):
        self.server_name = hs.config.server.server_name
        self.store = hs.get_datastores().main
        self.account_data_handler = hs.get_account_data_handler()
        self.read_marker_linearizer = Linearizer(name="read_marker")

    async def received_client_read_marker(self, room_id: str, user_id: str,
                                          event_id: str) -> None:
        """Updates the read marker for a given user in a given room if the event ID given
        is ahead in the stream relative to the current read marker.

        This uses a notifier to indicate that account data should be sent down /sync if
        the read marker has changed.
        """

        async with self.read_marker_linearizer.queue((room_id, user_id)):
            existing_read_marker = await self.store.get_account_data_for_room_and_type(
                user_id, room_id, "m.fully_read")

            should_update = True

            if existing_read_marker:
                # Only update if the new marker is ahead in the stream
                should_update = await self.store.is_event_after(
                    event_id, existing_read_marker["event_id"])

            if should_update:
                content = {"event_id": event_id}
                await self.account_data_handler.add_account_data_to_room(
                    user_id, room_id, "m.fully_read", content)
Example #4
class ReadMarkerHandler(BaseHandler):
    def __init__(self, hs):
        super(ReadMarkerHandler, self).__init__(hs)
        self.server_name = hs.config.server_name
        self.store = hs.get_datastore()
        self.read_marker_linearizer = Linearizer(name="read_marker")
        self.notifier = hs.get_notifier()

    @defer.inlineCallbacks
    def received_client_read_marker(self, room_id, user_id, event_id):
        """Updates the read marker for a given user in a given room if the event ID given
        is ahead in the stream relative to the current read marker.

        This uses a notifier to indicate that account data should be sent down /sync if
        the read marker has changed.
        """

        with (yield self.read_marker_linearizer.queue((room_id, user_id))):
            existing_read_marker = yield self.store.get_account_data_for_room_and_type(
                user_id, room_id, "m.fully_read")

            should_update = True

            if existing_read_marker:
                # Only update if the new marker is ahead in the stream
                should_update = yield self.store.is_event_after(
                    event_id, existing_read_marker["event_id"])

            if should_update:
                content = {"event_id": event_id}
                max_id = yield self.store.add_account_data_to_room(
                    user_id, room_id, "m.fully_read", content)
                self.notifier.on_new_event("account_data_key",
                                           max_id,
                                           users=[user_id])
Example #5
    def test_linearizer(self):
        linearizer = Linearizer()

        key = object()

        d1 = linearizer.queue(key)
        cm1 = yield d1

        d2 = linearizer.queue(key)
        self.assertFalse(d2.called)

        with cm1:
            self.assertFalse(d2.called)

        with (yield d2):
            pass
Example #6
class _LimitedHostnameResolver(object):
    """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
    """
    def __init__(self, resolver, max_dns_requests_in_flight):
        self._resolver = resolver
        self._limiter = Linearizer(
            name="dns_client_limiter",
            max_count=max_dns_requests_in_flight,
        )

    def resolveHostName(self,
                        resolutionReceiver,
                        hostName,
                        portNumber=0,
                        addressTypes=None,
                        transportSemantics='TCP'):
        # We need this function to return `resolutionReceiver` so we do all the
        # actual logic involving deferreds in a separate function.

        # even though this is happening within the depths of twisted, we need to drop
        # our logcontext before starting _resolve, otherwise: (a) _resolve will drop
        # the logcontext if it returns an incomplete deferred; (b) _resolve will
        # call the resolutionReceiver *with* a logcontext, which it won't be expecting.
        with PreserveLoggingContext():
            self._resolve(
                resolutionReceiver,
                hostName,
                portNumber,
                addressTypes,
                transportSemantics,
            )

        return resolutionReceiver

    @defer.inlineCallbacks
    def _resolve(self,
                 resolutionReceiver,
                 hostName,
                 portNumber=0,
                 addressTypes=None,
                 transportSemantics='TCP'):

        with (yield self._limiter.queue(())):
            # resolveHostName doesn't return a Deferred, so we need to hook into
            # the receiver interface to get told when resolution has finished.

            deferred = defer.Deferred()
            receiver = _DeferredResolutionReceiver(resolutionReceiver,
                                                   deferred)

            self._resolver.resolveHostName(
                receiver,
                hostName,
                portNumber,
                addressTypes,
                transportSemantics,
            )

            yield deferred
Example #7
class FederationSenderHandler(object):
    """Processes the replication stream and forwards the appropriate entries
    to the federation sender.
    """
    def __init__(self, hs, replication_client):
        self.store = hs.get_datastore()
        self.federation_sender = hs.get_federation_sender()
        self.replication_client = replication_client

        self.federation_position = self.store.federation_out_pos_startup
        self._fed_position_linearizer = Linearizer(
            name="_fed_position_linearizer")

        self._last_ack = self.federation_position

        self._room_serials = {}
        self._room_typing = {}

    def on_start(self):
        # There may be some events that are persisted but haven't been sent,
        # so send them now.
        self.federation_sender.notify_new_events(
            self.store.get_room_max_stream_ordering())

    def stream_positions(self):
        return {"federation": self.federation_position}

    def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender,
                                                   rows)
            run_in_background(self.update_token, token)

        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
            self.federation_sender.notify_new_events(token)

    @defer.inlineCallbacks
    def update_token(self, token):
        try:
            self.federation_position = token

            # We linearize here to ensure we don't have races updating the token
            with (yield self._fed_position_linearizer.queue(None)):
                if self._last_ack < self.federation_position:
                    yield self.store.update_federation_out_pos(
                        "federation", self.federation_position)

                    # We ACK this token over replication so that the master can drop
                    # its in memory queues
                    self.replication_client.send_federation_ack(
                        self.federation_position)
                    self._last_ack = self.federation_position
        except Exception:
            logger.exception("Error updating federation stream position")
Example #8
class _LimitedHostnameResolver(object):
    """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
    """
    def __init__(self, resolver, max_dns_requests_in_flight):
        self._resolver = resolver
        self._limiter = Linearizer(
            name="dns_client_limiter",
            max_count=max_dns_requests_in_flight,
        )

    def resolveHostName(self,
                        resolutionReceiver,
                        hostName,
                        portNumber=0,
                        addressTypes=None,
                        transportSemantics='TCP'):
        # Note this is happening deep within the reactor, so we don't need to
        # worry about log contexts.

        # We need this function to return `resolutionReceiver` so we do all the
        # actual logic involving deferreds in a separate function.
        self._resolve(
            resolutionReceiver,
            hostName,
            portNumber,
            addressTypes,
            transportSemantics,
        )

        return resolutionReceiver

    @defer.inlineCallbacks
    def _resolve(self,
                 resolutionReceiver,
                 hostName,
                 portNumber=0,
                 addressTypes=None,
                 transportSemantics='TCP'):

        with (yield self._limiter.queue(())):
            # resolveHostName doesn't return a Deferred, so we need to hook into
            # the receiver interface to get told when resolution has finished.

            deferred = defer.Deferred()
            receiver = _DeferredResolutionReceiver(resolutionReceiver,
                                                   deferred)

            self._resolver.resolveHostName(
                receiver,
                hostName,
                portNumber,
                addressTypes,
                transportSemantics,
            )

            yield deferred
Example #9
class _LimitedHostnameResolver(object):
    """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
    """

    def __init__(self, resolver, max_dns_requests_in_flight):
        self._resolver = resolver
        self._limiter = Linearizer(
            name="dns_client_limiter", max_count=max_dns_requests_in_flight,
        )

    def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
                        addressTypes=None, transportSemantics='TCP'):
        # Note this is happening deep within the reactor, so we don't need to
        # worry about log contexts.

        # We need this function to return `resolutionReceiver` so we do all the
        # actual logic involving deferreds in a separate function.
        self._resolve(
            resolutionReceiver, hostName, portNumber,
            addressTypes, transportSemantics,
        )

        return resolutionReceiver

    @defer.inlineCallbacks
    def _resolve(self, resolutionReceiver, hostName, portNumber=0,
                 addressTypes=None, transportSemantics='TCP'):

        with (yield self._limiter.queue(())):
            # resolveHostName doesn't return a Deferred, so we need to hook into
            # the receiver interface to get told when resolution has finished.

            deferred = defer.Deferred()
            receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)

            self._resolver.resolveHostName(
                receiver, hostName, portNumber,
                addressTypes, transportSemantics,
            )

            yield deferred
Example #10
    def test_multiple_entries(self):
        limiter = Linearizer(max_count=3)

        key = object()

        d1 = limiter.queue(key)
        cm1 = yield d1

        d2 = limiter.queue(key)
        cm2 = yield d2

        d3 = limiter.queue(key)
        cm3 = yield d3

        d4 = limiter.queue(key)
        self.assertFalse(d4.called)

        d5 = limiter.queue(key)
        self.assertFalse(d5.called)

        with cm1:
            self.assertFalse(d4.called)
            self.assertFalse(d5.called)

        cm4 = yield d4
        self.assertFalse(d5.called)

        with cm3:
            self.assertFalse(d5.called)

        cm5 = yield d5

        with cm2:
            pass

        with cm4:
            pass

        with cm5:
            pass

        d6 = limiter.queue(key)
        with (yield d6):
            pass
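
The same max_count behaviour can be expressed with the modern async interface. This is a hypothetical sketch, not taken from the examples above; limited_lookup is an illustrative name and the import path assumes a recent Synapse release.

from synapse.util.async_helpers import Linearizer

# With max_count=3, up to three callers may hold the lock for a given key at
# once; a fourth caller waits until one of the holders exits its block.
limiter = Linearizer(name="limiter", max_count=3)

async def limited_lookup(hostname: str) -> None:
    async with limiter.queue(()):
        ...  # at most three of these run concurrently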
Example #11
class RegistrationHandler(BaseHandler):
    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        super(RegistrationHandler, self).__init__(hs)
        self.hs = hs
        self.auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        self.profile_handler = hs.get_profile_handler()
        self.user_directory_handler = hs.get_user_directory_handler()
        self.captcha_client = CaptchaServerHttpClient(hs)

        self._next_generated_user_id = None

        self.macaroon_gen = hs.get_macaroon_generator()

        self._generate_user_id_linearizer = Linearizer(
            name="_generate_user_id_linearizer", )
        self._server_notices_mxid = hs.config.server_notices_mxid

    @defer.inlineCallbacks
    def check_username(self,
                       localpart,
                       guest_access_token=None,
                       assigned_user_id=None):
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
                Codes.INVALID_USERNAME)

        if not localpart:
            raise SynapseError(400, "User ID cannot be empty",
                               Codes.INVALID_USERNAME)

        if localpart[0] == '_':
            raise SynapseError(400, "User ID may not begin with _",
                               Codes.INVALID_USERNAME)

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        if assigned_user_id:
            if user_id == assigned_user_id:
                return
            else:
                raise SynapseError(
                    400,
                    "A different user ID has already been registered for this session",
                )

        self.check_user_id_not_appservice_exclusive(user_id)

        users = yield self.store.get_users_by_id_case_insensitive(user_id)
        if users:
            if not guest_access_token:
                raise SynapseError(
                    400,
                    "User ID already taken.",
                    errcode=Codes.USER_IN_USE,
                )
            user_data = yield self.auth.get_user_by_access_token(
                guest_access_token)
            if not user_data[
                    "is_guest"] or user_data["user"].localpart != localpart:
                raise AuthError(
                    403,
                    "Cannot register taken user ID without valid guest "
                    "credentials for that user.",
                    errcode=Codes.FORBIDDEN,
                )

    @defer.inlineCallbacks
    def register(
        self,
        localpart=None,
        password=None,
        generate_token=True,
        guest_access_token=None,
        make_guest=False,
        admin=False,
        threepid=None,
    ):
        """Registers a new client on the server.

        Args:
            localpart : The local part of the user ID to register. If None,
              one will be generated.
            password (unicode) : The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
            generate_token (bool): Whether a new access token should be
              generated. Having this be True should be considered deprecated,
              since it offers no means of associating a device_id with the
              access_token. Instead you should call auth_handler.issue_access_token
              after registration.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """

        yield self.auth.check_auth_blocking(threepid=threepid)
        password_hash = None
        if password:
            password_hash = yield self.auth_handler().hash(password)

        if localpart:
            yield self.check_username(localpart,
                                      guest_access_token=guest_access_token)

            was_guest = guest_access_token is not None

            if not was_guest:
                try:
                    int(localpart)
                    raise RegistrationError(
                        400, "Numeric user IDs are reserved for guest users.")
                except ValueError:
                    pass

            user = UserID(localpart, self.hs.hostname)
            user_id = user.to_string()

            token = None
            if generate_token:
                token = self.macaroon_gen.generate_access_token(user_id)
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                create_profile_with_localpart=(
                    # If the user was a guest then they already have a profile
                    None if was_guest else user.localpart),
                admin=admin,
            )

            if self.hs.config.user_directory_search_all_users:
                profile = yield self.store.get_profileinfo(localpart)
                yield self.user_directory_handler.handle_local_profile_change(
                    user_id, profile)

        else:
            # autogen a sequential user ID
            attempts = 0
            token = None
            user = None
            while not user:
                localpart = yield self._generate_user_id(attempts > 0)
                user = UserID(localpart, self.hs.hostname)
                user_id = user.to_string()
                yield self.check_user_id_not_appservice_exclusive(user_id)
                if generate_token:
                    token = self.macaroon_gen.generate_access_token(user_id)
                try:
                    yield self.store.register(
                        user_id=user_id,
                        token=token,
                        password_hash=password_hash,
                        make_guest=make_guest,
                        create_profile_with_localpart=user.localpart,
                    )
                except SynapseError:
                    # if user id is taken, just generate another
                    user = None
                    user_id = None
                    token = None
                    attempts += 1
        if not self.hs.config.user_consent_at_registration:
            yield self._auto_join_rooms(user_id)

        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def _auto_join_rooms(self, user_id):
        """Automatically joins users to auto join rooms - creating the room in the first place
        if the user is the first to be created.

        Args:
            user_id(str): The user to join
        """
        # auto-join the user to any rooms we're supposed to dump them into
        fake_requester = create_requester(user_id)

        # try to create the room if we're the first user on the server
        should_auto_create_rooms = False
        if self.hs.config.autocreate_auto_join_rooms:
            count = yield self.store.count_all_users()
            should_auto_create_rooms = count == 1
        for r in self.hs.config.auto_join_rooms:
            try:
                if should_auto_create_rooms:
                    room_alias = RoomAlias.from_string(r)
                    if self.hs.hostname != room_alias.domain:
                        logger.warning(
                            'Cannot create room alias %s, '
                            'it does not match server domain',
                            r,
                        )
                    else:
                        # create room expects the localpart of the room alias
                        room_alias_localpart = room_alias.localpart

                        # getting the RoomCreationHandler during init gives a dependency
                        # loop
                        yield self.hs.get_room_creation_handler().create_room(
                            fake_requester,
                            config={
                                "preset": "public_chat",
                                "room_alias_name": room_alias_localpart
                            },
                            ratelimit=False,
                        )
                else:
                    yield self._join_user_to_room(fake_requester, r)
            except Exception as e:
                logger.error("Failed to join new user to %r: %r", r, e)

    @defer.inlineCallbacks
    def post_consent_actions(self, user_id):
        """A series of registration actions that can only be carried out once consent
        has been granted

        Args:
            user_id (str): The user to join
        """
        yield self._auto_join_rooms(user_id)

    @defer.inlineCallbacks
    def appservice_register(self, user_localpart, as_token):
        user = UserID(user_localpart, self.hs.hostname)
        user_id = user.to_string()
        service = self.store.get_app_service_by_token(as_token)
        if not service:
            raise AuthError(403, "Invalid application service token.")
        if not service.is_interested_in_user(user_id):
            raise SynapseError(
                400,
                "Invalid user localpart for this application service.",
                errcode=Codes.EXCLUSIVE)

        service_id = service.id if service.is_exclusive_user(user_id) else None

        yield self.check_user_id_not_appservice_exclusive(
            user_id, allowed_appservice=service)

        yield self.store.register(
            user_id=user_id,
            password_hash="",
            appservice_id=service_id,
            create_profile_with_localpart=user.localpart,
        )
        defer.returnValue(user_id)

    @defer.inlineCallbacks
    def check_recaptcha(self, ip, private_key, challenge, response):
        """
        Checks a recaptcha is correct.

        Used only by c/s api v1
        """

        captcha_response = yield self._validate_captcha(
            ip, private_key, challenge, response)
        if not captcha_response["valid"]:
            logger.info("Invalid captcha entered from %s. Error: %s", ip,
                        captcha_response["error_url"])
            raise InvalidCaptchaError(error_url=captcha_response["error_url"])
        else:
            logger.info("Valid captcha entered from %s", ip)

    @defer.inlineCallbacks
    def register_saml2(self, localpart):
        """
        Registers email_id as SAML2 Based Auth.
        """
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
            )
        yield self.auth.check_auth_blocking()
        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        yield self.check_user_id_not_appservice_exclusive(user_id)
        token = self.macaroon_gen.generate_access_token(user_id)
        try:
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=None,
                create_profile_with_localpart=user.localpart,
            )
        except Exception as e:
            yield self.store.add_access_token_to_user(user_id, token)
            # Ignore Registration errors
            logger.exception(e)
        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def register_email(self, threepidCreds):
        """
        Registers emails with an identity server.

        Used only by c/s api v1
        """

        for c in threepidCreds:
            logger.info("validating threepidcred sid %s on id server %s",
                        c['sid'], c['idServer'])
            try:
                identity_handler = self.hs.get_handlers().identity_handler
                threepid = yield identity_handler.threepid_from_creds(c)
            except Exception:
                logger.exception("Couldn't validate 3pid")
                raise RegistrationError(400, "Couldn't validate 3pid")

            if not threepid:
                raise RegistrationError(400, "Couldn't validate 3pid")
            logger.info("got threepid with medium '%s' and address '%s'",
                        threepid['medium'], threepid['address'])

            if not check_3pid_allowed(self.hs, threepid['medium'],
                                      threepid['address']):
                raise RegistrationError(
                    403, "Third party identifier is not allowed")

    @defer.inlineCallbacks
    def bind_emails(self, user_id, threepidCreds):
        """Links emails with a user ID and informs an identity server.

        Used only by c/s api v1
        """

        # Now we have a matrix ID, bind it to the threepids we were given
        for c in threepidCreds:
            identity_handler = self.hs.get_handlers().identity_handler
            # XXX: This should be a deferred list, shouldn't it?
            yield identity_handler.bind_threepid(c, user_id)

    def check_user_id_not_appservice_exclusive(self,
                                               user_id,
                                               allowed_appservice=None):
        # don't allow people to register the server notices mxid
        if self._server_notices_mxid is not None:
            if user_id == self._server_notices_mxid:
                raise SynapseError(400,
                                   "This user ID is reserved.",
                                   errcode=Codes.EXCLUSIVE)

        # valid user IDs must not clash with any user ID namespaces claimed by
        # application services.
        services = self.store.get_app_services()
        interested_services = [
            s for s in services
            if s.is_interested_in_user(user_id) and s != allowed_appservice
        ]
        for service in interested_services:
            if service.is_exclusive_user(user_id):
                raise SynapseError(
                    400,
                    "This user ID is reserved by an application service.",
                    errcode=Codes.EXCLUSIVE)

    @defer.inlineCallbacks
    def _generate_user_id(self, reseed=False):
        if reseed or self._next_generated_user_id is None:
            with (yield self._generate_user_id_linearizer.queue(())):
                if reseed or self._next_generated_user_id is None:
                    self._next_generated_user_id = (
                        yield
                        self.store.find_next_generated_user_id_localpart())

        id = self._next_generated_user_id
        self._next_generated_user_id += 1
        defer.returnValue(str(id))

    @defer.inlineCallbacks
    def _validate_captcha(self, ip_addr, private_key, challenge, response):
        """Validates the captcha provided.

        Used only by c/s api v1

        Returns:
            dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.

        """
        response = yield self._submit_captcha(ip_addr, private_key, challenge,
                                              response)
        # parse Google's response. Lovely format..
        lines = response.split('\n')
        json = {
            "valid":
            lines[0] == 'true',
            "error_url":
            "http://www.google.com/recaptcha/api/challenge?" +
            "error=%s" % lines[1]
        }
        defer.returnValue(json)

    @defer.inlineCallbacks
    def _submit_captcha(self, ip_addr, private_key, challenge, response):
        """
        Used only by c/s api v1
        """
        data = yield self.captcha_client.post_urlencoded_get_raw(
            "http://www.google.com:80/recaptcha/api/verify",
            args={
                'privatekey': private_key,
                'remoteip': ip_addr,
                'challenge': challenge,
                'response': response
            })
        defer.returnValue(data)

    @defer.inlineCallbacks
    def get_or_create_user(self,
                           requester,
                           localpart,
                           displayname,
                           password_hash=None):
        """Creates a new user if the user does not exist,
        else revokes all previous access tokens and generates a new one.

        Args:
            localpart : The local part of the user ID to register. If None,
              one will be randomly generated.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """
        if localpart is None:
            raise SynapseError(400, "Request must include user id")
        yield self.auth.check_auth_blocking()
        need_register = True

        try:
            yield self.check_username(localpart)
        except SynapseError as e:
            if e.errcode == Codes.USER_IN_USE:
                need_register = False
            else:
                raise

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()
        token = self.macaroon_gen.generate_access_token(user_id)

        if need_register:
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                create_profile_with_localpart=user.localpart,
            )
        else:
            yield self._auth_handler.delete_access_tokens_for_user(user_id)
            yield self.store.add_access_token_to_user(user_id=user_id,
                                                      token=token)

        if displayname is not None:
            logger.info("setting user display name: %s -> %s", user_id,
                        displayname)
            yield self.profile_handler.set_displayname(
                user,
                requester,
                displayname,
                by_admin=True,
            )

        defer.returnValue((user_id, token))

    def auth_handler(self):
        return self.hs.get_auth_handler()

    @defer.inlineCallbacks
    def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
            3PID guest account.
        """
        access_token = yield self.store.get_3pid_guest_access_token(
            medium, address)
        if access_token:
            user_info = yield self.auth.get_user_by_access_token(access_token)

            defer.returnValue((user_info["user"].to_string(), access_token))

        user_id, access_token = yield self.register(generate_token=True,
                                                    make_guest=True)
        access_token = yield self.store.save_or_get_3pid_guest_access_token(
            medium, address, access_token, inviter_user_id)

        defer.returnValue((user_id, access_token))

    @defer.inlineCallbacks
    def _join_user_to_room(self, requester, room_identifier):
        room_id = None
        # Ensure remote_room_hosts is always defined; it is only assigned in the
        # room-alias branch below but is passed to update_membership either way.
        remote_room_hosts = None
        room_member_handler = self.hs.get_room_member_handler()
        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
        elif RoomAlias.is_valid(room_identifier):
            room_alias = RoomAlias.from_string(room_identifier)
            room_id, remote_room_hosts = (
                yield room_member_handler.lookup_room_alias(room_alias))
            room_id = room_id.to_string()
        else:
            raise SynapseError(
                400,
                "%s was not legal room ID or room alias" % (room_identifier, ))

        yield room_member_handler.update_membership(
            requester=requester,
            target=requester.user,
            room_id=room_id,
            remote_room_hosts=remote_room_hosts,
            action="join",
            ratelimit=False,
        )
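
The _generate_user_id method above uses a check/lock/re-check pattern so that only one request reseeds the counter. Below is a condensed sketch of that pattern in async style; the module-level globals are illustrative, and store.find_next_generated_user_id_localpart is assumed to behave as in the example.

from synapse.util.async_helpers import Linearizer

_linearizer = Linearizer(name="_generate_user_id_linearizer")
_next_generated_user_id = None

async def generate_user_id(store) -> str:
    global _next_generated_user_id
    if _next_generated_user_id is None:
        async with _linearizer.queue(()):
            # Re-check inside the lock: another caller may have seeded the
            # counter while we were waiting for the linearizer.
            if _next_generated_user_id is None:
                _next_generated_user_id = (
                    await store.find_next_generated_user_id_localpart()
                )
    user_id = _next_generated_user_id
    _next_generated_user_id += 1
    return str(user_id)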
Example #12
class RoomMemberHandler(object):
    # TODO(paul): This handler currently contains a messy conflation of
    #   low-level API that works on UserID objects and so on, and REST-level
    #   API that takes ID strings and returns pagination chunks. These concerns
    #   ought to be separated out a lot better.

    __metaclass__ = abc.ABCMeta

    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        self.hs = hs
        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
        self.state_handler = hs.get_state_handler()
        self.config = hs.config

        self.federation_handler = hs.get_handlers().federation_handler
        self.directory_handler = hs.get_handlers().directory_handler
        self.identity_handler = hs.get_handlers().identity_handler
        self.registration_handler = hs.get_registration_handler()
        self.profile_handler = hs.get_profile_handler()
        self.event_creation_handler = hs.get_event_creation_handler()

        self.member_linearizer = Linearizer(name="member")

        self.clock = hs.get_clock()
        self.spam_checker = hs.get_spam_checker()
        self.third_party_event_rules = hs.get_third_party_event_rules()
        self._server_notices_mxid = self.config.server_notices_mxid
        self._enable_lookup = hs.config.enable_3pid_lookup
        self.allow_per_room_profiles = self.config.allow_per_room_profiles

        # This is only used to get at ratelimit function, and
        # maybe_kick_guest_users. It's fine that there are multiple of these as
        # it doesn't store state.
        self.base_handler = BaseHandler(hs)

    @abc.abstractmethod
    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
        """Try and join a room that this server is not in

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers that can be used
                to join via.
            room_id (str): Room that we are trying to join
            user (UserID): User who is trying to join
            content (dict): A dict that should be used as the content of the
                join event.

        Returns:
            Deferred
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _remote_reject_invite(
        self, requester, remote_room_hosts, room_id, target, content
    ):
        """Attempt to reject an invite for a room this server is not in. If we
        fail to do so we locally mark the invite as rejected.

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers to use to try and
                reject invite
            room_id (str)
            target (UserID): The user rejecting the invite
            content (dict): The content for the rejection event

        Returns:
            Deferred[dict]: A dictionary to be returned to the client, may
            include event_id etc, or nothing if we locally rejected
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_joined_room(self, target, room_id):
        """Notifies distributor on master process that the user has joined the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_left_room(self, target, room_id):
        """Notifies distributor on master process that the user has left the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @defer.inlineCallbacks
    def _local_membership_update(
        self,
        requester,
        target,
        room_id,
        membership,
        prev_event_ids: Collection[str],
        txn_id=None,
        ratelimit=True,
        content=None,
        require_consent=True,
    ):
        user_id = target.to_string()

        if content is None:
            content = {}

        content["membership"] = membership
        if requester.is_guest:
            content["kind"] = "guest"

        event, context = yield self.event_creation_handler.create_event(
            requester,
            {
                "type": EventTypes.Member,
                "content": content,
                "room_id": room_id,
                "sender": requester.user.to_string(),
                "state_key": user_id,
                # For backwards compatibility:
                "membership": membership,
            },
            token_id=requester.access_token_id,
            txn_id=txn_id,
            prev_event_ids=prev_event_ids,
            require_consent=require_consent,
        )

        # Check if this event matches the previous membership event for the user.
        duplicate = yield self.event_creation_handler.deduplicate_state_event(
            event, context
        )
        if duplicate is not None:
            # Discard the new event since this membership change is a no-op.
            return duplicate

        yield self.event_creation_handler.handle_new_client_event(
            requester, event, context, extra_users=[target], ratelimit=ratelimit
        )

        prev_state_ids = yield context.get_prev_state_ids()

        prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target, room_id)

        return event

    @defer.inlineCallbacks
    def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
        """Copies the tags and direct room state from one room to another.

        Args:
            old_room_id (str)
            new_room_id (str)
            user_id (str)

        Returns:
            Deferred[None]
        """
        # Retrieve user account data for predecessor room
        user_account_data, _ = yield self.store.get_account_data_for_user(user_id)

        # Copy direct message state if applicable
        direct_rooms = user_account_data.get("m.direct", {})

        # Check which key this room is under
        if isinstance(direct_rooms, dict):
            for key, room_id_list in direct_rooms.items():
                if old_room_id in room_id_list and new_room_id not in room_id_list:
                    # Add new room_id to this key
                    direct_rooms[key].append(new_room_id)

                    # Save back to user's m.direct account data
                    yield self.store.add_account_data_for_user(
                        user_id, "m.direct", direct_rooms
                    )
                    break

        # Copy room tags if applicable
        room_tags = yield self.store.get_tags_for_room(user_id, old_room_id)

        # Copy each room tag to the new room
        for tag, tag_content in room_tags.items():
            yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)

    @defer.inlineCallbacks
    def update_membership(
        self,
        requester,
        target,
        room_id,
        action,
        txn_id=None,
        remote_room_hosts=None,
        third_party_signed=None,
        ratelimit=True,
        content=None,
        require_consent=True,
    ):
        key = (room_id,)

        with (yield self.member_linearizer.queue(key)):
            result = yield self._update_membership(
                requester,
                target,
                room_id,
                action,
                txn_id=txn_id,
                remote_room_hosts=remote_room_hosts,
                third_party_signed=third_party_signed,
                ratelimit=ratelimit,
                content=content,
                require_consent=require_consent,
            )

        return result

    @defer.inlineCallbacks
    def _update_membership(
        self,
        requester,
        target,
        room_id,
        action,
        txn_id=None,
        remote_room_hosts=None,
        third_party_signed=None,
        ratelimit=True,
        content=None,
        require_consent=True,
    ):
        content_specified = bool(content)
        if content is None:
            content = {}
        else:
            # We do a copy here as we potentially change some keys
            # later on.
            content = dict(content)

        if not self.allow_per_room_profiles:
            # Strip profile data, knowing that new profile data will be added to the
            # event's content in event_creation_handler.create_event() using the target's
            # global profile.
            content.pop("displayname", None)
            content.pop("avatar_url", None)

        effective_membership_state = action
        if action in ["kick", "unban"]:
            effective_membership_state = "leave"

        # if this is a join with a 3pid signature, we may need to turn a 3pid
        # invite into a normal invite before we can handle the join.
        if third_party_signed is not None:
            yield self.federation_handler.exchange_third_party_invite(
                third_party_signed["sender"],
                target.to_string(),
                room_id,
                third_party_signed,
            )

        if not remote_room_hosts:
            remote_room_hosts = []

        if effective_membership_state not in ("leave", "ban"):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(403, "This room has been blocked on this server")

        if effective_membership_state == Membership.INVITE:
            # block any attempts to invite the server notices mxid
            if target.to_string() == self._server_notices_mxid:
                raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")

            block_invite = False

            if (
                self._server_notices_mxid is not None
                and requester.user.to_string() == self._server_notices_mxid
            ):
                # allow the server notices mxid to send invites
                is_requester_admin = True

            else:
                is_requester_admin = yield self.auth.is_server_admin(requester.user)

            if not is_requester_admin:
                if self.config.block_non_admin_invites:
                    logger.info(
                        "Blocking invite: user is not admin and non-admin "
                        "invites disabled"
                    )
                    block_invite = True

                if not self.spam_checker.user_may_invite(
                    requester.user.to_string(), target.to_string(), room_id
                ):
                    logger.info("Blocking invite due to spam checker")
                    block_invite = True

            if block_invite:
                raise SynapseError(403, "Invites have been disabled on this server")

        latest_event_ids = yield self.store.get_prev_events_for_room(room_id)

        current_state_ids = yield self.state_handler.get_current_state_ids(
            room_id, latest_event_ids=latest_event_ids
        )

        # TODO: Refactor into dictionary of explicitly allowed transitions
        # between old and new state, with specific error messages for some
        # transitions and generic otherwise
        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
        if old_state_id:
            old_state = yield self.store.get_event(old_state_id, allow_none=True)
            old_membership = old_state.content.get("membership") if old_state else None
            if action == "unban" and old_membership != "ban":
                raise SynapseError(
                    403,
                    "Cannot unban user who was not banned"
                    " (membership=%s)" % old_membership,
                    errcode=Codes.BAD_STATE,
                )
            if old_membership == "ban" and action != "unban":
                raise SynapseError(
                    403,
                    "Cannot %s user who was banned" % (action,),
                    errcode=Codes.BAD_STATE,
                )

            if old_state:
                same_content = content == old_state.content
                same_membership = old_membership == effective_membership_state
                same_sender = requester.user.to_string() == old_state.sender
                if same_sender and same_membership and same_content:
                    return old_state

            if old_membership in ["ban", "leave"] and action == "kick":
                raise AuthError(403, "The target user is not in the room")

            # we don't allow people to reject invites to the server notice
            # room, but they can leave it once they are joined.
            if (
                old_membership == Membership.INVITE
                and effective_membership_state == Membership.LEAVE
            ):
                is_blocked = yield self._is_server_notice_room(room_id)
                if is_blocked:
                    raise SynapseError(
                        http_client.FORBIDDEN,
                        "You cannot reject this invite",
                        errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
                    )
        else:
            if action == "kick":
                raise AuthError(403, "The target user is not in the room")

        is_host_in_room = yield self._is_host_in_room(current_state_ids)

        if effective_membership_state == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(current_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

            if not is_host_in_room:
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if inviter and not self.hs.is_mine(inviter):
                    remote_room_hosts.append(inviter.domain)

                content["membership"] = Membership.JOIN

                profile = self.profile_handler
                if not content_specified:
                    content["displayname"] = yield profile.get_displayname(target)
                    content["avatar_url"] = yield profile.get_avatar_url(target)

                if requester.is_guest:
                    content["kind"] = "guest"

                remote_join_response = yield self._remote_join(
                    requester, remote_room_hosts, room_id, target, content
                )

                return remote_join_response

        elif effective_membership_state == Membership.LEAVE:
            if not is_host_in_room:
                # perhaps we've been invited
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if not inviter:
                    raise SynapseError(404, "Not a known room")

                if self.hs.is_mine(inviter):
                    # the inviter was on our server, but has now left. Carry on
                    # with the normal rejection codepath.
                    #
                    # This is a bit of a hack, because the room might still be
                    # active on other servers.
                    pass
                else:
                    # send the rejection to the inviter's HS.
                    remote_room_hosts = remote_room_hosts + [inviter.domain]
                    res = yield self._remote_reject_invite(
                        requester, remote_room_hosts, room_id, target, content,
                    )
                    return res

        res = yield self._local_membership_update(
            requester=requester,
            target=target,
            room_id=room_id,
            membership=effective_membership_state,
            txn_id=txn_id,
            ratelimit=ratelimit,
            prev_event_ids=latest_event_ids,
            content=content,
            require_consent=require_consent,
        )
        return res

    @defer.inlineCallbacks
    def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
        """Upon our server becoming aware of an upgraded room, either by upgrading a room
        ourselves or joining one, we can transfer over information from the previous room.

        Copies user state (tags/push rules) for every local user that was in the old room, as
        well as migrating the room directory state.

        Args:
            old_room_id (str): The ID of the old room

            room_id (str): The ID of the new room

        Returns:
            Deferred
        """
        logger.info("Transferring room state from %s to %s", old_room_id, room_id)

        # Find all local users that were in the old room and copy over each user's state
        users = yield self.store.get_users_in_room(old_room_id)
        yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)

        # Add new room to the room directory if the old room was there
        # Remove old room from the room directory
        old_room = yield self.store.get_room(old_room_id)
        if old_room and old_room["is_public"]:
            yield self.store.set_room_is_public(old_room_id, False)
            yield self.store.set_room_is_public(room_id, True)

        # Check if any groups we own contain the predecessor room
        local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
        for group_id in local_group_ids:
            # Add the new room to those groups
            yield self.store.add_room_to_group(group_id, room_id, old_room["is_public"])

            # Remove the old room from those groups
            yield self.store.remove_room_from_group(group_id, old_room_id)

    @defer.inlineCallbacks
    def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
        """Copy user-specific information when they join a new room when that new room is the
        result of a room upgrade

        Args:
            old_room_id (str): The ID of upgraded room
            new_room_id (str): The ID of the new room
            user_ids (Iterable[str]): User IDs to copy state for

        Returns:
            Deferred
        """

        logger.debug(
            "Copying over room tags and push rules from %s to %s for users %s",
            old_room_id,
            new_room_id,
            user_ids,
        )

        for user_id in user_ids:
            try:
                # It is an upgraded room. Copy over old tags
                yield self.copy_room_tags_and_direct_to_room(
                    old_room_id, new_room_id, user_id
                )
                # Copy over push rules
                yield self.store.copy_push_rules_from_room_to_room_for_user(
                    old_room_id, new_room_id, user_id
                )
            except Exception:
                logger.exception(
                    "Error copying tags and/or push rules from rooms %s to %s for user %s. "
                    "Skipping...",
                    old_room_id,
                    new_room_id,
                    user_id,
                )
                continue

    @defer.inlineCallbacks
    def send_membership_event(self, requester, event, context, ratelimit=True):
        """
        Change the membership status of a user in a room.

        Args:
            requester (Requester): The local user who requested the membership
                event. If None, certain checks, like whether this homeserver can
                act as the sender, will be skipped.
            event (SynapseEvent): The membership event.
            context: The context of the event.
            ratelimit (bool): Whether to rate limit this request.
        Raises:
            SynapseError if there was a problem changing the membership.
        """
        target_user = UserID.from_string(event.state_key)
        room_id = event.room_id

        if requester is not None:
            sender = UserID.from_string(event.sender)
            assert (
                sender == requester.user
            ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
            assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
        else:
            requester = types.create_requester(target_user)

        prev_event = yield self.event_creation_handler.deduplicate_state_event(
            event, context
        )
        if prev_event is not None:
            return

        prev_state_ids = yield context.get_prev_state_ids()
        if event.membership == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(prev_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

        if event.membership not in (Membership.LEAVE, Membership.BAN):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(403, "This room has been blocked on this server")

        yield self.event_creation_handler.handle_new_client_event(
            requester, event, context, extra_users=[target_user], ratelimit=ratelimit
        )

        prev_member_event_id = prev_state_ids.get(
            (EventTypes.Member, event.state_key), None
        )

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target_user, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target_user, room_id)

    @defer.inlineCallbacks
    def _can_guest_join(self, current_state_ids):
        """
        Returns whether a guest can join a room based on its current state.
        """
        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
        if not guest_access_id:
            return False

        guest_access = yield self.store.get_event(guest_access_id)

        return (
            guest_access
            and guest_access.content
            and "guest_access" in guest_access.content
            and guest_access.content["guest_access"] == "can_join"
        )

    @defer.inlineCallbacks
    def lookup_room_alias(self, room_alias):
        """
        Get the room ID associated with a room alias.

        Args:
            room_alias (RoomAlias): The alias to look up.
        Returns:
            A tuple of:
                The room ID as a RoomID object.
                Hosts likely to be participating in the room ([str]).
        Raises:
            SynapseError if room alias could not be found.
        """
        directory_handler = self.directory_handler
        mapping = yield directory_handler.get_association(room_alias)

        if not mapping:
            raise SynapseError(404, "No such room alias")

        room_id = mapping["room_id"]
        servers = mapping["servers"]

        # put the server which owns the alias at the front of the server list.
        if room_alias.domain in servers:
            servers.remove(room_alias.domain)
        servers.insert(0, room_alias.domain)

        return RoomID.from_string(room_id), servers

    @defer.inlineCallbacks
    def _get_inviter(self, user_id, room_id):
        invite = yield self.store.get_invite_for_local_user_in_room(
            user_id=user_id, room_id=room_id
        )
        if invite:
            return UserID.from_string(invite.sender)

    @defer.inlineCallbacks
    def do_3pid_invite(
        self,
        room_id,
        inviter,
        medium,
        address,
        id_server,
        requester,
        txn_id,
        id_access_token=None,
    ):
        if self.config.block_non_admin_invites:
            is_requester_admin = yield self.auth.is_server_admin(requester.user)
            if not is_requester_admin:
                raise SynapseError(
                    403, "Invites have been disabled on this server", Codes.FORBIDDEN
                )

        # We need to rate limit *before* we send out any 3PID invites, so we
        # can't just rely on the standard ratelimiting of events.
        yield self.base_handler.ratelimit(requester)

        can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
            medium, address, room_id
        )
        if not can_invite:
            raise SynapseError(
                403,
                "This third-party identifier can not be invited in this room",
                Codes.FORBIDDEN,
            )

        if not self._enable_lookup:
            raise SynapseError(
                403, "Looking up third-party identifiers is denied from this server"
            )

        invitee = yield self.identity_handler.lookup_3pid(
            id_server, medium, address, id_access_token
        )

        if invitee:
            yield self.update_membership(
                requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
            )
        else:
            yield self._make_and_store_3pid_invite(
                requester,
                id_server,
                medium,
                address,
                room_id,
                inviter,
                txn_id=txn_id,
                id_access_token=id_access_token,
            )

    @defer.inlineCallbacks
    def _make_and_store_3pid_invite(
        self,
        requester,
        id_server,
        medium,
        address,
        room_id,
        user,
        txn_id,
        id_access_token=None,
    ):
        room_state = yield self.state_handler.get_current_state(room_id)

        inviter_display_name = ""
        inviter_avatar_url = ""
        member_event = room_state.get((EventTypes.Member, user.to_string()))
        if member_event:
            inviter_display_name = member_event.content.get("displayname", "")
            inviter_avatar_url = member_event.content.get("avatar_url", "")

        # if user has no display name, default to their MXID
        if not inviter_display_name:
            inviter_display_name = user.to_string()

        canonical_room_alias = ""
        canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event:
            canonical_room_alias = canonical_alias_event.content.get("alias", "")

        room_name = ""
        room_name_event = room_state.get((EventTypes.Name, ""))
        if room_name_event:
            room_name = room_name_event.content.get("name", "")

        room_join_rules = ""
        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
        if join_rules_event:
            room_join_rules = join_rules_event.content.get("join_rule", "")

        room_avatar_url = ""
        room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
        if room_avatar_event:
            room_avatar_url = room_avatar_event.content.get("url", "")

        (
            token,
            public_keys,
            fallback_public_key,
            display_name,
        ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
            requester=requester,
            id_server=id_server,
            medium=medium,
            address=address,
            room_id=room_id,
            inviter_user_id=user.to_string(),
            room_alias=canonical_room_alias,
            room_avatar_url=room_avatar_url,
            room_join_rules=room_join_rules,
            room_name=room_name,
            inviter_display_name=inviter_display_name,
            inviter_avatar_url=inviter_avatar_url,
            id_access_token=id_access_token,
        )

        yield self.event_creation_handler.create_and_send_nonmember_event(
            requester,
            {
                "type": EventTypes.ThirdPartyInvite,
                "content": {
                    "display_name": display_name,
                    "public_keys": public_keys,
                    # For backwards compatibility:
                    "key_validity_url": fallback_public_key["key_validity_url"],
                    "public_key": fallback_public_key["public_key"],
                },
                "room_id": room_id,
                "sender": user.to_string(),
                "state_key": token,
            },
            ratelimit=False,
            txn_id=txn_id,
        )

    @defer.inlineCallbacks
    def _is_host_in_room(self, current_state_ids):
        # Have we just created the room, and is this about to be the very
        # first member event?
        create_event_id = current_state_ids.get(("m.room.create", ""))
        if len(current_state_ids) == 1 and create_event_id:
            # We can only get here if we're in the process of creating the room
            return True

        for etype, state_key in current_state_ids:
            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                continue

            event_id = current_state_ids[(etype, state_key)]
            event = yield self.store.get_event(event_id, allow_none=True)
            if not event:
                continue

            if event.membership == Membership.JOIN:
                return True

        return False

    @defer.inlineCallbacks
    def _is_server_notice_room(self, room_id):
        if self._server_notices_mxid is None:
            return False
        user_ids = yield self.store.get_users_in_room(room_id)
        return self._server_notices_mxid in user_ids
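
A minimal standalone restatement of the join/leave transition rules that send_membership_event applies above: _user_joined_room only fires on a genuine transition into "join" (so profile-only updates are skipped), and _user_left_room only fires on a transition from "join" to "leave". The helper name and return values below are invented purely for illustration; this is a sketch, not part of the handler.

def membership_transition(prev_membership, new_membership):
    """Classify a membership change the way send_membership_event does.

    Returns "joined", "left" or None (e.g. for a profile-only update).
    """
    if new_membership == "join" and prev_membership != "join":
        return "joined"
    if new_membership == "leave" and prev_membership == "join":
        return "left"
    return None

# membership_transition(None, "join")    -> "joined"
# membership_transition("join", "join")  -> None   (displayname/avatar change)
# membership_transition("join", "leave") -> "left"
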
Example #16
class PresenceHandler(object):
    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        self.hs = hs
        self.is_mine = hs.is_mine
        self.is_mine_id = hs.is_mine_id
        self.server_name = hs.hostname
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()
        self.wheel_timer = WheelTimer()
        self.notifier = hs.get_notifier()
        self.federation = hs.get_federation_sender()
        self.state = hs.get_state_handler()

        federation_registry = hs.get_federation_registry()

        federation_registry.register_edu_handler("m.presence",
                                                 self.incoming_presence)

        active_presence = self.store.take_presence_startup_info()

        # A dictionary of the current state of users. This is prefilled with
        # non-offline presence from the DB. We should fetch from the DB if
        # we can't find a user's presence in here.
        self.user_to_current_state = {
            state.user_id: state
            for state in active_presence
        }

        LaterGauge("synapse_handlers_presence_user_to_current_state_size", "",
                   [], lambda: len(self.user_to_current_state))

        now = self.clock.time_msec()
        for state in active_presence:
            self.wheel_timer.insert(
                now=now,
                obj=state.user_id,
                then=state.last_active_ts + IDLE_TIMER,
            )
            self.wheel_timer.insert(
                now=now,
                obj=state.user_id,
                then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
            )
            if self.is_mine_id(state.user_id):
                self.wheel_timer.insert(
                    now=now,
                    obj=state.user_id,
                    then=state.last_federation_update_ts +
                    FEDERATION_PING_INTERVAL,
                )
            else:
                self.wheel_timer.insert(
                    now=now,
                    obj=state.user_id,
                    then=state.last_federation_update_ts + FEDERATION_TIMEOUT,
                )

        # Set of users whose presence in `user_to_current_state` has not yet
        # been persisted
        self.unpersisted_users_changes = set()

        hs.get_reactor().addSystemEventTrigger("before", "shutdown",
                                               self._on_shutdown)

        self.serial_to_user = {}
        self._next_serial = 1

        # Keeps track of the number of *ongoing* syncs on this process. While
        # this is non zero a user will never go offline.
        self.user_to_num_current_syncs = {}

        # Keeps track of the number of *ongoing* syncs on other processes.
        # While any sync is ongoing on another process the user will never
        # go offline.
        # Each process has a unique identifier and an update frequency. If
        # no update is received from that process within the update period then
        # we assume that all the sync requests on that process have stopped.
        # Stored as a dict from process_id to set of user_id, and a dict of
        # process_id to millisecond timestamp last updated.
        self.external_process_to_current_syncs = {}
        self.external_process_last_updated_ms = {}
        self.external_sync_linearizer = Linearizer(
            name="external_sync_linearizer")

        # Start a LoopingCall in 30s that fires every 5s.
        # The initial delay is to allow disconnected clients a chance to
        # reconnect before we treat them as offline.
        self.clock.call_later(
            30,
            self.clock.looping_call,
            self._handle_timeouts,
            5000,
        )

        self.clock.call_later(
            60,
            self.clock.looping_call,
            self._persist_unpersisted_changes,
            60 * 1000,
        )

        LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
                   lambda: len(self.wheel_timer))

        # Used to handle sending of presence to newly joined users/servers
        if hs.config.use_presence:
            self.notifier.add_replication_callback(self.notify_new_event)

        # Presence is best effort and quickly heals itself, so let's just always
        # stream from the current state when we restart.
        self._event_pos = self.store.get_current_events_token()
        self._event_processing = False

    @defer.inlineCallbacks
    def _on_shutdown(self):
        """Gets called when shutting down. This lets us persist any updates that
        we haven't yet persisted, e.g. updates that only change some internal
        timers. This allows changes to persist across startup without having to
        persist every single change.

        If this does not run it simply means that some of the timers will fire
        earlier than they should when synapse is restarted. The effect of this
        is some spurious presence changes that will self-correct.
        """
        # If the DB pool has already terminated, don't try updating
        if not self.hs.get_db_pool().running:
            return

        logger.info(
            "Performing _on_shutdown. Persisting %d unpersisted changes",
            len(self.unpersisted_users_changes))

        if self.unpersisted_users_changes:
            yield self.store.update_presence([
                self.user_to_current_state[user_id]
                for user_id in self.unpersisted_users_changes
            ])
        logger.info("Finished _on_shutdown")

    @defer.inlineCallbacks
    def _persist_unpersisted_changes(self):
        """We periodically persist the unpersisted changes, as otherwise they
        may stack up and slow down shutdown times.
        """
        logger.info(
            "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
            len(self.unpersisted_users_changes))

        unpersisted = self.unpersisted_users_changes
        self.unpersisted_users_changes = set()

        if unpersisted:
            yield self.store.update_presence([
                self.user_to_current_state[user_id] for user_id in unpersisted
            ])

        logger.info("Finished _persist_unpersisted_changes")

    @defer.inlineCallbacks
    def _update_states_and_catch_exception(self, new_states):
        try:
            res = yield self._update_states(new_states)
            defer.returnValue(res)
        except Exception:
            logger.exception("Error updating presence")

    @defer.inlineCallbacks
    def _update_states(self, new_states):
        """Updates presence of users. Sets the appropriate timeouts. Pokes
        the notifier and federation if and only if the changed presence state
        should be sent to clients/servers.
        """
        now = self.clock.time_msec()

        with Measure(self.clock, "presence_update_states"):

            # NOTE: We purposefully don't yield between now and when we've
            # calculated what we want to do with the new states, to avoid races.

            to_notify = {}  # Changes we want to notify everyone about
            to_federation_ping = {}  # These need sending keep-alives

            # Only bother handling the last presence change for each user
            new_states_dict = {}
            for new_state in new_states:
                new_states_dict[new_state.user_id] = new_state
            new_states = new_states_dict.values()

            for new_state in new_states:
                user_id = new_state.user_id

                # It's fine not to hit the database here, as the only things not in
                # the current state cache are OFFLINE states, where the only field
                # of interest is last_active, which is safe enough to assume is 0
                # here.
                prev_state = self.user_to_current_state.get(
                    user_id, UserPresenceState.default(user_id))

                new_state, should_notify, should_ping = handle_update(
                    prev_state,
                    new_state,
                    is_mine=self.is_mine_id(user_id),
                    wheel_timer=self.wheel_timer,
                    now=now)

                self.user_to_current_state[user_id] = new_state

                if should_notify:
                    to_notify[user_id] = new_state
                elif should_ping:
                    to_federation_ping[user_id] = new_state

            # TODO: We should probably ensure there are no races hereafter

            presence_updates_counter.inc(len(new_states))

            if to_notify:
                notified_presence_counter.inc(len(to_notify))
                yield self._persist_and_notify(list(to_notify.values()))

            self.unpersisted_users_changes |= set(s.user_id
                                                  for s in new_states)
            self.unpersisted_users_changes -= set(to_notify.keys())

            to_federation_ping = {
                user_id: state
                for user_id, state in to_federation_ping.items()
                if user_id not in to_notify
            }
            if to_federation_ping:
                federation_presence_out_counter.inc(len(to_federation_ping))

                self._push_to_remotes(to_federation_ping.values())

    def _handle_timeouts(self):
        """Checks the presence of users that have timed out and updates as
        appropriate.
        """
        logger.info("Handling presence timeouts")
        now = self.clock.time_msec()

        try:
            with Measure(self.clock, "presence_handle_timeouts"):
                # Fetch the list of users that *may* have timed out. Things may have
                # changed since the timeout was set, so we won't necessarily have to
                # take any action.
                users_to_check = set(self.wheel_timer.fetch(now))

                # Check whether the lists of syncing processes from an external
                # process have expired.
                expired_process_ids = [
                    process_id for process_id, last_update in
                    self.external_process_last_updated_ms.items()
                    if now - last_update > EXTERNAL_PROCESS_EXPIRY
                ]
                for process_id in expired_process_ids:
                    # Drop the users that were syncing on the expired process and
                    # re-check them, since they may now have timed out.
                    users_to_check.update(
                        self.external_process_to_current_syncs.pop(
                            process_id, ()))
                    self.external_process_last_updated_ms.pop(process_id)

                states = [
                    self.user_to_current_state.get(
                        user_id, UserPresenceState.default(user_id))
                    for user_id in users_to_check
                ]

                timers_fired_counter.inc(len(states))

                changes = handle_timeouts(
                    states,
                    is_mine_fn=self.is_mine_id,
                    syncing_user_ids=self.get_currently_syncing_users(),
                    now=now,
                )

            run_in_background(self._update_states_and_catch_exception, changes)
        except Exception:
            logger.exception("Exception in _handle_timeouts loop")

    @defer.inlineCallbacks
    def bump_presence_active_time(self, user):
        """We've seen the user do something that indicates they're interacting
        with the app.
        """
        # If presence is disabled, no-op
        if not self.hs.config.use_presence:
            return

        user_id = user.to_string()

        bump_active_time_counter.inc()

        prev_state = yield self.current_state_for_user(user_id)

        new_fields = {
            "last_active_ts": self.clock.time_msec(),
        }
        if prev_state.state == PresenceState.UNAVAILABLE:
            new_fields["state"] = PresenceState.ONLINE

        yield self._update_states([prev_state.copy_and_replace(**new_fields)])

    @defer.inlineCallbacks
    def user_syncing(self, user_id, affect_presence=True):
        """Returns a context manager that should surround any stream requests
        from the user.

        This allows us to keep track of who is currently streaming and who isn't,
        without needing timers outside of this module, and avoids flickering
        when users disconnect/reconnect.

        Args:
            user_id (str)
            affect_presence (bool): If false this function will be a no-op.
                Useful for streams that are not associated with an actual
                client that is being used by a user.
        """
        # Override if it should affect the user's presence, if presence is
        # disabled.
        if not self.hs.config.use_presence:
            affect_presence = False

        if affect_presence:
            curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
            self.user_to_num_current_syncs[user_id] = curr_sync + 1

            prev_state = yield self.current_state_for_user(user_id)
            if prev_state.state == PresenceState.OFFLINE:
                # If they're currently offline then bring them online, otherwise
                # just update the last sync times.
                yield self._update_states([
                    prev_state.copy_and_replace(
                        state=PresenceState.ONLINE,
                        last_active_ts=self.clock.time_msec(),
                        last_user_sync_ts=self.clock.time_msec(),
                    )
                ])
            else:
                yield self._update_states([
                    prev_state.copy_and_replace(
                        last_user_sync_ts=self.clock.time_msec(), )
                ])

        @defer.inlineCallbacks
        def _end():
            try:
                self.user_to_num_current_syncs[user_id] -= 1

                prev_state = yield self.current_state_for_user(user_id)
                yield self._update_states([
                    prev_state.copy_and_replace(
                        last_user_sync_ts=self.clock.time_msec(), )
                ])
            except Exception:
                logger.exception("Error updating presence after sync")

        @contextmanager
        def _user_syncing():
            try:
                yield
            finally:
                if affect_presence:
                    run_in_background(_end)

        defer.returnValue(_user_syncing())

    def get_currently_syncing_users(self):
        """Get the set of user ids that are currently syncing on this HS.
        Returns:
            set(str): A set of user_id strings.
        """
        if self.hs.config.use_presence:
            syncing_user_ids = {
                user_id
                for user_id, count in self.user_to_num_current_syncs.items()
                if count
            }
            for user_ids in self.external_process_to_current_syncs.values():
                syncing_user_ids.update(user_ids)
            return syncing_user_ids
        else:
            return set()

    @defer.inlineCallbacks
    def update_external_syncs_row(self, process_id, user_id, is_syncing,
                                  sync_time_msec):
        """Update the syncing users for an external process as a delta.

        Args:
            process_id (str): An identifier for the process the users are
                syncing against. This allows synapse to process updates
                as users start and stop syncing against a given process.
            user_id (str): The user who has started or stopped syncing
            is_syncing (bool): Whether or not the user is now syncing
            sync_time_msec(int): Time in ms when the user was last syncing
        """
        with (yield self.external_sync_linearizer.queue(process_id)):
            prev_state = yield self.current_state_for_user(user_id)

            process_presence = self.external_process_to_current_syncs.setdefault(
                process_id, set())

            updates = []
            if is_syncing and user_id not in process_presence:
                if prev_state.state == PresenceState.OFFLINE:
                    updates.append(
                        prev_state.copy_and_replace(
                            state=PresenceState.ONLINE,
                            last_active_ts=sync_time_msec,
                            last_user_sync_ts=sync_time_msec,
                        ))
                else:
                    updates.append(
                        prev_state.copy_and_replace(
                            last_user_sync_ts=sync_time_msec, ))
                process_presence.add(user_id)
            elif user_id in process_presence:
                updates.append(
                    prev_state.copy_and_replace(
                        last_user_sync_ts=sync_time_msec, ))

            if not is_syncing:
                process_presence.discard(user_id)

            if updates:
                yield self._update_states(updates)

            self.external_process_last_updated_ms[
                process_id] = self.clock.time_msec()

    @defer.inlineCallbacks
    def update_external_syncs_clear(self, process_id):
        """Marks all users that had been marked as syncing by a given process
        as offline.

        Used when the process has stopped/disappeared.
        """
        with (yield self.external_sync_linearizer.queue(process_id)):
            process_presence = self.external_process_to_current_syncs.pop(
                process_id, set())
            prev_states = yield self.current_state_for_users(process_presence)
            time_now_ms = self.clock.time_msec()

            yield self._update_states([
                prev_state.copy_and_replace(last_user_sync_ts=time_now_ms, )
                for prev_state in itervalues(prev_states)
            ])
            self.external_process_last_updated_ms.pop(process_id, None)

    @defer.inlineCallbacks
    def current_state_for_user(self, user_id):
        """Get the current presence state for a user.
        """
        res = yield self.current_state_for_users([user_id])
        defer.returnValue(res[user_id])

    @defer.inlineCallbacks
    def current_state_for_users(self, user_ids):
        """Get the current presence state for multiple users.

        Returns:
            dict: `user_id` -> `UserPresenceState`
        """
        states = {
            user_id: self.user_to_current_state.get(user_id, None)
            for user_id in user_ids
        }

        missing = [
            user_id for user_id, state in iteritems(states) if not state
        ]
        if missing:
            # There are things not in our in-memory cache. Let's pull them out of
            # the database.
            res = yield self.store.get_presence_for_users(missing)
            states.update(res)

            missing = [
                user_id for user_id, state in iteritems(states) if not state
            ]
            if missing:
                new = {
                    user_id: UserPresenceState.default(user_id)
                    for user_id in missing
                }
                states.update(new)
                self.user_to_current_state.update(new)

        defer.returnValue(states)

    @defer.inlineCallbacks
    def _persist_and_notify(self, states):
        """Persist states in the database, poke the notifier and send to
        interested remote servers
        """
        stream_id, max_token = yield self.store.update_presence(states)

        parties = yield get_interested_parties(self.store, states)
        room_ids_to_states, users_to_states = parties

        self.notifier.on_new_event(
            "presence_key",
            stream_id,
            rooms=room_ids_to_states.keys(),
            users=[UserID.from_string(u) for u in users_to_states])

        self._push_to_remotes(states)

    @defer.inlineCallbacks
    def notify_for_states(self, state, stream_id):
        parties = yield get_interested_parties(self.store, [state])
        room_ids_to_states, users_to_states = parties

        self.notifier.on_new_event(
            "presence_key",
            stream_id,
            rooms=room_ids_to_states.keys(),
            users=[UserID.from_string(u) for u in users_to_states])

    def _push_to_remotes(self, states):
        """Sends state updates to remote servers.

        Args:
            states (list(UserPresenceState))
        """
        self.federation.send_presence(states)

    @defer.inlineCallbacks
    def incoming_presence(self, origin, content):
        """Called when we receive a `m.presence` EDU from a remote server.
        """
        now = self.clock.time_msec()
        updates = []
        for push in content.get("push", []):
            # A "push" contains a list of presence that we are probably interested
            # in.
            # TODO: Actually check if we're interested, rather than blindly
            # accepting presence updates.
            user_id = push.get("user_id", None)
            if not user_id:
                logger.info(
                    "Got presence update from %r with no 'user_id': %r",
                    origin,
                    push,
                )
                continue

            if get_domain_from_id(user_id) != origin:
                logger.info(
                    "Got presence update from %r with bad 'user_id': %r",
                    origin,
                    user_id,
                )
                continue

            presence_state = push.get("presence", None)
            if not presence_state:
                logger.info(
                    "Got presence update from %r with no 'presence_state': %r",
                    origin,
                    push,
                )
                continue

            new_fields = {
                "state": presence_state,
                "last_federation_update_ts": now,
            }

            last_active_ago = push.get("last_active_ago", None)
            if last_active_ago is not None:
                new_fields["last_active_ts"] = now - last_active_ago

            new_fields["status_msg"] = push.get("status_msg", None)
            new_fields["currently_active"] = push.get("currently_active",
                                                      False)

            prev_state = yield self.current_state_for_user(user_id)
            updates.append(prev_state.copy_and_replace(**new_fields))

        if updates:
            federation_presence_counter.inc(len(updates))
            yield self._update_states(updates)

    @defer.inlineCallbacks
    def get_state(self, target_user, as_event=False):
        results = yield self.get_states(
            [target_user.to_string()],
            as_event=as_event,
        )

        defer.returnValue(results[0])

    @defer.inlineCallbacks
    def get_states(self, target_user_ids, as_event=False):
        """Get the presence state for users.

        Args:
            target_user_ids (list)
            as_event (bool): Whether to format it as a client event or not.

        Returns:
            list
        """

        updates = yield self.current_state_for_users(target_user_ids)
        updates = list(updates.values())

        for user_id in set(target_user_ids) - set(u.user_id for u in updates):
            updates.append(UserPresenceState.default(user_id))

        now = self.clock.time_msec()
        if as_event:
            defer.returnValue([
                {
                    "type": "m.presence",
                    "content": format_user_presence_state(state, now),
                }
                for state in updates
            ])
        else:
            defer.returnValue(updates)

    @defer.inlineCallbacks
    def set_state(self, target_user, state, ignore_status_msg=False):
        """Set the presence state of the user.
        """
        status_msg = state.get("status_msg", None)
        presence = state["presence"]

        valid_presence = (PresenceState.ONLINE, PresenceState.UNAVAILABLE,
                          PresenceState.OFFLINE)
        if presence not in valid_presence:
            raise SynapseError(400, "Invalid presence state")

        user_id = target_user.to_string()

        prev_state = yield self.current_state_for_user(user_id)

        new_fields = {"state": presence}

        if not ignore_status_msg:
            msg = status_msg if presence != PresenceState.OFFLINE else None
            new_fields["status_msg"] = msg

        if presence == PresenceState.ONLINE:
            new_fields["last_active_ts"] = self.clock.time_msec()

        yield self._update_states([prev_state.copy_and_replace(**new_fields)])

    @defer.inlineCallbacks
    def is_visible(self, observed_user, observer_user):
        """Returns whether a user can see another user's presence.
        """
        observer_room_ids = yield self.store.get_rooms_for_user(
            observer_user.to_string())
        observed_room_ids = yield self.store.get_rooms_for_user(
            observed_user.to_string())

        if observer_room_ids & observed_room_ids:
            defer.returnValue(True)

        defer.returnValue(False)

    @defer.inlineCallbacks
    def get_all_presence_updates(self, last_id, current_id):
        """
        Gets a list of presence update rows from between the given stream ids.
        Each row has:
        - stream_id(str)
        - user_id(str)
        - state(str)
        - last_active_ts(int)
        - last_federation_update_ts(int)
        - last_user_sync_ts(int)
        - status_msg(str)
        - currently_active(int)
        """
        # TODO(markjh): replicate the unpersisted changes.
        # This could use the in-memory stores for recent changes.
        rows = yield self.store.get_all_presence_updates(last_id, current_id)
        defer.returnValue(rows)

    def notify_new_event(self):
        """Called when new events have happened. Handles users and servers
        joining rooms and require being sent presence.
        """

        if self._event_processing:
            return

        @defer.inlineCallbacks
        def _process_presence():
            assert not self._event_processing

            self._event_processing = True
            try:
                yield self._unsafe_process()
            finally:
                self._event_processing = False

        run_as_background_process("presence.notify_new_event",
                                  _process_presence)

    @defer.inlineCallbacks
    def _unsafe_process(self):
        # Loop round handling deltas until we're up to date
        while True:
            with Measure(self.clock, "presence_delta"):
                deltas = yield self.store.get_current_state_deltas(
                    self._event_pos)
                if not deltas:
                    return

                yield self._handle_state_delta(deltas)

                self._event_pos = deltas[-1]["stream_id"]

                # Expose current event processing position to prometheus
                synapse.metrics.event_processing_positions.labels(
                    "presence").set(self._event_pos)

    @defer.inlineCallbacks
    def _handle_state_delta(self, deltas):
        """Process current state deltas to find new joins that need to be
        handled.
        """
        for delta in deltas:
            typ = delta["type"]
            state_key = delta["state_key"]
            room_id = delta["room_id"]
            event_id = delta["event_id"]
            prev_event_id = delta["prev_event_id"]

            logger.debug("Handling: %r %r, %s", typ, state_key, event_id)

            if typ != EventTypes.Member:
                continue

            if event_id is None:
                # state has been deleted, so this is not a join. We only care about
                # joins.
                continue

            event = yield self.store.get_event(event_id)
            if event.content.get("membership") != Membership.JOIN:
                # We only care about joins
                continue

            if prev_event_id:
                prev_event = yield self.store.get_event(prev_event_id)
                if prev_event.content.get("membership") == Membership.JOIN:
                    # Ignore changes to join events.
                    continue

            yield self._on_user_joined_room(room_id, state_key)

    @defer.inlineCallbacks
    def _on_user_joined_room(self, room_id, user_id):
        """Called when we detect a user joining the room via the current state
        delta stream.

        Args:
            room_id (str)
            user_id (str)

        Returns:
            Deferred
        """

        if self.is_mine_id(user_id):
            # If this is a local user then we need to send their presence
            # out to hosts in the room (who don't already have it)

            # TODO: We should be able to filter the hosts down to those that
            # haven't previously seen the user

            state = yield self.current_state_for_user(user_id)
            hosts = yield self.state.get_current_hosts_in_room(room_id)

            # Filter out ourselves.
            hosts = set(host for host in hosts if host != self.server_name)

            self.federation.send_presence_to_destinations(
                states=[state],
                destinations=hosts,
            )
        else:
            # A remote user has joined the room, so we need to:
            #   1. Check if this is a new server in the room
            #   2. If so send any presence they don't already have for
            #      local users in the room.

            # TODO: We should be able to filter the users down to those that
            # the server hasn't previously seen

            # TODO: Check that this is actually a new server joining the
            # room.

            user_ids = yield self.state.get_current_users_in_room(room_id)
            user_ids = list(filter(self.is_mine_id, user_ids))

            states = yield self.current_state_for_users(user_ids)

            # Filter out old presence, i.e. offline presence states where
            # the user hasn't been active for a week. We can change this
            # depending on what we want the UX to be, but at the least we
            # should filter out offline presence where the state is just the
            # default state.
            now = self.clock.time_msec()
            states = [
                state
                for state in states.values()
                if state.state != PresenceState.OFFLINE
                or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
                or state.status_msg is not None
            ]

            if states:
                self.federation.send_presence_to_destinations(
                    states=states,
                    destinations=[get_domain_from_id(user_id)],
                )
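
The filter above, applied before sending presence to a newly joined remote server, only forwards offline states if the user was recently active or has a status message. A minimal pure-function restatement as a sketch, with an invented helper name and assuming PresenceState.OFFLINE is the string "offline":

ONE_WEEK_MS = 7 * 24 * 60 * 60 * 1000  # same one-week cutoff as used above

def is_worth_sending(presence_state, last_active_ts, status_msg, now_ms):
    """Mirror of the filter above: offline presence is only worth sending
    if the user was active within the last week or has a status message set."""
    if presence_state != "offline":  # i.e. not PresenceState.OFFLINE
        return True
    if now_ms - last_active_ts < ONE_WEEK_MS:
        return True
    return status_msg is not None
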
Example #17
class DeviceListUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @trace
    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        set_tag("origin", origin)
        set_tag("edu_content", edu_content)
        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]  # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning(
                "Got device list update edu for %r/%r from %r",
                user_id,
                device_id,
                origin,
            )

            set_tag("error", True)
            log_kv(
                {
                    "message": "Got a device list update edu from a user and "
                    "device which does not match the origin of the request.",
                    "user_id": user_id,
                    "device_id": device_id,
                }
            )
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            set_tag("error", True)
            log_kv(
                {
                    "message": "Got an update from a user for which "
                    "we don't share any rooms",
                    "other user_id": user_id,
                }
            )
            logger.warning(
                "Got device list update edu for %r/%r, but don't share a room",
                user_id,
                device_id,
            )
            return

        logger.debug("Received device list update for %r/%r", user_id, device_id)

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            for device_id, stream_id, prev_ids, content in pending_updates:
                logger.debug(
                    "Handling update %r/%r, ID: %r, prev: %r ",
                    user_id,
                    device_id,
                    stream_id,
                    prev_ids,
                )

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if logger.isEnabledFor(logging.INFO):
                logger.info(
                    "Received device list update for %s, requiring resync: %s. Devices: %s",
                    user_id,
                    resync,
                    ", ".join(u[0] for u in pending_updates),
                )

            if resync:
                yield self.user_device_resync(user_id)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

                self._seen_updates.setdefault(user_id, set()).update(
                    stream_id for _, stream_id, _, _ in pending_updates
                )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)

        logger.debug("Current extremity for %r: %r", user_id, extremity)

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                return True

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    return True

            stream_id_in_updates.add(stream_id)

        return False

    @defer.inlineCallbacks
    def user_device_resync(self, user_id):
        """Fetches all devices for a user and updates the device cache with them.

        Args:
            user_id (str): The user's id whose device_list will be updated.
        Returns:
            Deferred[dict]: a dict with device info, in the same form as the "devices" field
            in the result of this request:
            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
        """
        log_kv({"message": "Doing resync to update device list."})
        # Fetch all devices for the user.
        origin = get_domain_from_id(user_id)
        try:
            result = yield self.federation.query_user_devices(origin, user_id)
        except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
            # TODO: Remember that we are now out of sync and try again
            # later
            logger.warning("Failed to handle device list update for %s", user_id)
            # We abort on exceptions rather than accepting the update
            # as otherwise synapse will 'forget' that its device list
            # is out of date. If we bail then we will retry the resync
            # next time we get a device list update for this user_id.
            # This makes it more likely that the device lists will
            # eventually become consistent.
            return
        except FederationDeniedError as e:
            set_tag("error", True)
            log_kv({"reason": "FederationDeniedError"})
            logger.info(e)
            return
        except Exception as e:
            # TODO: Remember that we are now out of sync and try again
            # later
            set_tag("error", True)
            log_kv(
                {"message": "Exception raised by federation request", "exception": e}
            )
            logger.exception("Failed to handle device list update for %s", user_id)
            return
        log_kv({"result": result})
        stream_id = result["stream_id"]
        devices = result["devices"]

        # If the remote server has more than ~1000 devices for this user
        # we assume that something is going horribly wrong (e.g. a bot
        # that logs in and creates a new device every time it tries to
        # send a message).  Maintaining lots of devices per user in the
        # cache can cause serious performance issues: if this request
        # takes more than 60s to complete, internal replication from the
        # inbound federation worker to the synapse master may time out,
        # causing the inbound federation request to fail and the remote
        # server to retry, causing a DoS.  So in this scenario we give
        # up on storing the total list of devices and only handle the
        # delta instead.
        if len(devices) > 1000:
            logger.warning(
                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
                user_id,
                len(devices),
            )
            devices = []

        for device in devices:
            logger.debug(
                "Handling resync update %r/%r, ID: %r",
                user_id,
                device["device_id"],
                stream_id,
            )

        yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
        device_ids = [device["device_id"] for device in devices]
        yield self.device_handler.notify_device_update(user_id, device_ids)

        # We clobber the seen updates since we've re-synced from a given
        # point.
        self._seen_updates[user_id] = set([stream_id])

        defer.returnValue(result)
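
The resync decision in _need_to_do_resync above boils down to: a batch of updates can be applied as a delta only if every prev_id it references is already accounted for, either as the stored extremity, a recently seen update, or an earlier update in the same batch. A standalone sketch of that rule, with an invented function name and hypothetical example data:

def need_resync(updates, extremity, seen_updates):
    """updates: iterable of (device_id, stream_id, prev_ids, content) tuples
    in arrival order. Returns True if a full device list resync is required."""
    seen_in_batch = set()
    for _device_id, stream_id, prev_ids, _content in updates:
        if not prev_ids:
            # No previous IDs at all: we can't reason about the gap.
            return True
        for prev_id in prev_ids:
            if (
                prev_id != extremity
                and prev_id not in seen_updates
                and prev_id not in seen_in_batch
            ):
                # Unknown prev_id: we have missed an update somewhere.
                return True
        seen_in_batch.add(stream_id)
    return False

# Hypothetical data, with a stored extremity of "5":
# need_resync([("DEV", "6", ["5"], {})], "5", set())  -> False (plain delta)
# need_resync([("DEV", "7", ["4"], {})], "5", set())  -> True  (missed an update)
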
Example #18
class MediaRepository(object):
    def __init__(self, hs):
        self.hs = hs
        self.auth = hs.get_auth()
        self.client = hs.get_http_client()
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels

        self.primary_base_path = hs.config.media_store_path
        self.filepaths = MediaFilePaths(self.primary_base_path)

        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        self.remote_media_linearizer = Linearizer(name="media_remote")

        self.recently_accessed_remotes = set()
        self.recently_accessed_locals = set()

        self.federation_domain_whitelist = hs.config.federation_domain_whitelist

        # List of StorageProviders where we should search for media and
        # potentially upload to.
        storage_providers = []

        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
            backend = clz(hs, provider_config)
            provider = StorageProviderWrapper(
                backend,
                store_local=wrapper_config.store_local,
                store_remote=wrapper_config.store_remote,
                store_synchronous=wrapper_config.store_synchronous,
            )
            storage_providers.append(provider)

        self.media_storage = MediaStorage(
            self.hs, self.primary_base_path, self.filepaths, storage_providers,
        )

        self.clock.looping_call(
            self._start_update_recently_accessed,
            UPDATE_RECENTLY_ACCESSED_TS,
        )

    def _start_update_recently_accessed(self):
        return run_as_background_process(
            "update_recently_accessed_media", self._update_recently_accessed,
        )

    @defer.inlineCallbacks
    def _update_recently_accessed(self):
        remote_media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        local_media = self.recently_accessed_locals
        self.recently_accessed_locals = set()

        yield self.store.update_cached_last_access_time(
            local_media, remote_media, self.clock.time_msec()
        )

    def mark_recently_accessed(self, server_name, media_id):
        """Mark the given media as recently accessed.

        Args:
            server_name (str|None): Origin server of media, or None if local
            media_id (str): The media ID of the content
        """
        if server_name:
            self.recently_accessed_remotes.add((server_name, media_id))
        else:
            self.recently_accessed_locals.add(media_id)

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type(str): The content type of the file
            upload_name(str): The name of the file
            content: A file like object that is the content to store
            content_length(int): The length of the content
            auth_user(str): The user_id of the uploader

        Returns:
            Deferred[str]: The mxc url of the stored content
        """
        media_id = random_string(24)

        file_info = FileInfo(
            server_name=None,
            file_id=media_id,
        )

        fname = yield self.media_storage.store_file(content, file_info)

        logger.info("Stored local media in file %r", fname)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )

        yield self._generate_thumbnails(
            None, media_id, media_id, media_type,
        )

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_local_media(self, request, media_id, name):
        """Responds to reqests for local media, if exists, or returns 404.

        Args:
            request(twisted.web.http.Request)
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content.)
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        media_info = yield self.store.get_local_media(media_id)
        if not media_info or media_info["quarantined_by"]:
            respond_404(request)
            return

        self.mark_recently_accessed(None, media_id)

        media_type = media_info["media_type"]
        media_length = media_info["media_length"]
        upload_name = name if name else media_info["upload_name"]
        url_cache = media_info["url_cache"]

        file_info = FileInfo(
            None, media_id,
            url_cache=url_cache,
        )

        responder = yield self.media_storage.fetch_media(file_info)
        yield respond_with_responder(
            request, responder, media_type, media_length, upload_name,
        )

    @defer.inlineCallbacks
    def get_remote_media(self, request, server_name, media_id, name):
        """Respond to requests for remote media.

        Args:
            request(twisted.web.http.Request)
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        self.mark_recently_accessed(server_name, media_id)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # We deliberately stream the file outside the lock
        if responder:
            media_type = media_info["media_type"]
            media_length = media_info["media_length"]
            upload_name = name if name else media_info["upload_name"]
            yield respond_with_responder(
                request, responder, media_type, media_length, upload_name,
            )
        else:
            respond_404(request)

    @defer.inlineCallbacks
    def get_remote_media_info(self, server_name, media_id):
        """Gets the media info associated with the remote file, downloading
        if necessary.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[dict]: The media_info of the file
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # Ensure we actually use the responder so that it releases resources
        if responder:
            with responder:
                pass

        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Looks for media in local cache, if not there then attempt to
        download from remote server.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[(Responder, media_info)]
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )

        # file_id is the ID we use to track the file locally. If we've already
        # seen the file then reuse the existing ID, otherwise generate a new
        # one.
        if media_info:
            file_id = media_info["filesystem_id"]
        else:
            file_id = random_string(24)

        file_info = FileInfo(server_name, file_id)

        # If we have an entry in the DB, try and look for it
        if media_info:
            if media_info["quarantined_by"]:
                logger.info("Media is quarantined")
                raise NotFoundError()

            responder = yield self.media_storage.fetch_media(file_info)
            if responder:
                defer.returnValue((responder, media_info))

        # Failed to find the file anywhere, let's download it.

        media_info = yield self._download_remote_file(
            server_name, media_id, file_id
        )

        responder = yield self.media_storage.fetch_media(file_info)
        defer.returnValue((responder, media_info))

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id, file_id):
        """Attempt to download the remote file from the given server name,
        using the given file_id as the local id.

        Args:
            server_name (str): Originating server
            media_id (str): The media ID of the content (as defined by the
                remote server). This is different than the file_id, which is
                locally generated.
            file_id (str): Local file ID

        Returns:
            Deferred[MediaInfo]
        """

        file_info = FileInfo(
            server_name=server_name,
            file_id=file_id,
        )

        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
            request_path = "/".join((
                "/_matrix/media/v1/download", server_name, media_id,
            ))
            try:
                length, headers = yield self.client.get_file(
                    server_name, request_path, output_stream=f,
                    max_size=self.max_upload_size, args={
                        # tell the remote server to 404 if it doesn't
                        # recognise the server_name, to make sure we don't
                        # end up with a routing loop.
                        "allow_remote": "false",
                    }
                )
            except RequestSendFailed as e:
                logger.warn("Request failed fetching remote media %s/%s: %r",
                            server_name, media_id, e)
                raise SynapseError(502, "Failed to fetch remote media")

            except HttpResponseException as e:
                logger.warn("HTTP error fetching remote media %s/%s: %s",
                            server_name, media_id, e.response)
                if e.code == twisted.web.http.NOT_FOUND:
                    raise e.to_synapse_error()
                raise SynapseError(502, "Failed to fetch remote media")

            except SynapseError:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise
            except NotRetryingDestination:
                logger.warn("Not retrying destination %r", server_name)
                raise SynapseError(502, "Failed to fetch remote media")
            except Exception:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise SynapseError(502, "Failed to fetch remote media")

            yield finish()

        media_type = headers[b"Content-Type"][0].decode('ascii')
        upload_name = get_filename_from_headers(headers)
        time_now_ms = self.clock.time_msec()

        logger.info("Stored remote media in file %r", fname)

        yield self.store.store_cached_remote_media(
            origin=server_name,
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=length,
            filesystem_id=file_id,
        )

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_thumbnails(
            server_name, media_id, file_id, media_type,
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, thumbnailer, t_width, t_height,
                            t_method, t_type):
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
        else:
            t_byte_source = None

        return t_byte_source
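
    # Worked example (illustrative, not from the original source): for a
    # 1000x500 source image and a 320x240 requirement,
    #   - "crop" produces a 320x240 thumbnail, trimming the excess width;
    #   - "scale" first adjusts the requested box to preserve the aspect ratio
    #     (thumbnailer.aspect(320, 240) would give roughly 320x160 here) and
    #     the result is then clamped so it never exceeds the source size.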

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type, url_cache):
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            None, media_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield logcontext.defer_to_thread(
            self.hs.get_reactor(),
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        )

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=None,
                    file_id=media_id,
                    url_cache=url_cache,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=False,
        ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield logcontext.defer_to_thread(
            self.hs.get_reactor(),
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        )

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=media_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
                             url_cache=False):
        """Generate and store thumbnails for an image.

        Args:
            server_name (str|None): The server name if remote media, else None if local
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content)
            file_id (str): Local file ID
            media_type (str): The content type of the file
            url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
                used exclusively by the url previewer

        Returns:
            Deferred[dict]: Dict with "width" and "height" keys of original image
        """
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
        # they have the same dimensions as a scaled one.
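        # Illustrative example: for a 320x240 original with requirements of a
        # 320x240 "crop" and a 480x360 "scale", the scale clamps to 320x240
        # and takes precedence over the crop entry for that (width, height,
        # type) key, so a single 320x240 scaled thumbnail is generated.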
        thumbnails = {}
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "crop":
                thumbnails.setdefault((r_width, r_height, r_type), r_method)
            elif r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                t_width = min(m_width, t_width)
                t_height = min(m_height, t_height)
                thumbnails[(t_width, t_height, r_type)] = r_method

        # Now generate the thumbnail for each dimension and store it
        for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
            # Generate the thumbnail
            if t_method == "crop":
                t_byte_source = yield logcontext.defer_to_thread(
                    self.hs.get_reactor(),
                    thumbnailer.crop,
                    t_width, t_height, t_type,
                )
            elif t_method == "scale":
                t_byte_source = yield logcontext.defer_to_thread(
                    self.hs.get_reactor(),
                    thumbnailer.scale,
                    t_width, t_height, t_type,
                )
            else:
                logger.error("Unrecognized method: %r", t_method)
                continue

            if not t_byte_source:
                continue

            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                    url_cache=url_cache,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            t_len = os.path.getsize(output_path)

            # Write to database
            if server_name:
                yield self.store.store_remote_media_thumbnail(
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                )
            else:
                yield self.store.store_local_thumbnail(
                    media_id, t_width, t_height, t_type, t_method, t_len
                )

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    if e.errno == errno.ENOENT:
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
Example #19
0
class RoomMemberHandler(object):
    # TODO(paul): This handler currently contains a messy conflation of
    #   low-level API that works on UserID objects and so on, and REST-level
    #   API that takes ID strings and returns pagination chunks. These concerns
    #   ought to be separated out a lot better.

    __metaclass__ = abc.ABCMeta

    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        self.hs = hs
        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
        self.state_handler = hs.get_state_handler()
        self.config = hs.config
        self.simple_http_client = hs.get_simple_http_client()

        self.federation_handler = hs.get_handlers().federation_handler
        self.directory_handler = hs.get_handlers().directory_handler
        self.registration_handler = hs.get_handlers().registration_handler
        self.profile_handler = hs.get_profile_handler()
        self.event_creation_handler = hs.get_event_creation_handler()

        self.member_linearizer = Linearizer(name="member")

        self.clock = hs.get_clock()
        self.spam_checker = hs.get_spam_checker()
        self._server_notices_mxid = self.config.server_notices_mxid

    @abc.abstractmethod
    def _remote_join(self, requester, remote_room_hosts, room_id, user,
                     content):
        """Try and join a room that this server is not in

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers that can be used
                to join via.
            room_id (str): Room that we are trying to join
            user (UserID): User who is trying to join
            content (dict): A dict that should be used as the content of the
                join event.

        Returns:
            Deferred
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
        """Attempt to reject an invite for a room this server is not in. If we
        fail to do so we locally mark the invite as rejected.

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers to use to try and
                reject invite
            room_id (str)
            target (UserID): The user rejecting the invite

        Returns:
            Deferred[dict]: A dictionary to be returned to the client, may
            include event_id etc, or nothing if we locally rejected
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_or_register_3pid_guest(self, requester, medium, address,
                                   inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            requester (Requester)
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
            3PID guest account.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_joined_room(self, target, room_id):
        """Notifies distributor on master process that the user has joined the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_left_room(self, target, room_id):
        """Notifies distributor on master process that the user has left the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @defer.inlineCallbacks
    def _local_membership_update(
        self,
        requester,
        target,
        room_id,
        membership,
        prev_events_and_hashes,
        txn_id=None,
        ratelimit=True,
        content=None,
    ):
        user_id = target.to_string()

        if content is None:
            content = {}

        content["membership"] = membership
        if requester.is_guest:
            content["kind"] = "guest"

        event, context = yield self.event_creation_handler.create_event(
            requester,
            {
                "type": EventTypes.Member,
                "content": content,
                "room_id": room_id,
                "sender": requester.user.to_string(),
                "state_key": user_id,

                # For backwards compatibility:
                "membership": membership,
            },
            token_id=requester.access_token_id,
            txn_id=txn_id,
            prev_events_and_hashes=prev_events_and_hashes,
        )

        # Check if this event matches the previous membership event for the user.
        duplicate = yield self.event_creation_handler.deduplicate_state_event(
            event,
            context,
        )
        if duplicate is not None:
            # Discard the new event since this membership change is a no-op.
            defer.returnValue(duplicate)

        yield self.event_creation_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target],
            ratelimit=ratelimit,
        )

        prev_state_ids = yield context.get_prev_state_ids(self.store)

        prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id),
                                                  None)

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(
                    prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target, room_id)

            # Copy over direct message status and room tags if this is a join
            # on an upgraded room

            # Check if this is an upgraded room
            predecessor = yield self.store.get_room_predecessor(room_id)

            if predecessor:
                # It is an upgraded room. Copy over old tags
                self.copy_room_tags_and_direct_to_room(
                    predecessor["room_id"],
                    room_id,
                    user_id,
                )
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(
                    prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target, room_id)

        defer.returnValue(event)

    @defer.inlineCallbacks
    def copy_room_tags_and_direct_to_room(
        self,
        old_room_id,
        new_room_id,
        user_id,
    ):
        """Copies the tags and direct room state from one room to another.

        Args:
            old_room_id (str)
            new_room_id (str)
            user_id (str)

        Returns:
            Deferred[None]
        """
        # Retrieve user account data for predecessor room
        user_account_data, _ = yield self.store.get_account_data_for_user(
            user_id, )

        # Copy direct message state if applicable
        direct_rooms = user_account_data.get("m.direct", {})

        # Check which key this room is under
        if isinstance(direct_rooms, dict):
            for key, room_id_list in direct_rooms.items():
                if old_room_id in room_id_list and new_room_id not in room_id_list:
                    # Add new room_id to this key
                    direct_rooms[key].append(new_room_id)

                    # Save back to user's m.direct account data
                    yield self.store.add_account_data_for_user(
                        user_id,
                        "m.direct",
                        direct_rooms,
                    )
                    break

        # Copy room tags if applicable
        room_tags = yield self.store.get_tags_for_room(
            user_id,
            old_room_id,
        )

        # Copy each room tag to the new room
        for tag, tag_content in room_tags.items():
            yield self.store.add_tag_to_room(user_id, new_room_id, tag,
                                             tag_content)

    @defer.inlineCallbacks
    def update_membership(
        self,
        requester,
        target,
        room_id,
        action,
        txn_id=None,
        remote_room_hosts=None,
        third_party_signed=None,
        ratelimit=True,
        content=None,
    ):
        key = (room_id, )

        with (yield self.member_linearizer.queue(key)):
            result = yield self._update_membership(
                requester,
                target,
                room_id,
                action,
                txn_id=txn_id,
                remote_room_hosts=remote_room_hosts,
                third_party_signed=third_party_signed,
                ratelimit=ratelimit,
                content=content,
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _update_membership(
        self,
        requester,
        target,
        room_id,
        action,
        txn_id=None,
        remote_room_hosts=None,
        third_party_signed=None,
        ratelimit=True,
        content=None,
    ):
        content_specified = bool(content)
        if content is None:
            content = {}
        else:
            # We do a copy here as we potentially change some keys
            # later on.
            content = dict(content)

        effective_membership_state = action
        if action in ["kick", "unban"]:
            effective_membership_state = "leave"

        # if this is a join with a 3pid signature, we may need to turn a 3pid
        # invite into a normal invite before we can handle the join.
        if third_party_signed is not None:
            yield self.federation_handler.exchange_third_party_invite(
                third_party_signed["sender"],
                target.to_string(),
                room_id,
                third_party_signed,
            )

        if not remote_room_hosts:
            remote_room_hosts = []

        if effective_membership_state not in (
                "leave",
                "ban",
        ):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(
                    403, "This room has been blocked on this server")

        if effective_membership_state == Membership.INVITE:
            # block any attempts to invite the server notices mxid
            if target.to_string() == self._server_notices_mxid:
                raise SynapseError(
                    http_client.FORBIDDEN,
                    "Cannot invite this user",
                )

            block_invite = False

            if (self._server_notices_mxid is not None and
                    requester.user.to_string() == self._server_notices_mxid):
                # allow the server notices mxid to send invites
                is_requester_admin = True

            else:
                is_requester_admin = yield self.auth.is_server_admin(
                    requester.user, )

            if not is_requester_admin:
                if self.config.block_non_admin_invites:
                    logger.info(
                        "Blocking invite: user is not admin and non-admin "
                        "invites disabled")
                    block_invite = True

                if not self.spam_checker.user_may_invite(
                        requester.user.to_string(),
                        target.to_string(),
                        room_id,
                ):
                    logger.info("Blocking invite due to spam checker")
                    block_invite = True

            if block_invite:
                raise SynapseError(
                    403,
                    "Invites have been disabled on this server",
                )

        prev_events_and_hashes = yield self.store.get_prev_events_for_room(
            room_id, )
        latest_event_ids = (event_id
                            for (event_id, _, _) in prev_events_and_hashes)

        current_state_ids = yield self.state_handler.get_current_state_ids(
            room_id,
            latest_event_ids=latest_event_ids,
        )

        old_state_id = current_state_ids.get(
            (EventTypes.Member, target.to_string()))
        if old_state_id:
            old_state = yield self.store.get_event(old_state_id,
                                                   allow_none=True)
            old_membership = old_state.content.get(
                "membership") if old_state else None
            if action == "unban" and old_membership != "ban":
                raise SynapseError(403,
                                   "Cannot unban user who was not banned"
                                   " (membership=%s)" % old_membership,
                                   errcode=Codes.BAD_STATE)
            if old_membership == "ban" and action != "unban":
                raise SynapseError(403,
                                   "Cannot %s user who was banned" %
                                   (action, ),
                                   errcode=Codes.BAD_STATE)

            if old_state:
                same_content = content == old_state.content
                same_membership = old_membership == effective_membership_state
                same_sender = requester.user.to_string() == old_state.sender
                if same_sender and same_membership and same_content:
                    defer.returnValue(old_state)

            # we don't allow people to reject invites to the server notice
            # room, but they can leave it once they are joined.
            if (old_membership == Membership.INVITE
                    and effective_membership_state == Membership.LEAVE):
                is_blocked = yield self._is_server_notice_room(room_id)
                if is_blocked:
                    raise SynapseError(
                        http_client.FORBIDDEN,
                        "You cannot reject this invite",
                        errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
                    )

        is_host_in_room = yield self._is_host_in_room(current_state_ids)

        if effective_membership_state == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(current_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

            if not is_host_in_room:
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if inviter and not self.hs.is_mine(inviter):
                    remote_room_hosts.append(inviter.domain)

                content["membership"] = Membership.JOIN

                profile = self.profile_handler
                if not content_specified:
                    content["displayname"] = yield profile.get_displayname(
                        target)
                    content["avatar_url"] = yield profile.get_avatar_url(
                        target)

                if requester.is_guest:
                    content["kind"] = "guest"

                ret = yield self._remote_join(requester, remote_room_hosts,
                                              room_id, target, content)
                defer.returnValue(ret)

        elif effective_membership_state == Membership.LEAVE:
            if not is_host_in_room:
                # perhaps we've been invited
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if not inviter:
                    raise SynapseError(404, "Not a known room")

                if self.hs.is_mine(inviter):
                    # the inviter was on our server, but has now left. Carry on
                    # with the normal rejection codepath.
                    #
                    # This is a bit of a hack, because the room might still be
                    # active on other servers.
                    pass
                else:
                    # send the rejection to the inviter's HS.
                    remote_room_hosts = remote_room_hosts + [inviter.domain]
                    res = yield self._remote_reject_invite(
                        requester,
                        remote_room_hosts,
                        room_id,
                        target,
                    )
                    defer.returnValue(res)

        res = yield self._local_membership_update(
            requester=requester,
            target=target,
            room_id=room_id,
            membership=effective_membership_state,
            txn_id=txn_id,
            ratelimit=ratelimit,
            prev_events_and_hashes=prev_events_and_hashes,
            content=content,
        )
        defer.returnValue(res)

    @defer.inlineCallbacks
    def send_membership_event(
        self,
        requester,
        event,
        context,
        remote_room_hosts=None,
        ratelimit=True,
    ):
        """
        Change the membership status of a user in a room.

        Args:
            requester (Requester): The local user who requested the membership
                event. If None, certain checks, like whether this homeserver can
                act as the sender, will be skipped.
            event (SynapseEvent): The membership event.
            context: The context of the event.
            remote_room_hosts ([str]): Homeservers which are likely to already
                be in the room, and which could be used to join via if this
                homeserver is not yet in the room.
            ratelimit (bool): Whether to rate limit this request.
        Raises:
            SynapseError if there was a problem changing the membership.
        """
        remote_room_hosts = remote_room_hosts or []

        target_user = UserID.from_string(event.state_key)
        room_id = event.room_id

        if requester is not None:
            sender = UserID.from_string(event.sender)
            assert sender == requester.user, (
                "Sender (%s) must be same as requester (%s)" %
                (sender, requester.user))
            assert self.hs.is_mine(
                sender), "Sender must be our own: %s" % (sender, )
        else:
            requester = synapse.types.create_requester(target_user)

        prev_event = yield self.event_creation_handler.deduplicate_state_event(
            event,
            context,
        )
        if prev_event is not None:
            return

        prev_state_ids = yield context.get_prev_state_ids(self.store)
        if event.membership == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(prev_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

        if event.membership not in (Membership.LEAVE, Membership.BAN):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(
                    403, "This room has been blocked on this server")

        yield self.event_creation_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target_user],
            ratelimit=ratelimit,
        )

        prev_member_event_id = prev_state_ids.get(
            (EventTypes.Member, event.state_key), None)

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(
                    prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target_user, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(
                    prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target_user, room_id)

    @defer.inlineCallbacks
    def _can_guest_join(self, current_state_ids):
        """
        Returns whether a guest can join a room based on its current state.
        """
        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""),
                                                None)
        if not guest_access_id:
            defer.returnValue(False)

        guest_access = yield self.store.get_event(guest_access_id)

        defer.returnValue(
            guest_access and guest_access.content
            and "guest_access" in guest_access.content
            and guest_access.content["guest_access"] == "can_join")

    @defer.inlineCallbacks
    def lookup_room_alias(self, room_alias):
        """
        Get the room ID associated with a room alias.

        Args:
            room_alias (RoomAlias): The alias to look up.
        Returns:
            A tuple of:
                The room ID as a RoomID object.
                Hosts likely to be participating in the room ([str]).
        Raises:
            SynapseError if room alias could not be found.
        """
        directory_handler = self.directory_handler
        mapping = yield directory_handler.get_association(room_alias)

        if not mapping:
            raise SynapseError(404, "No such room alias")

        room_id = mapping["room_id"]
        servers = mapping["servers"]

        # put the server which owns the alias at the front of the server list.
        if room_alias.domain in servers:
            servers.remove(room_alias.domain)
        servers.insert(0, room_alias.domain)

        defer.returnValue((RoomID.from_string(room_id), servers))
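
    # Illustrative example (not from the original source): for an alias such
    # as "#room:example.org" the directory handler might return
    #     {"room_id": "!abc:example.org", "servers": ["other.org", "example.org"]}
    # and the reordering above would produce
    #     (RoomID for "!abc:example.org", ["example.org", "other.org"])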

    @defer.inlineCallbacks
    def _get_inviter(self, user_id, room_id):
        invite = yield self.store.get_invite_for_user_in_room(
            user_id=user_id,
            room_id=room_id,
        )
        if invite:
            defer.returnValue(UserID.from_string(invite.sender))

    @defer.inlineCallbacks
    def do_3pid_invite(self, room_id, inviter, medium, address, id_server,
                       requester, txn_id):
        if self.config.block_non_admin_invites:
            is_requester_admin = yield self.auth.is_server_admin(
                requester.user, )
            if not is_requester_admin:
                raise SynapseError(
                    403,
                    "Invites have been disabled on this server",
                    Codes.FORBIDDEN,
                )

        invitee = yield self._lookup_3pid(id_server, medium, address)

        if invitee:
            yield self.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                txn_id=txn_id,
            )
        else:
            yield self._make_and_store_3pid_invite(requester,
                                                   id_server,
                                                   medium,
                                                   address,
                                                   room_id,
                                                   inviter,
                                                   txn_id=txn_id)

    @defer.inlineCallbacks
    def _lookup_3pid(self, id_server, medium, address):
        """Looks up a 3pid in the passed identity server.

        Args:
            id_server (str): The server name (including port, if required)
                of the identity server to use.
            medium (str): The type of the third party identifier (e.g. "email").
            address (str): The third party identifier (e.g. an email address).

        Returns:
            str: the matrix ID of the 3pid, or None if it is not recognized.
        """
        try:
            data = yield self.simple_http_client.get_json(
                "%s%s/_matrix/identity/api/v1/lookup" % (
                    id_server_scheme,
                    id_server,
                ), {
                    "medium": medium,
                    "address": address,
                })

            if "mxid" in data:
                if "signatures" not in data:
                    raise AuthError(401, "No signatures on 3pid binding")
                yield self._verify_any_signature(data, id_server)
                defer.returnValue(data["mxid"])

        except IOError as e:
            logger.warn("Error from identity server lookup: %s" % (e, ))
            defer.returnValue(None)
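
    # Illustrative identity-server lookup response handled above (example
    # values only, not from the original source):
    #     {
    #         "mxid": "@alice:example.org",
    #         "medium": "email",
    #         "address": "alice@example.org",
    #         "signatures": {"id.example.org": {"ed25519:0": "<base64 signature>"}}
    #     }
    # A response without "mxid" means the 3pid is not bound to a Matrix ID and
    # None is returned to the caller.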

    @defer.inlineCallbacks
    def _verify_any_signature(self, data, server_hostname):
        if server_hostname not in data["signatures"]:
            raise AuthError(
                401, "No signature from server %s" % (server_hostname, ))
        for key_name, signature in data["signatures"][server_hostname].items():
            key_data = yield self.simple_http_client.get_json(
                "%s%s/_matrix/identity/api/v1/pubkey/%s" % (
                    id_server_scheme,
                    server_hostname,
                    key_name,
                ), )
            if "public_key" not in key_data:
                raise AuthError(
                    401, "No public key named %s from %s" % (
                        key_name,
                        server_hostname,
                    ))
            verify_signed_json(
                data, server_hostname,
                decode_verify_key_bytes(key_name,
                                        decode_base64(key_data["public_key"])))
            return

    @defer.inlineCallbacks
    def _make_and_store_3pid_invite(self, requester, id_server, medium,
                                    address, room_id, user, txn_id):
        room_state = yield self.state_handler.get_current_state(room_id)

        inviter_display_name = ""
        inviter_avatar_url = ""
        member_event = room_state.get((EventTypes.Member, user.to_string()))
        if member_event:
            inviter_display_name = member_event.content.get("displayname", "")
            inviter_avatar_url = member_event.content.get("avatar_url", "")

        # if user has no display name, default to their MXID
        if not inviter_display_name:
            inviter_display_name = user.to_string()

        canonical_room_alias = ""
        canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event:
            canonical_room_alias = canonical_alias_event.content.get(
                "alias", "")

        room_name = ""
        room_name_event = room_state.get((EventTypes.Name, ""))
        if room_name_event:
            room_name = room_name_event.content.get("name", "")

        room_join_rules = ""
        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
        if join_rules_event:
            room_join_rules = join_rules_event.content.get("join_rule", "")

        room_avatar_url = ""
        room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
        if room_avatar_event:
            room_avatar_url = room_avatar_event.content.get("url", "")

        token, public_keys, fallback_public_key, display_name = (
            yield self._ask_id_server_for_third_party_invite(
                requester=requester,
                id_server=id_server,
                medium=medium,
                address=address,
                room_id=room_id,
                inviter_user_id=user.to_string(),
                room_alias=canonical_room_alias,
                room_avatar_url=room_avatar_url,
                room_join_rules=room_join_rules,
                room_name=room_name,
                inviter_display_name=inviter_display_name,
                inviter_avatar_url=inviter_avatar_url))

        yield self.event_creation_handler.create_and_send_nonmember_event(
            requester,
            {
                "type": EventTypes.ThirdPartyInvite,
                "content": {
                    "display_name": display_name,
                    "public_keys": public_keys,

                    # For backwards compatibility:
                    "key_validity_url":
                    fallback_public_key["key_validity_url"],
                    "public_key": fallback_public_key["public_key"],
                },
                "room_id": room_id,
                "sender": user.to_string(),
                "state_key": token,
            },
            txn_id=txn_id,
        )

    @defer.inlineCallbacks
    def _ask_id_server_for_third_party_invite(self, requester, id_server,
                                              medium, address, room_id,
                                              inviter_user_id, room_alias,
                                              room_avatar_url, room_join_rules,
                                              room_name, inviter_display_name,
                                              inviter_avatar_url):
        """
        Asks an identity server for a third party invite.

        Args:
            requester (Requester)
            id_server (str): hostname + optional port for the identity server.
            medium (str): The literal string "email".
            address (str): The third party address being invited.
            room_id (str): The ID of the room to which the user is invited.
            inviter_user_id (str): The user ID of the inviter.
            room_alias (str): An alias for the room, for cosmetic notifications.
            room_avatar_url (str): The URL of the room's avatar, for cosmetic
                notifications.
            room_join_rules (str): The join rules of the room (e.g. "public").
            room_name (str): The m.room.name of the room.
            inviter_display_name (str): The current display name of the
                inviter.
            inviter_avatar_url (str): The URL of the inviter's avatar.

        Returns:
            A deferred tuple containing:
                token (str): The token which must be signed to prove authenticity.
                public_keys ([{"public_key": str, "key_validity_url": str}]):
                    public_key is a base64-encoded ed25519 public key.
                fallback_public_key: One element from public_keys.
                display_name (str): A user-friendly name to represent the invited
                    user.
        """

        is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
            id_server_scheme,
            id_server,
        )

        invite_config = {
            "medium": medium,
            "address": address,
            "room_id": room_id,
            "room_alias": room_alias,
            "room_avatar_url": room_avatar_url,
            "room_join_rules": room_join_rules,
            "room_name": room_name,
            "sender": inviter_user_id,
            "sender_display_name": inviter_display_name,
            "sender_avatar_url": inviter_avatar_url,
        }

        if self.config.invite_3pid_guest:
            guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
                requester=requester,
                medium=medium,
                address=address,
                inviter_user_id=inviter_user_id,
            )

            invite_config.update({
                "guest_access_token": guest_access_token,
                "guest_user_id": guest_user_id,
            })

        data = yield self.simple_http_client.post_urlencoded_get_json(
            is_url, invite_config)
        # TODO: Check for success
        token = data["token"]
        public_keys = data.get("public_keys", [])
        if "public_key" in data:
            fallback_public_key = {
                "public_key":
                data["public_key"],
                "key_validity_url":
                "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
                    id_server_scheme,
                    id_server,
                ),
            }
        else:
            fallback_public_key = public_keys[0]

        if not public_keys:
            public_keys.append(fallback_public_key)
        display_name = data["display_name"]
        defer.returnValue(
            (token, public_keys, fallback_public_key, display_name))
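
    # Illustrative store-invite response consumed above (example values only,
    # not from the original source):
    #     {
    #         "token": "abc123",
    #         "public_keys": [{"public_key": "<base64 key>",
    #                          "key_validity_url": "https://id.example.org/..."}],
    #         "public_key": "<base64 key>",  # some servers return only this
    #         "display_name": "alice"
    #     }
    # A top-level "public_key", when present, is used to build the fallback
    # entry; otherwise the first entry of "public_keys" is used.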

    @defer.inlineCallbacks
    def _is_host_in_room(self, current_state_ids):
        # Have we just created the room, and is this about to be the very
        # first member event?
        create_event_id = current_state_ids.get(("m.room.create", ""))
        if len(current_state_ids) == 1 and create_event_id:
            # We can only get here if we're in the process of creating the room
            defer.returnValue(True)

        for etype, state_key in current_state_ids:
            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                continue

            event_id = current_state_ids[(etype, state_key)]
            event = yield self.store.get_event(event_id, allow_none=True)
            if not event:
                continue

            if event.membership == Membership.JOIN:
                defer.returnValue(True)

        defer.returnValue(False)

    @defer.inlineCallbacks
    def _is_server_notice_room(self, room_id):
        if self._server_notices_mxid is None:
            defer.returnValue(False)
        user_ids = yield self.store.get_users_in_room(room_id)
        defer.returnValue(self._server_notices_mxid in user_ids)
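
A hedged sketch (not part of the original source) of how a caller might invoke
update_membership() to process a join; the helper name and the way the handler
instance is obtained are assumptions for illustration:

from twisted.internet import defer


@defer.inlineCallbacks
def handle_join_request(room_member_handler, requester, room_id, txn_id=None):
    # "join" is one of the membership actions accepted by update_membership();
    # concurrent changes to the same room are serialised by member_linearizer.
    event = yield room_member_handler.update_membership(
        requester,
        requester.user,  # target: the joining user themselves
        room_id,
        "join",
        txn_id=txn_id,
    )
    defer.returnValue(event)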
Example #20
0
class E2eRoomKeysHandler(object):
    """
    Implements an optional realtime backup mechanism for encrypted E2E megolm room keys.
    This gives a way for users to store and recover their megolm keys if they lose all
    their clients. It should also extend easily to future room key mechanisms.
    The actual payload of the encrypted keys is completely opaque to the handler.
    """

    def __init__(self, hs):
        self.store = hs.get_datastore()

        # Used to lock whenever a client is uploading key data.  This prevents collisions
        # between clients trying to upload the details of a new session, given all
        # clients belonging to a user will receive and try to upload a new session at
        # roughly the same time.  Also used to lock out uploads when the key is being
        # changed.
        self._upload_linearizer = Linearizer("upload_room_keys_lock")

    @defer.inlineCallbacks
    def get_room_keys(self, user_id, version, room_id=None, session_id=None):
        """Bulk get the E2E room keys for a given backup, optionally filtered to a given
        room, or a given session.
        See EndToEndRoomKeyStore.get_e2e_room_keys for full details.

        Args:
            user_id(str): the user whose keys we're getting
            version(str): the version ID of the backup we're getting keys from
            room_id(string): room ID to get keys for, or None to get keys for all rooms
            session_id(string): session ID to get keys for, or None to get keys for all
                sessions
        Raises:
            NotFoundError: if the backup version does not exist
        Returns:
            A deferred list of dicts giving the session_data and message metadata for
            these room keys.
        """

        # we deliberately take the lock to get keys so that changing the version
        # works atomically
        with (yield self._upload_linearizer.queue(user_id)):
            # make sure the backup version exists
            try:
                yield self.store.get_e2e_room_keys_version_info(user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise

            results = yield self.store.get_e2e_room_keys(
                user_id, version, room_id, session_id
            )

            defer.returnValue(results)

    @defer.inlineCallbacks
    def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
        """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
        room or a given session.
        See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.

        Args:
            user_id(str): the user whose backup we're deleting
            version(str): the version ID of the backup we're deleting
            room_id(string): room ID to delete keys for, or None to delete keys for all
                rooms
            session_id(string): session ID to delete keys for, or None to delete keys
                for all sessions
        Returns:
            A deferred of the deletion transaction
        """

        # lock for consistency with uploading
        with (yield self._upload_linearizer.queue(user_id)):
            yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)

    @defer.inlineCallbacks
    def upload_room_keys(self, user_id, version, room_keys):
        """Bulk upload a list of room keys into a given backup version, asserting
        that the given version is the current backup version.  room_keys are merged
        into the current backup as described in RoomKeysServlet.on_PUT().

        Args:
            user_id(str): the user whose backup we're setting
            version(str): the version ID of the backup we're updating
            room_keys(dict): a nested dict describing the room_keys we're setting:

        {
            "rooms": {
                "!abc:matrix.org": {
                    "sessions": {
                        "c0ff33": {
                            "first_message_index": 1,
                            "forwarded_count": 1,
                            "is_verified": false,
                            "session_data": "SSBBTSBBIEZJU0gK"
                        }
                    }
                }
            }
        }

        Raises:
            NotFoundError: if there are no versions defined
            RoomKeysVersionError: if the uploaded version is not the current version
        """

        # TODO: Validate the JSON to make sure it has the right keys.

        # XXX: perhaps we should use a finer grained lock here?
        with (yield self._upload_linearizer.queue(user_id)):

            # Check that the version we're trying to upload is the current version
            try:
                version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Version '%s' not found" % (version,))
                else:
                    raise

            if version_info['version'] != version:
                # Check that the version we're trying to upload actually exists
                try:
                    version_info = yield self.store.get_e2e_room_keys_version_info(
                        user_id, version,
                    )
                    # if we get this far, the version must exist
                    raise RoomKeysVersionError(current_version=version_info['version'])
                except StoreError as e:
                    if e.code == 404:
                        raise NotFoundError("Version '%s' not found" % (version,))
                    else:
                        raise

            # go through the room_keys.
            # XXX: this should/could be done concurrently, given we're in a lock.
            for room_id, room in iteritems(room_keys['rooms']):
                for session_id, session in iteritems(room['sessions']):
                    yield self._upload_room_key(
                        user_id, version, room_id, session_id, session
                    )

    @defer.inlineCallbacks
    def _upload_room_key(self, user_id, version, room_id, session_id, room_key):
        """Upload a given room_key for a given room and session into a given
        version of the backup.  Merges the key with any which might already exist.

        Args:
            user_id(str): the user whose backup we're setting
            version(str): the version ID of the backup we're updating
            room_id(str): the ID of the room whose keys we're setting
            session_id(str): the session whose room_key we're setting
            room_key(dict): the room_key being set
        """

        # get the room_key for this particular row
        current_room_key = None
        try:
            current_room_key = yield self.store.get_e2e_room_key(
                user_id, version, room_id, session_id
            )
        except StoreError as e:
            if e.code == 404:
                pass
            else:
                raise

        if self._should_replace_room_key(current_room_key, room_key):
            yield self.store.set_e2e_room_key(
                user_id, version, room_id, session_id, room_key
            )

    @staticmethod
    def _should_replace_room_key(current_room_key, room_key):
        """
        Determine whether to replace a given current_room_key (if any)
        with a newly uploaded room_key backup

        Args:
            current_room_key (dict): Optional, the current room_key dict if any
            room_key (dict): The new room_key dict which may or may not be fit to
                replace the current_room_key

        Returns:
            True if current_room_key should be replaced by room_key in the backup
        """

        if current_room_key:
            # spelt out with if/elifs rather than nested boolean expressions
            # purely for legibility.

            if room_key['is_verified'] and not current_room_key['is_verified']:
                return True
            elif (
                room_key['first_message_index'] <
                current_room_key['first_message_index']
            ):
                return True
            elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
                return True
            else:
                return False
        return True

    @defer.inlineCallbacks
    def create_version(self, user_id, version_info):
        """Create a new backup version.  This automatically becomes the new
        backup version for the user's keys; previous backups can no longer be
        written to.

        Args:
            user_id(str): the user whose backup version we're creating
            version_info(dict): metadata about the new version being created

        {
            "algorithm": "m.megolm_backup.v1",
            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
        }

        Returns:
            A deferred of a string that gives the new version number.
        """

        # TODO: Validate the JSON to make sure it has the right keys.

        # lock everyone out until we've switched version
        with (yield self._upload_linearizer.queue(user_id)):
            new_version = yield self.store.create_e2e_room_keys_version(
                user_id, version_info
            )
            defer.returnValue(new_version)

    @defer.inlineCallbacks
    def get_version_info(self, user_id, version=None):
        """Get the info about a given version of the user's backup

        Args:
            user_id(str): the user whose current backup version we're querying
            version(str): Optional; if None gives the most recent version
                otherwise a historical one.
        Raises:
            NotFoundError: if the requested backup version doesn't exist
        Returns:
            A deferred of an info dict giving the info about the requested version.

        {
            "version": "1234",
            "algorithm": "m.megolm_backup.v1",
            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
        }
        """

        with (yield self._upload_linearizer.queue(user_id)):
            try:
                res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise
            defer.returnValue(res)

    @defer.inlineCallbacks
    def delete_version(self, user_id, version=None):
        """Deletes a given version of the user's e2e_room_keys backup

        Args:
            user_id(str): the user whose current backup version we're deleting
            version(str): the version id of the backup being deleted
        Raises:
            NotFoundError: if this backup version doesn't exist
        """

        with (yield self._upload_linearizer.queue(user_id)):
            try:
                yield self.store.delete_e2e_room_keys_version(user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise

    @defer.inlineCallbacks
    def update_version(self, user_id, version, version_info):
        """Update the info about a given version of the user's backup

        Args:
            user_id(str): the user whose current backup version we're updating
            version(str): the backup version we're updating
            version_info(dict): the new information about the backup
        Raises:
            NotFoundError: if the requested backup version doesn't exist
        Returns:
            A deferred of an empty dict.
        """
        if "version" not in version_info:
            raise SynapseError(
                400,
                "Missing version in body",
                Codes.MISSING_PARAM
            )
        if version_info["version"] != version:
            raise SynapseError(
                400,
                "Version in body does not match",
                Codes.INVALID_PARAM
            )
        with (yield self._upload_linearizer.queue(user_id)):
            try:
                old_info = yield self.store.get_e2e_room_keys_version_info(
                    user_id, version
                )
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise
            if old_info["algorithm"] != version_info["algorithm"]:
                raise SynapseError(
                    400,
                    "Algorithm does not match",
                    Codes.INVALID_PARAM
                )

            yield self.store.update_e2e_room_keys_version(user_id, version, version_info)

            defer.returnValue({})
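
# --- Illustrative sketch (not part of the handler above) ---
# The version-scoped methods above share one shape: take the per-user
# upload lock, check that the supplied backup version is the active one,
# and only then touch storage. A minimal stand-in for that shape, using
# twisted.internet.defer.DeferredLock in place of Synapse's Linearizer and
# a plain dict in place of the store; the names `current_versions` and
# `guarded_backup_op` are invented for this sketch.
from twisted.internet import defer

_user_locks = {}       # user_id -> DeferredLock (one lock per user)
current_versions = {}  # user_id -> active backup version string


@defer.inlineCallbacks
def guarded_backup_op(user_id, version, apply_change):
    lock = _user_locks.setdefault(user_id, defer.DeferredLock())
    yield lock.acquire()
    try:
        if current_versions.get(user_id) != version:
            raise ValueError("stale backup version %r" % (version,))
        result = yield defer.maybeDeferred(apply_change)
        defer.returnValue(result)
    finally:
        lock.release()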
Example #21
0
class StateResolutionHandler(object):
    """Responsible for doing state conflict resolution.

    Note that the storage layer depends on this handler, so all functions must
    be storage-independent.
    """
    def __init__(self, hs):
        self.clock = hs.get_clock()

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None
        self.resolve_linearizer = Linearizer(name="state_resolve_lock")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            iterable=True,
            reset_expiry_on_get=True,
        )

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(
        self, room_id, room_version, state_groups_ids, event_map, state_res_store,
    ):
        """Resolves conflicts between a set of state groups

        Always generates a new state group (unless we hit the cache), so should
        not be called for a single state group

        Args:
            room_id (str): room we are resolving for (used for logging)
            room_version (str): version of the room
            state_groups_ids (dict[int, dict[(str, str), str]]):
                 map from state group id to the state in that state group
                (where 'state' is a map from state key to event id)

            event_map(dict[str,FrozenEvent]|None):
                a dict from event_id to event, for any events that we happen to
                have in flight (eg, those currently being persisted). This will be
                used as a starting point for finding the state we need; any missing
                events will be requested via state_res_store.

                If None, all events will be fetched via state_res_store.

            state_res_store (StateResolutionStore)

        Returns:
            Deferred[_StateCacheEntry]: resolved state
        """
        logger.debug(
            "resolve_state_groups state_groups %s",
            state_groups_ids.keys()
        )

        group_names = frozenset(state_groups_ids.keys())

        with (yield self.resolve_linearizer.queue(group_names)):
            if self._state_cache is not None:
                cache = self._state_cache.get(group_names, None)
                if cache:
                    defer.returnValue(cache)

            logger.info(
                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
            )

            # start by assuming we won't have any conflicted state, and build up the new
            # state map by iterating through the state groups. If we discover a conflict,
            # we give up and instead use `resolve_events_with_store`.
            #
            # XXX: is this actually worthwhile, or should we just let
            # resolve_events_with_store do it?
            new_state = {}
            conflicted_state = False
            for st in itervalues(state_groups_ids):
                for key, e_id in iteritems(st):
                    if key in new_state:
                        conflicted_state = True
                        break
                    new_state[key] = e_id
                if conflicted_state:
                    break

            if conflicted_state:
                logger.info("Resolving conflicted state for %r", room_id)
                with Measure(self.clock, "state._resolve_events"):
                    new_state = yield resolve_events_with_store(
                        room_version,
                        list(itervalues(state_groups_ids)),
                        event_map=event_map,
                        state_res_store=state_res_store,
                    )

            # if the new state matches any of the input state groups, we can
            # use that state group again. Otherwise we will generate a state_id
            # which will be used as a cache key for future resolutions, but
            # will not get persisted.

            with Measure(self.clock, "state.create_group_ids"):
                cache = _make_state_cache_entry(new_state, state_groups_ids)

            if self._state_cache is not None:
                self._state_cache[group_names] = cache

            defer.returnValue(cache)
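
# --- Illustrative sketch (not part of StateResolutionHandler above) ---
# resolve_state_groups linearizes on the frozenset of state group ids and
# checks a cache before doing any work, so concurrent resolutions of the
# same group combination only run once. The same "compute once per key"
# shape in miniature; the helper name and the plain-dict cache are
# stand-ins for Synapse's Linearizer and ExpiringCache.
from twisted.internet import defer

_key_locks = {}
_result_cache = {}


@defer.inlineCallbacks
def resolve_once(group_names, compute):
    key = frozenset(group_names)
    lock = _key_locks.setdefault(key, defer.DeferredLock())
    yield lock.acquire()
    try:
        if key in _result_cache:
            defer.returnValue(_result_cache[key])
        result = yield defer.maybeDeferred(compute)
        _result_cache[key] = result
        defer.returnValue(result)
    finally:
        lock.release()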
Example #22
0
class SigningKeyEduUpdater(object):
    """Handles incoming signing key updates from federation and updates the DB"""
    def __init__(self, hs, e2e_keys_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.e2e_keys_handler = e2e_keys_handler

        self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="signing_key_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_signing_key_update(self, origin, edu_content):
        """Called on incoming signing key update from federation. Responsible for
        parsing the EDU and adding to pending updates list.

        Args:
            origin (string): the server that sent the EDU
            edu_content (dict): the contents of the EDU
        """

        user_id = edu_content.pop("user_id")
        master_key = edu_content.pop("master_key", None)
        self_signing_key = edu_content.pop("self_signing_key", None)

        if get_domain_from_id(user_id) != origin:
            logger.warning("Got signing key update edu for %r from %r",
                           user_id, origin)
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            return

        self._pending_updates.setdefault(user_id, []).append(
            (master_key, self_signing_key))

        yield self._handle_signing_key_updates(user_id)

    @defer.inlineCallbacks
    def _handle_signing_key_updates(self, user_id):
        """Actually handle pending updates.

        Args:
            user_id (string): the user whose updates we are processing
        """

        device_handler = self.e2e_keys_handler.device_handler

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            device_ids = []

            logger.info("pending updates: %r", pending_updates)

            for master_key, self_signing_key in pending_updates:
                if master_key:
                    yield self.store.set_e2e_cross_signing_key(
                        user_id, "master", master_key)
                    _, verify_key = get_verify_key_from_cross_signing_key(
                        master_key)
                    # verify_key is a VerifyKey from signedjson, which uses
                    # .version to denote the portion of the key ID after the
                    # algorithm and colon, which is the device ID
                    device_ids.append(verify_key.version)
                if self_signing_key:
                    yield self.store.set_e2e_cross_signing_key(
                        user_id, "self_signing", self_signing_key)
                    _, verify_key = get_verify_key_from_cross_signing_key(
                        self_signing_key)
                    device_ids.append(verify_key.version)

            yield device_handler.notify_device_update(user_id, device_ids)
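
# --- Illustrative sketch (not part of SigningKeyEduUpdater above) ---
# incoming_signing_key_update appends to a per-user pending list and then
# drains that list inside the per-user lock; a caller that arrives while a
# drain is already running finds an empty list and returns early, which is
# how overlapping EDUs get batched. The same shape in miniature; all names
# below are invented, and a DeferredLock stands in for the Linearizer.
from twisted.internet import defer

_pending = {}      # user_id -> list of updates waiting to be handled
_drain_locks = {}  # user_id -> DeferredLock


def enqueue_update(user_id, update, process_update):
    _pending.setdefault(user_id, []).append(update)
    return _drain_updates(user_id, process_update)


@defer.inlineCallbacks
def _drain_updates(user_id, process_update):
    lock = _drain_locks.setdefault(user_id, defer.DeferredLock())
    yield lock.acquire()
    try:
        # An empty list means another drain already handled the batch we queued.
        updates = _pending.pop(user_id, [])
        for update in updates:
            yield defer.maybeDeferred(process_update, user_id, update)
    finally:
        lock.release()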
Example #23
0
class E2eRoomKeysHandler(object):
    """
    Implements an optional realtime backup mechanism for encrypted E2E megolm room keys.
    This gives a way for users to store and recover their megolm keys if they lose all
    their clients. It should also extend easily to future room key mechanisms.
    The actual payload of the encrypted keys is completely opaque to the handler.
    """

    def __init__(self, hs):
        self.store = hs.get_datastore()

        # Used to lock whenever a client is uploading key data.  This prevents collisions
        # between clients trying to upload the details of a new session, given all
        # clients belonging to a user will receive and try to upload a new session at
        # roughly the same time.  Also used to lock out uploads when the key is being
        # changed.
        self._upload_linearizer = Linearizer("upload_room_keys_lock")

    @defer.inlineCallbacks
    def get_room_keys(self, user_id, version, room_id=None, session_id=None):
        """Bulk get the E2E room keys for a given backup, optionally filtered to a given
        room, or a given session.
        See EndToEndRoomKeyStore.get_e2e_room_keys for full details.

        Args:
            user_id(str): the user whose keys we're getting
            version(str): the version ID of the backup we're getting keys from
            room_id(string): room ID to get keys for, or None to get keys for all rooms
            session_id(string): session ID to get keys for, or None to get keys for all
                sessions
        Returns:
            A deferred list of dicts giving the session_data and message metadata for
            these room keys.
        """

        # we deliberately take the lock to get keys so that changing the version
        # works atomically
        with (yield self._upload_linearizer.queue(user_id)):
            results = yield self.store.get_e2e_room_keys(
                user_id, version, room_id, session_id
            )

            if results['rooms'] == {}:
                raise SynapseError(404, "No room_keys found")

            defer.returnValue(results)

    @defer.inlineCallbacks
    def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
        """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
        room or a given session.
        See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.

        Args:
            user_id(str): the user whose backup we're deleting
            version(str): the version ID of the backup we're deleting
            room_id(string): room ID to delete keys for, or None to delete keys for all
                rooms
            session_id(string): session ID to delete keys for, or None to delete keys
                for all sessions
        Returns:
            A deferred of the deletion transaction
        """

        # lock for consistency with uploading
        with (yield self._upload_linearizer.queue(user_id)):
            yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)

    @defer.inlineCallbacks
    def upload_room_keys(self, user_id, version, room_keys):
        """Bulk upload a list of room keys into a given backup version, asserting
        that the given version is the current backup version.  room_keys are merged
        into the current backup as described in RoomKeysServlet.on_PUT().

        Args:
            user_id(str): the user whose backup we're setting
            version(str): the version ID of the backup we're updating
            room_keys(dict): a nested dict describing the room_keys we're setting:

        {
            "rooms": {
                "!abc:matrix.org": {
                    "sessions": {
                        "c0ff33": {
                            "first_message_index": 1,
                            "forwarded_count": 1,
                            "is_verified": false,
                            "session_data": "SSBBTSBBIEZJU0gK"
                        }
                    }
                }
            }
        }

        Raises:
            SynapseError: with code 404 if there are no versions defined
            RoomKeysVersionError: if the uploaded version is not the current version
        """

        # TODO: Validate the JSON to make sure it has the right keys.

        # XXX: perhaps we should use a finer grained lock here?
        with (yield self._upload_linearizer.queue(user_id)):

            # Check that the version we're trying to upload is the current version
            try:
                version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
            except StoreError as e:
                if e.code == 404:
                    raise SynapseError(404, "Version '%s' not found" % (version,))
                else:
                    raise

            if version_info['version'] != version:
                # Check that the version we're trying to upload actually exists
                try:
                    version_info = yield self.store.get_e2e_room_keys_version_info(
                        user_id, version,
                    )
                    # if we get this far, the version must exist
                    raise RoomKeysVersionError(current_version=version_info['version'])
                except StoreError as e:
                    if e.code == 404:
                        raise SynapseError(404, "Version '%s' not found" % (version,))
                    else:
                        raise

            # go through the room_keys.
            # XXX: this should/could be done concurrently, given we're in a lock.
            for room_id, room in iteritems(room_keys['rooms']):
                for session_id, session in iteritems(room['sessions']):
                    yield self._upload_room_key(
                        user_id, version, room_id, session_id, session
                    )

    @defer.inlineCallbacks
    def _upload_room_key(self, user_id, version, room_id, session_id, room_key):
        """Upload a given room_key for a given room and session into a given
        version of the backup.  Merges the key with any which might already exist.

        Args:
            user_id(str): the user whose backup we're setting
            version(str): the version ID of the backup we're updating
            room_id(str): the ID of the room whose keys we're setting
            session_id(str): the session whose room_key we're setting
            room_key(dict): the room_key being set
        """

        # get the room_key for this particular row
        current_room_key = None
        try:
            current_room_key = yield self.store.get_e2e_room_key(
                user_id, version, room_id, session_id
            )
        except StoreError as e:
            if e.code == 404:
                pass
            else:
                raise

        if self._should_replace_room_key(current_room_key, room_key):
            yield self.store.set_e2e_room_key(
                user_id, version, room_id, session_id, room_key
            )

    @staticmethod
    def _should_replace_room_key(current_room_key, room_key):
        """
        Determine whether to replace a given current_room_key (if any)
        with a newly uploaded room_key backup

        Args:
            current_room_key (dict): Optional, the current room_key dict if any
            room_key (dict): The new room_key dict which may or may not be fit to
                replace the current_room_key

        Returns:
            True if current_room_key should be replaced by room_key in the backup
        """

        if current_room_key:
            # spelt out with if/elifs rather than nested boolean expressions
            # purely for legibility.

            if room_key['is_verified'] and not current_room_key['is_verified']:
                return True
            elif (
                room_key['first_message_index'] <
                current_room_key['first_message_index']
            ):
                return True
            elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
                return True
            else:
                return False
        return True

    @defer.inlineCallbacks
    def create_version(self, user_id, version_info):
        """Create a new backup version.  This automatically becomes the new
        backup version for the user's keys; previous backups can no longer be
        written to.

        Args:
            user_id(str): the user whose backup version we're creating
            version_info(dict): metadata about the new version being created

        {
            "algorithm": "m.megolm_backup.v1",
            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
        }

        Returns:
            A deferred of a string that gives the new version number.
        """

        # TODO: Validate the JSON to make sure it has the right keys.

        # lock everyone out until we've switched version
        with (yield self._upload_linearizer.queue(user_id)):
            new_version = yield self.store.create_e2e_room_keys_version(
                user_id, version_info
            )
            defer.returnValue(new_version)

    @defer.inlineCallbacks
    def get_version_info(self, user_id, version=None):
        """Get the info about a given version of the user's backup

        Args:
            user_id(str): the user whose current backup version we're querying
            version(str): Optional; if None gives the most recent version
                otherwise a historical one.
        Raises:
            StoreError: code 404 if the requested backup version doesn't exist
        Returns:
            A deferred of an info dict giving the info about the requested version.

        {
            "version": "1234",
            "algorithm": "m.megolm_backup.v1",
            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
        }
        """

        with (yield self._upload_linearizer.queue(user_id)):
            res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
            defer.returnValue(res)

    @defer.inlineCallbacks
    def delete_version(self, user_id, version=None):
        """Deletes a given version of the user's e2e_room_keys backup

        Args:
            user_id(str): the user whose current backup version we're deleting
            version(str): the version id of the backup being deleted
        Raises:
            StoreError: code 404 if this backup version doesn't exist
        """

        with (yield self._upload_linearizer.queue(user_id)):
            yield self.store.delete_e2e_room_keys_version(user_id, version)
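
# --- Illustrative sketch (not part of the handler above) ---
# _should_replace_room_key prefers verified keys, then keys that cover
# earlier messages, then keys that were forwarded fewer times. The same
# decision as a standalone function plus a quick sanity check; the sample
# dicts are made up for illustration.
def should_replace(current, candidate):
    if current is None:
        return True
    if candidate['is_verified'] and not current['is_verified']:
        return True
    if candidate['first_message_index'] < current['first_message_index']:
        return True
    if candidate['forwarded_count'] < current['forwarded_count']:
        return True
    return False


existing = {"is_verified": False, "first_message_index": 3, "forwarded_count": 2}
uploaded = {"is_verified": True, "first_message_index": 3, "forwarded_count": 2}
assert should_replace(existing, uploaded)      # verified beats unverified
assert not should_replace(uploaded, existing)  # and not the other way round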
Example #24
0
class RegistrationHandler(BaseHandler):

    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        super(RegistrationHandler, self).__init__(hs)
        self.hs = hs
        self.auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        self.profile_handler = hs.get_profile_handler()
        self.user_directory_handler = hs.get_user_directory_handler()
        self.captcha_client = CaptchaServerHttpClient(hs)
        self.identity_handler = self.hs.get_handlers().identity_handler
        self.ratelimiter = hs.get_registration_ratelimiter()

        self._next_generated_user_id = None

        self.macaroon_gen = hs.get_macaroon_generator()

        self._generate_user_id_linearizer = Linearizer(
            name="_generate_user_id_linearizer",
        )
        self._server_notices_mxid = hs.config.server_notices_mxid

        if hs.config.worker_app:
            self._register_client = ReplicationRegisterServlet.make_client(hs)
            self._register_device_client = (
                RegisterDeviceReplicationServlet.make_client(hs)
            )
            self._post_registration_client = (
                ReplicationPostRegisterActionsServlet.make_client(hs)
            )
        else:
            self.device_handler = hs.get_device_handler()
            self.pusher_pool = hs.get_pusherpool()

    @defer.inlineCallbacks
    def check_username(self, localpart, guest_access_token=None,
                       assigned_user_id=None):
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
                Codes.INVALID_USERNAME
            )

        if not localpart:
            raise SynapseError(
                400,
                "User ID cannot be empty",
                Codes.INVALID_USERNAME
            )

        if localpart[0] == '_':
            raise SynapseError(
                400,
                "User ID may not begin with _",
                Codes.INVALID_USERNAME
            )

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        if assigned_user_id:
            if user_id == assigned_user_id:
                return
            else:
                raise SynapseError(
                    400,
                    "A different user ID has already been registered for this session",
                )

        self.check_user_id_not_appservice_exclusive(user_id)

        if len(user_id) > MAX_USERID_LENGTH:
            raise SynapseError(
                400,
                "User ID may not be longer than %s characters" % (
                    MAX_USERID_LENGTH,
                ),
                Codes.INVALID_USERNAME
            )

        users = yield self.store.get_users_by_id_case_insensitive(user_id)
        if users:
            if not guest_access_token:
                raise SynapseError(
                    400,
                    "User ID already taken.",
                    errcode=Codes.USER_IN_USE,
                )
            user_data = yield self.auth.get_user_by_access_token(guest_access_token)
            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
                raise AuthError(
                    403,
                    "Cannot register taken user ID without valid guest "
                    "credentials for that user.",
                    errcode=Codes.FORBIDDEN,
                )

    @defer.inlineCallbacks
    def register(
        self,
        localpart=None,
        password=None,
        generate_token=True,
        guest_access_token=None,
        make_guest=False,
        admin=False,
        threepid=None,
        user_type=None,
        default_display_name=None,
        address=None,
        bind_emails=[],
    ):
        """Registers a new client on the server.

        Args:
            localpart: The local part of the user ID to register. If None,
              one will be generated.
            password (unicode): The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
            generate_token (bool): Whether a new access token should be
              generated. Having this be True should be considered deprecated,
              since it offers no means of associating a device_id with the
              access_token. Instead you should call auth_handler.issue_access_token
              after registration.
            user_type (str|None): type of user. One of the values from
              api.constants.UserTypes, or None for a normal user.
            default_display_name (unicode|None): if set, the new user's displayname
              will be set to this. Defaults to 'localpart'.
            address (str|None): the IP address used to perform the registration.
            bind_emails (List[str]): list of emails to bind to this account.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """

        yield self.auth.check_auth_blocking(threepid=threepid)
        password_hash = None
        if password:
            password_hash = yield self._auth_handler.hash(password)

        if localpart:
            yield self.check_username(localpart, guest_access_token=guest_access_token)

            was_guest = guest_access_token is not None

            if not was_guest:
                try:
                    int(localpart)
                    raise RegistrationError(
                        400,
                        "Numeric user IDs are reserved for guest users."
                    )
                except ValueError:
                    pass

            user = UserID(localpart, self.hs.hostname)
            user_id = user.to_string()

            if was_guest:
                # If the user was a guest then they already have a profile
                default_display_name = None

            elif default_display_name is None:
                default_display_name = localpart

            token = None
            if generate_token:
                token = self.macaroon_gen.generate_access_token(user_id)
            yield self.register_with_store(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                create_profile_with_displayname=default_display_name,
                admin=admin,
                user_type=user_type,
                address=address,
            )

            if self.hs.config.user_directory_search_all_users:
                profile = yield self.store.get_profileinfo(localpart)
                yield self.user_directory_handler.handle_local_profile_change(
                    user_id, profile
                )

        else:
            # autogen a sequential user ID
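            # Keep trying localparts until one registers cleanly; passing
            # attempts > 0 below forces _generate_user_id to reseed from
            # the store after a collision.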
            attempts = 0
            token = None
            user = None
            while not user:
                localpart = yield self._generate_user_id(attempts > 0)
                user = UserID(localpart, self.hs.hostname)
                user_id = user.to_string()
                yield self.check_user_id_not_appservice_exclusive(user_id)
                if generate_token:
                    token = self.macaroon_gen.generate_access_token(user_id)
                if default_display_name is None:
                    default_display_name = localpart
                try:
                    yield self.register_with_store(
                        user_id=user_id,
                        token=token,
                        password_hash=password_hash,
                        make_guest=make_guest,
                        create_profile_with_displayname=default_display_name,
                        address=address,
                    )
                except SynapseError:
                    # if user id is taken, just generate another
                    user = None
                    user_id = None
                    token = None
                    attempts += 1
        if not self.hs.config.user_consent_at_registration:
            yield self._auto_join_rooms(user_id)

        # Bind any specified emails to this account
        current_time = self.hs.get_clock().time_msec()
        for email in bind_emails:
            # generate threepid dict
            threepid_dict = {
                "medium": "email",
                "address": email,
                "validated_at": current_time,
            }

            # Bind email to new account
            yield self._register_email_threepid(
                user_id, threepid_dict, None, False,
            )

        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def _auto_join_rooms(self, user_id):
        """Automatically joins users to auto join rooms - creating the room in the first place
        if the user is the first to be created.

        Args:
            user_id(str): The user to join
        """
        # auto-join the user to any rooms we're supposed to dump them into
        fake_requester = create_requester(user_id)

        # try to create the room if we're the first real user on the server. Note
        # that an auto-generated support user is not a real user and will never be
        # the user to create the room
        should_auto_create_rooms = False
        is_support = yield self.store.is_support_user(user_id)
        # There is an edge case where the first user is the support user, then
        # the room is never created, though this seems unlikely and
        # recoverable from given the support user being involved in the first
        # place.
        if self.hs.config.autocreate_auto_join_rooms and not is_support:
            count = yield self.store.count_all_users()
            should_auto_create_rooms = count == 1
        for r in self.hs.config.auto_join_rooms:
            try:
                if should_auto_create_rooms:
                    room_alias = RoomAlias.from_string(r)
                    if self.hs.hostname != room_alias.domain:
                        logger.warning(
                            'Cannot create room alias %s, '
                            'it does not match server domain',
                            r,
                        )
                    else:
                        # create room expects the localpart of the room alias
                        room_alias_localpart = room_alias.localpart

                        # getting the RoomCreationHandler during init gives a dependency
                        # loop
                        yield self.hs.get_room_creation_handler().create_room(
                            fake_requester,
                            config={
                                "preset": "public_chat",
                                "room_alias_name": room_alias_localpart
                            },
                            ratelimit=False,
                        )
                else:
                    yield self._join_user_to_room(fake_requester, r)
            except ConsentNotGivenError as e:
                # Technically not necessary to pull out this error though
                # moving away from bare excepts is a good thing to do.
                logger.error("Failed to join new user to %r: %r", r, e)
            except Exception as e:
                logger.error("Failed to join new user to %r: %r", r, e)

    @defer.inlineCallbacks
    def post_consent_actions(self, user_id):
        """A series of registration actions that can only be carried out once consent
        has been granted

        Args:
            user_id (str): The user to join
        """
        yield self._auto_join_rooms(user_id)

    @defer.inlineCallbacks
    def appservice_register(self, user_localpart, as_token):
        user = UserID(user_localpart, self.hs.hostname)
        user_id = user.to_string()
        service = self.store.get_app_service_by_token(as_token)
        if not service:
            raise AuthError(403, "Invalid application service token.")
        if not service.is_interested_in_user(user_id):
            raise SynapseError(
                400, "Invalid user localpart for this application service.",
                errcode=Codes.EXCLUSIVE
            )

        service_id = service.id if service.is_exclusive_user(user_id) else None

        yield self.check_user_id_not_appservice_exclusive(
            user_id, allowed_appservice=service
        )

        yield self.register_with_store(
            user_id=user_id,
            password_hash="",
            appservice_id=service_id,
            create_profile_with_displayname=user.localpart,
        )
        defer.returnValue(user_id)

    @defer.inlineCallbacks
    def check_recaptcha(self, ip, private_key, challenge, response):
        """
        Checks a recaptcha is correct.

        Used only by c/s api v1
        """

        captcha_response = yield self._validate_captcha(
            ip,
            private_key,
            challenge,
            response
        )
        if not captcha_response["valid"]:
            logger.info("Invalid captcha entered from %s. Error: %s",
                        ip, captcha_response["error_url"])
            raise InvalidCaptchaError(
                error_url=captcha_response["error_url"]
            )
        else:
            logger.info("Valid captcha entered from %s", ip)

    @defer.inlineCallbacks
    def register_email(self, threepidCreds):
        """
        Registers emails with an identity server.

        Used only by c/s api v1
        """

        for c in threepidCreds:
            logger.info("validating threepidcred sid %s on id server %s",
                        c['sid'], c['idServer'])
            try:
                threepid = yield self.identity_handler.threepid_from_creds(c)
            except Exception:
                logger.exception("Couldn't validate 3pid")
                raise RegistrationError(400, "Couldn't validate 3pid")

            if not threepid:
                raise RegistrationError(400, "Couldn't validate 3pid")
            logger.info("got threepid with medium '%s' and address '%s'",
                        threepid['medium'], threepid['address'])

            if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
                raise RegistrationError(
                    403, "Third party identifier is not allowed"
                )

    @defer.inlineCallbacks
    def bind_emails(self, user_id, threepidCreds):
        """Links emails with a user ID and informs an identity server.

        Used only by c/s api v1
        """

        # Now we have a matrix ID, bind it to the threepids we were given
        for c in threepidCreds:
            # XXX: This should be a deferred list, shouldn't it?
            yield self.identity_handler.bind_threepid(c, user_id)

    def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
        # don't allow people to register the server notices mxid
        if self._server_notices_mxid is not None:
            if user_id == self._server_notices_mxid:
                raise SynapseError(
                    400, "This user ID is reserved.",
                    errcode=Codes.EXCLUSIVE
                )

        # valid user IDs must not clash with any user ID namespaces claimed by
        # application services.
        services = self.store.get_app_services()
        interested_services = [
            s for s in services
            if s.is_interested_in_user(user_id)
            and s != allowed_appservice
        ]
        for service in interested_services:
            if service.is_exclusive_user(user_id):
                raise SynapseError(
                    400, "This user ID is reserved by an application service.",
                    errcode=Codes.EXCLUSIVE
                )

    @defer.inlineCallbacks
    def _generate_user_id(self, reseed=False):
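        # Double-checked locking: test outside the linearizer first, then
        # re-test once we hold it, so concurrent callers don't all hit the
        # store for the next free localpart.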
        if reseed or self._next_generated_user_id is None:
            with (yield self._generate_user_id_linearizer.queue(())):
                if reseed or self._next_generated_user_id is None:
                    self._next_generated_user_id = (
                        yield self.store.find_next_generated_user_id_localpart()
                    )

        id = self._next_generated_user_id
        self._next_generated_user_id += 1
        defer.returnValue(str(id))

    @defer.inlineCallbacks
    def _validate_captcha(self, ip_addr, private_key, challenge, response):
        """Validates the captcha provided.

        Used only by c/s api v1

        Returns:
            dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.

        """
        response = yield self._submit_captcha(ip_addr, private_key, challenge,
                                              response)
        # parse Google's response. Lovely format..
        lines = response.split('\n')
        json = {
            "valid": lines[0] == 'true',
            "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" +
                         "error=%s" % lines[1]
        }
        defer.returnValue(json)

    @defer.inlineCallbacks
    def _submit_captcha(self, ip_addr, private_key, challenge, response):
        """
        Used only by c/s api v1
        """
        data = yield self.captcha_client.post_urlencoded_get_raw(
            "http://www.recaptcha.net:80/recaptcha/api/verify",
            args={
                'privatekey': private_key,
                'remoteip': ip_addr,
                'challenge': challenge,
                'response': response
            }
        )
        defer.returnValue(data)

    @defer.inlineCallbacks
    def get_or_create_user(self, requester, localpart, displayname,
                           password_hash=None):
        """Creates a new user if the user does not exist,
        else revokes all previous access tokens and generates a new one.

        Args:
            localpart: The local part of the user ID to register. If None,
              one will be randomly generated.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """
        if localpart is None:
            raise SynapseError(400, "Request must include user id")
        yield self.auth.check_auth_blocking()
        need_register = True

        try:
            yield self.check_username(localpart)
        except SynapseError as e:
            if e.errcode == Codes.USER_IN_USE:
                need_register = False
            else:
                raise

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()
        token = self.macaroon_gen.generate_access_token(user_id)

        if need_register:
            yield self.register_with_store(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                create_profile_with_displayname=user.localpart,
            )
        else:
            yield self._auth_handler.delete_access_tokens_for_user(user_id)
            yield self.store.add_access_token_to_user(user_id=user_id, token=token)

        if displayname is not None:
            logger.info("setting user display name: %s -> %s", user_id, displayname)
            yield self.profile_handler.set_displayname(
                user, requester, displayname, by_admin=True,
            )

        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
            3PID guest account.
        """
        access_token = yield self.store.get_3pid_guest_access_token(medium, address)
        if access_token:
            user_info = yield self.auth.get_user_by_access_token(
                access_token
            )

            defer.returnValue((user_info["user"].to_string(), access_token))

        user_id, access_token = yield self.register(
            generate_token=True,
            make_guest=True
        )
        access_token = yield self.store.save_or_get_3pid_guest_access_token(
            medium, address, access_token, inviter_user_id
        )

        defer.returnValue((user_id, access_token))

    @defer.inlineCallbacks
    def _join_user_to_room(self, requester, room_identifier):
        room_id = None
        remote_room_hosts = []
        room_member_handler = self.hs.get_room_member_handler()
        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
        elif RoomAlias.is_valid(room_identifier):
            room_alias = RoomAlias.from_string(room_identifier)
            room_id, remote_room_hosts = (
                yield room_member_handler.lookup_room_alias(room_alias)
            )
            room_id = room_id.to_string()
        else:
            raise SynapseError(400, "%s was not legal room ID or room alias" % (
                room_identifier,
            ))

        yield room_member_handler.update_membership(
            requester=requester,
            target=requester.user,
            room_id=room_id,
            remote_room_hosts=remote_room_hosts,
            action="join",
            ratelimit=False,
        )

    def register_with_store(self, user_id, token=None, password_hash=None,
                            was_guest=False, make_guest=False, appservice_id=None,
                            create_profile_with_displayname=None, admin=False,
                            user_type=None, address=None):
        """Register user in the datastore.

        Args:
            user_id (str): The desired user ID to register.
            token (str): The desired access token to use for this user. If this
                is not None, the given access token is associated with the user
                id.
            password_hash (str|None): Optional. The password hash for this user.
            was_guest (bool): Optional. Whether this is a guest account being
                upgraded to a non-guest account.
            make_guest (boolean): True if the new user should be a guest,
                false to add a regular user account.
            appservice_id (str|None): The ID of the appservice registering the user.
            create_profile_with_displayname (unicode|None): Optionally create a
                profile for the user, setting their displayname to the given value
            admin (boolean): is an admin user?
            user_type (str|None): type of user. One of the values from
                api.constants.UserTypes, or None for a normal user.
            address (str|None): the IP address used to perform the registration.

        Returns:
            Deferred
        """
        # Don't rate limit for app services
        if appservice_id is None and address is not None:
            time_now = self.clock.time()

            allowed, time_allowed = self.ratelimiter.can_do_action(
                address, time_now_s=time_now,
                rate_hz=self.hs.config.rc_registration.per_second,
                burst_count=self.hs.config.rc_registration.burst_count,
            )

            if not allowed:
                raise LimitExceededError(
                    retry_after_ms=int(1000 * (time_allowed - time_now)),
                )

        if self.hs.config.worker_app:
            return self._register_client(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                appservice_id=appservice_id,
                create_profile_with_displayname=create_profile_with_displayname,
                admin=admin,
                user_type=user_type,
                address=address,
            )
        else:
            return self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                appservice_id=appservice_id,
                create_profile_with_displayname=create_profile_with_displayname,
                admin=admin,
                user_type=user_type,
            )

    @defer.inlineCallbacks
    def register_device(self, user_id, device_id, initial_display_name,
                        is_guest=False):
        """Register a device for a user and generate an access token.

        Args:
            user_id (str): full canonical @user:id
            device_id (str|None): The device ID to check, or None to generate
                a new one.
            initial_display_name (str|None): An optional display name for the
                device.
            is_guest (bool): Whether this is a guest account

        Returns:
            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
        """

        if self.hs.config.worker_app:
            r = yield self._register_device_client(
                user_id=user_id,
                device_id=device_id,
                initial_display_name=initial_display_name,
                is_guest=is_guest,
            )
            defer.returnValue((r["device_id"], r["access_token"]))
        else:
            device_id = yield self.device_handler.check_device_registered(
                user_id, device_id, initial_display_name
            )
            if is_guest:
                access_token = self.macaroon_gen.generate_access_token(
                    user_id, ["guest = true"]
                )
            else:
                access_token = yield self._auth_handler.get_access_token_for_user_id(
                    user_id, device_id=device_id,
                )

            defer.returnValue((device_id, access_token))

    @defer.inlineCallbacks
    def post_registration_actions(self, user_id, auth_result, access_token,
                                  bind_email, bind_msisdn):
        """A user has completed registration

        Args:
            user_id (str): The user ID that consented
            auth_result (dict): The authenticated credentials of the newly
                registered user.
            access_token (str|None): The access token of the newly logged in
                device, or None if `inhibit_login` enabled.
            bind_email (bool): Whether to bind the email with the identity
                server.
            bind_msisdn (bool): Whether to bind the msisdn with the identity
                server.
        """
        if self.hs.config.worker_app:
            yield self._post_registration_client(
                user_id=user_id,
                auth_result=auth_result,
                access_token=access_token,
                bind_email=bind_email,
                bind_msisdn=bind_msisdn,
            )
            return

        if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
            threepid = auth_result[LoginType.EMAIL_IDENTITY]
            # Necessary due to auth checks prior to the threepid being
            # written to the db
            if is_threepid_reserved(
                self.hs.config.mau_limits_reserved_threepids, threepid
            ):
                yield self.store.upsert_monthly_active_user(user_id)

            yield self._register_email_threepid(
                user_id, threepid, access_token,
                bind_email,
            )

        if auth_result and LoginType.MSISDN in auth_result:
            threepid = auth_result[LoginType.MSISDN]
            yield self._register_msisdn_threepid(
                user_id, threepid, bind_msisdn,
            )

        if auth_result and LoginType.TERMS in auth_result:
            yield self._on_user_consented(
                user_id, self.hs.config.user_consent_version,
            )

    @defer.inlineCallbacks
    def _on_user_consented(self, user_id, consent_version):
        """A user consented to the terms on registration

        Args:
            user_id (str): The user ID that consented.
            consent_version (str): version of the policy the user has
                consented to.
        """
        logger.info("%s has consented to the privacy policy", user_id)
        yield self.store.user_set_consent_version(
            user_id, consent_version,
        )
        yield self.post_consent_actions(user_id)

    @defer.inlineCallbacks
    def _register_email_threepid(self, user_id, threepid, token, bind_email):
        """Add an email address as a 3pid identifier

        Also adds an email pusher for the email address, if configured in the
        HS config

        Also optionally binds emails to the given user_id on the identity server

        Must be called on master.

        Args:
            user_id (str): id of user
            threepid (object): m.login.email.identity auth response
            token (str|None): access_token for the user, or None if not logged
                in.
            bind_email (bool): true if the client requested the email to be
                bound at the identity server
        Returns:
            defer.Deferred:
        """
        reqd = ('medium', 'address', 'validated_at')
        if any(x not in threepid for x in reqd):
            # This will only happen if the ID server returns a malformed response
            logger.info("Can't add incomplete 3pid")
            return

        yield self._auth_handler.add_threepid(
            user_id,
            threepid['medium'],
            threepid['address'],
            threepid['validated_at'],
        )

        # And we add an email pusher for them by default, but only
        # if email notifications are enabled (so people don't start
        # getting mail spam where they weren't before if email
        # notifs are set up on a home server)
        if (self.hs.config.email_enable_notifs
                and self.hs.config.email_notif_for_new_users
                and token):
            # Pull the ID of the access token back out of the db
            # It would really make more sense for this to be passed
            # up when the access token is saved, but that's quite an
            # invasive change I'd rather do separately.
            user_tuple = yield self.store.get_user_by_access_token(
                token
            )
            token_id = user_tuple["token_id"]

            yield self.pusher_pool.add_pusher(
                user_id=user_id,
                access_token=token_id,
                kind="email",
                app_id="m.email",
                app_display_name="Email Notifications",
                device_display_name=threepid["address"],
                pushkey=threepid["address"],
                lang=None,  # We don't know a user's language here
                data={},
            )

        if bind_email:
            logger.info("bind_email specified: binding")
            logger.debug("Binding emails %s to %s" % (
                threepid, user_id
            ))
            yield self.identity_handler.bind_threepid(
                threepid['threepid_creds'], user_id
            )
        else:
            logger.info("bind_email not specified: not binding email")

    @defer.inlineCallbacks
    def _register_msisdn_threepid(self, user_id, threepid, bind_msisdn):
        """Add a phone number as a 3pid identifier

        Also optionally binds msisdn to the given user_id on the identity server

        Must be called on master.

        Args:
            user_id (str): id of user
            threepid (object): m.login.msisdn auth response
            bind_msisdn (bool): true if the client requested the msisdn to be
                bound at the identity server
        Returns:
            defer.Deferred:
        """
        try:
            assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
        except SynapseError as ex:
            if ex.errcode == Codes.MISSING_PARAM:
                # This will only happen if the ID server returns a malformed response
                logger.info("Can't add incomplete 3pid")
                defer.returnValue(None)
            raise

        yield self._auth_handler.add_threepid(
            user_id,
            threepid['medium'],
            threepid['address'],
            threepid['validated_at'],
        )

        if bind_msisdn:
            logger.info("bind_msisdn specified: binding")
            logger.debug("Binding msisdn %s to %s", threepid, user_id)
            yield self.identity_handler.bind_threepid(
                threepid['threepid_creds'], user_id
            )
        else:
            logger.info("bind_msisdn not specified: not binding msisdn")
Example #25
0
class FederationSenderHandler(object):
    """Processes the replication stream and forwards the appropriate entries
    to the federation sender.
    """
    def __init__(self, hs, replication_client):
        self.store = hs.get_datastore()
        self._is_mine_id = hs.is_mine_id
        self.federation_sender = hs.get_federation_sender()
        self.replication_client = replication_client

        self.federation_position = self.store.federation_out_pos_startup
        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

        self._last_ack = self.federation_position

        self._room_serials = {}
        self._room_typing = {}

    def on_start(self):
        # There may be some events that are persisted but haven't been sent,
        # so send them now.
        self.federation_sender.notify_new_events(
            self.store.get_room_max_stream_ordering()
        )

    def stream_positions(self):
        return {"federation": self.federation_position}

    def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender, rows)
            run_in_background(self.update_token, token)

        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
            self.federation_sender.notify_new_events(token)

        # ... and when new receipts happen
        elif stream_name == ReceiptsStream.NAME:
            run_as_background_process(
                "process_receipts_for_federation", self._on_new_receipts, rows,
            )

    @defer.inlineCallbacks
    def _on_new_receipts(self, rows):
        """
        Args:
            rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
                new receipts to be processed
        """
        for receipt in rows:
            # we only want to send on receipts for our own users
            if not self._is_mine_id(receipt.user_id):
                continue
            receipt_info = ReadReceipt(
                receipt.room_id,
                receipt.receipt_type,
                receipt.user_id,
                [receipt.event_id],
                receipt.data,
            )
            yield self.federation_sender.send_read_receipt(receipt_info)

    @defer.inlineCallbacks
    def update_token(self, token):
        try:
            self.federation_position = token

            # We linearize here to ensure we don't have races updating the token
            with (yield self._fed_position_linearizer.queue(None)):
                if self._last_ack < self.federation_position:
                    yield self.store.update_federation_out_pos(
                        "federation", self.federation_position
                    )

                    # We ACK this token over replication so that the master can drop
                    # its in memory queues
                    self.replication_client.send_federation_ack(self.federation_position)
                    self._last_ack = self.federation_position
        except Exception:
            logger.exception("Error updating federation stream position")
Example #26
0
class _JoinedHostsCache(object):
    """Cache for joined hosts in a room that is optimised to handle updates
    via state deltas.
    """

    def __init__(self, store, room_id):
        self.store = store
        self.room_id = room_id

        self.hosts_to_joined_users = {}

        self.state_group = object()

        self.linearizer = Linearizer("_JoinedHostsCache")

        self._len = 0

    @defer.inlineCallbacks
    def get_destinations(self, state_entry):
        """Get set of destinations for a state entry

        Args:
            state_entry(synapse.state._StateCacheEntry)
        """
        if state_entry.state_group == self.state_group:
            defer.returnValue(frozenset(self.hosts_to_joined_users))

        with (yield self.linearizer.queue(())):
            if state_entry.state_group == self.state_group:
                pass
            elif state_entry.prev_group == self.state_group:
                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
                    if typ != EventTypes.Member:
                        continue

                    host = intern_string(get_domain_from_id(state_key))
                    user_id = state_key
                    known_joins = self.hosts_to_joined_users.setdefault(host, set())

                    event = yield self.store.get_event(event_id)
                    if event.membership == Membership.JOIN:
                        known_joins.add(user_id)
                    else:
                        known_joins.discard(user_id)

                        if not known_joins:
                            self.hosts_to_joined_users.pop(host, None)
            else:
                joined_users = yield self.store.get_joined_users_from_state(
                    self.room_id, state_entry
                )

                self.hosts_to_joined_users = {}
                for user_id in joined_users:
                    host = intern_string(get_domain_from_id(user_id))
                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)

            if state_entry.state_group:
                self.state_group = state_entry.state_group
            else:
                self.state_group = object()
            self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
        defer.returnValue(frozenset(self.hosts_to_joined_users))

    def __len__(self):
        return self._len
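
The cache above keys joined users by their server name so that per-host membership can be updated from state deltas. A standalone sketch of that grouping step follows; domain_from_user_id is a simplified local stand-in for Synapse's get_domain_from_id, and the user IDs are purely illustrative.

def domain_from_user_id(user_id):
    # Matrix user IDs look like "@localpart:server.name"; the domain is everything
    # after the first colon.
    return user_id.split(":", 1)[1]


def group_users_by_host(joined_users):
    hosts_to_joined_users = {}
    for user_id in joined_users:
        host = domain_from_user_id(user_id)
        hosts_to_joined_users.setdefault(host, set()).add(user_id)
    return hosts_to_joined_users


grouped = group_users_by_host(["@alice:example.com", "@bob:example.com", "@carol:other.org"])
assert grouped["example.com"] == {"@alice:example.com", "@bob:example.com"}
assert grouped["other.org"] == {"@carol:other.org"}
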
Example #27
0
class E2eRoomKeysHandler:
    """
    Implements an optional realtime backup mechanism for encrypted E2E megolm room keys.
    This gives a way for users to store and recover their megolm keys if they lose all
    their clients. It should also extend easily to future room key mechanisms.
    The actual payload of the encrypted keys is completely opaque to the handler.
    """
    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastores().main

        # Used to lock whenever a client is uploading key data.  This prevents collisions
        # between clients trying to upload the details of a new session, given all
        # clients belonging to a user will receive and try to upload a new session at
        # roughly the same time.  Also used to lock out uploads when the key is being
        # changed.
        self._upload_linearizer = Linearizer("upload_room_keys_lock")

    @trace
    async def get_room_keys(
        self,
        user_id: str,
        version: str,
        room_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> Dict[Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[
            str, RoomKey]]]]:
        """Bulk get the E2E room keys for a given backup, optionally filtered to a given
        room, or a given session.
        See EndToEndRoomKeyStore.get_e2e_room_keys for full details.

        Args:
            user_id: the user whose keys we're getting
            version: the version ID of the backup we're getting keys from
            room_id: room ID to get keys for, or None to get keys for all rooms
            session_id: session ID to get keys for, or None to get keys for all
                sessions
        Raises:
            NotFoundError: if the backup version does not exist
        Returns:
            A dict giving the session_data and message metadata for these room keys.
            `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
        """

        # we deliberately take the lock to get keys so that changing the version
        # works atomically
        async with self._upload_linearizer.queue(user_id):
            # make sure the backup version exists
            try:
                await self.store.get_e2e_room_keys_version_info(
                    user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise

            results = await self.store.get_e2e_room_keys(
                user_id, version, room_id, session_id)

            log_kv(results)
            return results

    @trace
    async def delete_room_keys(
        self,
        user_id: str,
        version: str,
        room_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> JsonDict:
        """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
        room or a given session.
        See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.

        Args:
            user_id: the user whose backup we're deleting
            version: the version ID of the backup we're deleting
            room_id: room ID to delete keys for, or None to delete keys for all
                rooms
            session_id: session ID to delete keys for, or None to delete keys
                for all sessions
        Raises:
            NotFoundError: if the backup version does not exist
        Returns:
            A dict containing the count and etag for the backup version
        """

        # lock for consistency with uploading
        async with self._upload_linearizer.queue(user_id):
            # make sure the backup version exists
            try:
                version_info = await self.store.get_e2e_room_keys_version_info(
                    user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise

            await self.store.delete_e2e_room_keys(user_id, version, room_id,
                                                  session_id)

            version_etag = version_info["etag"] + 1
            await self.store.update_e2e_room_keys_version(
                user_id, version, None, version_etag)

            count = await self.store.count_e2e_room_keys(user_id, version)
            return {"etag": str(version_etag), "count": count}

    @trace
    async def upload_room_keys(self, user_id: str, version: str,
                               room_keys: JsonDict) -> JsonDict:
        """Bulk upload a list of room keys into a given backup version, asserting
        that the given version is the current backup version.  room_keys are merged
        into the current backup as described in RoomKeysServlet.on_PUT().

        Args:
            user_id: the user whose backup we're setting
            version: the version ID of the backup we're updating
            room_keys: a nested dict describing the room_keys we're setting:

        {
            "rooms": {
                "!abc:matrix.org": {
                    "sessions": {
                        "c0ff33": {
                            "first_message_index": 1,
                            "forwarded_count": 1,
                            "is_verified": false,
                            "session_data": "SSBBTSBBIEZJU0gK"
                        }
                    }
                }
            }
        }

        Returns:
            A dict containing the count and etag for the backup version

        Raises:
            NotFoundError: if there are no versions defined
            RoomKeysVersionError: if the uploaded version is not the current version
        """

        # TODO: Validate the JSON to make sure it has the right keys.

        # XXX: perhaps we should use a finer grained lock here?
        async with self._upload_linearizer.queue(user_id):

            # Check that the version we're trying to upload is the current version
            try:
                version_info = await self.store.get_e2e_room_keys_version_info(
                    user_id)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Version '%s' not found" % (version, ))
                else:
                    raise

            if version_info["version"] != version:
                # Check that the version we're trying to upload actually exists
                try:
                    version_info = await self.store.get_e2e_room_keys_version_info(
                        user_id, version)
                    # if we get this far, the version must exist
                    raise RoomKeysVersionError(
                        current_version=version_info["version"])
                except StoreError as e:
                    if e.code == 404:
                        raise NotFoundError("Version '%s' not found" %
                                            (version, ))
                    else:
                        raise

            # Fetch any existing room keys for the sessions that have been
            # submitted.  Then compare them with the submitted keys.  If the
            # key is new, insert it; if the key should be updated, then update
            # it; otherwise, drop it.
            existing_keys = await self.store.get_e2e_room_keys_multi(
                user_id, version, room_keys["rooms"])
            to_insert = []  # batch the inserts together
            changed = False  # if anything has changed, we need to update the etag
            for room_id, room in room_keys["rooms"].items():
                for session_id, room_key in room["sessions"].items():
                    if not isinstance(room_key["is_verified"], bool):
                        msg = (
                            "is_verified must be a boolean in keys for session %s in"
                            "room %s" % (session_id, room_id))
                        raise SynapseError(400, msg, Codes.INVALID_PARAM)

                    log_kv({
                        "message": "Trying to upload room key",
                        "room_id": room_id,
                        "session_id": session_id,
                        "user_id": user_id,
                    })
                    current_room_key = existing_keys.get(room_id,
                                                         {}).get(session_id)
                    if current_room_key:
                        if self._should_replace_room_key(
                                current_room_key, room_key):
                            log_kv({"message": "Replacing room key."})
                            # updates are done one at a time in the DB, so send
                            # updates right away rather than batching them up,
                            # like we do with the inserts
                            await self.store.update_e2e_room_key(
                                user_id, version, room_id, session_id,
                                room_key)
                            changed = True
                        else:
                            log_kv({"message": "Not replacing room_key."})
                    else:
                        log_kv({
                            "message": "Room key not found.",
                            "room_id": room_id,
                            "user_id": user_id,
                        })
                        log_kv({"message": "Replacing room key."})
                        to_insert.append((room_id, session_id, room_key))
                        changed = True

            if len(to_insert):
                await self.store.add_e2e_room_keys(user_id, version, to_insert)

            version_etag = version_info["etag"]
            if changed:
                version_etag = version_etag + 1
                await self.store.update_e2e_room_keys_version(
                    user_id, version, None, version_etag)

            count = await self.store.count_e2e_room_keys(user_id, version)
            return {"etag": str(version_etag), "count": count}

    @staticmethod
    def _should_replace_room_key(current_room_key: Optional[RoomKey],
                                 room_key: RoomKey) -> bool:
        """
        Determine whether to replace a given current_room_key (if any)
        with a newly uploaded room_key backup

        Args:
            current_room_key: Optional, the current room_key dict if any
            room_key : The new room_key dict which may or may not be fit to
                replace the current_room_key

        Returns:
            True if current_room_key should be replaced by room_key in the backup
        """

        if current_room_key:
            # spelt out with if/elifs rather than nested boolean expressions
            # purely for legibility.

            if room_key["is_verified"] and not current_room_key["is_verified"]:
                return True
            elif (room_key["first_message_index"] <
                  current_room_key["first_message_index"]):
                return True
            elif room_key["forwarded_count"] < current_room_key[
                    "forwarded_count"]:
                return True
            else:
                return False
        return True

    @trace
    async def create_version(self, user_id: str,
                             version_info: JsonDict) -> str:
        """Create a new backup version.  This automatically becomes the new
        backup version for the user's keys; previous backups will no longer be
        writeable.

        Args:
            user_id: the user whose backup version we're creating
            version_info: metadata about the new version being created

        {
            "algorithm": "m.megolm_backup.v1",
            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
        }

        Returns:
            The new version number.
        """

        # TODO: Validate the JSON to make sure it has the right keys.

        # lock everyone out until we've switched version
        async with self._upload_linearizer.queue(user_id):
            new_version = await self.store.create_e2e_room_keys_version(
                user_id, version_info)
            return new_version

    async def get_version_info(self,
                               user_id: str,
                               version: Optional[str] = None) -> JsonDict:
        """Get the info about a given version of the user's backup

        Args:
            user_id: the user whose current backup version we're querying
            version: Optional; if None gives the most recent version
                otherwise a historical one.
        Raises:
            NotFoundError: if the requested backup version doesn't exist
        Returns:
            An info dict describing the requested backup version.

        {
            "version": "1234",
            "algorithm": "m.megolm_backup.v1",
            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
        }
        """

        async with self._upload_linearizer.queue(user_id):
            try:
                res = await self.store.get_e2e_room_keys_version_info(
                    user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise

            res["count"] = await self.store.count_e2e_room_keys(
                user_id, res["version"])
            res["etag"] = str(res["etag"])
            return res

    @trace
    async def delete_version(self,
                             user_id: str,
                             version: Optional[str] = None) -> None:
        """Deletes a given version of the user's e2e_room_keys backup

        Args:
            user_id: the user whose current backup version we're deleting
            version: the version ID of the backup being deleted
        Raises:
            NotFoundError: if this backup version doesn't exist
        """

        async with self._upload_linearizer.queue(user_id):
            try:
                await self.store.delete_e2e_room_keys_version(user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise

    @trace
    async def update_version(self, user_id: str, version: str,
                             version_info: JsonDict) -> JsonDict:
        """Update the info about a given version of the user's backup

        Args:
            user_id: the user whose current backup version we're updating
            version: the backup version we're updating
            version_info: the new information about the backup
        Raises:
            NotFoundError: if the requested backup version doesn't exist
        Returns:
            An empty dict.
        """
        if "version" not in version_info:
            version_info["version"] = version
        elif version_info["version"] != version:
            raise SynapseError(400, "Version in body does not match",
                               Codes.INVALID_PARAM)
        async with self._upload_linearizer.queue(user_id):
            try:
                old_info = await self.store.get_e2e_room_keys_version_info(
                    user_id, version)
            except StoreError as e:
                if e.code == 404:
                    raise NotFoundError("Unknown backup version")
                else:
                    raise
            if old_info["algorithm"] != version_info["algorithm"]:
                raise SynapseError(400, "Algorithm does not match",
                                   Codes.INVALID_PARAM)

            await self.store.update_e2e_room_keys_version(
                user_id, version, version_info)

            return {}
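
The upload path above decides whether an incoming room key should overwrite the stored one: verified keys win, then keys with a lower first_message_index, then keys with a lower forwarded_count. A small standalone restatement of that ordering with illustrative key dicts:

def should_replace(current, new):
    # Restates the _should_replace_room_key ordering above, purely for illustration.
    if current is None:
        return True
    if new["is_verified"] and not current["is_verified"]:
        return True
    if new["first_message_index"] < current["first_message_index"]:
        return True
    if new["forwarded_count"] < current["forwarded_count"]:
        return True
    return False


current = {"is_verified": False, "first_message_index": 3, "forwarded_count": 2}
better = {"is_verified": True, "first_message_index": 3, "forwarded_count": 2}
worse = {"is_verified": False, "first_message_index": 5, "forwarded_count": 2}

assert should_replace(None, worse)      # no existing key: always insert
assert should_replace(current, better)  # verified beats unverified
assert not should_replace(current, worse)
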
Example #28
0
class FederationServer(FederationBase):

    def __init__(self, hs):
        super(FederationServer, self).__init__(hs)

        self.auth = hs.get_auth()
        self.handler = hs.get_handlers().federation_handler

        self._server_linearizer = Linearizer("fed_server")
        self._transaction_linearizer = Linearizer("fed_txn_handler")

        self.transaction_actions = TransactionActions(self.store)

        self.registry = hs.get_federation_registry()

        # We cache responses to state queries, as they take a while and often
        # come in waves.
        self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)

    @defer.inlineCallbacks
    @log_function
    def on_backfill_request(self, origin, room_id, versions, limit):
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            pdus = yield self.handler.on_backfill_request(
                origin, room_id, versions, limit
            )

            res = self._transaction_from_pdus(pdus).get_dict()

        defer.returnValue((200, res))

    @defer.inlineCallbacks
    @log_function
    def on_incoming_transaction(self, origin, transaction_data):
        # keep this as early as possible to make the calculated origin ts as
        # accurate as possible.
        request_time = self._clock.time_msec()

        transaction = Transaction(**transaction_data)

        if not transaction.transaction_id:
            raise Exception("Transaction missing transaction_id")

        logger.debug("[%s] Got transaction", transaction.transaction_id)

        # use a linearizer to ensure that we don't process the same transaction
        # multiple times in parallel.
        with (yield self._transaction_linearizer.queue(
                (origin, transaction.transaction_id),
        )):
            result = yield self._handle_incoming_transaction(
                origin, transaction, request_time,
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _handle_incoming_transaction(self, origin, transaction, request_time):
        """ Process an incoming transaction and return the HTTP response

        Args:
            origin (unicode): the server making the request
            transaction (Transaction): incoming transaction
            request_time (int): timestamp that the HTTP request arrived at

        Returns:
            Deferred[(int, object)]: http response code and body
        """
        response = yield self.transaction_actions.have_responded(origin, transaction)

        if response:
            logger.debug(
                "[%s] We've already responded to this request",
                transaction.transaction_id
            )
            defer.returnValue(response)
            return

        logger.debug("[%s] Transaction is new", transaction.transaction_id)

        # Reject if PDU count > 50 or EDU count > 100
        if (len(transaction.pdus) > 50
                or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):

            logger.info(
                "Transaction PDU or EDU count too large. Returning 400",
            )

            response = {}
            yield self.transaction_actions.set_response(
                origin,
                transaction,
                400, response
            )
            defer.returnValue((400, response))

        received_pdus_counter.inc(len(transaction.pdus))

        origin_host, _ = parse_server_name(origin)

        pdus_by_room = {}

        for p in transaction.pdus:
            if "unsigned" in p:
                unsigned = p["unsigned"]
                if "age" in unsigned:
                    p["age"] = unsigned["age"]
            if "age" in p:
                p["age_ts"] = request_time - int(p["age"])
                del p["age"]

            # We try and pull out an event ID so that if later checks fail we
            # can log something sensible. We don't mandate an event ID here in
            # case future event formats get rid of the key.
            possible_event_id = p.get("event_id", "<Unknown>")

            # Now we get the room ID so that we can check that we know the
            # version of the room.
            room_id = p.get("room_id")
            if not room_id:
                logger.info(
                    "Ignoring PDU as does not have a room_id. Event ID: %s",
                    possible_event_id,
                )
                continue

            try:
                room_version = yield self.store.get_room_version(room_id)
            except NotFoundError:
                logger.info("Ignoring PDU for unknown room_id: %s", room_id)
                continue

            try:
                format_ver = room_version_to_event_format(room_version)
            except UnsupportedRoomVersionError:
                # this can happen if support for a given room version is withdrawn,
                # in which case we may still receive events for that room.
                logger.info(
                    "Ignoring PDU for room %s with unknown version %s",
                    room_id,
                    room_version,
                )
                continue

            event = event_from_pdu_json(p, format_ver)
            pdus_by_room.setdefault(room_id, []).append(event)

        pdu_results = {}

        # we can process different rooms in parallel (which is useful if they
        # require callouts to other servers to fetch missing events), but
        # impose a limit to avoid going too crazy with ram/cpu.

        @defer.inlineCallbacks
        def process_pdus_for_room(room_id):
            logger.debug("Processing PDUs for %s", room_id)
            try:
                yield self.check_server_matches_acl(origin_host, room_id)
            except AuthError as e:
                logger.warn(
                    "Ignoring PDUs for room %s from banned server", room_id,
                )
                for pdu in pdus_by_room[room_id]:
                    event_id = pdu.event_id
                    pdu_results[event_id] = e.error_dict()
                return

            for pdu in pdus_by_room[room_id]:
                event_id = pdu.event_id
                with nested_logging_context(event_id):
                    try:
                        yield self._handle_received_pdu(
                            origin, pdu
                        )
                        pdu_results[event_id] = {}
                    except FederationError as e:
                        logger.warn("Error handling PDU %s: %s", event_id, e)
                        pdu_results[event_id] = {"error": str(e)}
                    except Exception as e:
                        f = failure.Failure()
                        pdu_results[event_id] = {"error": str(e)}
                        logger.error(
                            "Failed to handle PDU %s",
                            event_id,
                            exc_info=(f.type, f.value, f.getTracebackObject()),
                        )

        yield concurrently_execute(
            process_pdus_for_room, pdus_by_room.keys(),
            TRANSACTION_CONCURRENCY_LIMIT,
        )

        if hasattr(transaction, "edus"):
            for edu in (Edu(**x) for x in transaction.edus):
                yield self.received_edu(
                    origin,
                    edu.edu_type,
                    edu.content
                )

        response = {
            "pdus": pdu_results,
        }

        logger.debug("Returning: %s", str(response))

        yield self.transaction_actions.set_response(
            origin,
            transaction,
            200, response
        )
        defer.returnValue((200, response))

    @defer.inlineCallbacks
    def received_edu(self, origin, edu_type, content):
        received_edus_counter.inc()
        yield self.registry.on_edu(edu_type, origin, content)

    @defer.inlineCallbacks
    @log_function
    def on_context_state_request(self, origin, room_id, event_id):
        if not event_id:
            raise NotImplementedError("Specify an event")

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)

        in_room = yield self.auth.check_host_in_room(room_id, origin)
        if not in_room:
            raise AuthError(403, "Host not in room.")

        # we grab the linearizer to protect ourselves from servers which hammer
        # us. In theory we might already have the response to this query
        # in the cache so we could return it without waiting for the linearizer
        # - but that's non-trivial to get right, and anyway somewhat defeats
        # the point of the linearizer.
        with (yield self._server_linearizer.queue((origin, room_id))):
            resp = yield self._state_resp_cache.wrap(
                (room_id, event_id),
                self._on_context_state_request_compute,
                room_id, event_id,
            )

        defer.returnValue((200, resp))

    @defer.inlineCallbacks
    def on_state_ids_request(self, origin, room_id, event_id):
        if not event_id:
            raise NotImplementedError("Specify an event")

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)

        in_room = yield self.auth.check_host_in_room(room_id, origin)
        if not in_room:
            raise AuthError(403, "Host not in room.")

        state_ids = yield self.handler.get_state_ids_for_pdu(
            room_id, event_id,
        )
        auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)

        defer.returnValue((200, {
            "pdu_ids": state_ids,
            "auth_chain_ids": auth_chain_ids,
        }))

    @defer.inlineCallbacks
    def _on_context_state_request_compute(self, room_id, event_id):
        pdus = yield self.handler.get_state_for_pdu(
            room_id, event_id,
        )
        auth_chain = yield self.store.get_auth_chain(
            [pdu.event_id for pdu in pdus]
        )

        for event in auth_chain:
            # We sign these again because there was a bug where we
            # incorrectly signed things the first time round
            if self.hs.is_mine_id(event.event_id):
                event.signatures.update(
                    compute_event_signature(
                        event.get_pdu_json(),
                        self.hs.hostname,
                        self.hs.config.signing_key[0]
                    )
                )

        defer.returnValue({
            "pdus": [pdu.get_pdu_json() for pdu in pdus],
            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
        })

    @defer.inlineCallbacks
    @log_function
    def on_pdu_request(self, origin, event_id):
        pdu = yield self.handler.get_persisted_pdu(origin, event_id)

        if pdu:
            defer.returnValue(
                (200, self._transaction_from_pdus([pdu]).get_dict())
            )
        else:
            defer.returnValue((404, ""))

    @defer.inlineCallbacks
    def on_query_request(self, query_type, args):
        received_queries_counter.labels(query_type).inc()
        resp = yield self.registry.on_query(query_type, args)
        defer.returnValue((200, resp))

    @defer.inlineCallbacks
    def on_make_join_request(self, origin, room_id, user_id, supported_versions):
        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)

        room_version = yield self.store.get_room_version(room_id)
        if room_version not in supported_versions:
            logger.warn("Room version %s not in %s", room_version, supported_versions)
            raise IncompatibleRoomVersionError(room_version=room_version)

        pdu = yield self.handler.on_make_join_request(room_id, user_id)
        time_now = self._clock.time_msec()
        defer.returnValue({
            "event": pdu.get_pdu_json(time_now),
            "room_version": room_version,
        })

    @defer.inlineCallbacks
    def on_invite_request(self, origin, content, room_version):
        if room_version not in KNOWN_ROOM_VERSIONS:
            raise SynapseError(
                400,
                "Homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

        format_ver = room_version_to_event_format(room_version)

        pdu = event_from_pdu_json(content, format_ver)
        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, pdu.room_id)
        ret_pdu = yield self.handler.on_invite_request(origin, pdu)
        time_now = self._clock.time_msec()
        defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})

    @defer.inlineCallbacks
    def on_send_join_request(self, origin, content, room_id):
        logger.debug("on_send_join_request: content: %s", content)

        room_version = yield self.store.get_room_version(room_id)
        format_ver = room_version_to_event_format(room_version)
        pdu = event_from_pdu_json(content, format_ver)

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, pdu.room_id)

        logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
        res_pdus = yield self.handler.on_send_join_request(origin, pdu)
        time_now = self._clock.time_msec()
        defer.returnValue((200, {
            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
            "auth_chain": [
                p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
            ],
        }))

    @defer.inlineCallbacks
    def on_make_leave_request(self, origin, room_id, user_id):
        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)
        pdu = yield self.handler.on_make_leave_request(room_id, user_id)

        room_version = yield self.store.get_room_version(room_id)

        time_now = self._clock.time_msec()
        defer.returnValue({
            "event": pdu.get_pdu_json(time_now),
            "room_version": room_version,
        })

    @defer.inlineCallbacks
    def on_send_leave_request(self, origin, content, room_id):
        logger.debug("on_send_leave_request: content: %s", content)

        room_version = yield self.store.get_room_version(room_id)
        format_ver = room_version_to_event_format(room_version)
        pdu = event_from_pdu_json(content, format_ver)

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, pdu.room_id)

        logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
        yield self.handler.on_send_leave_request(origin, pdu)
        defer.returnValue((200, {}))

    @defer.inlineCallbacks
    def on_event_auth(self, origin, room_id, event_id):
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            time_now = self._clock.time_msec()
            auth_pdus = yield self.handler.on_event_auth(event_id)
            res = {
                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
            }
        defer.returnValue((200, res))

    @defer.inlineCallbacks
    def on_query_auth_request(self, origin, content, room_id, event_id):
        """
        Content is a dict with keys::
            auth_chain (list): A list of events that give the auth chain.
            missing (list): A list of event_ids indicating what the other
              side (`origin`) think we're missing.
            rejects (dict): A mapping from event_id to a 2-tuple of reason
              string and a proof (or None) of why the event was rejected.
              The keys of this dict give the list of events the `origin` has
              rejected.

        Args:
            origin (str)
            content (dict)
            room_id (str)
            event_id (str)

        Returns:
            Deferred: Results in `dict` with the same format as `content`
        """
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            room_version = yield self.store.get_room_version(room_id)
            format_ver = room_version_to_event_format(room_version)

            auth_chain = [
                event_from_pdu_json(e, format_ver)
                for e in content["auth_chain"]
            ]

            signed_auth = yield self._check_sigs_and_hash_and_fetch(
                origin, auth_chain, outlier=True, room_version=room_version,
            )

            ret = yield self.handler.on_query_auth(
                origin,
                event_id,
                room_id,
                signed_auth,
                content.get("rejects", []),
                content.get("missing", []),
            )

            time_now = self._clock.time_msec()
            send_content = {
                "auth_chain": [
                    e.get_pdu_json(time_now)
                    for e in ret["auth_chain"]
                ],
                "rejects": ret.get("rejects", []),
                "missing": ret.get("missing", []),
            }

        defer.returnValue(
            (200, send_content)
        )

    @log_function
    def on_query_client_keys(self, origin, content):
        return self.on_query_request("client_keys", content)

    def on_query_user_devices(self, origin, user_id):
        return self.on_query_request("user_devices", user_id)

    @defer.inlineCallbacks
    @log_function
    def on_claim_client_keys(self, origin, content):
        query = []
        for user_id, device_keys in content.get("one_time_keys", {}).items():
            for device_id, algorithm in device_keys.items():
                query.append((user_id, device_id, algorithm))

        results = yield self.store.claim_e2e_one_time_keys(query)

        json_result = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join((
                "%s for %s:%s" % (key_id, user_id, device_id)
                for user_id, user_keys in iteritems(json_result)
                for device_id, device_keys in iteritems(user_keys)
                for key_id, _ in iteritems(device_keys)
            )),
        )

        defer.returnValue({"one_time_keys": json_result})

    @defer.inlineCallbacks
    @log_function
    def on_get_missing_events(self, origin, room_id, earliest_events,
                              latest_events, limit):
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            logger.info(
                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                " limit: %d",
                earliest_events, latest_events, limit,
            )

            missing_events = yield self.handler.on_get_missing_events(
                origin, room_id, earliest_events, latest_events, limit,
            )

            if len(missing_events) < 5:
                logger.info(
                    "Returning %d events: %r", len(missing_events), missing_events
                )
            else:
                logger.info("Returning %d events", len(missing_events))

            time_now = self._clock.time_msec()

        defer.returnValue({
            "events": [ev.get_pdu_json(time_now) for ev in missing_events],
        })

    @log_function
    def on_openid_userinfo(self, token):
        ts_now_ms = self._clock.time_msec()
        return self.store.get_user_id_for_open_id_token(token, ts_now_ms)

    def _transaction_from_pdus(self, pdu_list):
        """Returns a new Transaction containing the given PDUs suitable for
        transmission.
        """
        time_now = self._clock.time_msec()
        pdus = [p.get_pdu_json(time_now) for p in pdu_list]
        return Transaction(
            origin=self.server_name,
            pdus=pdus,
            origin_server_ts=int(time_now),
            destination=None,
        )

    @defer.inlineCallbacks
    def _handle_received_pdu(self, origin, pdu):
        """ Process a PDU received in a federation /send/ transaction.

        If the event is invalid, then this method throws a FederationError.
        (The error will then be logged and sent back to the sender (which
        probably won't do anything with it), and other events in the
        transaction will be processed as normal).

        It is likely that we'll then receive other events which refer to
        this rejected_event in their prev_events, etc.  When that happens,
        we'll attempt to fetch the rejected event again, which will presumably
        fail, so those second-generation events will also get rejected.

        Eventually, we get to the point where there are more than 10 events
        between any new events and the original rejected event. Since we
        only try to backfill 10 events deep on received pdu, we then accept the
        new event, possibly introducing a discontinuity in the DAG, with new
        forward extremities, so normal service is approximately returned,
        until we try to backfill across the discontinuity.

        Args:
            origin (str): server which sent the pdu
            pdu (FrozenEvent): received pdu

        Returns (Deferred): completes with None

        Raises: FederationError if the signatures / hash do not match, or
            if the event was unacceptable for any other reason (eg, too large,
            too many prev_events, couldn't find the prev_events)
        """
        # check that it's actually being sent from a valid destination to
        # workaround bug #1753 in 0.18.5 and 0.18.6
        if origin != get_domain_from_id(pdu.sender):
            # We continue to accept join events from any server; this is
            # necessary for the federation join dance to work correctly.
            # (When we join over federation, the "helper" server is
            # responsible for sending out the join event, rather than the
            # origin. See bug #1893. This is also true for some third party
            # invites).
            if not (
                pdu.type == 'm.room.member' and
                pdu.content and
                pdu.content.get("membership", None) in (
                    Membership.JOIN, Membership.INVITE,
                )
            ):
                logger.info(
                    "Discarding PDU %s from invalid origin %s",
                    pdu.event_id, origin
                )
                return
            else:
                logger.info(
                    "Accepting join PDU %s from %s",
                    pdu.event_id, origin
                )

        # We've already checked that we know the room version by this point
        room_version = yield self.store.get_room_version(pdu.room_id)

        # Check signature.
        try:
            pdu = yield self._check_sigs_and_hash(room_version, pdu)
        except SynapseError as e:
            raise FederationError(
                "ERROR",
                e.code,
                e.msg,
                affected=pdu.event_id,
            )

        yield self.handler.on_receive_pdu(
            origin, pdu, sent_to_us_directly=True,
        )

    def __str__(self):
        return "<ReplicationLayer(%s)>" % self.server_name

    @defer.inlineCallbacks
    def exchange_third_party_invite(
            self,
            sender_user_id,
            target_user_id,
            room_id,
            signed,
    ):
        ret = yield self.handler.exchange_third_party_invite(
            sender_user_id,
            target_user_id,
            room_id,
            signed,
        )
        defer.returnValue(ret)

    @defer.inlineCallbacks
    def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
        ret = yield self.handler.on_exchange_third_party_invite_request(
            origin, room_id, event_dict
        )
        defer.returnValue(ret)

    @defer.inlineCallbacks
    def check_server_matches_acl(self, server_name, room_id):
        """Check if the given server is allowed by the server ACLs in the room

        Args:
            server_name (str): name of server, *without any port part*
            room_id (str): ID of the room to check

        Raises:
            AuthError if the server does not match the ACL
        """
        state_ids = yield self.store.get_current_state_ids(room_id)
        acl_event_id = state_ids.get((EventTypes.ServerACL, ""))

        if not acl_event_id:
            return

        acl_event = yield self.store.get_event(acl_event_id)
        if server_matches_acl_event(server_name, acl_event):
            return

        raise AuthError(code=403, msg="Server is banned from room")
Example #29
0
class RegistrationHandler(BaseHandler):
    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        super(RegistrationHandler, self).__init__(hs)
        self.hs = hs
        self.auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        self.profile_handler = hs.get_profile_handler()
        self.user_directory_handler = hs.get_user_directory_handler()
        self.identity_handler = self.hs.get_handlers().identity_handler
        self.ratelimiter = hs.get_registration_ratelimiter()

        self._next_generated_user_id = None

        self.macaroon_gen = hs.get_macaroon_generator()

        self._generate_user_id_linearizer = Linearizer(
            name="_generate_user_id_linearizer")
        self._server_notices_mxid = hs.config.server_notices_mxid

        if hs.config.worker_app:
            self._register_client = ReplicationRegisterServlet.make_client(hs)
            self._register_device_client = RegisterDeviceReplicationServlet.make_client(
                hs)
            self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
                hs)
        else:
            self.device_handler = hs.get_device_handler()
            self.pusher_pool = hs.get_pusherpool()

        self.session_lifetime = hs.config.session_lifetime

    @defer.inlineCallbacks
    def check_username(self,
                       localpart,
                       guest_access_token=None,
                       assigned_user_id=None):
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
                Codes.INVALID_USERNAME,
            )

        if not localpart:
            raise SynapseError(400, "User ID cannot be empty",
                               Codes.INVALID_USERNAME)

        if localpart[0] == "_":
            raise SynapseError(400, "User ID may not begin with _",
                               Codes.INVALID_USERNAME)

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        if assigned_user_id:
            if user_id == assigned_user_id:
                return
            else:
                raise SynapseError(
                    400,
                    "A different user ID has already been registered for this session",
                )

        self.check_user_id_not_appservice_exclusive(user_id)

        if len(user_id) > MAX_USERID_LENGTH:
            raise SynapseError(
                400,
                "User ID may not be longer than %s characters" %
                (MAX_USERID_LENGTH, ),
                Codes.INVALID_USERNAME,
            )

        users = yield self.store.get_users_by_id_case_insensitive(user_id)
        if users:
            if not guest_access_token:
                raise SynapseError(400,
                                   "User ID already taken.",
                                   errcode=Codes.USER_IN_USE)
            user_data = yield self.auth.get_user_by_access_token(
                guest_access_token)
            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
                raise AuthError(
                    403,
                    "Cannot register taken user ID without valid guest "
                    "credentials for that user.",
                    errcode=Codes.FORBIDDEN,
                )

    @defer.inlineCallbacks
    def register_user(
        self,
        localpart=None,
        password=None,
        guest_access_token=None,
        make_guest=False,
        admin=False,
        threepid=None,
        user_type=None,
        default_display_name=None,
        address=None,
        bind_emails=[],
    ):
        """Registers a new client on the server.

        Args:
            localpart : The local part of the user ID to register. If None,
              one will be generated.
            password (unicode) : The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
            user_type (str|None): type of user. One of the values from
              api.constants.UserTypes, or None for a normal user.
            default_display_name (unicode|None): if set, the new user's displayname
              will be set to this. Defaults to 'localpart'.
            address (str|None): the IP address used to perform the registration.
            bind_emails (List[str]): list of emails to bind to this account.
        Returns:
            Deferred[str]: user_id
        Raises:
            RegistrationError if there was a problem registering.
        """
        yield self.check_registration_ratelimit(address)

        yield self.auth.check_auth_blocking(threepid=threepid)
        password_hash = None
        if password:
            password_hash = yield self._auth_handler.hash(password)

        if localpart:
            yield self.check_username(localpart,
                                      guest_access_token=guest_access_token)

            was_guest = guest_access_token is not None

            if not was_guest:
                try:
                    int(localpart)
                    raise RegistrationError(
                        400, "Numeric user IDs are reserved for guest users.")
                except ValueError:
                    pass

            user = UserID(localpart, self.hs.hostname)
            user_id = user.to_string()

            if was_guest:
                # If the user was a guest then they already have a profile
                default_display_name = None

            elif default_display_name is None:
                default_display_name = localpart

            yield self.register_with_store(
                user_id=user_id,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                create_profile_with_displayname=default_display_name,
                admin=admin,
                user_type=user_type,
                address=address,
            )

            if self.hs.config.user_directory_search_all_users:
                profile = yield self.store.get_profileinfo(localpart)
                yield self.user_directory_handler.handle_local_profile_change(
                    user_id, profile)

        else:
            # autogen a sequential user ID
            fail_count = 0
            user = None
            while not user:
                # Fail after being unable to find a suitable ID a few times
                if fail_count > 10:
                    raise SynapseError(
                        500, "Unable to find a suitable guest user ID")

                localpart = yield self._generate_user_id()
                user = UserID(localpart, self.hs.hostname)
                user_id = user.to_string()
                yield self.check_user_id_not_appservice_exclusive(user_id)
                if default_display_name is None:
                    default_display_name = localpart
                try:
                    yield self.register_with_store(
                        user_id=user_id,
                        password_hash=password_hash,
                        make_guest=make_guest,
                        create_profile_with_displayname=default_display_name,
                        address=address,
                    )

                    # Successfully registered
                    break
                except SynapseError:
                    # if user id is taken, just generate another
                    user = None
                    user_id = None
                    fail_count += 1

        if not self.hs.config.user_consent_at_registration:
            yield self._auto_join_rooms(user_id)
        else:
            logger.info(
                "Skipping auto-join for %s because consent is required at registration",
                user_id,
            )

        # Bind any specified emails to this account
        current_time = self.hs.get_clock().time_msec()
        for email in bind_emails:
            # generate threepid dict
            threepid_dict = {
                "medium": "email",
                "address": email,
                "validated_at": current_time,
            }

            # Bind email to new account
            yield self._register_email_threepid(user_id, threepid_dict, None)

        return user_id

    @defer.inlineCallbacks
    def _auto_join_rooms(self, user_id):
        """Automatically joins users to auto join rooms - creating the room in the first place
        if the user is the first to be created.

        Args:
            user_id(str): The user to join
        """
        # auto-join the user to any rooms we're supposed to dump them into
        fake_requester = create_requester(user_id)

        # try to create the room if we're the first real user on the server. Note
        # that an auto-generated support or bot user is not a real user and will never be
        # the user to create the room
        should_auto_create_rooms = False
        is_real_user = yield self.store.is_real_user(user_id)
        if self.hs.config.autocreate_auto_join_rooms and is_real_user:
            count = yield self.store.count_real_users()
            should_auto_create_rooms = count == 1
        for r in self.hs.config.auto_join_rooms:
            logger.info("Auto-joining %s to %s", user_id, r)
            try:
                if should_auto_create_rooms:
                    room_alias = RoomAlias.from_string(r)
                    if self.hs.hostname != room_alias.domain:
                        logger.warning(
                            "Cannot create room alias %s, "
                            "it does not match server domain",
                            r,
                        )
                    else:
                        # create room expects the localpart of the room alias
                        room_alias_localpart = room_alias.localpart

                        # getting the RoomCreationHandler during init gives a dependency
                        # loop
                        yield self.hs.get_room_creation_handler().create_room(
                            fake_requester,
                            config={
                                "preset": "public_chat",
                                "room_alias_name": room_alias_localpart,
                            },
                            ratelimit=False,
                        )
                else:
                    yield self._join_user_to_room(fake_requester, r)
            except ConsentNotGivenError as e:
                # Technically not necessary to pull out this error separately,
                # but moving away from bare excepts is a good thing to do.
                logger.error("Failed to join new user to %r: %r", r, e)
            except Exception as e:
                logger.error("Failed to join new user to %r: %r", r, e)

    @defer.inlineCallbacks
    def post_consent_actions(self, user_id):
        """A series of registration actions that can only be carried out once consent
        has been granted

        Args:
            user_id (str): The user to join
        """
        yield self._auto_join_rooms(user_id)

    @defer.inlineCallbacks
    def appservice_register(self, user_localpart, as_token):
        user = UserID(user_localpart, self.hs.hostname)
        user_id = user.to_string()
        service = self.store.get_app_service_by_token(as_token)
        if not service:
            raise AuthError(403, "Invalid application service token.")
        if not service.is_interested_in_user(user_id):
            raise SynapseError(
                400,
                "Invalid user localpart for this application service.",
                errcode=Codes.EXCLUSIVE,
            )

        service_id = service.id if service.is_exclusive_user(user_id) else None

        yield self.check_user_id_not_appservice_exclusive(
            user_id, allowed_appservice=service)

        yield self.register_with_store(
            user_id=user_id,
            password_hash="",
            appservice_id=service_id,
            create_profile_with_displayname=user.localpart,
        )
        return user_id

    def check_user_id_not_appservice_exclusive(self,
                                               user_id,
                                               allowed_appservice=None):
        # don't allow people to register the server notices mxid
        if self._server_notices_mxid is not None:
            if user_id == self._server_notices_mxid:
                raise SynapseError(400,
                                   "This user ID is reserved.",
                                   errcode=Codes.EXCLUSIVE)

        # valid user IDs must not clash with any user ID namespaces claimed by
        # application services.
        services = self.store.get_app_services()
        interested_services = [
            s for s in services
            if s.is_interested_in_user(user_id) and s != allowed_appservice
        ]
        for service in interested_services:
            if service.is_exclusive_user(user_id):
                raise SynapseError(
                    400,
                    "This user ID is reserved by an application service.",
                    errcode=Codes.EXCLUSIVE,
                )

    @defer.inlineCallbacks
    def _generate_user_id(self):
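        # Lazily look up the next free numeric localpart, serialising the
        # lookup behind a linearizer (and re-checking inside it) so that
        # concurrent registrations cannot allocate the same ID.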
        if self._next_generated_user_id is None:
            with (yield self._generate_user_id_linearizer.queue(())):
                if self._next_generated_user_id is None:
                    self._next_generated_user_id = (
                        yield self.store.find_next_generated_user_id_localpart()
                    )

        id = self._next_generated_user_id
        self._next_generated_user_id += 1
        return str(id)

    @defer.inlineCallbacks
    def _join_user_to_room(self, requester, room_identifier):
        room_member_handler = self.hs.get_room_member_handler()
        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
        elif RoomAlias.is_valid(room_identifier):
            room_alias = RoomAlias.from_string(room_identifier)
            room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
                room_alias)
            room_id = room_id.to_string()
        else:
            raise SynapseError(
                400,
                "%s was not legal room ID or room alias" % (room_identifier, ))

        yield room_member_handler.update_membership(
            requester=requester,
            target=requester.user,
            room_id=room_id,
            remote_room_hosts=remote_room_hosts,
            action="join",
            ratelimit=False,
        )

    def check_registration_ratelimit(self, address):
        """A simple helper method to check whether the registration rate limit has been hit
        for a given IP address

        Args:
            address (str|None): the IP address used to perform the registration. If this is
                None, no ratelimiting will be performed.

        Raises:
            LimitExceededError: If the rate limit has been exceeded.
        """
        if not address:
            return

        time_now = self.clock.time()

        self.ratelimiter.ratelimit(
            address,
            time_now_s=time_now,
            rate_hz=self.hs.config.rc_registration.per_second,
            burst_count=self.hs.config.rc_registration.burst_count,
        )

    def register_with_store(
        self,
        user_id,
        password_hash=None,
        was_guest=False,
        make_guest=False,
        appservice_id=None,
        create_profile_with_displayname=None,
        admin=False,
        user_type=None,
        address=None,
    ):
        """Register user in the datastore.

        Args:
            user_id (str): The desired user ID to register.
            password_hash (str|None): Optional. The password hash for this user.
            was_guest (bool): Optional. Whether this is a guest account being
                upgraded to a non-guest account.
            make_guest (boolean): True if the new user should be a guest,
                false to add a regular user account.
            appservice_id (str|None): The ID of the appservice registering the user.
            create_profile_with_displayname (unicode|None): Optionally create a
                profile for the user, setting their displayname to the given value
            admin (boolean): is an admin user?
            user_type (str|None): type of user. One of the values from
                api.constants.UserTypes, or None for a normal user.
            address (str|None): the IP address used to perform the registration.

        Returns:
            Deferred
        """
        if self.hs.config.worker_app:
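            # When running on a worker, the actual registration must be
            # replicated to the master process via the replication client.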
            return self._register_client(
                user_id=user_id,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                appservice_id=appservice_id,
                create_profile_with_displayname=create_profile_with_displayname,
                admin=admin,
                user_type=user_type,
                address=address,
            )
        else:
            return self.store.register_user(
                user_id=user_id,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                appservice_id=appservice_id,
                create_profile_with_displayname=create_profile_with_displayname,
                admin=admin,
                user_type=user_type,
            )

    @defer.inlineCallbacks
    def register_device(self,
                        user_id,
                        device_id,
                        initial_display_name,
                        is_guest=False):
        """Register a device for a user and generate an access token.

        The access token will be limited by the homeserver's session_lifetime config.

        Args:
            user_id (str): full canonical @user:id
            device_id (str|None): The device ID to check, or None to generate
                a new one.
            initial_display_name (str|None): An optional display name for the
                device.
            is_guest (bool): Whether this is a guest account

        Returns:
            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
        """

        if self.hs.config.worker_app:
            r = yield self._register_device_client(
                user_id=user_id,
                device_id=device_id,
                initial_display_name=initial_display_name,
                is_guest=is_guest,
            )
            return r["device_id"], r["access_token"]

        valid_until_ms = None
        if self.session_lifetime is not None:
            if is_guest:
                raise Exception(
                    "session_lifetime is not currently implemented for guest access"
                )
            valid_until_ms = self.clock.time_msec() + self.session_lifetime

        device_id = yield self.device_handler.check_device_registered(
            user_id, device_id, initial_display_name)
        if is_guest:
            assert valid_until_ms is None
            access_token = self.macaroon_gen.generate_access_token(
                user_id, ["guest = true"])
        else:
            access_token = yield self._auth_handler.get_access_token_for_user_id(
                user_id, device_id=device_id, valid_until_ms=valid_until_ms)

        return (device_id, access_token)

    @defer.inlineCallbacks
    def post_registration_actions(self, user_id, auth_result, access_token):
        """A user has completed registration

        Args:
            user_id (str): The user ID that consented
            auth_result (dict): The authenticated credentials of the newly
                registered user.
            access_token (str|None): The access token of the newly logged in
                device, or None if `inhibit_login` enabled.
        """
        if self.hs.config.worker_app:
            yield self._post_registration_client(user_id=user_id,
                                                 auth_result=auth_result,
                                                 access_token=access_token)
            return

        if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
            threepid = auth_result[LoginType.EMAIL_IDENTITY]
            # Necessary due to auth checks prior to the threepid being
            # written to the db
            if is_threepid_reserved(
                    self.hs.config.mau_limits_reserved_threepids, threepid):
                yield self.store.upsert_monthly_active_user(user_id)

            yield self._register_email_threepid(user_id, threepid,
                                                access_token)

        if auth_result and LoginType.MSISDN in auth_result:
            threepid = auth_result[LoginType.MSISDN]
            yield self._register_msisdn_threepid(user_id, threepid)

        if auth_result and LoginType.TERMS in auth_result:
            yield self._on_user_consented(user_id,
                                          self.hs.config.user_consent_version)

    @defer.inlineCallbacks
    def _on_user_consented(self, user_id, consent_version):
        """A user consented to the terms on registration

        Args:
            user_id (str): The user ID that consented.
            consent_version (str): version of the policy the user has
                consented to.
        """
        logger.info("%s has consented to the privacy policy", user_id)
        yield self.store.user_set_consent_version(user_id, consent_version)
        yield self.post_consent_actions(user_id)

    @defer.inlineCallbacks
    def _register_email_threepid(self, user_id, threepid, token):
        """Add an email address as a 3pid identifier

        Also adds an email pusher for the email address, if configured in the
        HS config

        Must be called on master.

        Args:
            user_id (str): id of user
            threepid (object): m.login.email.identity auth response
            token (str|None): access_token for the user, or None if not logged
                in.
        Returns:
            defer.Deferred:
        """
        reqd = ("medium", "address", "validated_at")
        if any(x not in threepid for x in reqd):
            # This will only happen if the ID server returns a malformed response
            logger.info("Can't add incomplete 3pid")
            return

        yield self._auth_handler.add_threepid(user_id, threepid["medium"],
                                              threepid["address"],
                                              threepid["validated_at"])

        # And we add an email pusher for them by default, but only
        # if email notifications are enabled (so people don't start
        # getting mail spam where they weren't before if email
        # notifs are set up on a homeserver)
        if (self.hs.config.email_enable_notifs
                and self.hs.config.email_notif_for_new_users and token):
            # Pull the ID of the access token back out of the db
            # It would really make more sense for this to be passed
            # up when the access token is saved, but that's quite an
            # invasive change I'd rather do separately.
            user_tuple = yield self.store.get_user_by_access_token(token)
            token_id = user_tuple["token_id"]

            yield self.pusher_pool.add_pusher(
                user_id=user_id,
                access_token=token_id,
                kind="email",
                app_id="m.email",
                app_display_name="Email Notifications",
                device_display_name=threepid["address"],
                pushkey=threepid["address"],
                lang=None,  # We don't know a user's language here
                data={},
            )

    @defer.inlineCallbacks
    def _register_msisdn_threepid(self, user_id, threepid):
        """Add a phone number as a 3pid identifier

        Must be called on master.

        Args:
            user_id (str): id of user
            threepid (object): m.login.msisdn auth response
        Returns:
            defer.Deferred:
        """
        try:
            assert_params_in_dict(threepid,
                                  ["medium", "address", "validated_at"])
        except SynapseError as ex:
            if ex.errcode == Codes.MISSING_PARAM:
                # This will only happen if the ID server returns a malformed response
                logger.info("Can't add incomplete 3pid")
                return None
            raise

        yield self._auth_handler.add_threepid(user_id, threepid["medium"],
                                              threepid["address"],
                                              threepid["validated_at"])
Example #30
0
class RoomMemberHandler(object):
    # TODO(paul): This handler currently contains a messy conflation of
    #   low-level API that works on UserID objects and so on, and REST-level
    #   API that takes ID strings and returns pagination chunks. These concerns
    #   ought to be separated out a lot better.

    __metaclass__ = abc.ABCMeta

    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        self.hs = hs
        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
        self.state_handler = hs.get_state_handler()
        self.config = hs.config
        self.simple_http_client = hs.get_simple_http_client()

        self.federation_handler = hs.get_handlers().federation_handler
        self.directory_handler = hs.get_handlers().directory_handler
        self.registration_handler = hs.get_registration_handler()
        self.profile_handler = hs.get_profile_handler()
        self.event_creation_handler = hs.get_event_creation_handler()

        self.member_linearizer = Linearizer(name="member")

        self.clock = hs.get_clock()
        self.spam_checker = hs.get_spam_checker()
        self._server_notices_mxid = self.config.server_notices_mxid
        self._enable_lookup = hs.config.enable_3pid_lookup
        self.allow_per_room_profiles = self.config.allow_per_room_profiles

        # This is only used to get at ratelimit function, and
        # maybe_kick_guest_users. It's fine there are multiple of these as
        # it doesn't store state.
        self.base_handler = BaseHandler(hs)

    @abc.abstractmethod
    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
        """Try and join a room that this server is not in

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers that can be used
                to join via.
            room_id (str): Room that we are trying to join
            user (UserID): User who is trying to join
            content (dict): A dict that should be used as the content of the
                join event.

        Returns:
            Deferred
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
        """Attempt to reject an invite for a room this server is not in. If we
        fail to do so we locally mark the invite as rejected.

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers to use to try and
                reject invite
            room_id (str)
            target (UserID): The user rejecting the invite

        Returns:
            Deferred[dict]: A dictionary to be returned to the client, may
            include event_id etc, or nothing if we locally rejected
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            requester (Requester)
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
            3PID guest account.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_joined_room(self, target, room_id):
        """Notifies distributor on master process that the user has joined the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_left_room(self, target, room_id):
        """Notifies distributor on master process that the user has left the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @defer.inlineCallbacks
    def _local_membership_update(
        self, requester, target, room_id, membership,
        prev_events_and_hashes,
        txn_id=None,
        ratelimit=True,
        content=None,
        require_consent=True,
    ):
        user_id = target.to_string()

        if content is None:
            content = {}

        content["membership"] = membership
        if requester.is_guest:
            content["kind"] = "guest"

        event, context = yield self.event_creation_handler.create_event(
            requester,
            {
                "type": EventTypes.Member,
                "content": content,
                "room_id": room_id,
                "sender": requester.user.to_string(),
                "state_key": user_id,

                # For backwards compatibility:
                "membership": membership,
            },
            token_id=requester.access_token_id,
            txn_id=txn_id,
            prev_events_and_hashes=prev_events_and_hashes,
            require_consent=require_consent,
        )

        # Check if this event matches the previous membership event for the user.
        duplicate = yield self.event_creation_handler.deduplicate_state_event(
            event, context,
        )
        if duplicate is not None:
            # Discard the new event since this membership change is a no-op.
            defer.returnValue(duplicate)

        yield self.event_creation_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target],
            ratelimit=ratelimit,
        )

        prev_state_ids = yield context.get_prev_state_ids(self.store)

        prev_member_event_id = prev_state_ids.get(
            (EventTypes.Member, user_id),
            None
        )

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target, room_id)

            # Copy over direct message status and room tags if this is a join
            # on an upgraded room

            # Check if this is an upgraded room
            predecessor = yield self.store.get_room_predecessor(room_id)

            if predecessor:
                # It is an upgraded room. Copy over old tags
                self.copy_room_tags_and_direct_to_room(
                    predecessor["room_id"], room_id, user_id,
                )
                # Move over old push rules
                self.store.move_push_rules_from_room_to_room_for_user(
                    predecessor["room_id"], room_id, user_id,
                )
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target, room_id)

        defer.returnValue(event)

    @defer.inlineCallbacks
    def copy_room_tags_and_direct_to_room(
        self,
        old_room_id,
        new_room_id,
        user_id,
    ):
        """Copies the tags and direct room state from one room to another.

        Args:
            old_room_id (str)
            new_room_id (str)
            user_id (str)

        Returns:
            Deferred[None]
        """
        # Retrieve user account data for predecessor room
        user_account_data, _ = yield self.store.get_account_data_for_user(
            user_id,
        )

        # Copy direct message state if applicable
        direct_rooms = user_account_data.get("m.direct", {})

        # Check which key this room is under
        if isinstance(direct_rooms, dict):
            for key, room_id_list in direct_rooms.items():
                if old_room_id in room_id_list and new_room_id not in room_id_list:
                    # Add new room_id to this key
                    direct_rooms[key].append(new_room_id)

                    # Save back to user's m.direct account data
                    yield self.store.add_account_data_for_user(
                        user_id, "m.direct", direct_rooms,
                    )
                    break

        # Copy room tags if applicable
        room_tags = yield self.store.get_tags_for_room(
            user_id, old_room_id,
        )

        # Copy each room tag to the new room
        for tag, tag_content in room_tags.items():
            yield self.store.add_tag_to_room(
                user_id, new_room_id, tag, tag_content
            )

    @defer.inlineCallbacks
    def update_membership(
            self,
            requester,
            target,
            room_id,
            action,
            txn_id=None,
            remote_room_hosts=None,
            third_party_signed=None,
            ratelimit=True,
            content=None,
            require_consent=True,
    ):
        key = (room_id,)
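        # Serialise membership changes per room: queueing on the linearizer
        # ensures that concurrent joins, leaves and invites for the same room
        # are applied one at a time.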

        with (yield self.member_linearizer.queue(key)):
            result = yield self._update_membership(
                requester,
                target,
                room_id,
                action,
                txn_id=txn_id,
                remote_room_hosts=remote_room_hosts,
                third_party_signed=third_party_signed,
                ratelimit=ratelimit,
                content=content,
                require_consent=require_consent,
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _update_membership(
            self,
            requester,
            target,
            room_id,
            action,
            txn_id=None,
            remote_room_hosts=None,
            third_party_signed=None,
            ratelimit=True,
            content=None,
            require_consent=True,
    ):
        content_specified = bool(content)
        if content is None:
            content = {}
        else:
            # We do a copy here as we potentially change some keys
            # later on.
            content = dict(content)

        if not self.allow_per_room_profiles:
            # Strip profile data, knowing that new profile data will be added to the
            # event's content in event_creation_handler.create_event() using the target's
            # global profile.
            content.pop("displayname", None)
            content.pop("avatar_url", None)

        effective_membership_state = action
        if action in ["kick", "unban"]:
            effective_membership_state = "leave"

        # if this is a join with a 3pid signature, we may need to turn a 3pid
        # invite into a normal invite before we can handle the join.
        if third_party_signed is not None:
            yield self.federation_handler.exchange_third_party_invite(
                third_party_signed["sender"],
                target.to_string(),
                room_id,
                third_party_signed,
            )

        if not remote_room_hosts:
            remote_room_hosts = []

        if effective_membership_state not in ("leave", "ban",):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(403, "This room has been blocked on this server")

        if effective_membership_state == Membership.INVITE:
            # block any attempts to invite the server notices mxid
            if target.to_string() == self._server_notices_mxid:
                raise SynapseError(
                    http_client.FORBIDDEN,
                    "Cannot invite this user",
                )

            block_invite = False

            if (self._server_notices_mxid is not None and
                    requester.user.to_string() == self._server_notices_mxid):
                # allow the server notices mxid to send invites
                is_requester_admin = True

            else:
                is_requester_admin = yield self.auth.is_server_admin(
                    requester.user,
                )

            if not is_requester_admin:
                if self.config.block_non_admin_invites:
                    logger.info(
                        "Blocking invite: user is not admin and non-admin "
                        "invites disabled"
                    )
                    block_invite = True

                if not self.spam_checker.user_may_invite(
                    requester.user.to_string(), target.to_string(), room_id,
                ):
                    logger.info("Blocking invite due to spam checker")
                    block_invite = True

            if block_invite:
                raise SynapseError(
                    403, "Invites have been disabled on this server",
                )

        prev_events_and_hashes = yield self.store.get_prev_events_for_room(
            room_id,
        )
        latest_event_ids = (
            event_id for (event_id, _, _) in prev_events_and_hashes
        )

        current_state_ids = yield self.state_handler.get_current_state_ids(
            room_id, latest_event_ids=latest_event_ids,
        )

        # TODO: Refactor into dictionary of explicitly allowed transitions
        # between old and new state, with specific error messages for some
        # transitions and generic otherwise
        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
        if old_state_id:
            old_state = yield self.store.get_event(old_state_id, allow_none=True)
            old_membership = old_state.content.get("membership") if old_state else None
            if action == "unban" and old_membership != "ban":
                raise SynapseError(
                    403,
                    "Cannot unban user who was not banned"
                    " (membership=%s)" % old_membership,
                    errcode=Codes.BAD_STATE
                )
            if old_membership == "ban" and action != "unban":
                raise SynapseError(
                    403,
                    "Cannot %s user who was banned" % (action,),
                    errcode=Codes.BAD_STATE
                )

            if old_state:
                same_content = content == old_state.content
                same_membership = old_membership == effective_membership_state
                same_sender = requester.user.to_string() == old_state.sender
                if same_sender and same_membership and same_content:
                    defer.returnValue(old_state)

            if old_membership in ["ban", "leave"] and action == "kick":
                raise AuthError(403, "The target user is not in the room")

            # we don't allow people to reject invites to the server notice
            # room, but they can leave it once they are joined.
            if (
                old_membership == Membership.INVITE and
                effective_membership_state == Membership.LEAVE
            ):
                is_blocked = yield self._is_server_notice_room(room_id)
                if is_blocked:
                    raise SynapseError(
                        http_client.FORBIDDEN,
                        "You cannot reject this invite",
                        errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
                    )
        else:
            if action == "kick":
                raise AuthError(403, "The target user is not in the room")

        is_host_in_room = yield self._is_host_in_room(current_state_ids)

        if effective_membership_state == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(current_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

            if not is_host_in_room:
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if inviter and not self.hs.is_mine(inviter):
                    remote_room_hosts.append(inviter.domain)

                content["membership"] = Membership.JOIN

                profile = self.profile_handler
                if not content_specified:
                    content["displayname"] = yield profile.get_displayname(target)
                    content["avatar_url"] = yield profile.get_avatar_url(target)

                if requester.is_guest:
                    content["kind"] = "guest"

                ret = yield self._remote_join(
                    requester, remote_room_hosts, room_id, target, content
                )
                defer.returnValue(ret)

        elif effective_membership_state == Membership.LEAVE:
            if not is_host_in_room:
                # perhaps we've been invited
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if not inviter:
                    raise SynapseError(404, "Not a known room")

                if self.hs.is_mine(inviter):
                    # the inviter was on our server, but has now left. Carry on
                    # with the normal rejection codepath.
                    #
                    # This is a bit of a hack, because the room might still be
                    # active on other servers.
                    pass
                else:
                    # send the rejection to the inviter's HS.
                    remote_room_hosts = remote_room_hosts + [inviter.domain]
                    res = yield self._remote_reject_invite(
                        requester, remote_room_hosts, room_id, target,
                    )
                    defer.returnValue(res)

        res = yield self._local_membership_update(
            requester=requester,
            target=target,
            room_id=room_id,
            membership=effective_membership_state,
            txn_id=txn_id,
            ratelimit=ratelimit,
            prev_events_and_hashes=prev_events_and_hashes,
            content=content,
            require_consent=require_consent,
        )
        defer.returnValue(res)

    @defer.inlineCallbacks
    def send_membership_event(
            self,
            requester,
            event,
            context,
            remote_room_hosts=None,
            ratelimit=True,
    ):
        """
        Change the membership status of a user in a room.

        Args:
            requester (Requester): The local user who requested the membership
                event. If None, certain checks, like whether this homeserver can
                act as the sender, will be skipped.
            event (SynapseEvent): The membership event.
            context: The context of the event.
            remote_room_hosts (list[str]): Homeservers which are likely to already
                be in the room, and which can be used to perform the join if this
                homeserver is not yet in the room.
            ratelimit (bool): Whether to rate limit this request.
        Raises:
            SynapseError if there was a problem changing the membership.
        """
        remote_room_hosts = remote_room_hosts or []

        target_user = UserID.from_string(event.state_key)
        room_id = event.room_id

        if requester is not None:
            sender = UserID.from_string(event.sender)
            assert sender == requester.user, (
                "Sender (%s) must be same as requester (%s)" %
                (sender, requester.user)
            )
            assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
        else:
            requester = synapse.types.create_requester(target_user)

        prev_event = yield self.event_creation_handler.deduplicate_state_event(
            event, context,
        )
        if prev_event is not None:
            return

        prev_state_ids = yield context.get_prev_state_ids(self.store)
        if event.membership == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(prev_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

        if event.membership not in (Membership.LEAVE, Membership.BAN):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(403, "This room has been blocked on this server")

        yield self.event_creation_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target_user],
            ratelimit=ratelimit,
        )

        prev_member_event_id = prev_state_ids.get(
            (EventTypes.Member, event.state_key),
            None
        )

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target_user, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target_user, room_id)

    @defer.inlineCallbacks
    def _can_guest_join(self, current_state_ids):
        """
        Returns whether a guest can join a room based on its current state.
        """
        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
        if not guest_access_id:
            defer.returnValue(False)

        guest_access = yield self.store.get_event(guest_access_id)

        defer.returnValue(
            guest_access
            and guest_access.content
            and "guest_access" in guest_access.content
            and guest_access.content["guest_access"] == "can_join"
        )

    @defer.inlineCallbacks
    def lookup_room_alias(self, room_alias):
        """
        Get the room ID associated with a room alias.

        Args:
            room_alias (RoomAlias): The alias to look up.
        Returns:
            A tuple of:
                The room ID as a RoomID object.
                Hosts likely to be participating in the room ([str]).
        Raises:
            SynapseError if room alias could not be found.
        """
        directory_handler = self.directory_handler
        mapping = yield directory_handler.get_association(room_alias)

        if not mapping:
            raise SynapseError(404, "No such room alias")

        room_id = mapping["room_id"]
        servers = mapping["servers"]

        # put the server which owns the alias at the front of the server list.
        if room_alias.domain in servers:
            servers.remove(room_alias.domain)
        servers.insert(0, room_alias.domain)

        defer.returnValue((RoomID.from_string(room_id), servers))

    @defer.inlineCallbacks
    def _get_inviter(self, user_id, room_id):
        invite = yield self.store.get_invite_for_user_in_room(
            user_id=user_id,
            room_id=room_id,
        )
        if invite:
            defer.returnValue(UserID.from_string(invite.sender))

    @defer.inlineCallbacks
    def do_3pid_invite(
            self,
            room_id,
            inviter,
            medium,
            address,
            id_server,
            requester,
            txn_id
    ):
        if self.config.block_non_admin_invites:
            is_requester_admin = yield self.auth.is_server_admin(
                requester.user,
            )
            if not is_requester_admin:
                raise SynapseError(
                    403, "Invites have been disabled on this server",
                    Codes.FORBIDDEN,
                )

        # We need to rate limit *before* we send out any 3PID invites, so we
        # can't just rely on the standard ratelimiting of events.
        yield self.base_handler.ratelimit(requester)

        invitee = yield self._lookup_3pid(
            id_server, medium, address
        )

        if invitee:
            yield self.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                txn_id=txn_id,
            )
        else:
            yield self._make_and_store_3pid_invite(
                requester,
                id_server,
                medium,
                address,
                room_id,
                inviter,
                txn_id=txn_id
            )

    @defer.inlineCallbacks
    def _lookup_3pid(self, id_server, medium, address):
        """Looks up a 3pid in the passed identity server.

        Args:
            id_server (str): The server name (including port, if required)
                of the identity server to use.
            medium (str): The type of the third party identifier (e.g. "email").
            address (str): The third party identifier (e.g. "*****@*****.**").

        Returns:
            str: the matrix ID of the 3pid, or None if it is not recognized.
        """
        if not self._enable_lookup:
            raise SynapseError(
                403, "Looking up third-party identifiers is denied from this server",
            )
        try:
            data = yield self.simple_http_client.get_json(
                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
                {
                    "medium": medium,
                    "address": address,
                }
            )

            if "mxid" in data:
                if "signatures" not in data:
                    raise AuthError(401, "No signatures on 3pid binding")
                yield self._verify_any_signature(data, id_server)
                defer.returnValue(data["mxid"])

        except IOError as e:
            logger.warn("Error from identity server lookup: %s" % (e,))
            defer.returnValue(None)

    @defer.inlineCallbacks
    def _verify_any_signature(self, data, server_hostname):
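        # Accept the lookup response if any one signature from the identity
        # server's hostname verifies against one of its published public keys.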
        if server_hostname not in data["signatures"]:
            raise AuthError(401, "No signature from server %s" % (server_hostname,))
        for key_name, signature in data["signatures"][server_hostname].items():
            key_data = yield self.simple_http_client.get_json(
                "%s%s/_matrix/identity/api/v1/pubkey/%s" %
                (id_server_scheme, server_hostname, key_name,),
            )
            if "public_key" not in key_data:
                raise AuthError(401, "No public key named %s from %s" %
                                (key_name, server_hostname,))
            verify_signed_json(
                data,
                server_hostname,
                decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
            )
            return

    @defer.inlineCallbacks
    def _make_and_store_3pid_invite(
            self,
            requester,
            id_server,
            medium,
            address,
            room_id,
            user,
            txn_id
    ):
        room_state = yield self.state_handler.get_current_state(room_id)

        inviter_display_name = ""
        inviter_avatar_url = ""
        member_event = room_state.get((EventTypes.Member, user.to_string()))
        if member_event:
            inviter_display_name = member_event.content.get("displayname", "")
            inviter_avatar_url = member_event.content.get("avatar_url", "")

        # if user has no display name, default to their MXID
        if not inviter_display_name:
            inviter_display_name = user.to_string()

        canonical_room_alias = ""
        canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event:
            canonical_room_alias = canonical_alias_event.content.get("alias", "")

        room_name = ""
        room_name_event = room_state.get((EventTypes.Name, ""))
        if room_name_event:
            room_name = room_name_event.content.get("name", "")

        room_join_rules = ""
        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
        if join_rules_event:
            room_join_rules = join_rules_event.content.get("join_rule", "")

        room_avatar_url = ""
        room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
        if room_avatar_event:
            room_avatar_url = room_avatar_event.content.get("url", "")

        token, public_keys, fallback_public_key, display_name = (
            yield self._ask_id_server_for_third_party_invite(
                requester=requester,
                id_server=id_server,
                medium=medium,
                address=address,
                room_id=room_id,
                inviter_user_id=user.to_string(),
                room_alias=canonical_room_alias,
                room_avatar_url=room_avatar_url,
                room_join_rules=room_join_rules,
                room_name=room_name,
                inviter_display_name=inviter_display_name,
                inviter_avatar_url=inviter_avatar_url
            )
        )

        yield self.event_creation_handler.create_and_send_nonmember_event(
            requester,
            {
                "type": EventTypes.ThirdPartyInvite,
                "content": {
                    "display_name": display_name,
                    "public_keys": public_keys,

                    # For backwards compatibility:
                    "key_validity_url": fallback_public_key["key_validity_url"],
                    "public_key": fallback_public_key["public_key"],
                },
                "room_id": room_id,
                "sender": user.to_string(),
                "state_key": token,
            },
            txn_id=txn_id,
        )

    @defer.inlineCallbacks
    def _ask_id_server_for_third_party_invite(
            self,
            requester,
            id_server,
            medium,
            address,
            room_id,
            inviter_user_id,
            room_alias,
            room_avatar_url,
            room_join_rules,
            room_name,
            inviter_display_name,
            inviter_avatar_url
    ):
        """
        Asks an identity server for a third party invite.

        Args:
            requester (Requester)
            id_server (str): hostname + optional port for the identity server.
            medium (str): The literal string "email".
            address (str): The third party address being invited.
            room_id (str): The ID of the room to which the user is invited.
            inviter_user_id (str): The user ID of the inviter.
            room_alias (str): An alias for the room, for cosmetic notifications.
            room_avatar_url (str): The URL of the room's avatar, for cosmetic
                notifications.
            room_join_rules (str): The join rules of the room (e.g. "public").
            room_name (str): The m.room.name of the room.
            inviter_display_name (str): The current display name of the
                inviter.
            inviter_avatar_url (str): The URL of the inviter's avatar.

        Returns:
            A deferred tuple containing:
                token (str): The token which must be signed to prove authenticity.
                public_keys ([{"public_key": str, "key_validity_url": str}]):
                    public_key is a base64-encoded ed25519 public key.
                fallback_public_key: One element from public_keys.
                display_name (str): A user-friendly name to represent the invited
                    user.
        """

        is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
            id_server_scheme, id_server,
        )

        invite_config = {
            "medium": medium,
            "address": address,
            "room_id": room_id,
            "room_alias": room_alias,
            "room_avatar_url": room_avatar_url,
            "room_join_rules": room_join_rules,
            "room_name": room_name,
            "sender": inviter_user_id,
            "sender_display_name": inviter_display_name,
            "sender_avatar_url": inviter_avatar_url,
        }

        if self.config.invite_3pid_guest:
            guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
                requester=requester,
                medium=medium,
                address=address,
                inviter_user_id=inviter_user_id,
            )

            invite_config.update({
                "guest_access_token": guest_access_token,
                "guest_user_id": guest_user_id,
            })

        data = yield self.simple_http_client.post_urlencoded_get_json(
            is_url,
            invite_config
        )
        # TODO: Check for success
        token = data["token"]
        public_keys = data.get("public_keys", [])
        if "public_key" in data:
            fallback_public_key = {
                "public_key": data["public_key"],
                "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
                    id_server_scheme, id_server,
                ),
            }
        else:
            fallback_public_key = public_keys[0]

        if not public_keys:
            public_keys.append(fallback_public_key)
        display_name = data["display_name"]
        defer.returnValue((token, public_keys, fallback_public_key, display_name))

    @defer.inlineCallbacks
    def _is_host_in_room(self, current_state_ids):
        # Have we just created the room, and is this about to be the very
        # first member event?
        create_event_id = current_state_ids.get(("m.room.create", ""))
        if len(current_state_ids) == 1 and create_event_id:
            # We can only get here if we're in the process of creating the room
            defer.returnValue(True)

        for etype, state_key in current_state_ids:
            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                continue

            event_id = current_state_ids[(etype, state_key)]
            event = yield self.store.get_event(event_id, allow_none=True)
            if not event:
                continue

            if event.membership == Membership.JOIN:
                defer.returnValue(True)

        defer.returnValue(False)

    @defer.inlineCallbacks
    def _is_server_notice_room(self, room_id):
        if self._server_notices_mxid is None:
            defer.returnValue(False)
        user_ids = yield self.store.get_users_in_room(room_id)
        defer.returnValue(self._server_notices_mxid in user_ids)
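
As a rough sketch (with assumed wiring, not taken from the example above) of how
a caller might drive the membership machinery, an invite could be issued like this:

from twisted.internet import defer
from synapse.types import UserID, create_requester

@defer.inlineCallbacks
def invite_local_user(hs, room_id, inviter_id, invitee_id):
    # Illustrative only: build a requester for the inviter and go through
    # update_membership(), which linearises changes per room before handing
    # off to _update_membership().
    handler = hs.get_room_member_handler()
    requester = create_requester(inviter_id)
    event = yield handler.update_membership(
        requester=requester,
        target=UserID.from_string(invitee_id),
        room_id=room_id,
        action="invite",
        ratelimit=True,
    )
    defer.returnValue(event)
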
Example #31
0
class RegistrationHandler(BaseHandler):
    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        super(RegistrationHandler, self).__init__(hs)
        self.hs = hs
        self.auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        self.profile_handler = hs.get_profile_handler()
        self.user_directory_handler = hs.get_user_directory_handler()
        self.captcha_client = CaptchaServerHttpClient(hs)
        self.identity_handler = self.hs.get_handlers().identity_handler
        self.ratelimiter = hs.get_registration_ratelimiter()

        self._next_generated_user_id = None

        self.macaroon_gen = hs.get_macaroon_generator()

        self._generate_user_id_linearizer = Linearizer(
            name="_generate_user_id_linearizer",
        )
        self._server_notices_mxid = hs.config.server_notices_mxid

        if hs.config.worker_app:
            self._register_client = ReplicationRegisterServlet.make_client(hs)
            self._register_device_client = (
                RegisterDeviceReplicationServlet.make_client(hs))
            self._post_registration_client = (
                ReplicationPostRegisterActionsServlet.make_client(hs))
        else:
            self.device_handler = hs.get_device_handler()
            self.pusher_pool = hs.get_pusherpool()

    @defer.inlineCallbacks
    def check_username(self,
                       localpart,
                       guest_access_token=None,
                       assigned_user_id=None):
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
                Codes.INVALID_USERNAME)

        if not localpart:
            raise SynapseError(400, "User ID cannot be empty",
                               Codes.INVALID_USERNAME)

        if localpart[0] == '_':
            raise SynapseError(400, "User ID may not begin with _",
                               Codes.INVALID_USERNAME)

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        if assigned_user_id:
            if user_id == assigned_user_id:
                return
            else:
                raise SynapseError(
                    400,
                    "A different user ID has already been registered for this session",
                )

        self.check_user_id_not_appservice_exclusive(user_id)

        users = yield self.store.get_users_by_id_case_insensitive(user_id)
        if users:
            if not guest_access_token:
                raise SynapseError(
                    400,
                    "User ID already taken.",
                    errcode=Codes.USER_IN_USE,
                )
            user_data = yield self.auth.get_user_by_access_token(
                guest_access_token)
            if (not user_data["is_guest"]
                    or user_data["user"].localpart != localpart):
                raise AuthError(
                    403,
                    "Cannot register taken user ID without valid guest "
                    "credentials for that user.",
                    errcode=Codes.FORBIDDEN,
                )

    @defer.inlineCallbacks
    def register(
        self,
        localpart=None,
        password=None,
        generate_token=True,
        guest_access_token=None,
        make_guest=False,
        admin=False,
        threepid=None,
        user_type=None,
        default_display_name=None,
        address=None,
    ):
        """Registers a new client on the server.

        Args:
            localpart : The local part of the user ID to register. If None,
              one will be generated.
            password (unicode) : The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
            generate_token (bool): Whether a new access token should be
              generated. Having this be True should be considered deprecated,
              since it offers no means of associating a device_id with the
              access_token. Instead you should call auth_handler.issue_access_token
              after registration.
            user_type (str|None): type of user. One of the values from
              api.constants.UserTypes, or None for a normal user.
            default_display_name (unicode|None): if set, the new user's displayname
              will be set to this. Defaults to 'localpart'.
            address (str|None): the IP address used to perform the registration.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """

        yield self.auth.check_auth_blocking(threepid=threepid)
        password_hash = None
        if password:
            password_hash = yield self._auth_handler.hash(password)

        if localpart:
            yield self.check_username(localpart,
                                      guest_access_token=guest_access_token)

            was_guest = guest_access_token is not None

            if not was_guest:
                try:
                    int(localpart)
                    raise RegistrationError(
                        400, "Numeric user IDs are reserved for guest users.")
                except ValueError:
                    pass

            user = UserID(localpart, self.hs.hostname)
            user_id = user.to_string()

            if was_guest:
                # If the user was a guest then they already have a profile
                default_display_name = None

            elif default_display_name is None:
                default_display_name = localpart

            token = None
            if generate_token:
                token = self.macaroon_gen.generate_access_token(user_id)
            yield self.register_with_store(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                create_profile_with_displayname=default_display_name,
                admin=admin,
                user_type=user_type,
                address=address,
            )

            if self.hs.config.user_directory_search_all_users:
                profile = yield self.store.get_profileinfo(localpart)
                yield self.user_directory_handler.handle_local_profile_change(
                    user_id, profile)

        else:
            # autogen a sequential user ID
            attempts = 0
            token = None
            user = None
            while not user:
                localpart = yield self._generate_user_id(attempts > 0)
                user = UserID(localpart, self.hs.hostname)
                user_id = user.to_string()
                yield self.check_user_id_not_appservice_exclusive(user_id)
                if generate_token:
                    token = self.macaroon_gen.generate_access_token(user_id)
                if default_display_name is None:
                    default_display_name = localpart
                try:
                    yield self.register_with_store(
                        user_id=user_id,
                        token=token,
                        password_hash=password_hash,
                        make_guest=make_guest,
                        create_profile_with_displayname=default_display_name,
                        address=address,
                    )
                except SynapseError:
                    # if user id is taken, just generate another
                    user = None
                    user_id = None
                    token = None
                    attempts += 1
        if not self.hs.config.user_consent_at_registration:
            yield self._auto_join_rooms(user_id)

        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def _auto_join_rooms(self, user_id):
        """Automatically joins users to auto join rooms - creating the room in the first place
        if the user is the first to be created.

        Args:
            user_id(str): The user to join
        """
        # auto-join the user to any rooms we're supposed to dump them into
        fake_requester = create_requester(user_id)

        # try to create the room if we're the first real user on the server. Note
        # that an auto-generated support user is not a real user and will never be
        # the user to create the room
        should_auto_create_rooms = False
        is_support = yield self.store.is_support_user(user_id)
        # There is an edge case where the first user is the support user, then
        # the room is never created, though this seems unlikely and
        # recoverable from given the support user being involved in the first
        # place.
        if self.hs.config.autocreate_auto_join_rooms and not is_support:
            count = yield self.store.count_all_users()
            should_auto_create_rooms = count == 1
        for r in self.hs.config.auto_join_rooms:
            try:
                if should_auto_create_rooms:
                    room_alias = RoomAlias.from_string(r)
                    if self.hs.hostname != room_alias.domain:
                        logger.warning(
                            'Cannot create room alias %s, '
                            'it does not match server domain',
                            r,
                        )
                    else:
                        # create room expects the localpart of the room alias
                        room_alias_localpart = room_alias.localpart

                        # getting the RoomCreationHandler during init gives a dependency
                        # loop
                        yield self.hs.get_room_creation_handler().create_room(
                            fake_requester,
                            config={
                                "preset": "public_chat",
                                "room_alias_name": room_alias_localpart
                            },
                            ratelimit=False,
                        )
                else:
                    yield self._join_user_to_room(fake_requester, r)
            except ConsentNotGivenError as e:
                # Technically not necessary to pull out this error though
                # moving away from bare excepts is a good thing to do.
                logger.error("Failed to join new user to %r: %r", r, e)
            except Exception as e:
                logger.error("Failed to join new user to %r: %r", r, e)

    @defer.inlineCallbacks
    def post_consent_actions(self, user_id):
        """A series of registration actions that can only be carried out once consent
        has been granted

        Args:
            user_id (str): The user to join
        """
        yield self._auto_join_rooms(user_id)

    @defer.inlineCallbacks
    def appservice_register(self, user_localpart, as_token):
        user = UserID(user_localpart, self.hs.hostname)
        user_id = user.to_string()
        service = self.store.get_app_service_by_token(as_token)
        if not service:
            raise AuthError(403, "Invalid application service token.")
        if not service.is_interested_in_user(user_id):
            raise SynapseError(
                400,
                "Invalid user localpart for this application service.",
                errcode=Codes.EXCLUSIVE)

        service_id = service.id if service.is_exclusive_user(user_id) else None

        yield self.check_user_id_not_appservice_exclusive(
            user_id, allowed_appservice=service)

        yield self.register_with_store(
            user_id=user_id,
            password_hash="",
            appservice_id=service_id,
            create_profile_with_displayname=user.localpart,
        )
        defer.returnValue(user_id)

    @defer.inlineCallbacks
    def check_recaptcha(self, ip, private_key, challenge, response):
        """
        Checks a recaptcha is correct.

        Used only by c/s api v1
        """

        captcha_response = yield self._validate_captcha(
            ip, private_key, challenge, response)
        if not captcha_response["valid"]:
            logger.info("Invalid captcha entered from %s. Error: %s", ip,
                        captcha_response["error_url"])
            raise InvalidCaptchaError(error_url=captcha_response["error_url"])
        else:
            logger.info("Valid captcha entered from %s", ip)

    @defer.inlineCallbacks
    def register_email(self, threepidCreds):
        """
        Registers emails with an identity server.

        Used only by c/s api v1
        """

        for c in threepidCreds:
            logger.info("validating threepidcred sid %s on id server %s",
                        c['sid'], c['idServer'])
            try:
                threepid = yield self.identity_handler.threepid_from_creds(c)
            except Exception:
                logger.exception("Couldn't validate 3pid")
                raise RegistrationError(400, "Couldn't validate 3pid")

            if not threepid:
                raise RegistrationError(400, "Couldn't validate 3pid")
            logger.info("got threepid with medium '%s' and address '%s'",
                        threepid['medium'], threepid['address'])

            if not check_3pid_allowed(self.hs, threepid['medium'],
                                      threepid['address']):
                raise RegistrationError(
                    403, "Third party identifier is not allowed")

    @defer.inlineCallbacks
    def bind_emails(self, user_id, threepidCreds):
        """Links emails with a user ID and informs an identity server.

        Used only by c/s api v1
        """

        # Now we have a matrix ID, bind it to the threepids we were given
        for c in threepidCreds:
            # XXX: This should be a deferred list, shouldn't it?
            yield self.identity_handler.bind_threepid(c, user_id)

    def check_user_id_not_appservice_exclusive(self,
                                               user_id,
                                               allowed_appservice=None):
        # don't allow people to register the server notices mxid
        if self._server_notices_mxid is not None:
            if user_id == self._server_notices_mxid:
                raise SynapseError(400,
                                   "This user ID is reserved.",
                                   errcode=Codes.EXCLUSIVE)

        # valid user IDs must not clash with any user ID namespaces claimed by
        # application services.
        services = self.store.get_app_services()
        interested_services = [
            s for s in services
            if s.is_interested_in_user(user_id) and s != allowed_appservice
        ]
        for service in interested_services:
            if service.is_exclusive_user(user_id):
                raise SynapseError(
                    400,
                    "This user ID is reserved by an application service.",
                    errcode=Codes.EXCLUSIVE)

    @defer.inlineCallbacks
    def _generate_user_id(self, reseed=False):
        if reseed or self._next_generated_user_id is None:
            with (yield self._generate_user_id_linearizer.queue(())):
                if reseed or self._next_generated_user_id is None:
                    self._next_generated_user_id = (
                        yield self.store.find_next_generated_user_id_localpart()
                    )

        id = self._next_generated_user_id
        self._next_generated_user_id += 1
        defer.returnValue(str(id))

    @defer.inlineCallbacks
    def _validate_captcha(self, ip_addr, private_key, challenge, response):
        """Validates the captcha provided.

        Used only by c/s api v1

        Returns:
            dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.

        """
        response = yield self._submit_captcha(ip_addr, private_key, challenge,
                                              response)
        # parse Google's response. Lovely format..
        lines = response.split('\n')
        json = {
            "valid":
            lines[0] == 'true',
            "error_url":
            "http://www.recaptcha.net/recaptcha/api/challenge?" +
            "error=%s" % lines[1]
        }
        defer.returnValue(json)

    @defer.inlineCallbacks
    def _submit_captcha(self, ip_addr, private_key, challenge, response):
        """
        Used only by c/s api v1
        """
        data = yield self.captcha_client.post_urlencoded_get_raw(
            "http://www.recaptcha.net:80/recaptcha/api/verify",
            args={
                'privatekey': private_key,
                'remoteip': ip_addr,
                'challenge': challenge,
                'response': response
            })
        defer.returnValue(data)

    @defer.inlineCallbacks
    def get_or_create_user(self,
                           requester,
                           localpart,
                           displayname,
                           password_hash=None):
        """Creates a new user if the user does not exist,
        else revokes all previous access tokens and generates a new one.

        Args:
            localpart : The local part of the user ID to register. Must be
              provided; a SynapseError is raised if it is None.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """
        if localpart is None:
            raise SynapseError(400, "Request must include user id")
        yield self.auth.check_auth_blocking()
        need_register = True

        try:
            yield self.check_username(localpart)
        except SynapseError as e:
            if e.errcode == Codes.USER_IN_USE:
                need_register = False
            else:
                raise

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()
        token = self.macaroon_gen.generate_access_token(user_id)

        if need_register:
            yield self.register_with_store(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                create_profile_with_displayname=user.localpart,
            )
        else:
            yield self._auth_handler.delete_access_tokens_for_user(user_id)
            yield self.store.add_access_token_to_user(user_id=user_id,
                                                      token=token)

        if displayname is not None:
            logger.info("setting user display name: %s -> %s", user_id,
                        displayname)
            yield self.profile_handler.set_displayname(
                user,
                requester,
                displayname,
                by_admin=True,
            )

        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
            3PID guest account.
        """
        access_token = yield self.store.get_3pid_guest_access_token(
            medium, address)
        if access_token:
            user_info = yield self.auth.get_user_by_access_token(access_token)

            defer.returnValue((user_info["user"].to_string(), access_token))

        user_id, access_token = yield self.register(generate_token=True,
                                                    make_guest=True)
        access_token = yield self.store.save_or_get_3pid_guest_access_token(
            medium, address, access_token, inviter_user_id)

        defer.returnValue((user_id, access_token))

    @defer.inlineCallbacks
    def _join_user_to_room(self, requester, room_identifier):
        room_id = None
        # Default for the branch where room_identifier is already a room ID;
        # update_membership accepts remote_room_hosts=None.
        remote_room_hosts = None
        room_member_handler = self.hs.get_room_member_handler()
        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
        elif RoomAlias.is_valid(room_identifier):
            room_alias = RoomAlias.from_string(room_identifier)
            room_id, remote_room_hosts = (
                yield room_member_handler.lookup_room_alias(room_alias))
            room_id = room_id.to_string()
        else:
            raise SynapseError(
                400,
                "%s was not legal room ID or room alias" % (room_identifier, ))

        yield room_member_handler.update_membership(
            requester=requester,
            target=requester.user,
            room_id=room_id,
            remote_room_hosts=remote_room_hosts,
            action="join",
            ratelimit=False,
        )

    def register_with_store(self,
                            user_id,
                            token=None,
                            password_hash=None,
                            was_guest=False,
                            make_guest=False,
                            appservice_id=None,
                            create_profile_with_displayname=None,
                            admin=False,
                            user_type=None,
                            address=None):
        """Register user in the datastore.

        Args:
            user_id (str): The desired user ID to register.
            token (str): The desired access token to use for this user. If this
                is not None, the given access token is associated with the user
                id.
            password_hash (str|None): Optional. The password hash for this user.
            was_guest (bool): Optional. Whether this is a guest account being
                upgraded to a non-guest account.
            make_guest (boolean): True if the new user should be a guest,
                false to add a regular user account.
            appservice_id (str|None): The ID of the appservice registering the user.
            create_profile_with_displayname (unicode|None): Optionally create a
                profile for the user, setting their displayname to the given value
            admin (boolean): is an admin user?
            user_type (str|None): type of user. One of the values from
                api.constants.UserTypes, or None for a normal user.
            address (str|None): the IP address used to perform the registration.

        Returns:
            Deferred
        """
        # Don't rate limit for app services
        if appservice_id is None and address is not None:
            time_now = self.clock.time()

            allowed, time_allowed = self.ratelimiter.can_do_action(
                address,
                time_now_s=time_now,
                rate_hz=self.hs.config.rc_registration.per_second,
                burst_count=self.hs.config.rc_registration.burst_count,
            )

            if not allowed:
                raise LimitExceededError(retry_after_ms=int(
                    1000 * (time_allowed - time_now)), )

        if self.hs.config.worker_app:
            return self._register_client(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                appservice_id=appservice_id,
                create_profile_with_displayname=create_profile_with_displayname,
                admin=admin,
                user_type=user_type,
                address=address,
            )
        else:
            return self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                appservice_id=appservice_id,
                create_profile_with_displayname=create_profile_with_displayname,
                admin=admin,
                user_type=user_type,
            )

    @defer.inlineCallbacks
    def register_device(self,
                        user_id,
                        device_id,
                        initial_display_name,
                        is_guest=False):
        """Register a device for a user and generate an access token.

        Args:
            user_id (str): full canonical @user:id
            device_id (str|None): The device ID to check, or None to generate
                a new one.
            initial_display_name (str|None): An optional display name for the
                device.
            is_guest (bool): Whether this is a guest account

        Returns:
            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
        """

        if self.hs.config.worker_app:
            r = yield self._register_device_client(
                user_id=user_id,
                device_id=device_id,
                initial_display_name=initial_display_name,
                is_guest=is_guest,
            )
            defer.returnValue((r["device_id"], r["access_token"]))
        else:
            device_id = yield self.device_handler.check_device_registered(
                user_id, device_id, initial_display_name)
            if is_guest:
                access_token = self.macaroon_gen.generate_access_token(
                    user_id, ["guest = true"])
            else:
                access_token = yield self._auth_handler.get_access_token_for_user_id(
                    user_id,
                    device_id=device_id,
                )

            defer.returnValue((device_id, access_token))

    @defer.inlineCallbacks
    def post_registration_actions(self, user_id, auth_result, access_token,
                                  bind_email, bind_msisdn):
        """A user has completed registration

        Args:
            user_id (str): The user ID that consented
            auth_result (dict): The authenticated credentials of the newly
                registered user.
            access_token (str|None): The access token of the newly logged in
                device, or None if `inhibit_login` enabled.
            bind_email (bool): Whether to bind the email with the identity
                server.
            bind_msisdn (bool): Whether to bind the msisdn with the identity
                server.
        """
        if self.hs.config.worker_app:
            yield self._post_registration_client(
                user_id=user_id,
                auth_result=auth_result,
                access_token=access_token,
                bind_email=bind_email,
                bind_msisdn=bind_msisdn,
            )
            return

        if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
            threepid = auth_result[LoginType.EMAIL_IDENTITY]
            # Necessary due to auth checks prior to the threepid being
            # written to the db
            if is_threepid_reserved(
                    self.hs.config.mau_limits_reserved_threepids, threepid):
                yield self.store.upsert_monthly_active_user(user_id)

            yield self._register_email_threepid(
                user_id,
                threepid,
                access_token,
                bind_email,
            )

        if auth_result and LoginType.MSISDN in auth_result:
            threepid = auth_result[LoginType.MSISDN]
            yield self._register_msisdn_threepid(
                user_id,
                threepid,
                bind_msisdn,
            )

        if auth_result and LoginType.TERMS in auth_result:
            yield self._on_user_consented(
                user_id,
                self.hs.config.user_consent_version,
            )

    @defer.inlineCallbacks
    def _on_user_consented(self, user_id, consent_version):
        """A user consented to the terms on registration

        Args:
            user_id (str): The user ID that consented.
            consent_version (str): version of the policy the user has
                consented to.
        """
        logger.info("%s has consented to the privacy policy", user_id)
        yield self.store.user_set_consent_version(
            user_id,
            consent_version,
        )
        yield self.post_consent_actions(user_id)

    @defer.inlineCallbacks
    def _register_email_threepid(self, user_id, threepid, token, bind_email):
        """Add an email address as a 3pid identifier

        Also adds an email pusher for the email address, if configured in the
        HS config

        Also optionally binds emails to the given user_id on the identity server

        Must be called on master.

        Args:
            user_id (str): id of user
            threepid (object): m.login.email.identity auth response
            token (str|None): access_token for the user, or None if not logged
                in.
            bind_email (bool): true if the client requested the email to be
                bound at the identity server
        Returns:
            defer.Deferred:
        """
        reqd = ('medium', 'address', 'validated_at')
        if any(x not in threepid for x in reqd):
            # This will only happen if the ID server returns a malformed response
            logger.info("Can't add incomplete 3pid")
            return

        yield self._auth_handler.add_threepid(
            user_id,
            threepid['medium'],
            threepid['address'],
            threepid['validated_at'],
        )

        # And we add an email pusher for them by default, but only
        # if email notifications are enabled (so people don't start
        # getting mail spam where they weren't before if email
        # notifs are set up on a home server)
        if (self.hs.config.email_enable_notifs
                and self.hs.config.email_notif_for_new_users and token):
            # Pull the ID of the access token back out of the db
            # It would really make more sense for this to be passed
            # up when the access token is saved, but that's quite an
            # invasive change I'd rather do separately.
            user_tuple = yield self.store.get_user_by_access_token(token)
            token_id = user_tuple["token_id"]

            yield self.pusher_pool.add_pusher(
                user_id=user_id,
                access_token=token_id,
                kind="email",
                app_id="m.email",
                app_display_name="Email Notifications",
                device_display_name=threepid["address"],
                pushkey=threepid["address"],
                lang=None,  # We don't know a user's language here
                data={},
            )

        if bind_email:
            logger.info("bind_email specified: binding")
            logger.debug("Binding emails %s to %s" % (threepid, user_id))
            yield self.identity_handler.bind_threepid(
                threepid['threepid_creds'], user_id)
        else:
            logger.info("bind_email not specified: not binding email")

    @defer.inlineCallbacks
    def _register_msisdn_threepid(self, user_id, threepid, bind_msisdn):
        """Add a phone number as a 3pid identifier

        Also optionally binds msisdn to the given user_id on the identity server

        Must be called on master.

        Args:
            user_id (str): id of user
            threepid (object): m.login.msisdn auth response
            bind_msisdn (bool): true if the client requested the msisdn to be
                bound at the identity server
        Returns:
            defer.Deferred:
        """
        try:
            assert_params_in_dict(threepid,
                                  ['medium', 'address', 'validated_at'])
        except SynapseError as ex:
            if ex.errcode == Codes.MISSING_PARAM:
                # This will only happen if the ID server returns a malformed response
                logger.info("Can't add incomplete 3pid")
                defer.returnValue(None)
            raise

        yield self._auth_handler.add_threepid(
            user_id,
            threepid['medium'],
            threepid['address'],
            threepid['validated_at'],
        )

        if bind_msisdn:
            logger.info("bind_msisdn specified: binding")
            logger.debug("Binding msisdn %s to %s", threepid, user_id)
            yield self.identity_handler.bind_threepid(
                threepid['threepid_creds'], user_id)
        else:
            logger.info("bind_msisdn not specified: not binding msisdn")
Example #32
class MediaRepository(object):
    def __init__(self, hs):
        self.hs = hs
        self.auth = hs.get_auth()
        self.client = hs.get_http_client()
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels

        self.primary_base_path = hs.config.media_store_path
        self.filepaths = MediaFilePaths(self.primary_base_path)

        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        self.remote_media_linearizer = Linearizer(name="media_remote")

        self.recently_accessed_remotes = set()
        self.recently_accessed_locals = set()

        self.federation_domain_whitelist = hs.config.federation_domain_whitelist

        # List of StorageProviders where we should search for media and
        # potentially upload to.
        storage_providers = []

        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
            backend = clz(hs, provider_config)
            provider = StorageProviderWrapper(
                backend,
                store_local=wrapper_config.store_local,
                store_remote=wrapper_config.store_remote,
                store_synchronous=wrapper_config.store_synchronous,
            )
            storage_providers.append(provider)

        self.media_storage = MediaStorage(
            self.hs,
            self.primary_base_path,
            self.filepaths,
            storage_providers,
        )

        self.clock.looping_call(
            self._start_update_recently_accessed,
            UPDATE_RECENTLY_ACCESSED_TS,
        )

    def _start_update_recently_accessed(self):
        return run_as_background_process(
            "update_recently_accessed_media",
            self._update_recently_accessed,
        )

    @defer.inlineCallbacks
    def _update_recently_accessed(self):
        remote_media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        local_media = self.recently_accessed_locals
        self.recently_accessed_locals = set()

        yield self.store.update_cached_last_access_time(
            local_media, remote_media, self.clock.time_msec())

    def mark_recently_accessed(self, server_name, media_id):
        """Mark the given media as recently accessed.

        Args:
            server_name (str|None): Origin server of media, or None if local
            media_id (str): The media ID of the content
        """
        if server_name:
            self.recently_accessed_remotes.add((server_name, media_id))
        else:
            self.recently_accessed_locals.add(media_id)

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type(str): The content type of the file
            upload_name(str): The name of the file
            content: A file like object that is the content to store
            content_length(int): The length of the content
            auth_user(str): The user_id of the uploader

        Returns:
            Deferred[str]: The mxc url of the stored content
        """
        media_id = random_string(24)

        file_info = FileInfo(
            server_name=None,
            file_id=media_id,
        )

        fname = yield self.media_storage.store_file(content, file_info)

        logger.info("Stored local media in file %r", fname)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )

        yield self._generate_thumbnails(
            None,
            media_id,
            media_id,
            media_type,
        )

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_local_media(self, request, media_id, name):
        """Responds to reqests for local media, if exists, or returns 404.

        Args:
            request(twisted.web.http.Request)
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content.)
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        media_info = yield self.store.get_local_media(media_id)
        if not media_info or media_info["quarantined_by"]:
            respond_404(request)
            return

        self.mark_recently_accessed(None, media_id)

        media_type = media_info["media_type"]
        media_length = media_info["media_length"]
        upload_name = name if name else media_info["upload_name"]
        url_cache = media_info["url_cache"]

        file_info = FileInfo(
            None,
            media_id,
            url_cache=url_cache,
        )

        responder = yield self.media_storage.fetch_media(file_info)
        yield respond_with_responder(
            request,
            responder,
            media_type,
            media_length,
            upload_name,
        )

    @defer.inlineCallbacks
    def get_remote_media(self, request, server_name, media_id, name):
        """Respond to requests for remote media.

        Args:
            request(twisted.web.http.Request)
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        if (self.federation_domain_whitelist is not None
                and server_name not in self.federation_domain_whitelist):
            raise FederationDeniedError(server_name)

        self.mark_recently_accessed(server_name, media_id)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name,
                media_id,
            )

        # We deliberately stream the file outside the lock
        if responder:
            media_type = media_info["media_type"]
            media_length = media_info["media_length"]
            upload_name = name if name else media_info["upload_name"]
            yield respond_with_responder(
                request,
                responder,
                media_type,
                media_length,
                upload_name,
            )
        else:
            respond_404(request)

    @defer.inlineCallbacks
    def get_remote_media_info(self, server_name, media_id):
        """Gets the media info associated with the remote file, downloading
        if necessary.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[dict]: The media_info of the file
        """
        if (self.federation_domain_whitelist is not None
                and server_name not in self.federation_domain_whitelist):
            raise FederationDeniedError(server_name)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name,
                media_id,
            )

        # Ensure we actually use the responder so that it releases resources
        if responder:
            with responder:
                pass

        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Looks for media in local cache, if not there then attempt to
        download from remote server.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[(Responder, media_info)]
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id)

        # file_id is the ID we use to track the file locally. If we've already
        # seen the file then reuse the existing ID, otherwise generate a new
        # one.
        if media_info:
            file_id = media_info["filesystem_id"]
        else:
            file_id = random_string(24)

        file_info = FileInfo(server_name, file_id)

        # If we have an entry in the DB, try and look for it
        if media_info:
            if media_info["quarantined_by"]:
                logger.info("Media is quarantined")
                raise NotFoundError()

            responder = yield self.media_storage.fetch_media(file_info)
            if responder:
                defer.returnValue((responder, media_info))

        # Failed to find the file anywhere, lets download it.

        media_info = yield self._download_remote_file(server_name, media_id,
                                                      file_id)

        responder = yield self.media_storage.fetch_media(file_info)
        defer.returnValue((responder, media_info))

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id, file_id):
        """Attempt to download the remote file from the given server name,
        using the given file_id as the local id.

        Args:
            server_name (str): Originating server
            media_id (str): The media ID of the content (as defined by the
                remote server). This is different than the file_id, which is
                locally generated.
            file_id (str): Local file ID

        Returns:
            Deferred[MediaInfo]
        """

        file_info = FileInfo(
            server_name=server_name,
            file_id=file_id,
        )

        with self.media_storage.store_into_file(file_info) as (f, fname,
                                                               finish):
            request_path = "/".join((
                "/_matrix/media/v1/download",
                server_name,
                media_id,
            ))
            try:
                length, headers = yield self.client.get_file(
                    server_name,
                    request_path,
                    output_stream=f,
                    max_size=self.max_upload_size,
                    args={
                        # tell the remote server to 404 if it doesn't
                        # recognise the server_name, to make sure we don't
                        # end up with a routing loop.
                        "allow_remote": "false",
                    })
            except RequestSendFailed as e:
                logger.warn("Request failed fetching remote media %s/%s: %r",
                            server_name, media_id, e)
                raise SynapseError(502, "Failed to fetch remote media")

            except HttpResponseException as e:
                logger.warn("HTTP error fetching remote media %s/%s: %s",
                            server_name, media_id, e.response)
                if e.code == twisted.web.http.NOT_FOUND:
                    raise e.to_synapse_error()
                raise SynapseError(502, "Failed to fetch remote media")

            except SynapseError:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise
            except NotRetryingDestination:
                logger.warn("Not retrying destination %r", server_name)
                raise SynapseError(502, "Failed to fetch remote media")
            except Exception:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise SynapseError(502, "Failed to fetch remote media")

            yield finish()

        media_type = headers[b"Content-Type"][0].decode('ascii')
        upload_name = get_filename_from_headers(headers)
        time_now_ms = self.clock.time_msec()

        logger.info("Stored remote media in file %r", fname)

        yield self.store.store_cached_remote_media(
            origin=server_name,
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=length,
            filesystem_id=file_id,
        )

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_thumbnails(
            server_name,
            media_id,
            file_id,
            media_type,
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method,
                            t_type):
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info("Image too large to thumbnail %r x %r > %r", m_width,
                        m_height, self.max_image_pixels)
            return

        if thumbnailer.transpose_method is not None:
            m_width, m_height = thumbnailer.transpose()

        if t_method == "crop":
            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
        else:
            t_byte_source = None

        return t_byte_source

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type, url_cache):
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(
            FileInfo(
                None,
                media_id,
                url_cache=url_cache,
            ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield logcontext.defer_to_thread(
            self.hs.get_reactor(), self._generate_thumbnail, thumbnailer,
            t_width, t_height, t_method, t_type)

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=None,
                    file_id=media_id,
                    url_cache=url_cache,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source,
                    file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_local_thumbnail(media_id, t_width, t_height,
                                                   t_type, t_method, t_len)

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(
            FileInfo(
                server_name,
                file_id,
                url_cache=False,
            ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield logcontext.defer_to_thread(
            self.hs.get_reactor(), self._generate_thumbnail, thumbnailer,
            t_width, t_height, t_method, t_type)

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=media_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source,
                    file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id, t_width, t_height, t_type,
                t_method, t_len)

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def _generate_thumbnails(self,
                             server_name,
                             media_id,
                             file_id,
                             media_type,
                             url_cache=False):
        """Generate and store thumbnails for an image.

        Args:
            server_name (str|None): The server name if remote media, else None if local
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content)
            file_id (str): Local file ID
            media_type (str): The content type of the file
            url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
                used exclusively by the url previewer

        Returns:
            Deferred[dict]: Dict with "width" and "height" keys of original image
        """
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = yield self.media_storage.ensure_media_is_in_local_cache(
            FileInfo(
                server_name,
                file_id,
                url_cache=url_cache,
            ))

        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info("Image too large to thumbnail %r x %r > %r", m_width,
                        m_height, self.max_image_pixels)
            return

        if thumbnailer.transpose_method is not None:
            m_width, m_height = yield logcontext.defer_to_thread(
                self.hs.get_reactor(), thumbnailer.transpose)

        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
        # they have the same dimensions as a scaled one.
        thumbnails = {}
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "crop":
                thumbnails.setdefault((r_width, r_height, r_type), r_method)
            elif r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                t_width = min(m_width, t_width)
                t_height = min(m_height, t_height)
                thumbnails[(t_width, t_height, r_type)] = r_method

        # Now we generate the thumbnails for each dimension, store it
        for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
            # Generate the thumbnail
            if t_method == "crop":
                t_byte_source = yield logcontext.defer_to_thread(
                    self.hs.get_reactor(),
                    thumbnailer.crop,
                    t_width,
                    t_height,
                    t_type,
                )
            elif t_method == "scale":
                t_byte_source = yield logcontext.defer_to_thread(
                    self.hs.get_reactor(),
                    thumbnailer.scale,
                    t_width,
                    t_height,
                    t_type,
                )
            else:
                logger.error("Unrecognized method: %r", t_method)
                continue

            if not t_byte_source:
                continue

            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                    url_cache=url_cache,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source,
                    file_info,
                )
            finally:
                t_byte_source.close()

            t_len = os.path.getsize(output_path)

            # Write to database
            if server_name:
                yield self.store.store_remote_media_thumbnail(
                    server_name, media_id, file_id, t_width, t_height, t_type,
                    t_method, t_len)
            else:
                yield self.store.store_local_thumbnail(media_id, t_width,
                                                       t_height, t_type,
                                                       t_method, t_len)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(
                    origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    if e.errno == errno.ENOENT:
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id)
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
Example #33
class RegistrationHandler(BaseHandler):

    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        super(RegistrationHandler, self).__init__(hs)
        self.hs = hs
        self.auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        self.profile_handler = hs.get_profile_handler()
        self.user_directory_handler = hs.get_user_directory_handler()
        self.captcha_client = CaptchaServerHttpClient(hs)

        self._next_generated_user_id = None

        self.macaroon_gen = hs.get_macaroon_generator()

        self._generate_user_id_linearizer = Linearizer(
            name="_generate_user_id_linearizer",
        )
        self._server_notices_mxid = hs.config.server_notices_mxid

    @defer.inlineCallbacks
    def check_username(self, localpart, guest_access_token=None,
                       assigned_user_id=None):
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
                Codes.INVALID_USERNAME
            )

        if not localpart:
            raise SynapseError(
                400,
                "User ID cannot be empty",
                Codes.INVALID_USERNAME
            )

        if localpart[0] == '_':
            raise SynapseError(
                400,
                "User ID may not begin with _",
                Codes.INVALID_USERNAME
            )

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        if assigned_user_id:
            if user_id == assigned_user_id:
                return
            else:
                raise SynapseError(
                    400,
                    "A different user ID has already been registered for this session",
                )

        self.check_user_id_not_appservice_exclusive(user_id)

        users = yield self.store.get_users_by_id_case_insensitive(user_id)
        if users:
            if not guest_access_token:
                raise SynapseError(
                    400,
                    "User ID already taken.",
                    errcode=Codes.USER_IN_USE,
                )
            user_data = yield self.auth.get_user_by_access_token(guest_access_token)
            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
                raise AuthError(
                    403,
                    "Cannot register taken user ID without valid guest "
                    "credentials for that user.",
                    errcode=Codes.FORBIDDEN,
                )

    @defer.inlineCallbacks
    def register(
        self,
        localpart=None,
        password=None,
        generate_token=True,
        guest_access_token=None,
        make_guest=False,
        admin=False,
        threepid=None,
        user_type=None,
        default_display_name=None,
    ):
        """Registers a new client on the server.

        Args:
            localpart : The local part of the user ID to register. If None,
              one will be generated.
            password (unicode) : The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
            generate_token (bool): Whether a new access token should be
              generated. Having this be True should be considered deprecated,
              since it offers no means of associating a device_id with the
              access_token. Instead you should call auth_handler.issue_access_token
              after registration.
            user_type (str|None): type of user. One of the values from
              api.constants.UserTypes, or None for a normal user.
            default_display_name (unicode|None): if set, the new user's displayname
              will be set to this. Defaults to 'localpart'.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """

        yield self.auth.check_auth_blocking(threepid=threepid)
        password_hash = None
        if password:
            password_hash = yield self.auth_handler().hash(password)

        if localpart:
            yield self.check_username(localpart, guest_access_token=guest_access_token)

            was_guest = guest_access_token is not None

            if not was_guest:
                try:
                    int(localpart)
                    raise RegistrationError(
                        400,
                        "Numeric user IDs are reserved for guest users."
                    )
                except ValueError:
                    pass

            user = UserID(localpart, self.hs.hostname)
            user_id = user.to_string()

            if was_guest:
                # If the user was a guest then they already have a profile
                default_display_name = None

            elif default_display_name is None:
                default_display_name = localpart

            token = None
            if generate_token:
                token = self.macaroon_gen.generate_access_token(user_id)
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                create_profile_with_displayname=default_display_name,
                admin=admin,
                user_type=user_type,
            )

            if self.hs.config.user_directory_search_all_users:
                profile = yield self.store.get_profileinfo(localpart)
                yield self.user_directory_handler.handle_local_profile_change(
                    user_id, profile
                )

        else:
            # autogen a sequential user ID
            attempts = 0
            token = None
            user = None
            while not user:
                localpart = yield self._generate_user_id(attempts > 0)
                user = UserID(localpart, self.hs.hostname)
                user_id = user.to_string()
                yield self.check_user_id_not_appservice_exclusive(user_id)
                if generate_token:
                    token = self.macaroon_gen.generate_access_token(user_id)
                if default_display_name is None:
                    default_display_name = localpart
                try:
                    yield self.store.register(
                        user_id=user_id,
                        token=token,
                        password_hash=password_hash,
                        make_guest=make_guest,
                        create_profile_with_displayname=default_display_name,
                    )
                except SynapseError:
                    # if user id is taken, just generate another
                    user = None
                    user_id = None
                    token = None
                    attempts += 1
        if not self.hs.config.user_consent_at_registration:
            yield self._auto_join_rooms(user_id)

        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def _auto_join_rooms(self, user_id):
        """Automatically joins users to auto join rooms - creating the room in the first place
        if the user is the first to be created.

        Args:
            user_id(str): The user to join
        """
        # auto-join the user to any rooms we're supposed to dump them into
        fake_requester = create_requester(user_id)

        # try to create the room if we're the first real user on the server. Note
        # that an auto-generated support user is not a real user and will never be
        # the user to create the room
        should_auto_create_rooms = False
        is_support = yield self.store.is_support_user(user_id)
        # There is an edge case where the first user is the support user, in
        # which case the room is never created. This seems unlikely, and is
        # recoverable, given that a support user would have to be involved in
        # the first place.
        if self.hs.config.autocreate_auto_join_rooms and not is_support:
            count = yield self.store.count_all_users()
            should_auto_create_rooms = count == 1
        for r in self.hs.config.auto_join_rooms:
            try:
                if should_auto_create_rooms:
                    room_alias = RoomAlias.from_string(r)
                    if self.hs.hostname != room_alias.domain:
                        logger.warning(
                            'Cannot create room alias %s, '
                            'it does not match server domain',
                            r,
                        )
                    else:
                        # create room expects the localpart of the room alias
                        room_alias_localpart = room_alias.localpart

                        # getting the RoomCreationHandler during init gives a dependency
                        # loop
                        yield self.hs.get_room_creation_handler().create_room(
                            fake_requester,
                            config={
                                "preset": "public_chat",
                                "room_alias_name": room_alias_localpart
                            },
                            ratelimit=False,
                        )
                else:
                    yield self._join_user_to_room(fake_requester, r)
            except Exception as e:
                logger.error("Failed to join new user to %r: %r", r, e)

    @defer.inlineCallbacks
    def post_consent_actions(self, user_id):
        """A series of registration actions that can only be carried out once consent
        has been granted

        Args:
            user_id (str): The user to join
        """
        yield self._auto_join_rooms(user_id)

    @defer.inlineCallbacks
    def appservice_register(self, user_localpart, as_token):
        user = UserID(user_localpart, self.hs.hostname)
        user_id = user.to_string()
        service = self.store.get_app_service_by_token(as_token)
        if not service:
            raise AuthError(403, "Invalid application service token.")
        if not service.is_interested_in_user(user_id):
            raise SynapseError(
                400, "Invalid user localpart for this application service.",
                errcode=Codes.EXCLUSIVE
            )

        service_id = service.id if service.is_exclusive_user(user_id) else None

        yield self.check_user_id_not_appservice_exclusive(
            user_id, allowed_appservice=service
        )

        yield self.store.register(
            user_id=user_id,
            password_hash="",
            appservice_id=service_id,
            create_profile_with_displayname=user.localpart,
        )
        defer.returnValue(user_id)

    @defer.inlineCallbacks
    def check_recaptcha(self, ip, private_key, challenge, response):
        """
        Checks a recaptcha is correct.

        Used only by c/s api v1
        """

        captcha_response = yield self._validate_captcha(
            ip,
            private_key,
            challenge,
            response
        )
        if not captcha_response["valid"]:
            logger.info("Invalid captcha entered from %s. Error: %s",
                        ip, captcha_response["error_url"])
            raise InvalidCaptchaError(
                error_url=captcha_response["error_url"]
            )
        else:
            logger.info("Valid captcha entered from %s", ip)

    @defer.inlineCallbacks
    def register_email(self, threepidCreds):
        """
        Registers emails with an identity server.

        Used only by c/s api v1
        """

        for c in threepidCreds:
            logger.info("validating threepidcred sid %s on id server %s",
                        c['sid'], c['idServer'])
            try:
                identity_handler = self.hs.get_handlers().identity_handler
                threepid = yield identity_handler.threepid_from_creds(c)
            except Exception:
                logger.exception("Couldn't validate 3pid")
                raise RegistrationError(400, "Couldn't validate 3pid")

            if not threepid:
                raise RegistrationError(400, "Couldn't validate 3pid")
            logger.info("got threepid with medium '%s' and address '%s'",
                        threepid['medium'], threepid['address'])

            if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
                raise RegistrationError(
                    403, "Third party identifier is not allowed"
                )

    @defer.inlineCallbacks
    def bind_emails(self, user_id, threepidCreds):
        """Links emails with a user ID and informs an identity server.

        Used only by c/s api v1
        """

        # Now we have a matrix ID, bind it to the threepids we were given
        for c in threepidCreds:
            identity_handler = self.hs.get_handlers().identity_handler
            # XXX: This should be a deferred list, shouldn't it?
            yield identity_handler.bind_threepid(c, user_id)

    def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
        # don't allow people to register the server notices mxid
        if self._server_notices_mxid is not None:
            if user_id == self._server_notices_mxid:
                raise SynapseError(
                    400, "This user ID is reserved.",
                    errcode=Codes.EXCLUSIVE
                )

        # valid user IDs must not clash with any user ID namespaces claimed by
        # application services.
        services = self.store.get_app_services()
        interested_services = [
            s for s in services
            if s.is_interested_in_user(user_id)
            and s != allowed_appservice
        ]
        for service in interested_services:
            if service.is_exclusive_user(user_id):
                raise SynapseError(
                    400, "This user ID is reserved by an application service.",
                    errcode=Codes.EXCLUSIVE
                )

    @defer.inlineCallbacks
    def _generate_user_id(self, reseed=False):
        if reseed or self._next_generated_user_id is None:
            with (yield self._generate_user_id_linearizer.queue(())):
                if reseed or self._next_generated_user_id is None:
                    self._next_generated_user_id = (
                        yield self.store.find_next_generated_user_id_localpart()
                    )

        id = self._next_generated_user_id
        self._next_generated_user_id += 1
        defer.returnValue(str(id))

    @defer.inlineCallbacks
    def _validate_captcha(self, ip_addr, private_key, challenge, response):
        """Validates the captcha provided.

        Used only by c/s api v1

        Returns:
            dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.

        """
        response = yield self._submit_captcha(ip_addr, private_key, challenge,
                                              response)
        # parse Google's response. Lovely format..
        lines = response.split('\n')
        json = {
            "valid": lines[0] == 'true',
            "error_url": "http://www.google.com/recaptcha/api/challenge?" +
                         "error=%s" % lines[1]
        }
        defer.returnValue(json)

    @defer.inlineCallbacks
    def _submit_captcha(self, ip_addr, private_key, challenge, response):
        """
        Used only by c/s api v1
        """
        data = yield self.captcha_client.post_urlencoded_get_raw(
            "http://www.google.com:80/recaptcha/api/verify",
            args={
                'privatekey': private_key,
                'remoteip': ip_addr,
                'challenge': challenge,
                'response': response
            }
        )
        defer.returnValue(data)

    @defer.inlineCallbacks
    def get_or_create_user(self, requester, localpart, displayname,
                           password_hash=None):
        """Creates a new user if the user does not exist,
        else revokes all previous access tokens and generates a new one.

        Args:
            localpart : The local part of the user ID to register. Must not
              be None; a missing localpart is rejected with a 400 error.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """
        if localpart is None:
            raise SynapseError(400, "Request must include user id")
        yield self.auth.check_auth_blocking()
        need_register = True

        try:
            yield self.check_username(localpart)
        except SynapseError as e:
            if e.errcode == Codes.USER_IN_USE:
                need_register = False
            else:
                raise

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()
        token = self.macaroon_gen.generate_access_token(user_id)

        if need_register:
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                create_profile_with_displayname=user.localpart,
            )
        else:
            yield self._auth_handler.delete_access_tokens_for_user(user_id)
            yield self.store.add_access_token_to_user(user_id=user_id, token=token)

        if displayname is not None:
            logger.info("setting user display name: %s -> %s", user_id, displayname)
            yield self.profile_handler.set_displayname(
                user, requester, displayname, by_admin=True,
            )

        defer.returnValue((user_id, token))

    def auth_handler(self):
        return self.hs.get_auth_handler()

    @defer.inlineCallbacks
    def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
            3PID guest account.
        """
        access_token = yield self.store.get_3pid_guest_access_token(medium, address)
        if access_token:
            user_info = yield self.auth.get_user_by_access_token(
                access_token
            )

            defer.returnValue((user_info["user"].to_string(), access_token))

        user_id, access_token = yield self.register(
            generate_token=True,
            make_guest=True
        )
        access_token = yield self.store.save_or_get_3pid_guest_access_token(
            medium, address, access_token, inviter_user_id
        )

        defer.returnValue((user_id, access_token))

    @defer.inlineCallbacks
    def _join_user_to_room(self, requester, room_identifier):
        room_id = None
        # Default to no remote hosts; this is only populated when joining via
        # a room alias.
        remote_room_hosts = []
        room_member_handler = self.hs.get_room_member_handler()
        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
        elif RoomAlias.is_valid(room_identifier):
            room_alias = RoomAlias.from_string(room_identifier)
            room_id, remote_room_hosts = (
                yield room_member_handler.lookup_room_alias(room_alias)
            )
            room_id = room_id.to_string()
        else:
            raise SynapseError(400, "%s was not a legal room ID or room alias" % (
                room_identifier,
            ))

        yield room_member_handler.update_membership(
            requester=requester,
            target=requester.user,
            room_id=room_id,
            remote_room_hosts=remote_room_hosts,
            action="join",
            ratelimit=False,
        )
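The registration example relies on _generate_user_id for its Linearizer usage: the cached counter is checked once outside the lock and again inside it, so only the first caller pays for the database lookup. Below is a minimal sketch of that double-checked pattern under assumed names (SequentialIdAllocator, find_next_free) and an assumed import path.

from twisted.internet import defer
from synapse.util.async_helpers import Linearizer  # import path assumed

class SequentialIdAllocator(object):
    """Hands out sequential ids, looking up the starting point only once."""

    def __init__(self, find_next_free):
        self._find_next_free = find_next_free   # hypothetical DB lookup returning a Deferred
        self._next_id = None
        self._linearizer = Linearizer(name="id_allocator_example")

    @defer.inlineCallbacks
    def allocate(self):
        if self._next_id is None:
            with (yield self._linearizer.queue(())):
                # Re-check inside the critical section: another caller may have
                # populated the counter while we were queued.
                if self._next_id is None:
                    self._next_id = yield self._find_next_free()
        allocated = self._next_id
        self._next_id += 1
        defer.returnValue(str(allocated))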
Пример #34
0
class FederationSenderHandler(object):
    """Processes the replication stream and forwards the appropriate entries
    to the federation sender.
    """
    def __init__(self, hs, replication_client):
        self.store = hs.get_datastore()
        self._is_mine_id = hs.is_mine_id
        self.federation_sender = hs.get_federation_sender()
        self.replication_client = replication_client

        self.federation_position = self.store.federation_out_pos_startup
        self._fed_position_linearizer = Linearizer(
            name="_fed_position_linearizer")

        self._last_ack = self.federation_position

        self._room_serials = {}
        self._room_typing = {}

    def on_start(self):
        # There may be some events that are persisted but haven't been sent,
        # so send them now.
        self.federation_sender.notify_new_events(
            self.store.get_room_max_stream_ordering())

    def stream_positions(self):
        return {"federation": self.federation_position}

    def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender,
                                                   rows)
            run_in_background(self.update_token, token)

        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
            self.federation_sender.notify_new_events(token)

        # ... and when new receipts happen
        elif stream_name == ReceiptsStream.NAME:
            run_as_background_process(
                "process_receipts_for_federation",
                self._on_new_receipts,
                rows,
            )

    @defer.inlineCallbacks
    def _on_new_receipts(self, rows):
        """
        Args:
            rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
                new receipts to be processed
        """
        for receipt in rows:
            # we only want to send on receipts for our own users
            if not self._is_mine_id(receipt.user_id):
                continue
            receipt_info = ReadReceipt(
                receipt.room_id,
                receipt.receipt_type,
                receipt.user_id,
                [receipt.event_id],
                receipt.data,
            )
            yield self.federation_sender.send_read_receipt(receipt_info)

    @defer.inlineCallbacks
    def update_token(self, token):
        try:
            self.federation_position = token

            # We linearize here to ensure we don't have races updating the token
            with (yield self._fed_position_linearizer.queue(None)):
                if self._last_ack < self.federation_position:
                    yield self.store.update_federation_out_pos(
                        "federation", self.federation_position)

                    # We ACK this token over replication so that the master can drop
                    # its in memory queues
                    self.replication_client.send_federation_ack(
                        self.federation_position)
                    self._last_ack = self.federation_position
        except Exception:
            logger.exception("Error updating federation stream position")
Пример #35
0
class DeviceListEduUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]   # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning("Got device list update edu for %r from %r", user_id, origin)
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            return

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if resync:
                # Fetch all devices for the user.
                origin = get_domain_from_id(user_id)
                try:
                    result = yield self.federation.query_user_devices(origin, user_id)
                except NotRetryingDestination:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.warn(
                        "Failed to handle device list update for %s,"
                        " we're not retrying the remote",
                        user_id,
                    )
                    # We abort on exceptions rather than accepting the update
                    # as otherwise synapse will 'forget' that its device list
                    # is out of date. If we bail then we will retry the resync
                    # next time we get a device list update for this user_id.
                    # This makes it more likely that the device lists will
                    # eventually become consistent.
                    return
                except FederationDeniedError as e:
                    logger.info(e)
                    return
                except Exception:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.exception(
                        "Failed to handle device list update for %s", user_id
                    )
                    return

                stream_id = result["stream_id"]
                devices = result["devices"]
                yield self.store.update_remote_device_list_cache(
                    user_id, devices, stream_id,
                )
                device_ids = [device["device_id"] for device in devices]
                yield self.device_handler.notify_device_update(user_id, device_ids)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id,
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

            self._seen_updates.setdefault(user_id, set()).update(
                stream_id for _, stream_id, _, _ in pending_updates
            )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
            user_id
        )

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                defer.returnValue(True)

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    defer.returnValue(True)

            stream_id_in_updates.add(stream_id)

        defer.returnValue(False)
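DeviceListEduUpdater queues incoming updates into a plain dict and only drains them inside a per-user Linearizer, so a burst of EDUs for one user is handled in a single pass while later arrivals wait their turn. A minimal sketch of that queue-then-drain pattern follows, with hypothetical names (PendingUpdateProcessor, apply_update) and an assumed import path.

from twisted.internet import defer
from synapse.util.async_helpers import Linearizer  # import path assumed

class PendingUpdateProcessor(object):
    """Queues updates per key and drains them under a per-key Linearizer."""

    def __init__(self, apply_update):
        self._apply_update = apply_update   # hypothetical handler returning a Deferred
        self._pending = {}                  # key -> list of queued updates
        self._linearizer = Linearizer(name="pending_update_example")

    @defer.inlineCallbacks
    def submit(self, key, update):
        self._pending.setdefault(key, []).append(update)
        with (yield self._linearizer.queue(key)):
            updates = self._pending.pop(key, [])
            if not updates:
                # A drain that was already running picked up our update.
                return
            for update in updates:
                yield self._apply_update(key, update)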
Пример #36
0
class StateResolutionHandler(object):
    """Responsible for doing state conflict resolution.

    Note that the storage layer depends on this handler, so all functions must
    be storage-independent.
    """
    def __init__(self, hs):
        self.clock = hs.get_clock()

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None
        self.resolve_linearizer = Linearizer(name="state_resolve_lock")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=100000,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            iterable=True,
            reset_expiry_on_get=True,
        )

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(self, room_id, room_version, state_groups_ids,
                             event_map, state_res_store):
        """Resolves conflicts between a set of state groups

        Always generates a new state group (unless we hit the cache), so should
        not be called for a single state group

        Args:
            room_id (str): room we are resolving for (used for logging and sanity checks)
            room_version (str): version of the room
            state_groups_ids (dict[int, dict[(str, str), str]]):
                 map from state group id to the state in that state group
                (where 'state' is a map from state key to event id)

            event_map (dict[str, FrozenEvent]|None):
                a dict from event_id to event, for any events that we happen to
                have in flight (eg, those currently being persisted). This will be
                used as a starting point for finding the state we need; any missing
                events will be requested via state_res_store.

                If None, all events will be fetched via state_res_store.

            state_res_store (StateResolutionStore)

        Returns:
            Deferred[_StateCacheEntry]: resolved state
        """
        logger.debug("resolve_state_groups state_groups %s",
                     state_groups_ids.keys())

        group_names = frozenset(state_groups_ids.keys())

        with (yield self.resolve_linearizer.queue(group_names)):
            if self._state_cache is not None:
                cache = self._state_cache.get(group_names, None)
                if cache:
                    return cache

            logger.info("Resolving state for %s with %d groups", room_id,
                        len(state_groups_ids))

            state_groups_histogram.observe(len(state_groups_ids))

            # start by assuming we won't have any conflicted state, and build up the new
            # state map by iterating through the state groups. If we discover a conflict,
            # we give up and instead use `resolve_events_with_store`.
            #
            # XXX: is this actually worthwhile, or should we just let
            # resolve_events_with_store do it?
            new_state = {}
            conflicted_state = False
            for st in state_groups_ids.values():
                for key, e_id in st.items():
                    if key in new_state:
                        conflicted_state = True
                        break
                    new_state[key] = e_id
                if conflicted_state:
                    break

            if conflicted_state:
                logger.info("Resolving conflicted state for %r", room_id)
                with Measure(self.clock, "state._resolve_events"):
                    new_state = yield resolve_events_with_store(
                        self.clock,
                        room_id,
                        room_version,
                        list(state_groups_ids.values()),
                        event_map=event_map,
                        state_res_store=state_res_store,
                    )

            # if the new state matches any of the input state groups, we can
            # use that state group again. Otherwise we will generate a state_id
            # which will be used as a cache key for future resolutions, but
            # not get persisted.

            with Measure(self.clock, "state.create_group_ids"):
                cache = _make_state_cache_entry(new_state, state_groups_ids)

            if self._state_cache is not None:
                self._state_cache[group_names] = cache

            return cache
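StateResolutionHandler checks its cache inside the Linearizer, keyed on the frozenset of state group ids, so concurrent requests for the same groups collapse into one computation and the rest reuse the cached entry. The sketch below illustrates that compute-once-per-key pattern; CachedResolver and compute are hypothetical, and the import path is assumed.

from twisted.internet import defer
from synapse.util.async_helpers import Linearizer  # import path assumed

class CachedResolver(object):
    """Computes a result at most once per key, even under concurrent requests."""

    def __init__(self, compute):
        self._compute = compute   # hypothetical expensive function returning a Deferred
        self._cache = {}
        self._linearizer = Linearizer(name="cached_resolver_example")

    @defer.inlineCallbacks
    def resolve(self, group_ids):
        key = frozenset(group_ids)
        with (yield self._linearizer.queue(key)):
            cached = self._cache.get(key)
            if cached is not None:
                # A concurrent caller already resolved these groups.
                defer.returnValue(cached)
            result = yield self._compute(group_ids)
            self._cache[key] = result
            defer.returnValue(result)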
Пример #37
0
class EventCreationHandler(object):
    def __init__(self, hs):
        self.hs = hs
        self.auth = hs.get_auth()
        self.store = hs.get_datastore()
        self.storage = hs.get_storage()
        self.state = hs.get_state_handler()
        self.clock = hs.get_clock()
        self.validator = EventValidator()
        self.profile_handler = hs.get_profile_handler()
        self.event_builder_factory = hs.get_event_builder_factory()
        self.server_name = hs.hostname
        self.ratelimiter = hs.get_ratelimiter()
        self.notifier = hs.get_notifier()
        self.config = hs.config
        self.require_membership_for_aliases = hs.config.require_membership_for_aliases

        self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)

        # This is only used to get at the ratelimit function and maybe_kick_guest_users
        self.base_handler = BaseHandler(hs)

        self.pusher_pool = hs.get_pusherpool()

        # We arbitrarily limit concurrent event creation for a room to 5.
        # This is to stop us from diverging history *too* much.
        self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")

        self.action_generator = hs.get_action_generator()

        self.spam_checker = hs.get_spam_checker()
        self.third_party_event_rules = hs.get_third_party_event_rules()

        self._block_events_without_consent_error = (
            self.config.block_events_without_consent_error
        )

        # Rooms which should be excluded from dummy insertion. (For instance,
        # those without local users who can send events into the room).
        #
        # map from room id to time-of-last-attempt.
        #
        self._rooms_to_exclude_from_dummy_event_insertion = {}  # type: dict[str, int]

        # we need to construct a ConsentURIBuilder here, as it checks that the
        # necessary config options are set, but *only* if we have a configuration
        # for which we are going to need it.
        if self._block_events_without_consent_error:
            self._consent_uri_builder = ConsentURIBuilder(self.config)

        if (
            not self.config.worker_app
            and self.config.cleanup_extremities_with_dummy_events
        ):
            self.clock.looping_call(
                lambda: run_as_background_process(
                    "send_dummy_events_to_fill_extremities",
                    self._send_dummy_events_to_fill_extremities,
                ),
                5 * 60 * 1000,
            )

    @defer.inlineCallbacks
    def create_event(
        self,
        requester,
        event_dict,
        token_id=None,
        txn_id=None,
        prev_events_and_hashes=None,
        require_consent=True,
    ):
        """
        Given a dict from a client, create a new event.

        Creates a FrozenEvent object, filling out auth_events, prev_events,
        etc.

        Adds display names to Join membership events.

        Args:
            requester
            event_dict (dict): An entire event
            token_id (str)
            txn_id (str)

            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
                the forward extremities to use as the prev_events for the
                new event. For each event, a tuple of (event_id, hashes, depth)
                where *hashes* is a map from algorithm to hash.

                If None, they will be requested from the database.

            require_consent (bool): Whether to check if the requester has
                consented to privacy policy.
        Raises:
            ResourceLimitError if server is blocked to some resource being
            exceeded
        Returns:
            Tuple of created event (FrozenEvent), Context
        """
        yield self.auth.check_auth_blocking(requester.user.to_string())

        if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
            room_version = event_dict["content"]["room_version"]
        else:
            try:
                room_version = yield self.store.get_room_version(event_dict["room_id"])
            except NotFoundError:
                raise AuthError(403, "Unknown room")

        builder = self.event_builder_factory.new(room_version, event_dict)

        self.validator.validate_builder(builder)

        if builder.type == EventTypes.Member:
            membership = builder.content.get("membership", None)
            target = UserID.from_string(builder.state_key)

            if membership in {Membership.JOIN, Membership.INVITE}:
                # If event doesn't include a display name, add one.
                profile = self.profile_handler
                content = builder.content

                try:
                    if "displayname" not in content:
                        content["displayname"] = yield profile.get_displayname(target)
                    if "avatar_url" not in content:
                        content["avatar_url"] = yield profile.get_avatar_url(target)
                except Exception as e:
                    logger.info(
                        "Failed to get profile information for %r: %s", target, e
                    )

        is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
        if require_consent and not is_exempt:
            yield self.assert_accepted_privacy_policy(requester)

        if token_id is not None:
            builder.internal_metadata.token_id = token_id

        if txn_id is not None:
            builder.internal_metadata.txn_id = txn_id

        event, context = yield self.create_new_client_event(
            builder=builder,
            requester=requester,
            prev_events_and_hashes=prev_events_and_hashes,
        )

        # In an ideal world we wouldn't need the second part of this condition. However,
        # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
        # behaviour. Another reason is that this code is also evaluated each time a new
        # m.room.aliases event is created, which includes hitting a /directory route.
        # Therefore not including this condition here would render the similar one in
        # synapse.handlers.directory pointless.
        if builder.type == EventTypes.Aliases and self.require_membership_for_aliases:
            # Ideally we'd do the membership check in event_auth.check(), which
            # describes a spec'd algorithm for authenticating events received over
            # federation as well as those created locally. As of room v3, aliases events
            # can be created by users that are not in the room, therefore we have to
            # tolerate them in event_auth.check().
            prev_state_ids = yield context.get_prev_state_ids(self.store)
            prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
            prev_event = (
                yield self.store.get_event(prev_event_id, allow_none=True)
                if prev_event_id
                else None
            )
            if not prev_event or prev_event.membership != Membership.JOIN:
                logger.warning(
                    (
                        "Attempt to send `m.room.aliases` in room %s by user %s but"
                        " membership is %s"
                    ),
                    event.room_id,
                    event.sender,
                    prev_event.membership if prev_event else None,
                )

                raise AuthError(
                    403, "You must be in the room to create an alias for it"
                )

        self.validator.validate_new(event)

        return (event, context)

    def _is_exempt_from_privacy_policy(self, builder, requester):
        """"Determine if an event to be sent is exempt from having to consent
        to the privacy policy

        Args:
            builder (synapse.events.builder.EventBuilder): event being created
            requester (Requester): user requesting this event

        Returns:
            Deferred[bool]: true if the event can be sent without the user
                consenting
        """
        # the only thing the user can do is join the server notices room.
        if builder.type == EventTypes.Member:
            membership = builder.content.get("membership", None)
            if membership == Membership.JOIN:
                return self._is_server_notices_room(builder.room_id)
            elif membership == Membership.LEAVE:
                # the user is always allowed to leave (but not kick people)
                return builder.state_key == requester.user.to_string()
        return succeed(False)

    @defer.inlineCallbacks
    def _is_server_notices_room(self, room_id):
        if self.config.server_notices_mxid is None:
            return False
        user_ids = yield self.store.get_users_in_room(room_id)
        return self.config.server_notices_mxid in user_ids

    @defer.inlineCallbacks
    def assert_accepted_privacy_policy(self, requester):
        """Check if a user has accepted the privacy policy

        Called when the given user is about to do something that requires
        privacy consent. We see if the user is exempt and otherwise check that
        they have given consent. If they have not, a ConsentNotGiven error is
        raised.

        Args:
            requester (synapse.types.Requester):
                The user making the request

        Returns:
            Deferred[None]: returns normally if the user has consented or is
                exempt

        Raises:
            ConsentNotGivenError: if the user has not given consent yet
        """
        if self._block_events_without_consent_error is None:
            return

        # exempt AS users from needing consent
        if requester.app_service is not None:
            return

        user_id = requester.user.to_string()

        # exempt the system notices user
        if (
            self.config.server_notices_mxid is not None
            and user_id == self.config.server_notices_mxid
        ):
            return

        u = yield self.store.get_user_by_id(user_id)
        assert u is not None
        if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
            # support and bot users are not required to consent
            return
        if u["appservice_id"] is not None:
            # users registered by an appservice are exempt
            return
        if u["consent_version"] == self.config.user_consent_version:
            return

        consent_uri = self._consent_uri_builder.build_user_consent_uri(
            requester.user.localpart
        )
        msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
        raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)

    @defer.inlineCallbacks
    def send_nonmember_event(self, requester, event, context, ratelimit=True):
        """
        Persists and notifies local clients and federation of an event.

        Args:
            requester (Requester): the user requesting the send.
            event (FrozenEvent): the event to send.
            context (Context): the context of the event.
            ratelimit (bool): Whether to rate limit this send.
        """
        if event.type == EventTypes.Member:
            raise SynapseError(
                500, "Tried to send member event through non-member codepath"
            )

        user = UserID.from_string(event.sender)

        assert self.hs.is_mine(user), "User must be our own: %s" % (user,)

        if event.is_state():
            prev_state = yield self.deduplicate_state_event(event, context)
            if prev_state is not None:
                logger.info(
                    "Not bothering to persist state event %s duplicated by %s",
                    event.event_id,
                    prev_state.event_id,
                )
                return prev_state

        yield self.handle_new_client_event(
            requester=requester, event=event, context=context, ratelimit=ratelimit
        )

    @defer.inlineCallbacks
    def deduplicate_state_event(self, event, context):
        """
        Checks whether event is in the latest resolved state in context.

        If so, returns the version of the event in context.
        Otherwise, returns None.
        """
        prev_state_ids = yield context.get_prev_state_ids(self.store)
        prev_event_id = prev_state_ids.get((event.type, event.state_key))
        if not prev_event_id:
            return
        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
        if not prev_event:
            return

        if prev_event and event.user_id == prev_event.user_id:
            prev_content = encode_canonical_json(prev_event.content)
            next_content = encode_canonical_json(event.content)
            if prev_content == next_content:
                return prev_event
        return

    @defer.inlineCallbacks
    def create_and_send_nonmember_event(
        self, requester, event_dict, ratelimit=True, txn_id=None
    ):
        """
        Creates an event, then sends it.

        See self.create_event and self.send_nonmember_event.
        """

        # We limit the number of concurrent event sends in a room so that we
        # don't fork the DAG too much. If we don't limit then we can end up in
        # a situation where event persistence can't keep up, causing
        # extremities to pile up, which in turn leads to state resolution
        # taking longer.
        with (yield self.limiter.queue(event_dict["room_id"])):
            event, context = yield self.create_event(
                requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
            )

            spam_error = self.spam_checker.check_event_for_spam(event)
            if spam_error:
                if not isinstance(spam_error, string_types):
                    spam_error = "Spam is not permitted here"
                raise SynapseError(403, spam_error, Codes.FORBIDDEN)

            yield self.send_nonmember_event(
                requester, event, context, ratelimit=ratelimit
            )
        return event

    @measure_func("create_new_client_event")
    @defer.inlineCallbacks
    def create_new_client_event(
        self, builder, requester=None, prev_events_and_hashes=None
    ):
        """Create a new event for a local client

        Args:
            builder (EventBuilder):

            requester (synapse.types.Requester|None):

            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
                the forward extremities to use as the prev_events for the
                new event. For each event, a tuple of (event_id, hashes, depth)
                where *hashes* is a map from algorithm to hash.

                If None, they will be requested from the database.

        Returns:
            Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
        """

        if prev_events_and_hashes is not None:
            assert len(prev_events_and_hashes) <= 10, (
                "Attempting to create an event with %i prev_events"
                % (len(prev_events_and_hashes),)
            )
        else:
            prev_events_and_hashes = yield self.store.get_prev_events_for_room(
                builder.room_id
            )

        prev_events = [
            (event_id, prev_hashes)
            for event_id, prev_hashes, _ in prev_events_and_hashes
        ]

        event = yield builder.build(prev_event_ids=[p for p, _ in prev_events])
        context = yield self.state.compute_event_context(event)
        if requester:
            context.app_service = requester.app_service

        self.validator.validate_new(event)

        # If this event is an annotation then we check that the sender
        # can't annotate the same way twice (e.g. stops users from liking an
        # event multiple times).
        relation = event.content.get("m.relates_to", {})
        if relation.get("rel_type") == RelationTypes.ANNOTATION:
            relates_to = relation["event_id"]
            aggregation_key = relation["key"]

            already_exists = yield self.store.has_user_annotated_event(
                relates_to, event.type, aggregation_key, event.sender
            )
            if already_exists:
                raise SynapseError(400, "Can't send same reaction twice")

        logger.debug("Created event %s", event.event_id)

        return (event, context)

    @measure_func("handle_new_client_event")
    @defer.inlineCallbacks
    def handle_new_client_event(
        self, requester, event, context, ratelimit=True, extra_users=[]
    ):
        """Processes a new event. This includes checking auth, persisting it,
        notifying users, sending to remote servers, etc.

        If called from a worker will hit out to the master process for final
        processing.

        Args:
            requester (Requester)
            event (FrozenEvent)
            context (EventContext)
            ratelimit (bool)
            extra_users (list(UserID)): Any extra users to notify about event
        """

        if event.is_state() and (event.type, event.state_key) == (
            EventTypes.Create,
            "",
        ):
            room_version = event.content.get("room_version", RoomVersions.V1.identifier)
        else:
            room_version = yield self.store.get_room_version(event.room_id)

        event_allowed = yield self.third_party_event_rules.check_event_allowed(
            event, context
        )
        if not event_allowed:
            raise SynapseError(
                403, "This event is not allowed in this context", Codes.FORBIDDEN
            )

        try:
            yield self.auth.check_from_context(room_version, event, context)
        except AuthError as err:
            logger.warning("Denying new event %r because %s", event, err)
            raise err

        # Ensure that we can round trip before trying to persist in db
        try:
            dump = frozendict_json_encoder.encode(event.content)
            json.loads(dump)
        except Exception:
            logger.exception("Failed to encode content: %r", event.content)
            raise

        yield self.action_generator.handle_push_actions_for_event(event, context)

        # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
        # hack around with a try/finally instead.
        success = False
        try:
            # If we're a worker we need to hit out to the master.
            if self.config.worker_app:
                yield self.send_event_to_master(
                    event_id=event.event_id,
                    store=self.store,
                    requester=requester,
                    event=event,
                    context=context,
                    ratelimit=ratelimit,
                    extra_users=extra_users,
                )
                success = True
                return

            yield self.persist_and_notify_client_event(
                requester, event, context, ratelimit=ratelimit, extra_users=extra_users
            )

            success = True
        finally:
            if not success:
                # Ensure that we actually remove the entries in the push actions
                # staging area, if we calculated them.
                run_in_background(
                    self.store.remove_push_actions_from_staging, event.event_id
                )

    @defer.inlineCallbacks
    def persist_and_notify_client_event(
        self, requester, event, context, ratelimit=True, extra_users=[]
    ):
        """Called when we have fully built the event, have already
        calculated the push actions for the event, and checked auth.

        This should only be run on master.
        """
        assert not self.config.worker_app

        if ratelimit:
            # We check if this is a room admin redacting an event so that we
            # can apply different ratelimiting. We do this by simply checking
            # it's not a self-redaction (to avoid having to look up whether the
            # user is actually admin or not).
            is_admin_redaction = False
            if event.type == EventTypes.Redaction:
                original_event = yield self.store.get_event(
                    event.redacts,
                    check_redacted=False,
                    get_prev_content=False,
                    allow_rejected=False,
                    allow_none=True,
                )

                is_admin_redaction = (
                    original_event and event.sender != original_event.sender
                )

            yield self.base_handler.ratelimit(
                requester, is_admin_redaction=is_admin_redaction
            )

        yield self.base_handler.maybe_kick_guest_users(event, context)

        if event.type == EventTypes.CanonicalAlias:
            # Check the alias is actually valid (at this time at least)
            room_alias_str = event.content.get("alias", None)
            if room_alias_str:
                room_alias = RoomAlias.from_string(room_alias_str)
                directory_handler = self.hs.get_handlers().directory_handler
                mapping = yield directory_handler.get_association(room_alias)

                if mapping["room_id"] != event.room_id:
                    raise SynapseError(
                        400,
                        "Room alias %s does not point to the room" % (room_alias_str,),
                    )

        federation_handler = self.hs.get_handlers().federation_handler

        if event.type == EventTypes.Member:
            if event.content["membership"] == Membership.INVITE:

                def is_inviter_member_event(e):
                    return e.type == EventTypes.Member and e.sender == event.sender

                current_state_ids = yield context.get_current_state_ids(self.store)

                state_to_include_ids = [
                    e_id
                    for k, e_id in iteritems(current_state_ids)
                    if k[0] in self.hs.config.room_invite_state_types
                    or k == (EventTypes.Member, event.sender)
                ]

                state_to_include = yield self.store.get_events(state_to_include_ids)

                event.unsigned["invite_room_state"] = [
                    {
                        "type": e.type,
                        "state_key": e.state_key,
                        "content": e.content,
                        "sender": e.sender,
                    }
                    for e in itervalues(state_to_include)
                ]

                invitee = UserID.from_string(event.state_key)
                if not self.hs.is_mine(invitee):
                    # TODO: Can we add signature from remote server in a nicer
                    # way? If we have been invited by a remote server, we need
                    # to get them to sign the event.

                    returned_invite = yield federation_handler.send_invite(
                        invitee.domain, event
                    )

                    event.unsigned.pop("room_state", None)

                    # TODO: Make sure the signatures actually are correct.
                    event.signatures.update(returned_invite.signatures)

        if event.type == EventTypes.Redaction:
            original_event = yield self.store.get_event(
                event.redacts,
                check_redacted=False,
                get_prev_content=False,
                allow_rejected=False,
                allow_none=True,
            )

            # we can make some additional checks now if we have the original event.
            if original_event:
                if original_event.type == EventTypes.Create:
                    raise AuthError(403, "Redacting create events is not permitted")

                if original_event.room_id != event.room_id:
                    raise SynapseError(400, "Cannot redact event from a different room")

            prev_state_ids = yield context.get_prev_state_ids(self.store)
            auth_events_ids = yield self.auth.compute_auth_events(
                event, prev_state_ids, for_verification=True
            )
            auth_events = yield self.store.get_events(auth_events_ids)
            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
            room_version = yield self.store.get_room_version(event.room_id)

            if event_auth.check_redaction(room_version, event, auth_events=auth_events):
                # this user doesn't have 'redact' rights, so we need to do some more
                # checks on the original event. Let's start by checking the original
                # event exists.
                if not original_event:
                    raise NotFoundError("Could not find event %s" % (event.redacts,))

                if event.user_id != original_event.user_id:
                    raise AuthError(403, "You don't have permission to redact events")

                # all the checks are done.
                event.internal_metadata.recheck_redaction = False

        if event.type == EventTypes.Create:
            prev_state_ids = yield context.get_prev_state_ids(self.store)
            if prev_state_ids:
                raise AuthError(403, "Changing the room create event is forbidden")

        event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
            event, context=context
        )

        yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)

        def _notify():
            try:
                self.notifier.on_new_room_event(
                    event, event_stream_id, max_stream_id, extra_users=extra_users
                )
            except Exception:
                logger.exception("Error notifying about new room event")

        run_in_background(_notify)

        if event.type == EventTypes.Message:
            # We don't want to block sending messages on any presence code. This
            # matters as sometimes presence code can take a while.
            run_in_background(self._bump_active_time, requester.user)

    @defer.inlineCallbacks
    def _bump_active_time(self, user):
        try:
            presence = self.hs.get_presence_handler()
            yield presence.bump_presence_active_time(user)
        except Exception:
            logger.exception("Error bumping presence active time")

    @defer.inlineCallbacks
    def _send_dummy_events_to_fill_extremities(self):
        """Background task to send dummy events into rooms that have a large
        number of extremities
        """
        self._expire_rooms_to_exclude_from_dummy_event_insertion()
        room_ids = yield self.store.get_rooms_with_many_extremities(
            min_count=10,
            limit=5,
            room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
        )

        for room_id in room_ids:
            # For each room we need to find a joined member we can use to send
            # the dummy event with.

            prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)

            latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)

            members = yield self.state.get_current_users_in_room(
                room_id, latest_event_ids=latest_event_ids
            )
            dummy_event_sent = False
            for user_id in members:
                if not self.hs.is_mine_id(user_id):
                    continue
                requester = create_requester(user_id)
                try:
                    event, context = yield self.create_event(
                        requester,
                        {
                            "type": "org.matrix.dummy_event",
                            "content": {},
                            "room_id": room_id,
                            "sender": user_id,
                        },
                        prev_events_and_hashes=prev_events_and_hashes,
                    )

                    event.internal_metadata.proactively_send = False

                    yield self.send_nonmember_event(
                        requester, event, context, ratelimit=False
                    )
                    dummy_event_sent = True
                    break
                except ConsentNotGivenError:
                    logger.info(
                        "Failed to send dummy event into room %s for user %s due to "
                        "lack of consent. Will try another user" % (room_id, user_id)
                    )
                except AuthError:
                    logger.info(
                        "Failed to send dummy event into room %s for user %s due to "
                        "lack of power. Will try another user" % (room_id, user_id)
                    )

            if not dummy_event_sent:
                # Did not find a valid user in the room, so remove from future attempts
                # Exclusion is time limited, so the room will be rechecked in the future
                # dependent on _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
                logger.info(
                    "Failed to send dummy event into room %s. Will exclude it from "
                    "future attempts until cache expires" % (room_id,)
                )
                now = self.clock.time_msec()
                self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now

    def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
        expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
        to_expire = set()
        for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():
            if time < expire_before:
                to_expire.add(room_id)
        for room_id in to_expire:
            logger.debug(
                "Expiring room id %s from dummy event insertion exclusion cache",
                room_id,
            )
            del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
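
The exclusion cache above is just a dict mapping room_id to the time the last dummy-event send failed; entries are dropped once they are older than the expiry window. A minimal, self-contained sketch of that pattern (EXCLUSION_EXPIRY_MS and the module-level dict are illustrative stand-ins for _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY and the handler attribute):

import time

# Illustrative stand-in for _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY (milliseconds).
EXCLUSION_EXPIRY_MS = 10 * 60 * 1000

# room_id -> time (ms) at which sending a dummy event last failed.
rooms_to_exclude = {}

def exclude_room(room_id, now_ms):
    # Record the failure so the room is skipped for a while.
    rooms_to_exclude[room_id] = now_ms

def expire_exclusions(now_ms):
    # Drop entries older than the expiry window, mirroring
    # _expire_rooms_to_exclude_from_dummy_event_insertion above.
    expire_before = now_ms - EXCLUSION_EXPIRY_MS
    for room_id in [r for r, t in rooms_to_exclude.items() if t < expire_before]:
        del rooms_to_exclude[room_id]

now = int(time.time() * 1000)
exclude_room("!stale:example.org", now - 2 * EXCLUSION_EXPIRY_MS)
exclude_room("!fresh:example.org", now)
expire_exclusions(now)
assert list(rooms_to_exclude) == ["!fresh:example.org"]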
Example #38
0
class EventCreationHandler(object):
    def __init__(self, hs):
        self.hs = hs
        self.auth = hs.get_auth()
        self.store = hs.get_datastore()
        self.state = hs.get_state_handler()
        self.clock = hs.get_clock()
        self.validator = EventValidator()
        self.profile_handler = hs.get_profile_handler()
        self.event_builder_factory = hs.get_event_builder_factory()
        self.server_name = hs.hostname
        self.ratelimiter = hs.get_ratelimiter()
        self.notifier = hs.get_notifier()
        self.config = hs.config
        self.require_membership_for_aliases = hs.config.require_membership_for_aliases

        self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)

        # This is only used to get at ratelimit function, and maybe_kick_guest_users
        self.base_handler = BaseHandler(hs)

        self.pusher_pool = hs.get_pusherpool()

        # We arbitrarily limit concurrent event creation for a room to 5.
        # This is to stop us from diverging history *too* much.
        self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")

        self.action_generator = hs.get_action_generator()

        self.spam_checker = hs.get_spam_checker()

        self._block_events_without_consent_error = (
            self.config.block_events_without_consent_error
        )

        # we need to construct a ConsentURIBuilder here, as it checks that the necessary
        # config options are set, but *only* if we have a configuration for which we are
        # going to need it.
        if self._block_events_without_consent_error:
            self._consent_uri_builder = ConsentURIBuilder(self.config)

    @defer.inlineCallbacks
    def create_event(self, requester, event_dict, token_id=None, txn_id=None,
                     prev_events_and_hashes=None, require_consent=True):
        """
        Given a dict from a client, create a new event.

        Creates a FrozenEvent object, filling out auth_events, prev_events,
        etc.

        Adds display names to Join membership events.

        Args:
            requester
            event_dict (dict): An entire event
            token_id (str)
            txn_id (str)

            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
                the forward extremities to use as the prev_events for the
                new event. For each event, a tuple of (event_id, hashes, depth)
                where *hashes* is a map from algorithm to hash.

                If None, they will be requested from the database.

            require_consent (bool): Whether to check if the requester has
                consented to privacy policy.
        Raises:
            ResourceLimitError if the server is blocked due to some resource being
            exceeded
        Returns:
            Tuple of created event (FrozenEvent), Context
        """
        yield self.auth.check_auth_blocking(requester.user.to_string())

        if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
            room_version = event_dict["content"]["room_version"]
        else:
            try:
                room_version = yield self.store.get_room_version(event_dict["room_id"])
            except NotFoundError:
                raise AuthError(403, "Unknown room")

        builder = self.event_builder_factory.new(room_version, event_dict)

        self.validator.validate_builder(builder)

        if builder.type == EventTypes.Member:
            membership = builder.content.get("membership", None)
            target = UserID.from_string(builder.state_key)

            if membership in {Membership.JOIN, Membership.INVITE}:
                # If event doesn't include a display name, add one.
                profile = self.profile_handler
                content = builder.content

                try:
                    if "displayname" not in content:
                        content["displayname"] = yield profile.get_displayname(target)
                    if "avatar_url" not in content:
                        content["avatar_url"] = yield profile.get_avatar_url(target)
                except Exception as e:
                    logger.info(
                        "Failed to get profile information for %r: %s",
                        target, e
                    )

        is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
        if require_consent and not is_exempt:
            yield self.assert_accepted_privacy_policy(requester)

        if token_id is not None:
            builder.internal_metadata.token_id = token_id

        if txn_id is not None:
            builder.internal_metadata.txn_id = txn_id

        event, context = yield self.create_new_client_event(
            builder=builder,
            requester=requester,
            prev_events_and_hashes=prev_events_and_hashes,
        )

        # In an ideal world we wouldn't need the second part of this condition. However,
        # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
        # behaviour. Another reason is that this code is also evaluated each time a new
        # m.room.aliases event is created, which includes hitting a /directory route.
        # Therefore not including this condition here would render the similar one in
        # synapse.handlers.directory pointless.
        if builder.type == EventTypes.Aliases and self.require_membership_for_aliases:
            # Ideally we'd do the membership check in event_auth.check(), which
            # describes a spec'd algorithm for authenticating events received over
            # federation as well as those created locally. As of room v3, aliases events
            # can be created by users that are not in the room, therefore we have to
            # tolerate them in event_auth.check().
            prev_state_ids = yield context.get_prev_state_ids(self.store)
            prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
            prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
            if not prev_event or prev_event.membership != Membership.JOIN:
                logger.warning(
                    ("Attempt to send `m.room.aliases` in room %s by user %s but"
                     " membership is %s"),
                    event.room_id,
                    event.sender,
                    prev_event.membership if prev_event else None,
                )

                raise AuthError(
                    403,
                    "You must be in the room to create an alias for it",
                )

        self.validator.validate_new(event)

        defer.returnValue((event, context))

    def _is_exempt_from_privacy_policy(self, builder, requester):
        """"Determine if an event to be sent is exempt from having to consent
        to the privacy policy

        Args:
            builder (synapse.events.builder.EventBuilder): event being created
            requester (Requester): user requesting this event

        Returns:
            Deferred[bool]: true if the event can be sent without the user
                consenting
        """
        # the only thing the user can do is join the server notices room.
        if builder.type == EventTypes.Member:
            membership = builder.content.get("membership", None)
            if membership == Membership.JOIN:
                return self._is_server_notices_room(builder.room_id)
            elif membership == Membership.LEAVE:
                # the user is always allowed to leave (but not kick people)
                return builder.state_key == requester.user.to_string()
        return succeed(False)

    @defer.inlineCallbacks
    def _is_server_notices_room(self, room_id):
        if self.config.server_notices_mxid is None:
            defer.returnValue(False)
        user_ids = yield self.store.get_users_in_room(room_id)
        defer.returnValue(self.config.server_notices_mxid in user_ids)

    @defer.inlineCallbacks
    def assert_accepted_privacy_policy(self, requester):
        """Check if a user has accepted the privacy policy

        Called when the given user is about to do something that requires
        privacy consent. We see if the user is exempt and otherwise check that
        they have given consent. If they have not, a ConsentNotGiven error is
        raised.

        Args:
            requester (synapse.types.Requester):
                The user making the request

        Returns:
            Deferred[None]: returns normally if the user has consented or is
                exempt

        Raises:
            ConsentNotGivenError: if the user has not given consent yet
        """
        if self._block_events_without_consent_error is None:
            return

        # exempt AS users from needing consent
        if requester.app_service is not None:
            return

        user_id = requester.user.to_string()

        # exempt the system notices user
        if (
            self.config.server_notices_mxid is not None and
            user_id == self.config.server_notices_mxid
        ):
            return

        u = yield self.store.get_user_by_id(user_id)
        assert u is not None
        if u["appservice_id"] is not None:
            # users registered by an appservice are exempt
            return
        if u["consent_version"] == self.config.user_consent_version:
            return

        consent_uri = self._consent_uri_builder.build_user_consent_uri(
            requester.user.localpart,
        )
        msg = self._block_events_without_consent_error % {
            'consent_uri': consent_uri,
        }
        raise ConsentNotGivenError(
            msg=msg,
            consent_uri=consent_uri,
        )

    @defer.inlineCallbacks
    def send_nonmember_event(self, requester, event, context, ratelimit=True):
        """
        Persists and notifies local clients and federation of an event.

        Args:
            requester (Requester): the user requesting the send.
            event (FrozenEvent): the event to send.
            context (EventContext): the context of the event.
            ratelimit (bool): Whether to rate limit this send.
        """
        if event.type == EventTypes.Member:
            raise SynapseError(
                500,
                "Tried to send member event through non-member codepath"
            )

        user = UserID.from_string(event.sender)

        assert self.hs.is_mine(user), "User must be our own: %s" % (user,)

        if event.is_state():
            prev_state = yield self.deduplicate_state_event(event, context)
            if prev_state is not None:
                logger.info(
                    "Not bothering to persist state event %s duplicated by %s",
                    event.event_id, prev_state.event_id,
                )
                defer.returnValue(prev_state)

        yield self.handle_new_client_event(
            requester=requester,
            event=event,
            context=context,
            ratelimit=ratelimit,
        )

    @defer.inlineCallbacks
    def deduplicate_state_event(self, event, context):
        """
        Checks whether event is in the latest resolved state in context.

        If so, returns the version of the event in context.
        Otherwise, returns None.
        """
        prev_state_ids = yield context.get_prev_state_ids(self.store)
        prev_event_id = prev_state_ids.get((event.type, event.state_key))
        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
        if not prev_event:
            return

        if prev_event and event.user_id == prev_event.user_id:
            prev_content = encode_canonical_json(prev_event.content)
            next_content = encode_canonical_json(event.content)
            if prev_content == next_content:
                defer.returnValue(prev_event)
        return

    @defer.inlineCallbacks
    def create_and_send_nonmember_event(
        self,
        requester,
        event_dict,
        ratelimit=True,
        txn_id=None
    ):
        """
        Creates an event, then sends it.

        See self.create_event and self.send_nonmember_event.
        """

        # We limit the number of concurrent event sends in a room so that we
        # don't fork the DAG too much. If we don't limit then we can end up in
        # a situation where event persistence can't keep up, causing
        # extremities to pile up, which in turn leads to state resolution
        # taking longer.
        with (yield self.limiter.queue(event_dict["room_id"])):
            event, context = yield self.create_event(
                requester,
                event_dict,
                token_id=requester.access_token_id,
                txn_id=txn_id
            )

            spam_error = self.spam_checker.check_event_for_spam(event)
            if spam_error:
                if not isinstance(spam_error, string_types):
                    spam_error = "Spam is not permitted here"
                raise SynapseError(
                    403, spam_error, Codes.FORBIDDEN
                )

            yield self.send_nonmember_event(
                requester,
                event,
                context,
                ratelimit=ratelimit,
            )
        defer.returnValue(event)

    @measure_func("create_new_client_event")
    @defer.inlineCallbacks
    def create_new_client_event(self, builder, requester=None,
                                prev_events_and_hashes=None):
        """Create a new event for a local client

        Args:
            builder (EventBuilder):

            requester (synapse.types.Requester|None):

            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
                the forward extremities to use as the prev_events for the
                new event. For each event, a tuple of (event_id, hashes, depth)
                where *hashes* is a map from algorithm to hash.

                If None, they will be requested from the database.

        Returns:
            Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
        """

        if prev_events_and_hashes is not None:
            assert len(prev_events_and_hashes) <= 10, \
                "Attempting to create an event with %i prev_events" % (
                    len(prev_events_and_hashes),
            )
        else:
            prev_events_and_hashes = \
                yield self.store.get_prev_events_for_room(builder.room_id)

        prev_events = [
            (event_id, prev_hashes)
            for event_id, prev_hashes, _ in prev_events_and_hashes
        ]

        event = yield builder.build(
            prev_event_ids=[p for p, _ in prev_events],
        )
        context = yield self.state.compute_event_context(event)
        if requester:
            context.app_service = requester.app_service

        self.validator.validate_new(event)

        # If this event is an annotation then we check that the sender
        # can't annotate the same way twice (e.g. stops users from liking an
        # event multiple times).
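        # An annotation's ``m.relates_to`` looks roughly like (illustrative
        # values; only the fields read below matter here):
        #   {"rel_type": "m.annotation", "event_id": "$target", "key": "👍"}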
        relation = event.content.get("m.relates_to", {})
        if relation.get("rel_type") == RelationTypes.ANNOTATION:
            relates_to = relation["event_id"]
            aggregation_key = relation["key"]

            already_exists = yield self.store.has_user_annotated_event(
                relates_to, event.type, aggregation_key, event.sender,
            )
            if already_exists:
                raise SynapseError(400, "Can't send same reaction twice")

        logger.debug(
            "Created event %s",
            event.event_id,
        )

        defer.returnValue(
            (event, context,)
        )

    @measure_func("handle_new_client_event")
    @defer.inlineCallbacks
    def handle_new_client_event(
        self,
        requester,
        event,
        context,
        ratelimit=True,
        extra_users=[],
    ):
        """Processes a new event. This includes checking auth, persisting it,
        notifying users, sending to remote servers, etc.

        If called from a worker will hit out to the master process for final
        processing.

        Args:
            requester (Requester)
            event (FrozenEvent)
            context (EventContext)
            ratelimit (bool)
            extra_users (list(UserID)): Any extra users to notify about event
        """

        if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
            room_version = event.content.get(
                "room_version", RoomVersions.V1.identifier
            )
        else:
            room_version = yield self.store.get_room_version(event.room_id)

        try:
            yield self.auth.check_from_context(room_version, event, context)
        except AuthError as err:
            logger.warn("Denying new event %r because %s", event, err)
            raise err

        # Ensure that we can round trip before trying to persist in db
        try:
            dump = frozendict_json_encoder.encode(event.content)
            json.loads(dump)
        except Exception:
            logger.exception("Failed to encode content: %r", event.content)
            raise

        yield self.action_generator.handle_push_actions_for_event(
            event, context
        )

        # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
        # hack around with a try/finally instead.
        success = False
        try:
            # If we're a worker we need to hit out to the master.
            if self.config.worker_app:
                yield self.send_event_to_master(
                    event_id=event.event_id,
                    store=self.store,
                    requester=requester,
                    event=event,
                    context=context,
                    ratelimit=ratelimit,
                    extra_users=extra_users,
                )
                success = True
                return

            yield self.persist_and_notify_client_event(
                requester,
                event,
                context,
                ratelimit=ratelimit,
                extra_users=extra_users,
            )

            success = True
        finally:
            if not success:
                # Ensure that we actually remove the entries in the push actions
                # staging area, if we calculated them.
                run_in_background(
                    self.store.remove_push_actions_from_staging,
                    event.event_id,
                )

    @defer.inlineCallbacks
    def persist_and_notify_client_event(
        self,
        requester,
        event,
        context,
        ratelimit=True,
        extra_users=[],
    ):
        """Called when we have fully built the event, have already
        calculated the push actions for the event, and checked auth.

        This should only be run on master.
        """
        assert not self.config.worker_app

        if ratelimit:
            yield self.base_handler.ratelimit(requester)

        yield self.base_handler.maybe_kick_guest_users(event, context)

        if event.type == EventTypes.CanonicalAlias:
            # Check the alias is actually valid (at this time at least)
            room_alias_str = event.content.get("alias", None)
            if room_alias_str:
                room_alias = RoomAlias.from_string(room_alias_str)
                directory_handler = self.hs.get_handlers().directory_handler
                mapping = yield directory_handler.get_association(room_alias)

                if mapping["room_id"] != event.room_id:
                    raise SynapseError(
                        400,
                        "Room alias %s does not point to the room" % (
                            room_alias_str,
                        )
                    )

        federation_handler = self.hs.get_handlers().federation_handler

        if event.type == EventTypes.Member:
            if event.content["membership"] == Membership.INVITE:
                def is_inviter_member_event(e):
                    return (
                        e.type == EventTypes.Member and
                        e.sender == event.sender
                    )

                current_state_ids = yield context.get_current_state_ids(self.store)

                state_to_include_ids = [
                    e_id
                    for k, e_id in iteritems(current_state_ids)
                    if k[0] in self.hs.config.room_invite_state_types
                    or k == (EventTypes.Member, event.sender)
                ]

                state_to_include = yield self.store.get_events(state_to_include_ids)

                event.unsigned["invite_room_state"] = [
                    {
                        "type": e.type,
                        "state_key": e.state_key,
                        "content": e.content,
                        "sender": e.sender,
                    }
                    for e in itervalues(state_to_include)
                ]

                invitee = UserID.from_string(event.state_key)
                if not self.hs.is_mine(invitee):
                    # TODO: Can we add signature from remote server in a nicer
                    # way? If we have been invited by a remote server, we need
                    # to get them to sign the event.

                    returned_invite = yield federation_handler.send_invite(
                        invitee.domain,
                        event,
                    )

                    event.unsigned.pop("room_state", None)

                    # TODO: Make sure the signatures actually are correct.
                    event.signatures.update(
                        returned_invite.signatures
                    )

        if event.type == EventTypes.Redaction:
            prev_state_ids = yield context.get_prev_state_ids(self.store)
            auth_events_ids = yield self.auth.compute_auth_events(
                event, prev_state_ids, for_verification=True,
            )
            auth_events = yield self.store.get_events(auth_events_ids)
            auth_events = {
                (e.type, e.state_key): e for e in auth_events.values()
            }
            room_version = yield self.store.get_room_version(event.room_id)
            if self.auth.check_redaction(room_version, event, auth_events=auth_events):
                original_event = yield self.store.get_event(
                    event.redacts,
                    check_redacted=False,
                    get_prev_content=False,
                    allow_rejected=False,
                    allow_none=False
                )
                if event.user_id != original_event.user_id:
                    raise AuthError(
                        403,
                        "You don't have permission to redact events"
                    )

                # We've already checked.
                event.internal_metadata.recheck_redaction = False

        if event.type == EventTypes.Create:
            prev_state_ids = yield context.get_prev_state_ids(self.store)
            if prev_state_ids:
                raise AuthError(
                    403,
                    "Changing the room create event is forbidden",
                )

        (event_stream_id, max_stream_id) = yield self.store.persist_event(
            event, context=context
        )

        yield self.pusher_pool.on_new_notifications(
            event_stream_id, max_stream_id,
        )

        def _notify():
            try:
                self.notifier.on_new_room_event(
                    event, event_stream_id, max_stream_id,
                    extra_users=extra_users
                )
            except Exception:
                logger.exception("Error notifying about new room event")

        run_in_background(_notify)

        if event.type == EventTypes.Message:
            # We don't want to block sending messages on any presence code. This
            # matters as sometimes presence code can take a while.
            run_in_background(self._bump_active_time, requester.user)

    @defer.inlineCallbacks
    def _bump_active_time(self, user):
        try:
            presence = self.hs.get_presence_handler()
            yield presence.bump_presence_active_time(user)
        except Exception:
            logger.exception("Error bumping presence active time")
Example #39
0
class EventCreationHandler(object):
    def __init__(self, hs):
        self.hs = hs
        self.auth = hs.get_auth()
        self.store = hs.get_datastore()
        self.state = hs.get_state_handler()
        self.clock = hs.get_clock()
        self.validator = EventValidator()
        self.profile_handler = hs.get_profile_handler()
        self.event_builder_factory = hs.get_event_builder_factory()
        self.server_name = hs.hostname
        self.ratelimiter = hs.get_ratelimiter()
        self.notifier = hs.get_notifier()
        self.config = hs.config

        self.send_event_to_master = ReplicationSendEventRestServlet.make_client(
            hs)

        # This is only used to get at ratelimit function, and maybe_kick_guest_users
        self.base_handler = BaseHandler(hs)

        self.pusher_pool = hs.get_pusherpool()

        # We arbitrarily limit concurrent event creation for a room to 5.
        # This is to stop us from diverging history *too* much.
        self.limiter = Linearizer(max_count=5,
                                  name="room_event_creation_limit")

        self.action_generator = hs.get_action_generator()

        self.spam_checker = hs.get_spam_checker()

        if self.config.block_events_without_consent_error is not None:
            self._consent_uri_builder = ConsentURIBuilder(self.config)

    @defer.inlineCallbacks
    def create_event(self,
                     requester,
                     event_dict,
                     token_id=None,
                     txn_id=None,
                     prev_events_and_hashes=None):
        """
        Given a dict from a client, create a new event.

        Creates a FrozenEvent object, filling out auth_events, prev_events,
        etc.

        Adds display names to Join membership events.

        Args:
            requester
            event_dict (dict): An entire event
            token_id (str)
            txn_id (str)

            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
                the forward extremities to use as the prev_events for the
                new event. For each event, a tuple of (event_id, hashes, depth)
                where *hashes* is a map from algorithm to hash.

                If None, they will be requested from the database.
        Raises:
            ResourceLimitError if the server is blocked due to some resource being
            exceeded
        Returns:
            Tuple of created event (FrozenEvent), Context
        """
        yield self.auth.check_auth_blocking(requester.user.to_string())

        builder = self.event_builder_factory.new(event_dict)

        self.validator.validate_new(builder)

        if builder.type == EventTypes.Member:
            membership = builder.content.get("membership", None)
            target = UserID.from_string(builder.state_key)

            if membership in {Membership.JOIN, Membership.INVITE}:
                # If event doesn't include a display name, add one.
                profile = self.profile_handler
                content = builder.content

                try:
                    if "displayname" not in content:
                        content["displayname"] = yield profile.get_displayname(
                            target)
                    if "avatar_url" not in content:
                        content["avatar_url"] = yield profile.get_avatar_url(
                            target)
                except Exception as e:
                    logger.info("Failed to get profile information for %r: %s",
                                target, e)

        is_exempt = yield self._is_exempt_from_privacy_policy(
            builder, requester)
        if not is_exempt:
            yield self.assert_accepted_privacy_policy(requester)

        if token_id is not None:
            builder.internal_metadata.token_id = token_id

        if txn_id is not None:
            builder.internal_metadata.txn_id = txn_id

        event, context = yield self.create_new_client_event(
            builder=builder,
            requester=requester,
            prev_events_and_hashes=prev_events_and_hashes,
        )

        defer.returnValue((event, context))

    def _is_exempt_from_privacy_policy(self, builder, requester):
        """"Determine if an event to be sent is exempt from having to consent
        to the privacy policy

        Args:
            builder (synapse.events.builder.EventBuilder): event being created
            requester (Requester): user requesting this event

        Returns:
            Deferred[bool]: true if the event can be sent without the user
                consenting
        """
        # the only thing the user can do is join the server notices room.
        if builder.type == EventTypes.Member:
            membership = builder.content.get("membership", None)
            if membership == Membership.JOIN:
                return self._is_server_notices_room(builder.room_id)
            elif membership == Membership.LEAVE:
                # the user is always allowed to leave (but not kick people)
                return builder.state_key == requester.user.to_string()
        return succeed(False)

    @defer.inlineCallbacks
    def _is_server_notices_room(self, room_id):
        if self.config.server_notices_mxid is None:
            defer.returnValue(False)
        user_ids = yield self.store.get_users_in_room(room_id)
        defer.returnValue(self.config.server_notices_mxid in user_ids)

    @defer.inlineCallbacks
    def assert_accepted_privacy_policy(self, requester):
        """Check if a user has accepted the privacy policy

        Called when the given user is about to do something that requires
        privacy consent. We see if the user is exempt and otherwise check that
        they have given consent. If they have not, a ConsentNotGiven error is
        raised.

        Args:
            requester (synapse.types.Requester):
                The user making the request

        Returns:
            Deferred[None]: returns normally if the user has consented or is
                exempt

        Raises:
            ConsentNotGivenError: if the user has not given consent yet
        """
        if self.config.block_events_without_consent_error is None:
            return

        # exempt AS users from needing consent
        if requester.app_service is not None:
            return

        user_id = requester.user.to_string()

        # exempt the system notices user
        if (self.config.server_notices_mxid is not None
                and user_id == self.config.server_notices_mxid):
            return

        u = yield self.store.get_user_by_id(user_id)
        assert u is not None
        if u["appservice_id"] is not None:
            # users registered by an appservice are exempt
            return
        if u["consent_version"] == self.config.user_consent_version:
            return

        consent_uri = self._consent_uri_builder.build_user_consent_uri(
            requester.user.localpart, )
        msg = self.config.block_events_without_consent_error % {
            'consent_uri': consent_uri,
        }
        raise ConsentNotGivenError(
            msg=msg,
            consent_uri=consent_uri,
        )

    @defer.inlineCallbacks
    def send_nonmember_event(self, requester, event, context, ratelimit=True):
        """
        Persists and notifies local clients and federation of an event.

        Args:
            requester (Requester): the user requesting the send.
            event (FrozenEvent): the event to send.
            context (EventContext): the context of the event.
            ratelimit (bool): Whether to rate limit this send.
        """
        if event.type == EventTypes.Member:
            raise SynapseError(
                500, "Tried to send member event through non-member codepath")

        user = UserID.from_string(event.sender)

        assert self.hs.is_mine(user), "User must be our own: %s" % (user, )

        if event.is_state():
            prev_state = yield self.deduplicate_state_event(event, context)
            if prev_state is not None:
                defer.returnValue(prev_state)

        yield self.handle_new_client_event(
            requester=requester,
            event=event,
            context=context,
            ratelimit=ratelimit,
        )

    @defer.inlineCallbacks
    def deduplicate_state_event(self, event, context):
        """
        Checks whether event is in the latest resolved state in context.

        If so, returns the version of the event in context.
        Otherwise, returns None.
        """
        prev_state_ids = yield context.get_prev_state_ids(self.store)
        prev_event_id = prev_state_ids.get((event.type, event.state_key))
        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
        if not prev_event:
            return

        if prev_event and event.user_id == prev_event.user_id:
            prev_content = encode_canonical_json(prev_event.content)
            next_content = encode_canonical_json(event.content)
            if prev_content == next_content:
                defer.returnValue(prev_event)
        return

    @defer.inlineCallbacks
    def create_and_send_nonmember_event(self,
                                        requester,
                                        event_dict,
                                        ratelimit=True,
                                        txn_id=None):
        """
        Creates an event, then sends it.

        See self.create_event and self.send_nonmember_event.
        """

        # We limit the number of concurrent event sends in a room so that we
        # don't fork the DAG too much. If we don't limit then we can end up in
        # a situation where event persistence can't keep up, causing
        # extremities to pile up, which in turn leads to state resolution
        # taking longer.
        with (yield self.limiter.queue(event_dict["room_id"])):
            event, context = yield self.create_event(
                requester,
                event_dict,
                token_id=requester.access_token_id,
                txn_id=txn_id)

            spam_error = self.spam_checker.check_event_for_spam(event)
            if spam_error:
                if not isinstance(spam_error, string_types):
                    spam_error = "Spam is not permitted here"
                raise SynapseError(403, spam_error, Codes.FORBIDDEN)

            yield self.send_nonmember_event(
                requester,
                event,
                context,
                ratelimit=ratelimit,
            )
        defer.returnValue(event)

    @measure_func("create_new_client_event")
    @defer.inlineCallbacks
    def create_new_client_event(self,
                                builder,
                                requester=None,
                                prev_events_and_hashes=None):
        """Create a new event for a local client

        Args:
            builder (EventBuilder):

            requester (synapse.types.Requester|None):

            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
                the forward extremities to use as the prev_events for the
                new event. For each event, a tuple of (event_id, hashes, depth)
                where *hashes* is a map from algorithm to hash.

                If None, they will be requested from the database.

        Returns:
            Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
        """

        if prev_events_and_hashes is not None:
            assert len(prev_events_and_hashes) <= 10, \
                "Attempting to create an event with %i prev_events" % (
                    len(prev_events_and_hashes),
            )
        else:
            prev_events_and_hashes = \
                yield self.store.get_prev_events_for_room(builder.room_id)

        if prev_events_and_hashes:
            depth = max([d for _, _, d in prev_events_and_hashes]) + 1
            # we cap depth of generated events, to ensure that they are not
            # rejected by other servers (and so that they can be persisted in
            # the db)
            depth = min(depth, MAX_DEPTH)
        else:
            depth = 1

        prev_events = [(event_id, prev_hashes)
                       for event_id, prev_hashes, _ in prev_events_and_hashes]

        builder.prev_events = prev_events
        builder.depth = depth

        context = yield self.state.compute_event_context(builder)
        if requester:
            context.app_service = requester.app_service

        if builder.is_state():
            builder.prev_state = yield self.store.add_event_hashes(
                context.prev_state_events)

        yield self.auth.add_auth_events(builder, context)

        signing_key = self.hs.config.signing_key[0]
        add_hashes_and_signatures(builder, self.server_name, signing_key)

        event = builder.build()

        logger.debug(
            "Created event %s",
            event.event_id,
        )

        defer.returnValue((
            event,
            context,
        ))

    @measure_func("handle_new_client_event")
    @defer.inlineCallbacks
    def handle_new_client_event(
        self,
        requester,
        event,
        context,
        ratelimit=True,
        extra_users=[],
    ):
        """Processes a new event. This includes checking auth, persisting it,
        notifying users, sending to remote servers, etc.

        If called from a worker will hit out to the master process for final
        processing.

        Args:
            requester (Requester)
            event (FrozenEvent)
            context (EventContext)
            ratelimit (bool)
            extra_users (list(UserID)): Any extra users to notify about event
        """

        try:
            yield self.auth.check_from_context(event, context)
        except AuthError as err:
            logger.warn("Denying new event %r because %s", event, err)
            raise err

        # Ensure that we can round trip before trying to persist in db
        try:
            dump = frozendict_json_encoder.encode(event.content)
            json.loads(dump)
        except Exception:
            logger.exception("Failed to encode content: %r", event.content)
            raise

        yield self.action_generator.handle_push_actions_for_event(
            event, context)

        try:
            # If we're a worker we need to hit out to the master.
            if self.config.worker_app:
                yield self.send_event_to_master(
                    event_id=event.event_id,
                    store=self.store,
                    requester=requester,
                    event=event,
                    context=context,
                    ratelimit=ratelimit,
                    extra_users=extra_users,
                )
                return

            yield self.persist_and_notify_client_event(
                requester,
                event,
                context,
                ratelimit=ratelimit,
                extra_users=extra_users,
            )
        except:  # noqa: E722, as we reraise the exception this is fine.
            # Ensure that we actually remove the entries in the push actions
            # staging area, if we calculated them.
            tp, value, tb = sys.exc_info()

            run_in_background(
                self.store.remove_push_actions_from_staging,
                event.event_id,
            )

            six.reraise(tp, value, tb)

    @defer.inlineCallbacks
    def persist_and_notify_client_event(
        self,
        requester,
        event,
        context,
        ratelimit=True,
        extra_users=[],
    ):
        """Called when we have fully built the event, have already
        calculated the push actions for the event, and checked auth.

        This should only be run on master.
        """
        assert not self.config.worker_app

        if ratelimit:
            yield self.base_handler.ratelimit(requester)

        yield self.base_handler.maybe_kick_guest_users(event, context)

        if event.type == EventTypes.CanonicalAlias:
            # Check the alias is actually valid (at this time at least)
            room_alias_str = event.content.get("alias", None)
            if room_alias_str:
                room_alias = RoomAlias.from_string(room_alias_str)
                directory_handler = self.hs.get_handlers().directory_handler
                mapping = yield directory_handler.get_association(room_alias)

                if mapping["room_id"] != event.room_id:
                    raise SynapseError(
                        400, "Room alias %s does not point to the room" %
                        (room_alias_str, ))

        federation_handler = self.hs.get_handlers().federation_handler

        if event.type == EventTypes.Member:
            if event.content["membership"] == Membership.INVITE:

                def is_inviter_member_event(e):
                    return (e.type == EventTypes.Member
                            and e.sender == event.sender)

                current_state_ids = yield context.get_current_state_ids(
                    self.store)

                state_to_include_ids = [
                    e_id for k, e_id in iteritems(current_state_ids)
                    if k[0] in self.hs.config.room_invite_state_types or k == (
                        EventTypes.Member, event.sender)
                ]

                state_to_include = yield self.store.get_events(
                    state_to_include_ids)

                event.unsigned["invite_room_state"] = [{
                    "type": e.type,
                    "state_key": e.state_key,
                    "content": e.content,
                    "sender": e.sender,
                } for e in itervalues(state_to_include)]

                invitee = UserID.from_string(event.state_key)
                if not self.hs.is_mine(invitee):
                    # TODO: Can we add signature from remote server in a nicer
                    # way? If we have been invited by a remote server, we need
                    # to get them to sign the event.

                    returned_invite = yield federation_handler.send_invite(
                        invitee.domain,
                        event,
                    )

                    event.unsigned.pop("room_state", None)

                    # TODO: Make sure the signatures actually are correct.
                    event.signatures.update(returned_invite.signatures)

        if event.type == EventTypes.Redaction:
            prev_state_ids = yield context.get_prev_state_ids(self.store)
            auth_events_ids = yield self.auth.compute_auth_events(
                event,
                prev_state_ids,
                for_verification=True,
            )
            auth_events = yield self.store.get_events(auth_events_ids)
            auth_events = {(e.type, e.state_key): e
                           for e in auth_events.values()}
            if self.auth.check_redaction(event, auth_events=auth_events):
                original_event = yield self.store.get_event(
                    event.redacts,
                    check_redacted=False,
                    get_prev_content=False,
                    allow_rejected=False,
                    allow_none=False)
                if event.user_id != original_event.user_id:
                    raise AuthError(
                        403, "You don't have permission to redact events")

        if event.type == EventTypes.Create:
            prev_state_ids = yield context.get_prev_state_ids(self.store)
            if prev_state_ids:
                raise AuthError(
                    403,
                    "Changing the room create event is forbidden",
                )

        (event_stream_id,
         max_stream_id) = yield self.store.persist_event(event,
                                                         context=context)

        self.pusher_pool.on_new_notifications(
            event_stream_id,
            max_stream_id,
        )

        def _notify():
            try:
                self.notifier.on_new_room_event(event,
                                                event_stream_id,
                                                max_stream_id,
                                                extra_users=extra_users)
            except Exception:
                logger.exception("Error notifying about new room event")

        run_in_background(_notify)

        if event.type == EventTypes.Message:
            # We don't want to block sending messages on any presence code. This
            # matters as sometimes presence code can take a while.
            run_in_background(self._bump_active_time, requester.user)

    @defer.inlineCallbacks
    def _bump_active_time(self, user):
        try:
            presence = self.hs.get_presence_handler()
            yield presence.bump_presence_active_time(user)
        except Exception:
            logger.exception("Error bumping presence active time")
Example #40
0
class PresenceHandler(object):

    def __init__(self, hs):
        """

        Args:
            hs (synapse.server.HomeServer):
        """
        self.hs = hs
        self.is_mine = hs.is_mine
        self.is_mine_id = hs.is_mine_id
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()
        self.wheel_timer = WheelTimer()
        self.notifier = hs.get_notifier()
        self.federation = hs.get_federation_sender()
        self.state = hs.get_state_handler()

        federation_registry = hs.get_federation_registry()

        federation_registry.register_edu_handler(
            "m.presence", self.incoming_presence
        )
        federation_registry.register_edu_handler(
            "m.presence_invite",
            lambda origin, content: self.invite_presence(
                observed_user=UserID.from_string(content["observed_user"]),
                observer_user=UserID.from_string(content["observer_user"]),
            )
        )
        federation_registry.register_edu_handler(
            "m.presence_accept",
            lambda origin, content: self.accept_presence(
                observed_user=UserID.from_string(content["observed_user"]),
                observer_user=UserID.from_string(content["observer_user"]),
            )
        )
        federation_registry.register_edu_handler(
            "m.presence_deny",
            lambda origin, content: self.deny_presence(
                observed_user=UserID.from_string(content["observed_user"]),
                observer_user=UserID.from_string(content["observer_user"]),
            )
        )

        distributor = hs.get_distributor()
        distributor.observe("user_joined_room", self.user_joined_room)

        active_presence = self.store.take_presence_startup_info()

        # A dictionary of the current state of users. This is prefilled with
        # non-offline presence from the DB. We should fetch from the DB if
        # we can't find a users presence in here.
        self.user_to_current_state = {
            state.user_id: state
            for state in active_presence
        }

        LaterGauge(
            "synapse_handlers_presence_user_to_current_state_size", "", [],
            lambda: len(self.user_to_current_state)
        )

        now = self.clock.time_msec()
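        # Prime the wheel timer for every user loaded from the DB, so their
        # idle/sync/federation timeouts are re-checked after a restart.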
        for state in active_presence:
            self.wheel_timer.insert(
                now=now,
                obj=state.user_id,
                then=state.last_active_ts + IDLE_TIMER,
            )
            self.wheel_timer.insert(
                now=now,
                obj=state.user_id,
                then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
            )
            if self.is_mine_id(state.user_id):
                self.wheel_timer.insert(
                    now=now,
                    obj=state.user_id,
                    then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL,
                )
            else:
                self.wheel_timer.insert(
                    now=now,
                    obj=state.user_id,
                    then=state.last_federation_update_ts + FEDERATION_TIMEOUT,
                )

        # Set of users whose presence in `user_to_current_state` has not yet
        # been persisted
        self.unpersisted_users_changes = set()

        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self._on_shutdown)

        self.serial_to_user = {}
        self._next_serial = 1

        # Keeps track of the number of *ongoing* syncs on this process. While
        # this is non-zero a user will never go offline.
        self.user_to_num_current_syncs = {}

        # Keeps track of the number of *ongoing* syncs on other processes.
        # While any sync is ongoing on another process the user will never
        # go offline.
        # Each process has a unique identifier and an update frequency. If
        # no update is received from that process within the update period then
        # we assume that all the sync requests on that process have stopped.
        # Stored as a dict from process_id to set of user_id, and a dict of
        # process_id to millisecond timestamp last updated.
        self.external_process_to_current_syncs = {}
        self.external_process_last_updated_ms = {}
        self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")

        # Start a LoopingCall in 30s that fires every 5s.
        # The initial delay is to allow disconnected clients a chance to
        # reconnect before we treat them as offline.
        self.clock.call_later(
            30,
            self.clock.looping_call,
            self._handle_timeouts,
            5000,
        )

        self.clock.call_later(
            60,
            self.clock.looping_call,
            self._persist_unpersisted_changes,
            60 * 1000,
        )

        LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
                   lambda: len(self.wheel_timer))

    @defer.inlineCallbacks
    def _on_shutdown(self):
        """Gets called when shutting down. This lets us persist any updates that
        we haven't yet persisted, e.g. updates that only change some internal
        timers. This allows changes to persist across startup without having to
        persist every single change.

        If this does not run it simply means that some of the timers will fire
        earlier than they should when synapse is restarted. The effect of this
        is some spurious presence changes that will self-correct.
        """
        # If the DB pool has already terminated, don't try updating
        if not self.hs.get_db_pool().running:
            return

        logger.info(
            "Performing _on_shutdown. Persisting %d unpersisted changes",
            len(self.unpersisted_users_changes)
        )

        if self.unpersisted_users_changes:
            yield self.store.update_presence([
                self.user_to_current_state[user_id]
                for user_id in self.unpersisted_users_changes
            ])
        logger.info("Finished _on_shutdown")

    @defer.inlineCallbacks
    def _persist_unpersisted_changes(self):
        """We periodically persist the unpersisted changes, as otherwise they
        may stack up and slow down shutdown times.
        """
        logger.info(
            "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
            len(self.unpersisted_users_changes)
        )

        unpersisted = self.unpersisted_users_changes
        self.unpersisted_users_changes = set()

        if unpersisted:
            yield self.store.update_presence([
                self.user_to_current_state[user_id]
                for user_id in unpersisted
            ])

        logger.info("Finished _persist_unpersisted_changes")

    @defer.inlineCallbacks
    def _update_states_and_catch_exception(self, new_states):
        try:
            res = yield self._update_states(new_states)
            defer.returnValue(res)
        except Exception:
            logger.exception("Error updating presence")

    @defer.inlineCallbacks
    def _update_states(self, new_states):
        """Updates presence of users. Sets the appropriate timeouts. Pokes
        the notifier and federation if and only if the changed presence state
        should be sent to clients/servers.
        """
        now = self.clock.time_msec()

        with Measure(self.clock, "presence_update_states"):

            # NOTE: We purposefully don't yield between now and when we've
            # calculated what we want to do with the new states, to avoid races.

            to_notify = {}  # Changes we want to notify everyone about
            to_federation_ping = {}  # These need sending keep-alives

            # Only bother handling the last presence change for each user
            new_states_dict = {}
            for new_state in new_states:
                new_states_dict[new_state.user_id] = new_state
            new_states = new_states_dict.values()

            for new_state in new_states:
                user_id = new_state.user_id

                # It's fine to not hit the database here, as the only thing not in
                # the current state cache are OFFLINE states, where the only field
                # of interest is last_active which is safe enough to assume is 0
                # here.
                prev_state = self.user_to_current_state.get(
                    user_id, UserPresenceState.default(user_id)
                )

                new_state, should_notify, should_ping = handle_update(
                    prev_state, new_state,
                    is_mine=self.is_mine_id(user_id),
                    wheel_timer=self.wheel_timer,
                    now=now
                )
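                # `handle_update` returns the merged state plus two flags:
                # `should_notify` when the change needs to reach clients and
                # remote servers (for example a user coming online), and
                # `should_ping` when we only need to refresh the federation
                # keep-alive for an otherwise unchanged state.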

                self.user_to_current_state[user_id] = new_state

                if should_notify:
                    to_notify[user_id] = new_state
                elif should_ping:
                    to_federation_ping[user_id] = new_state

            # TODO: We should probably ensure there are no races hereafter

            presence_updates_counter.inc(len(new_states))

            if to_notify:
                notified_presence_counter.inc(len(to_notify))
                yield self._persist_and_notify(list(to_notify.values()))

            self.unpersisted_users_changes |= set(s.user_id for s in new_states)
            self.unpersisted_users_changes -= set(to_notify.keys())

            to_federation_ping = {
                user_id: state for user_id, state in to_federation_ping.items()
                if user_id not in to_notify
            }
            if to_federation_ping:
                federation_presence_out_counter.inc(len(to_federation_ping))

                self._push_to_remotes(to_federation_ping.values())

    def _handle_timeouts(self):
        """Checks the presence of users that have timed out and updates as
        appropriate.
        """
        logger.info("Handling presence timeouts")
        now = self.clock.time_msec()

        try:
            with Measure(self.clock, "presence_handle_timeouts"):
                # Fetch the list of users that *may* have timed out. Things may have
                # changed since the timeout was set, so we won't necessarily have to
                # take any action.
                users_to_check = set(self.wheel_timer.fetch(now))

                # Check whether the lists of syncing processes from an external
                # process have expired.
                expired_process_ids = [
                    process_id for process_id, last_update
                    in self.external_process_last_updated_ms.items()
                    if now - last_update > EXTERNAL_PROCESS_EXPIRY
                ]
                for process_id in expired_process_ids:
                    users_to_check.update(
                        self.external_process_to_current_syncs.pop(process_id, ())
                    )
                    self.external_process_last_updated_ms.pop(process_id)

                states = [
                    self.user_to_current_state.get(
                        user_id, UserPresenceState.default(user_id)
                    )
                    for user_id in users_to_check
                ]

                timers_fired_counter.inc(len(states))

                changes = handle_timeouts(
                    states,
                    is_mine_fn=self.is_mine_id,
                    syncing_user_ids=self.get_currently_syncing_users(),
                    now=now,
                )

            run_in_background(self._update_states_and_catch_exception, changes)
        except Exception:
            logger.exception("Exception in _handle_timeouts loop")

    @defer.inlineCallbacks
    def bump_presence_active_time(self, user):
        """We've seen the user do something that indicates they're interacting
        with the app.
        """
        # If presence is disabled, no-op
        if not self.hs.config.use_presence:
            return

        user_id = user.to_string()

        bump_active_time_counter.inc()

        prev_state = yield self.current_state_for_user(user_id)

        new_fields = {
            "last_active_ts": self.clock.time_msec(),
        }
        if prev_state.state == PresenceState.UNAVAILABLE:
            new_fields["state"] = PresenceState.ONLINE

        yield self._update_states([prev_state.copy_and_replace(**new_fields)])

    @defer.inlineCallbacks
    def user_syncing(self, user_id, affect_presence=True):
        """Returns a context manager that should surround any stream requests
        from the user.

        This allows us to keep track of who is currently streaming and who isn't
        without having to have timers outside of this module to avoid flickering
        when users disconnect/reconnect.

        Args:
            user_id (str)
            affect_presence (bool): If false this function will be a no-op.
                Useful for streams that are not associated with an actual
                client that is being used by a user.
        """
        # If presence is disabled globally, override affect_presence so this
        # becomes a no-op.
        if not self.hs.config.use_presence:
            affect_presence = False

        if affect_presence:
            curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
            self.user_to_num_current_syncs[user_id] = curr_sync + 1

            prev_state = yield self.current_state_for_user(user_id)
            if prev_state.state == PresenceState.OFFLINE:
                # If they're currently offline then bring them online, otherwise
                # just update the last sync times.
                yield self._update_states([prev_state.copy_and_replace(
                    state=PresenceState.ONLINE,
                    last_active_ts=self.clock.time_msec(),
                    last_user_sync_ts=self.clock.time_msec(),
                )])
            else:
                yield self._update_states([prev_state.copy_and_replace(
                    last_user_sync_ts=self.clock.time_msec(),
                )])

        @defer.inlineCallbacks
        def _end():
            try:
                self.user_to_num_current_syncs[user_id] -= 1

                prev_state = yield self.current_state_for_user(user_id)
                yield self._update_states([prev_state.copy_and_replace(
                    last_user_sync_ts=self.clock.time_msec(),
                )])
            except Exception:
                logger.exception("Error updating presence after sync")

        @contextmanager
        def _user_syncing():
            try:
                yield
            finally:
                if affect_presence:
                    run_in_background(_end)

        defer.returnValue(_user_syncing())
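    # A rough usage sketch (illustrative; none of these names come from this
    # example): a /sync request handler might wrap its work in the returned
    # context manager, e.g.
    #
    #     context = yield presence_handler.user_syncing(
    #         user_id, affect_presence=(set_presence != "offline")
    #     )
    #     with context:
    #         result = yield do_the_actual_sync()
    #
    # so the user is counted as syncing only while the request is in flight.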

    def get_currently_syncing_users(self):
        """Get the set of user ids that are currently syncing on this HS.
        Returns:
            set(str): A set of user_id strings.
        """
        if self.hs.config.use_presence:
            syncing_user_ids = {
                user_id for user_id, count in self.user_to_num_current_syncs.items()
                if count
            }
            for user_ids in self.external_process_to_current_syncs.values():
                syncing_user_ids.update(user_ids)
            return syncing_user_ids
        else:
            return set()

    @defer.inlineCallbacks
    def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
        """Update the syncing users for an external process as a delta.

        Args:
            process_id (str): An identifier for the process the users are
                syncing against. This allows synapse to process updates
                as users start and stop syncing against a given process.
            user_id (str): The user who has started or stopped syncing
            is_syncing (bool): Whether or not the user is now syncing
            sync_time_msec(int): Time in ms when the user was last syncing
        """
        with (yield self.external_sync_linearizer.queue(process_id)):
            prev_state = yield self.current_state_for_user(user_id)

            process_presence = self.external_process_to_current_syncs.setdefault(
                process_id, set()
            )

            updates = []
            if is_syncing and user_id not in process_presence:
                if prev_state.state == PresenceState.OFFLINE:
                    updates.append(prev_state.copy_and_replace(
                        state=PresenceState.ONLINE,
                        last_active_ts=sync_time_msec,
                        last_user_sync_ts=sync_time_msec,
                    ))
                else:
                    updates.append(prev_state.copy_and_replace(
                        last_user_sync_ts=sync_time_msec,
                    ))
                process_presence.add(user_id)
            elif user_id in process_presence:
                updates.append(prev_state.copy_and_replace(
                    last_user_sync_ts=sync_time_msec,
                ))

            if not is_syncing:
                process_presence.discard(user_id)

            if updates:
                yield self._update_states(updates)

            self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
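    # Illustrative only: a separate worker handling /sync requests could
    # report deltas here (e.g. over replication), calling
    #
    #     yield presence_handler.update_external_syncs_row(
    #         "sync-worker-1", "@alice:example.com", True, clock.time_msec()
    #     )
    #
    # when a sync starts, and the same call with is_syncing=False when it
    # stops. "sync-worker-1" and "@alice:example.com" are made-up values.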

    @defer.inlineCallbacks
    def update_external_syncs_clear(self, process_id):
        """Marks all users that had been marked as syncing by a given process
        as offline.

        Used when the process has stopped/disappeared.
        """
        with (yield self.external_sync_linearizer.queue(process_id)):
            process_presence = self.external_process_to_current_syncs.pop(
                process_id, set()
            )
            prev_states = yield self.current_state_for_users(process_presence)
            time_now_ms = self.clock.time_msec()

            yield self._update_states([
                prev_state.copy_and_replace(
                    last_user_sync_ts=time_now_ms,
                )
                for prev_state in itervalues(prev_states)
            ])
            self.external_process_last_updated_ms.pop(process_id, None)

    @defer.inlineCallbacks
    def current_state_for_user(self, user_id):
        """Get the current presence state for a user.
        """
        res = yield self.current_state_for_users([user_id])
        defer.returnValue(res[user_id])

    @defer.inlineCallbacks
    def current_state_for_users(self, user_ids):
        """Get the current presence state for multiple users.

        Returns:
            dict: `user_id` -> `UserPresenceState`
        """
        states = {
            user_id: self.user_to_current_state.get(user_id, None)
            for user_id in user_ids
        }

        missing = [user_id for user_id, state in iteritems(states) if not state]
        if missing:
            # There are things not in our in memory cache. Lets pull them out of
            # the database.
            res = yield self.store.get_presence_for_users(missing)
            states.update(res)

            missing = [user_id for user_id, state in iteritems(states) if not state]
            if missing:
                new = {
                    user_id: UserPresenceState.default(user_id)
                    for user_id in missing
                }
                states.update(new)
                self.user_to_current_state.update(new)

        defer.returnValue(states)

    @defer.inlineCallbacks
    def _persist_and_notify(self, states):
        """Persist states in the database, poke the notifier and send to
        interested remote servers
        """
        stream_id, max_token = yield self.store.update_presence(states)

        parties = yield get_interested_parties(self.store, states)
        room_ids_to_states, users_to_states = parties

        self.notifier.on_new_event(
            "presence_key", stream_id, rooms=room_ids_to_states.keys(),
            users=[UserID.from_string(u) for u in users_to_states]
        )

        self._push_to_remotes(states)

    @defer.inlineCallbacks
    def notify_for_states(self, state, stream_id):
        parties = yield get_interested_parties(self.store, [state])
        room_ids_to_states, users_to_states = parties

        self.notifier.on_new_event(
            "presence_key", stream_id, rooms=room_ids_to_states.keys(),
            users=[UserID.from_string(u) for u in users_to_states]
        )

    def _push_to_remotes(self, states):
        """Sends state updates to remote servers.

        Args:
            states (list(UserPresenceState))
        """
        self.federation.send_presence(states)

    @defer.inlineCallbacks
    def incoming_presence(self, origin, content):
        """Called when we receive a `m.presence` EDU from a remote server.
        """
        now = self.clock.time_msec()
        updates = []
        for push in content.get("push", []):
            # A "push" contains a list of presence that we are probably interested
            # in.
            # TODO: Actually check if we're interested, rather than blindly
            # accepting presence updates.
            user_id = push.get("user_id", None)
            if not user_id:
                logger.info(
                    "Got presence update from %r with no 'user_id': %r",
                    origin, push,
                )
                continue

            if get_domain_from_id(user_id) != origin:
                logger.info(
                    "Got presence update from %r with bad 'user_id': %r",
                    origin, user_id,
                )
                continue

            presence_state = push.get("presence", None)
            if not presence_state:
                logger.info(
                    "Got presence update from %r with no 'presence_state': %r",
                    origin, push,
                )
                continue

            new_fields = {
                "state": presence_state,
                "last_federation_update_ts": now,
            }

            last_active_ago = push.get("last_active_ago", None)
            if last_active_ago is not None:
                new_fields["last_active_ts"] = now - last_active_ago

            new_fields["status_msg"] = push.get("status_msg", None)
            new_fields["currently_active"] = push.get("currently_active", False)

            prev_state = yield self.current_state_for_user(user_id)
            updates.append(prev_state.copy_and_replace(**new_fields))

        if updates:
            federation_presence_counter.inc(len(updates))
            yield self._update_states(updates)

    @defer.inlineCallbacks
    def get_state(self, target_user, as_event=False):
        results = yield self.get_states(
            [target_user.to_string()],
            as_event=as_event,
        )

        defer.returnValue(results[0])

    @defer.inlineCallbacks
    def get_states(self, target_user_ids, as_event=False):
        """Get the presence state for users.

        Args:
            target_user_ids (list)
            as_event (bool): Whether to format it as a client event or not.

        Returns:
            list
        """

        updates = yield self.current_state_for_users(target_user_ids)
        updates = list(updates.values())

        for user_id in set(target_user_ids) - set(u.user_id for u in updates):
            updates.append(UserPresenceState.default(user_id))

        now = self.clock.time_msec()
        if as_event:
            defer.returnValue([
                {
                    "type": "m.presence",
                    "content": format_user_presence_state(state, now),
                }
                for state in updates
            ])
        else:
            defer.returnValue(updates)
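    # For reference, with as_event=True each returned entry looks roughly like
    # (the exact content keys depend on format_user_presence_state, which is
    # defined elsewhere):
    #
    #     {
    #         "type": "m.presence",
    #         "content": {"presence": "online", "last_active_ago": 1234},
    #     }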

    @defer.inlineCallbacks
    def set_state(self, target_user, state, ignore_status_msg=False):
        """Set the presence state of the user.
        """
        status_msg = state.get("status_msg", None)
        presence = state["presence"]

        valid_presence = (
            PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE
        )
        if presence not in valid_presence:
            raise SynapseError(400, "Invalid presence state")

        user_id = target_user.to_string()

        prev_state = yield self.current_state_for_user(user_id)

        new_fields = {
            "state": presence
        }

        if not ignore_status_msg:
            msg = status_msg if presence != PresenceState.OFFLINE else None
            new_fields["status_msg"] = msg

        if presence == PresenceState.ONLINE:
            new_fields["last_active_ts"] = self.clock.time_msec()

        yield self._update_states([prev_state.copy_and_replace(**new_fields)])
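    # Illustrative only: the client-facing presence API would typically call
    # this as
    #
    #     yield presence_handler.set_state(
    #         target_user, {"presence": "unavailable", "status_msg": "afk"}
    #     )
    #
    # where "afk" is a made-up status message.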

    @defer.inlineCallbacks
    def user_joined_room(self, user, room_id):
        """Called (via the distributor) when a user joins a room. This funciton
        sends presence updates to servers, either:
            1. the joining user is a local user and we send their presence to
               all servers in the room.
            2. the joining user is a remote user and so we send presence for all
               local users in the room.
        """
        # We only need to send presence to servers that don't have it yet. We
        # don't need to send to local clients here, as that is done as part
        # of the event stream/sync.
        # TODO: Only send to servers not already in the room.
        if self.is_mine(user):
            state = yield self.current_state_for_user(user.to_string())

            self._push_to_remotes([state])
        else:
            user_ids = yield self.store.get_users_in_room(room_id)
            user_ids = list(filter(self.is_mine_id, user_ids))

            states = yield self.current_state_for_users(user_ids)

            self._push_to_remotes(list(states.values()))

    @defer.inlineCallbacks
    def get_presence_list(self, observer_user, accepted=None):
        """Returns the presence for all users in their presence list.
        """
        if not self.is_mine(observer_user):
            raise SynapseError(400, "User is not hosted on this Home Server")

        presence_list = yield self.store.get_presence_list(
            observer_user.localpart, accepted=accepted
        )

        results = yield self.get_states(
            target_user_ids=[row["observed_user_id"] for row in presence_list],
            as_event=False,
        )

        now = self.clock.time_msec()
        results[:] = [format_user_presence_state(r, now) for r in results]

        is_accepted = {
            row["observed_user_id"]: row["accepted"] for row in presence_list
        }

        for result in results:
            result.update({
                "accepted": is_accepted,
            })

        defer.returnValue(results)

    @defer.inlineCallbacks
    def send_presence_invite(self, observer_user, observed_user):
        """Sends a presence invite.
        """
        yield self.store.add_presence_list_pending(
            observer_user.localpart, observed_user.to_string()
        )

        if self.is_mine(observed_user):
            yield self.invite_presence(observed_user, observer_user)
        else:
            yield self.federation.send_edu(
                destination=observed_user.domain,
                edu_type="m.presence_invite",
                content={
                    "observed_user": observed_user.to_string(),
                    "observer_user": observer_user.to_string(),
                }
            )

    @defer.inlineCallbacks
    def invite_presence(self, observed_user, observer_user):
        """Handles new presence invites.
        """
        if not self.is_mine(observed_user):
            raise SynapseError(400, "User is not hosted on this Home Server")

        # TODO: Don't auto accept
        if self.is_mine(observer_user):
            yield self.accept_presence(observed_user, observer_user)
        else:
            self.federation.send_edu(
                destination=observer_user.domain,
                edu_type="m.presence_accept",
                content={
                    "observed_user": observed_user.to_string(),
                    "observer_user": observer_user.to_string(),
                }
            )

            state_dict = yield self.get_state(observed_user, as_event=False)
            state_dict = format_user_presence_state(state_dict, self.clock.time_msec())

            self.federation.send_edu(
                destination=observer_user.domain,
                edu_type="m.presence",
                content={
                    "push": [state_dict]
                }
            )

    @defer.inlineCallbacks
    def accept_presence(self, observed_user, observer_user):
        """Handles a m.presence_accept EDU. Mark a presence invite from a
        local or remote user as accepted in a local user's presence list.
        Starts polling for presence updates from the local or remote user.
        Args:
            observed_user(UserID): The user to update in the presence list.
            observer_user(UserID): The owner of the presence list to update.
        """
        yield self.store.set_presence_list_accepted(
            observer_user.localpart, observed_user.to_string()
        )

    @defer.inlineCallbacks
    def deny_presence(self, observed_user, observer_user):
        """Handle a m.presence_deny EDU. Removes a local or remote user from a
        local user's presence list.
        Args:
            observed_user(UserID): The local or remote user to remove from the
                list.
            observer_user(UserID): The local owner of the presence list.
        Returns:
            A Deferred.
        """
        yield self.store.del_presence_list(
            observer_user.localpart, observed_user.to_string()
        )

        # TODO(paul): Inform the user somehow?

    @defer.inlineCallbacks
    def drop(self, observed_user, observer_user):
        """Remove a local or remote user from a local user's presence list and
        unsubscribe the local user from updates from that user.
        Args:
            observed_user(UserId): The local or remote user to remove from the
                list.
            observer_user(UserId): The local owner of the presence list.
        Returns:
            A Deferred.
        """
        if not self.is_mine(observer_user):
            raise SynapseError(400, "User is not hosted on this Home Server")

        yield self.store.del_presence_list(
            observer_user.localpart, observed_user.to_string()
        )

        # TODO: Inform the remote that we've dropped the presence list.

    @defer.inlineCallbacks
    def is_visible(self, observed_user, observer_user):
        """Returns whether a user can see another user's presence.
        """
        observer_room_ids = yield self.store.get_rooms_for_user(
            observer_user.to_string()
        )
        observed_room_ids = yield self.store.get_rooms_for_user(
            observed_user.to_string()
        )

        if observer_room_ids & observed_room_ids:
            defer.returnValue(True)

        accepted_observers = yield self.store.get_presence_list_observers_accepted(
            observed_user.to_string()
        )

        defer.returnValue(observer_user.to_string() in accepted_observers)

    @defer.inlineCallbacks
    def get_all_presence_updates(self, last_id, current_id):
        """
        Gets a list of presence update rows from between the given stream ids.
        Each row has:
        - stream_id(int)
        - user_id(str)
        - state(str)
        - last_active_ts(int)
        - last_federation_update_ts(int)
        - last_user_sync_ts(int)
        - status_msg(str)
        - currently_active(bool)
        """
        # TODO(markjh): replicate the unpersisted changes.
        # This could use the in-memory stores for recent changes.
        rows = yield self.store.get_all_presence_updates(last_id, current_id)
        defer.returnValue(rows)
Пример #41
0
class MediaRepository:
    def __init__(self, hs: "HomeServer"):
        self.hs = hs
        self.auth = hs.get_auth()
        self.client = hs.get_federation_http_client()
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastores().main
        self.max_upload_size = hs.config.media.max_upload_size
        self.max_image_pixels = hs.config.media.max_image_pixels

        Thumbnailer.set_limits(self.max_image_pixels)

        self.primary_base_path: str = hs.config.media.media_store_path
        self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)

        self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.media.thumbnail_requirements

        self.remote_media_linearizer = Linearizer(name="media_remote")

        self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
        self.recently_accessed_locals: Set[str] = set()

        self.federation_domain_whitelist = (
            hs.config.federation.federation_domain_whitelist)

        # List of StorageProviders where we should search for media and
        # potentially upload to.
        storage_providers = []

        for (
                clz,
                provider_config,
                wrapper_config,
        ) in hs.config.media.media_storage_providers:
            backend = clz(hs, provider_config)
            provider = StorageProviderWrapper(
                backend,
                store_local=wrapper_config.store_local,
                store_remote=wrapper_config.store_remote,
                store_synchronous=wrapper_config.store_synchronous,
            )
            storage_providers.append(provider)

        self.media_storage = MediaStorage(self.hs, self.primary_base_path,
                                          self.filepaths, storage_providers)
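        # For context (assumed, not taken from this code): each entry in
        # media_storage_providers typically comes from homeserver config along
        # the lines of
        #
        #     media_storage_providers:
        #       - module: file_system
        #         store_local: true
        #         store_remote: true
        #         store_synchronous: false
        #         config:
        #           directory: /mnt/media-backup
        #
        # where the directory path is a placeholder.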

        self.clock.looping_call(self._start_update_recently_accessed,
                                UPDATE_RECENTLY_ACCESSED_TS)

        # Media retention configuration options
        self._media_retention_local_media_lifetime_ms = (
            hs.config.media.media_retention_local_media_lifetime_ms)
        self._media_retention_remote_media_lifetime_ms = (
            hs.config.media.media_retention_remote_media_lifetime_ms)

        # Check whether local or remote media retention is configured
        if (hs.config.media.media_retention_local_media_lifetime_ms is not None
                or hs.config.media.media_retention_remote_media_lifetime_ms
                is not None):
            # Run the background job to apply media retention rules routinely,
            # with the duration between runs dictated by the homeserver config.
            self.clock.looping_call(
                self._start_apply_media_retention_rules,
                MEDIA_RETENTION_CHECK_PERIOD_MS,
            )

    def _start_update_recently_accessed(self) -> Deferred:
        return run_as_background_process("update_recently_accessed_media",
                                         self._update_recently_accessed)

    def _start_apply_media_retention_rules(self) -> Deferred:
        return run_as_background_process("apply_media_retention_rules",
                                         self._apply_media_retention_rules)

    async def _update_recently_accessed(self) -> None:
        remote_media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        local_media = self.recently_accessed_locals
        self.recently_accessed_locals = set()

        await self.store.update_cached_last_access_time(
            local_media, remote_media, self.clock.time_msec())

    def mark_recently_accessed(self, server_name: Optional[str],
                               media_id: str) -> None:
        """Mark the given media as recently accessed.

        Args:
            server_name: Origin server of media, or None if local
            media_id: The media ID of the content
        """
        if server_name:
            self.recently_accessed_remotes.add((server_name, media_id))
        else:
            self.recently_accessed_locals.add(media_id)

    async def create_content(
        self,
        media_type: str,
        upload_name: Optional[str],
        content: IO,
        content_length: int,
        auth_user: UserID,
    ) -> str:
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type: The content type of the file.
            upload_name: The name of the file, if provided.
            content: A file like object that is the content to store
            content_length: The length of the content
            auth_user: The user_id of the uploader

        Returns:
            The mxc url of the stored content
        """

        media_id = random_string(24)

        file_info = FileInfo(server_name=None, file_id=media_id)

        fname = await self.media_storage.store_file(content, file_info)

        logger.info("Stored local media in file %r", fname)

        await self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )

        await self._generate_thumbnails(None, media_id, media_id, media_type)

        return "mxc://%s/%s" % (self.server_name, media_id)

    async def get_local_media(self, request: SynapseRequest, media_id: str,
                              name: Optional[str]) -> None:
        """Responds to requests for local media, if exists, or returns 404.

        Args:
            request: The incoming request.
            media_id: The media ID of the content. (This is the same as
                the file_id for local content.)
            name: Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Resolves once a response has successfully been written to request
        """
        media_info = await self.store.get_local_media(media_id)
        if not media_info or media_info["quarantined_by"]:
            respond_404(request)
            return

        self.mark_recently_accessed(None, media_id)

        media_type = media_info["media_type"]
        if not media_type:
            media_type = "application/octet-stream"
        media_length = media_info["media_length"]
        upload_name = name if name else media_info["upload_name"]
        url_cache = media_info["url_cache"]

        file_info = FileInfo(None, media_id, url_cache=bool(url_cache))

        responder = await self.media_storage.fetch_media(file_info)
        await respond_with_responder(request, responder, media_type,
                                     media_length, upload_name)

    async def get_remote_media(
        self,
        request: SynapseRequest,
        server_name: str,
        media_id: str,
        name: Optional[str],
    ) -> None:
        """Respond to requests for remote media.

        Args:
            request: The incoming request.
            server_name: Remote server_name where the media originated.
            media_id: The media ID of the content (as defined by the remote server).
            name: Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Resolves once a response has successfully been written to request
        """
        if (self.federation_domain_whitelist is not None
                and server_name not in self.federation_domain_whitelist):
            raise FederationDeniedError(server_name)

        self.mark_recently_accessed(server_name, media_id)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        async with self.remote_media_linearizer.queue(key):
            responder, media_info = await self._get_remote_media_impl(
                server_name, media_id)

        # We deliberately stream the file outside the lock
        if responder:
            media_type = media_info["media_type"]
            media_length = media_info["media_length"]
            upload_name = name if name else media_info["upload_name"]
            await respond_with_responder(request, responder, media_type,
                                         media_length, upload_name)
        else:
            respond_404(request)

    async def get_remote_media_info(self, server_name: str,
                                    media_id: str) -> dict:
        """Gets the media info associated with the remote file, downloading
        if necessary.

        Args:
            server_name: Remote server_name where the media originated.
            media_id: The media ID of the content (as defined by the remote server).

        Returns:
            The media info of the file
        """
        if (self.federation_domain_whitelist is not None
                and server_name not in self.federation_domain_whitelist):
            raise FederationDeniedError(server_name)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        async with self.remote_media_linearizer.queue(key):
            responder, media_info = await self._get_remote_media_impl(
                server_name, media_id)

        # Ensure we actually use the responder so that it releases resources
        if responder:
            with responder:
                pass

        return media_info

    async def _get_remote_media_impl(
            self, server_name: str,
            media_id: str) -> Tuple[Optional[Responder], dict]:
        """Looks for media in local cache, if not there then attempt to
        download from remote server.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            A tuple of responder and the media info of the file.
        """
        media_info = await self.store.get_cached_remote_media(
            server_name, media_id)

        # file_id is the ID we use to track the file locally. If we've already
        # seen the file then reuse the existing ID, otherwise generate a new
        # one.

        # If we have an entry in the DB, try and look for it
        if media_info:
            file_id = media_info["filesystem_id"]
            file_info = FileInfo(server_name, file_id)

            if media_info["quarantined_by"]:
                logger.info("Media is quarantined")
                raise NotFoundError()

            if not media_info["media_type"]:
                media_info["media_type"] = "application/octet-stream"

            responder = await self.media_storage.fetch_media(file_info)
            if responder:
                return responder, media_info

        # Failed to find the file anywhere, let's download it.

        try:
            media_info = await self._download_remote_file(
                server_name,
                media_id,
            )
        except SynapseError:
            raise
        except Exception as e:
            # An exception may be because we downloaded media in another
            # process, so let's check if we magically have the media.
            media_info = await self.store.get_cached_remote_media(
                server_name, media_id)
            if not media_info:
                raise e

        file_id = media_info["filesystem_id"]
        if not media_info["media_type"]:
            media_info["media_type"] = "application/octet-stream"
        file_info = FileInfo(server_name, file_id)

        # We generate thumbnails even if another process downloaded the media
        # as a) it's conceivable that the other download request dies before it
        # generates thumbnails, but mainly b) we want to be sure the thumbnails
        # have finished being generated before responding to the client,
        # otherwise they'll request thumbnails and get a 404 if they're not
        # ready yet.
        await self._generate_thumbnails(server_name, media_id, file_id,
                                        media_info["media_type"])

        responder = await self.media_storage.fetch_media(file_info)
        return responder, media_info

    async def _download_remote_file(
        self,
        server_name: str,
        media_id: str,
    ) -> dict:
        """Attempt to download the remote file from the given server name,
        using the given file_id as the local id.

        Args:
            server_name: Originating server
            media_id: The media ID of the content (as defined by the
                remote server). This is different than the file_id, which is
                locally generated.
            file_id: Local file ID

        Returns:
            The media info of the file.
        """

        file_id = random_string(24)

        file_info = FileInfo(server_name=server_name, file_id=file_id)

        with self.media_storage.store_into_file(file_info) as (f, fname,
                                                               finish):
            request_path = "/".join(
                ("/_matrix/media/r0/download", server_name, media_id))
            try:
                length, headers = await self.client.get_file(
                    server_name,
                    request_path,
                    output_stream=f,
                    max_size=self.max_upload_size,
                    args={
                        # tell the remote server to 404 if it doesn't
                        # recognise the server_name, to make sure we don't
                        # end up with a routing loop.
                        "allow_remote": "false"
                    },
                )
            except RequestSendFailed as e:
                logger.warning(
                    "Request failed fetching remote media %s/%s: %r",
                    server_name,
                    media_id,
                    e,
                )
                raise SynapseError(502, "Failed to fetch remote media")

            except HttpResponseException as e:
                logger.warning(
                    "HTTP error fetching remote media %s/%s: %s",
                    server_name,
                    media_id,
                    e.response,
                )
                if e.code == twisted.web.http.NOT_FOUND:
                    raise e.to_synapse_error()
                raise SynapseError(502, "Failed to fetch remote media")

            except SynapseError:
                logger.warning("Failed to fetch remote media %s/%s",
                               server_name, media_id)
                raise
            except NotRetryingDestination:
                logger.warning("Not retrying destination %r", server_name)
                raise SynapseError(502, "Failed to fetch remote media")
            except Exception:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise SynapseError(502, "Failed to fetch remote media")

            await finish()

            if b"Content-Type" in headers:
                media_type = headers[b"Content-Type"][0].decode("ascii")
            else:
                media_type = "application/octet-stream"
            upload_name = get_filename_from_headers(headers)
            time_now_ms = self.clock.time_msec()

            # Multiple remote media download requests can race (when using
            # multiple media repos), so this may throw a violation constraint
            # exception. If it does we'll delete the newly downloaded file from
            # disk (as we're in the ctx manager).
            #
            # However: we've already called `finish()` so we may have also
            # written to the storage providers. This is preferable to the
            # alternative where we call `finish()` *after* this, where we could
            # end up having an entry in the DB but fail to write the files to
            # the storage providers.
            await self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )

        logger.info("Stored remote media in file %r", fname)

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        return media_info

    def _get_thumbnail_requirements(
            self, media_type: str) -> Tuple[ThumbnailRequirement, ...]:
        scpos = media_type.find(";")
        if scpos > 0:
            media_type = media_type[:scpos]
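        # For example, a media type of "image/jpeg; charset=utf-8" is
        # truncated to "image/jpeg" before the lookup below.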
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(
        self,
        thumbnailer: Thumbnailer,
        t_width: int,
        t_height: int,
        t_method: str,
        t_type: str,
    ) -> Optional[BytesIO]:
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width,
                m_height,
                self.max_image_pixels,
            )
            return None

        if thumbnailer.transpose_method is not None:
            m_width, m_height = thumbnailer.transpose()

        if t_method == "crop":
            return thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            return thumbnailer.scale(t_width, t_height, t_type)

        return None

    async def generate_local_exact_thumbnail(
        self,
        media_id: str,
        t_width: int,
        t_height: int,
        t_method: str,
        t_type: str,
        url_cache: bool,
    ) -> Optional[str]:
        input_path = await self.media_storage.ensure_media_is_in_local_cache(
            FileInfo(None, media_id, url_cache=url_cache))

        try:
            thumbnailer = Thumbnailer(input_path)
        except ThumbnailError as e:
            logger.warning(
                "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
                media_id,
                t_method,
                t_type,
                e,
            )
            return None

        with thumbnailer:
            t_byte_source = await defer_to_thread(
                self.hs.get_reactor(),
                self._generate_thumbnail,
                thumbnailer,
                t_width,
                t_height,
                t_method,
                t_type,
            )

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=None,
                    file_id=media_id,
                    url_cache=url_cache,
                    thumbnail=ThumbnailInfo(
                        width=t_width,
                        height=t_height,
                        method=t_method,
                        type=t_type,
                    ),
                )

                output_path = await self.media_storage.store_file(
                    t_byte_source, file_info)
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            await self.store.store_local_thumbnail(media_id, t_width, t_height,
                                                   t_type, t_method, t_len)

            return output_path

        # Could not generate thumbnail.
        return None

    async def generate_remote_exact_thumbnail(
        self,
        server_name: str,
        file_id: str,
        media_id: str,
        t_width: int,
        t_height: int,
        t_method: str,
        t_type: str,
    ) -> Optional[str]:
        input_path = await self.media_storage.ensure_media_is_in_local_cache(
            FileInfo(server_name, file_id))

        try:
            thumbnailer = Thumbnailer(input_path)
        except ThumbnailError as e:
            logger.warning(
                "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
                media_id,
                server_name,
                t_method,
                t_type,
                e,
            )
            return None

        with thumbnailer:
            t_byte_source = await defer_to_thread(
                self.hs.get_reactor(),
                self._generate_thumbnail,
                thumbnailer,
                t_width,
                t_height,
                t_method,
                t_type,
            )

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    thumbnail=ThumbnailInfo(
                        width=t_width,
                        height=t_height,
                        method=t_method,
                        type=t_type,
                    ),
                )

                output_path = await self.media_storage.store_file(
                    t_byte_source, file_info)
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            await self.store.store_remote_media_thumbnail(
                server_name,
                media_id,
                file_id,
                t_width,
                t_height,
                t_type,
                t_method,
                t_len,
            )

            return output_path

        # Could not generate thumbnail.
        return None

    async def _generate_thumbnails(
        self,
        server_name: Optional[str],
        media_id: str,
        file_id: str,
        media_type: str,
        url_cache: bool = False,
    ) -> Optional[dict]:
        """Generate and store thumbnails for an image.

        Args:
            server_name: The server name if remote media, else None if local
            media_id: The media ID of the content. (This is the same as
                the file_id for local content)
            file_id: Local file ID
            media_type: The content type of the file
            url_cache: If we are thumbnailing images downloaded for the URL cache,
                used exclusively by the url previewer

        Returns:
            Dict with "width" and "height" keys of original image or None if the
            media cannot be thumbnailed.
        """
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return None

        input_path = await self.media_storage.ensure_media_is_in_local_cache(
            FileInfo(server_name, file_id, url_cache=url_cache))

        try:
            thumbnailer = Thumbnailer(input_path)
        except ThumbnailError as e:
            logger.warning(
                "Unable to generate thumbnails for remote media %s from %s of type %s: %s",
                media_id,
                server_name,
                media_type,
                e,
            )
            return None

        with thumbnailer:
            m_width = thumbnailer.width
            m_height = thumbnailer.height

            if m_width * m_height >= self.max_image_pixels:
                logger.info(
                    "Image too large to thumbnail %r x %r > %r",
                    m_width,
                    m_height,
                    self.max_image_pixels,
                )
                return None

            if thumbnailer.transpose_method is not None:
                m_width, m_height = await defer_to_thread(
                    self.hs.get_reactor(), thumbnailer.transpose)

            # We deduplicate the thumbnail sizes by ignoring the cropped versions if
            # they have the same dimensions as a scaled one.
            thumbnails: Dict[Tuple[int, int, str], str] = {}
            for requirement in requirements:
                if requirement.method == "crop":
                    thumbnails.setdefault(
                        (requirement.width, requirement.height,
                         requirement.media_type),
                        requirement.method,
                    )
                elif requirement.method == "scale":
                    t_width, t_height = thumbnailer.aspect(
                        requirement.width, requirement.height)
                    t_width = min(m_width, t_width)
                    t_height = min(m_height, t_height)
                    thumbnails[(t_width, t_height,
                                requirement.media_type)] = requirement.method

            # Now we generate the thumbnails for each dimension and store them
            for (t_width, t_height, t_type), t_method in thumbnails.items():
                # Generate the thumbnail
                if t_method == "crop":
                    t_byte_source = await defer_to_thread(
                        self.hs.get_reactor(),
                        thumbnailer.crop,
                        t_width,
                        t_height,
                        t_type,
                    )
                elif t_method == "scale":
                    t_byte_source = await defer_to_thread(
                        self.hs.get_reactor(),
                        thumbnailer.scale,
                        t_width,
                        t_height,
                        t_type,
                    )
                else:
                    logger.error("Unrecognized method: %r", t_method)
                    continue

                if not t_byte_source:
                    continue

                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    url_cache=url_cache,
                    thumbnail=ThumbnailInfo(
                        width=t_width,
                        height=t_height,
                        method=t_method,
                        type=t_type,
                    ),
                )

                with self.media_storage.store_into_file(file_info) as (
                        f,
                        fname,
                        finish,
                ):
                    try:
                        await self.media_storage.write_to_file(
                            t_byte_source, f)
                        await finish()
                    finally:
                        t_byte_source.close()

                    t_len = os.path.getsize(fname)

                    # Write to database
                    if server_name:
                        # Multiple remote media download requests can race (when
                        # using multiple media repos), so this may throw a violation
                        # constraint exception. If it does we'll delete the newly
                        # generated thumbnail from disk (as we're in the ctx
                        # manager).
                        #
                        # However: we've already called `finish()` so we may have
                        # also written to the storage providers. This is preferable
                        # to the alternative where we call `finish()` *after* this,
                        # where we could end up having an entry in the DB but fail
                        # to write the files to the storage providers.
                        try:
                            await self.store.store_remote_media_thumbnail(
                                server_name,
                                media_id,
                                file_id,
                                t_width,
                                t_height,
                                t_type,
                                t_method,
                                t_len,
                            )
                        except Exception as e:
                            thumbnail_exists = (
                                await self.store.get_remote_media_thumbnail(
                                    server_name,
                                    media_id,
                                    t_width,
                                    t_height,
                                    t_type,
                                ))
                            if not thumbnail_exists:
                                raise e
                    else:
                        await self.store.store_local_thumbnail(
                            media_id, t_width, t_height, t_type, t_method,
                            t_len)

        return {"width": m_width, "height": m_height}

    async def _apply_media_retention_rules(self) -> None:
        """
        Purge old local and remote media according to the media retention rules
        defined in the homeserver config.
        """
        # Purge remote media
        if self._media_retention_remote_media_lifetime_ms is not None:
            # Calculate a threshold timestamp derived from the configured lifetime. Any
            # media that has not been accessed since this timestamp will be removed.
            remote_media_threshold_timestamp_ms = (
                self.clock.time_msec() -
                self._media_retention_remote_media_lifetime_ms)

            logger.info("Purging remote media last accessed before"
                        f" {remote_media_threshold_timestamp_ms}")

            await self.delete_old_remote_media(
                before_ts=remote_media_threshold_timestamp_ms)

        # And now do the same for local media
        if self._media_retention_local_media_lifetime_ms is not None:
            # This works the same as the remote media threshold
            local_media_threshold_timestamp_ms = (
                self.clock.time_msec() -
                self._media_retention_local_media_lifetime_ms)

            logger.info("Purging local media last accessed before"
                        f" {local_media_threshold_timestamp_ms}")

            await self.delete_old_local_media(
                before_ts=local_media_threshold_timestamp_ms,
                keep_profiles=True,
                delete_quarantined_media=False,
                delete_protected_media=False,
            )

    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
        old_media = await self.store.get_remote_media_ids(
            before_ts, include_quarantined_media=False)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store?
            async with self.remote_media_linearizer.queue(key):
                full_path = self.filepaths.remote_media_filepath(
                    origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warning("Failed to remove file: %r", full_path)
                    if e.errno == errno.ENOENT:
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id)
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                await self.store.delete_remote_media(origin, media_id)
                deleted += 1

        return {"deleted": deleted}

    async def delete_local_media_ids(
            self, media_ids: List[str]) -> Tuple[List[str], int]:
        """
        Delete the given local media IDs from this server.

        Args:
            media_ids: The media IDs to delete.
        Returns:
            A tuple of (list of deleted media IDs, total deleted media IDs).
        """
        return await self._remove_local_media_from_disk(media_ids)

    async def delete_old_local_media(
        self,
        before_ts: int,
        size_gt: int = 0,
        keep_profiles: bool = True,
        delete_quarantined_media: bool = False,
        delete_protected_media: bool = False,
    ) -> Tuple[List[str], int]:
        """
        Delete local media from this server by size and timestamp. Removes
        media files, any thumbnails and cached URLs.

        Args:
            before_ts: Unix timestamp in ms.
                Files that were last used before this timestamp will be deleted.
            size_gt: Size of the media in bytes. Files that are larger will be deleted.
            keep_profiles: If True, media that is still referenced in image data
                (e.g. user profiles, room avatars) will be kept. If False, these
                files will be deleted as well.
            delete_quarantined_media: If True, media marked as quarantined will be deleted.
            delete_protected_media: If True, media marked as protected will be deleted.

        Returns:
            A tuple of (list of deleted media IDs, total deleted media IDs).
        """
        old_media = await self.store.get_local_media_ids(
            before_ts,
            size_gt,
            keep_profiles,
            include_quarantined_media=delete_quarantined_media,
            include_protected_media=delete_protected_media,
        )
        return await self._remove_local_media_from_disk(old_media)

    async def _remove_local_media_from_disk(
            self, media_ids: List[str]) -> Tuple[List[str], int]:
        """
        Delete local media from this server. Removes media files,
        any thumbnails and cached URLs.

        Args:
            media_ids: List of media_id to delete
        Returns:
            A tuple of (list of deleted media IDs, total deleted media IDs).
        """
        removed_media = []
        for media_id in media_ids:
            logger.info("Deleting media with ID '%s'", media_id)
            full_path = self.filepaths.local_media_filepath(media_id)
            try:
                os.remove(full_path)
            except OSError as e:
                logger.warning("Failed to remove file: %r: %s", full_path, e)
                if e.errno == errno.ENOENT:
                    pass
                else:
                    continue

            thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id)
            shutil.rmtree(thumbnail_dir, ignore_errors=True)

            await self.store.delete_remote_media(self.server_name, media_id)

            await self.store.delete_url_cache((media_id, ))
            await self.store.delete_url_cache_media((media_id, ))

            removed_media.append(media_id)

        return removed_media, len(removed_media)
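
A minimal usage sketch (not part of the example above): assuming `media_repo` is an
instance of the class shown here and `clock` is the homeserver clock, old local media
could be purged like this. The 30-day lifetime and the helper name are illustrative
assumptions.

async def purge_old_local_media(media_repo, clock):
    # Purge local media not accessed in the last 30 days, keeping anything
    # still referenced by user profiles or room avatars.
    thirty_days_ms = 30 * 24 * 60 * 60 * 1000
    deleted_ids, deleted_count = await media_repo.delete_old_local_media(
        before_ts=clock.time_msec() - thirty_days_ms,
        keep_profiles=True,
        delete_quarantined_media=False,
        delete_protected_media=False,
    )
    return deleted_count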
Example #42
0
class DeviceListUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

        # Attempt to resync out of sync device lists every 30s.
        self._resync_retry_in_progress = False
        self.clock.looping_call(
            run_as_background_process,
            30 * 1000,
            func=self._maybe_retry_device_resync,
            desc="_maybe_retry_device_resync",
        )

    @trace
    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        set_tag("origin", origin)
        set_tag("edu_content", edu_content)
        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]  # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning(
                "Got device list update edu for %r/%r from %r",
                user_id,
                device_id,
                origin,
            )

            set_tag("error", True)
            log_kv(
                {
                    "message": "Got a device list update edu from a user and "
                    "device which does not match the origin of the request.",
                    "user_id": user_id,
                    "device_id": device_id,
                }
            )
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            set_tag("error", True)
            log_kv(
                {
                    "message": "Got an update from a user for which "
                    "we don't share any rooms",
                    "other user_id": user_id,
                }
            )
            logger.warning(
                "Got device list update edu for %r/%r, but don't share a room",
                user_id,
                device_id,
            )
            return

        logger.debug("Received device list update for %r/%r", user_id, device_id)

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            for device_id, stream_id, prev_ids, content in pending_updates:
                logger.debug(
                    "Handling update %r/%r, ID: %r, prev: %r ",
                    user_id,
                    device_id,
                    stream_id,
                    prev_ids,
                )

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if logger.isEnabledFor(logging.INFO):
                logger.info(
                    "Received device list update for %s, requiring resync: %s. Devices: %s",
                    user_id,
                    resync,
                    ", ".join(u[0] for u in pending_updates),
                )

            if resync:
                yield self.user_device_resync(user_id)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

                self._seen_updates.setdefault(user_id, set()).update(
                    stream_id for _, stream_id, _, _ in pending_updates
                )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)

        logger.debug("Current extremity for %r: %r", user_id, extremity)

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                return True

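            # A prev_id is acceptable if it's something we already know about:
            # our current extremity, an update we've recently processed, or an
            # earlier update in this same batch. Any other prev_id means we've
            # missed an update and must do a full resync.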
            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    return True

            stream_id_in_updates.add(stream_id)

        return False

    @defer.inlineCallbacks
    def _maybe_retry_device_resync(self):
        """Retry to resync device lists that are out of sync, except if another retry is
        in progress.
        """
        if self._resync_retry_in_progress:
            return

        try:
            # Prevent another call of this function to retry resyncing device lists so
            # we don't send too many requests.
            self._resync_retry_in_progress = True
            # Get all of the users that need resyncing.
            need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
            # Iterate over the set of user IDs.
            for user_id in need_resync:
                try:
                    # Try to resync the current user's devices list.
                    result = yield self.user_device_resync(
                        user_id=user_id, mark_failed_as_stale=False,
                    )

                    # user_device_resync only returns a result if it managed to
                    # successfully resync and update the database. Updating the table
                    # of users requiring resync isn't necessary here as
                    # user_device_resync already does it (through
                    # self.store.update_remote_device_list_cache).
                    if result:
                        logger.debug(
                            "Successfully resynced the device list for %s", user_id,
                        )
                except Exception as e:
                    # If there was an issue resyncing this user, e.g. if the remote
                    # server sent a malformed result, just log the error instead of
                    # aborting all the subsequent resyncs.
                    logger.debug(
                        "Could not resync the device list for %s: %s", user_id, e,
                    )
        finally:
            # Allow future calls to retry resyncing out of sync device lists.
            self._resync_retry_in_progress = False

    @defer.inlineCallbacks
    def user_device_resync(self, user_id, mark_failed_as_stale=True):
        """Fetches all devices for a user and updates the device cache with them.

        Args:
            user_id (str): The ID of the user whose device list will be updated.
            mark_failed_as_stale (bool): Whether to mark the user's device list as stale
                if the attempt to resync failed.
        Returns:
            Deferred[dict]: a dict with device info, as under the "devices" key in the
            result of this request:
            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
        """
        logger.debug("Attempting to resync the device list for %s", user_id)
        log_kv({"message": "Doing resync to update device list."})
        # Fetch all devices for the user.
        origin = get_domain_from_id(user_id)
        try:
            result = yield self.federation.query_user_devices(origin, user_id)
        except NotRetryingDestination:
            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                yield self.store.mark_remote_user_device_cache_as_stale(user_id)

            return
        except (RequestSendFailed, HttpResponseException) as e:
            logger.warning(
                "Failed to handle device list update for %s: %s", user_id, e,
            )

            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                yield self.store.mark_remote_user_device_cache_as_stale(user_id)

            # We abort on exceptions rather than accepting the update
            # as otherwise synapse will 'forget' that its device list
            # is out of date. If we bail then we will retry the resync
            # next time we get a device list update for this user_id.
            # This makes it more likely that the device lists will
            # eventually become consistent.
            return
        except FederationDeniedError as e:
            set_tag("error", True)
            log_kv({"reason": "FederationDeniedError"})
            logger.info(e)
            return
        except Exception as e:
            set_tag("error", True)
            log_kv(
                {"message": "Exception raised by federation request", "exception": e}
            )
            logger.exception("Failed to handle device list update for %s", user_id)

            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                yield self.store.mark_remote_user_device_cache_as_stale(user_id)

            return
        log_kv({"result": result})
        stream_id = result["stream_id"]
        devices = result["devices"]

        # Get the master key and the self-signing key for this user if provided in the
        # response (None if not in the response).
        # The response will not contain the user signing key, as this key is only used by
        # its owner, so it doesn't make sense to send it over federation.
        master_key = result.get("master_key")
        self_signing_key = result.get("self_signing_key")

        # If the remote server has more than ~1000 devices for this user
        # we assume that something is going horribly wrong (e.g. a bot
        # that logs in and creates a new device every time it tries to
        # send a message).  Maintaining lots of devices per user in the
        # cache can cause serious performance issues as if this request
        # takes more than 60s to complete, internal replication from the
        # inbound federation worker to the synapse master may time out
        # causing the inbound federation to fail and causing the remote
        # server to retry, causing a DoS.  So in this scenario we give
        # up on storing the total list of devices and only handle the
        # delta instead.
        if len(devices) > 1000:
            logger.warning(
                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
                user_id,
                len(devices),
            )
            devices = []

        for device in devices:
            logger.debug(
                "Handling resync update %r/%r, ID: %r",
                user_id,
                device["device_id"],
                stream_id,
            )

        yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
        device_ids = [device["device_id"] for device in devices]

        # Handle cross-signing keys.
        cross_signing_device_ids = yield self.process_cross_signing_key_update(
            user_id, master_key, self_signing_key,
        )
        device_ids = device_ids + cross_signing_device_ids

        yield self.device_handler.notify_device_update(user_id, device_ids)

        # We clobber the seen updates since we've re-synced from a given
        # point.
        self._seen_updates[user_id] = {stream_id}

        defer.returnValue(result)

    @defer.inlineCallbacks
    def process_cross_signing_key_update(
        self,
        user_id: str,
        master_key: Optional[Dict[str, Any]],
        self_signing_key: Optional[Dict[str, Any]],
    ) -> list:
        """Process the given new master and self-signing key for the given remote user.

        Args:
            user_id: The ID of the user these keys are for.
            master_key: The dict of the cross-signing master key as returned by the
                remote server.
            self_signing_key: The dict of the cross-signing self-signing key as returned
                by the remote server.

        Returns:
            The device IDs for the given keys.
        """
        device_ids = []

        if master_key:
            yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
            _, verify_key = get_verify_key_from_cross_signing_key(master_key)
            # verify_key is a VerifyKey from signedjson, which uses
            # .version to denote the portion of the key ID after the
            # algorithm and colon, which is the device ID
            device_ids.append(verify_key.version)
        if self_signing_key:
            yield self.store.set_e2e_cross_signing_key(
                user_id, "self_signing", self_signing_key
            )
            _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
            device_ids.append(verify_key.version)

        return device_ids
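
An illustrative sketch (assumptions: `updater` is a constructed DeviceListUpdater and
the EDU body follows the m.device_list_update schema; the concrete values are made up):

from twisted.internet import defer


@defer.inlineCallbacks
def handle_device_list_edu(updater):
    # A single m.device_list_update EDU from a remote server. The updater
    # queues it per user and falls back to a full resync if the prev_id chain
    # doesn't match what we've already seen.
    edu_content = {
        "user_id": "@alice:remote.example",
        "device_id": "ABCDEFGHIJ",
        "stream_id": 6,
        "prev_id": [5],
    }
    yield updater.incoming_device_list_update("remote.example", edu_content)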
Example #43
0
class FederationServer(FederationBase):
    def __init__(self, hs):
        super(FederationServer, self).__init__(hs)

        self.auth = hs.get_auth()
        self.handler = hs.get_handlers().federation_handler

        self._server_linearizer = Linearizer("fed_server")
        self._transaction_linearizer = Linearizer("fed_txn_handler")

        self.transaction_actions = TransactionActions(self.store)

        self.registry = hs.get_federation_registry()

        # We cache responses to state queries, as they take a while and often
        # come in waves.
        self._state_resp_cache = ResponseCache(hs,
                                               "state_resp",
                                               timeout_ms=30000)

    @defer.inlineCallbacks
    @log_function
    def on_backfill_request(self, origin, room_id, versions, limit):
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            pdus = yield self.handler.on_backfill_request(
                origin, room_id, versions, limit)

            res = self._transaction_from_pdus(pdus).get_dict()

        defer.returnValue((200, res))

    @defer.inlineCallbacks
    @log_function
    def on_incoming_transaction(self, origin, transaction_data):
        # keep this as early as possible to make the calculated origin ts as
        # accurate as possible.
        request_time = self._clock.time_msec()

        transaction = Transaction(**transaction_data)

        if not transaction.transaction_id:
            raise Exception("Transaction missing transaction_id")

        logger.debug("[%s] Got transaction", transaction.transaction_id)

        # use a linearizer to ensure that we don't process the same transaction
        # multiple times in parallel.
        with (yield self._transaction_linearizer.queue(
            (origin, transaction.transaction_id), )):
            result = yield self._handle_incoming_transaction(
                origin,
                transaction,
                request_time,
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _handle_incoming_transaction(self, origin, transaction, request_time):
        """ Process an incoming transaction and return the HTTP response

        Args:
            origin (unicode): the server making the request
            transaction (Transaction): incoming transaction
            request_time (int): timestamp that the HTTP request arrived at

        Returns:
            Deferred[(int, object)]: http response code and body
        """
        response = yield self.transaction_actions.have_responded(
            origin, transaction)

        if response:
            logger.debug("[%s] We've already responded to this request",
                         transaction.transaction_id)
            defer.returnValue(response)
            return

        logger.debug("[%s] Transaction is new", transaction.transaction_id)

        received_pdus_counter.inc(len(transaction.pdus))

        origin_host, _ = parse_server_name(origin)

        pdus_by_room = {}

        for p in transaction.pdus:
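            # The sender may report how old a PDU is via a relative "age"
            # field; convert it into an absolute origin timestamp ("age_ts")
            # based on when this request arrived.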
            if "unsigned" in p:
                unsigned = p["unsigned"]
                if "age" in unsigned:
                    p["age"] = unsigned["age"]
            if "age" in p:
                p["age_ts"] = request_time - int(p["age"])
                del p["age"]

            event = event_from_pdu_json(p)
            room_id = event.room_id
            pdus_by_room.setdefault(room_id, []).append(event)

        pdu_results = {}
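        # Map from event_id to the outcome of handling that PDU: an empty dict
        # on success, or {"error": ...} on failure. This is returned to the
        # sending server in the transaction response.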

        # we can process different rooms in parallel (which is useful if they
        # require callouts to other servers to fetch missing events), but
        # impose a limit to avoid going too crazy with ram/cpu.

        @defer.inlineCallbacks
        def process_pdus_for_room(room_id):
            logger.debug("Processing PDUs for %s", room_id)
            try:
                yield self.check_server_matches_acl(origin_host, room_id)
            except AuthError as e:
                logger.warn(
                    "Ignoring PDUs for room %s from banned server",
                    room_id,
                )
                for pdu in pdus_by_room[room_id]:
                    event_id = pdu.event_id
                    pdu_results[event_id] = e.error_dict()
                return

            for pdu in pdus_by_room[room_id]:
                event_id = pdu.event_id
                with nested_logging_context(event_id):
                    try:
                        yield self._handle_received_pdu(origin, pdu)
                        pdu_results[event_id] = {}
                    except FederationError as e:
                        logger.warn("Error handling PDU %s: %s", event_id, e)
                        pdu_results[event_id] = {"error": str(e)}
                    except Exception as e:
                        f = failure.Failure()
                        pdu_results[event_id] = {"error": str(e)}
                        logger.error(
                            "Failed to handle PDU %s: %s",
                            event_id,
                            f.getTraceback().rstrip(),
                        )

        yield concurrently_execute(
            process_pdus_for_room,
            pdus_by_room.keys(),
            TRANSACTION_CONCURRENCY_LIMIT,
        )

        if hasattr(transaction, "edus"):
            for edu in (Edu(**x) for x in transaction.edus):
                yield self.received_edu(origin, edu.edu_type, edu.content)

        response = {
            "pdus": pdu_results,
        }

        logger.debug("Returning: %s", str(response))

        yield self.transaction_actions.set_response(origin, transaction, 200,
                                                    response)
        defer.returnValue((200, response))

    @defer.inlineCallbacks
    def received_edu(self, origin, edu_type, content):
        received_edus_counter.inc()
        yield self.registry.on_edu(edu_type, origin, content)

    @defer.inlineCallbacks
    @log_function
    def on_context_state_request(self, origin, room_id, event_id):
        if not event_id:
            raise NotImplementedError("Specify an event")

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)

        in_room = yield self.auth.check_host_in_room(room_id, origin)
        if not in_room:
            raise AuthError(403, "Host not in room.")

        # we grab the linearizer to protect ourselves from servers which hammer
        # us. In theory we might already have the response to this query
        # in the cache so we could return it without waiting for the linearizer
        # - but that's non-trivial to get right, and anyway somewhat defeats
        # the point of the linearizer.
        with (yield self._server_linearizer.queue((origin, room_id))):
            resp = yield self._state_resp_cache.wrap(
                (room_id, event_id),
                self._on_context_state_request_compute,
                room_id,
                event_id,
            )

        defer.returnValue((200, resp))

    @defer.inlineCallbacks
    def on_state_ids_request(self, origin, room_id, event_id):
        if not event_id:
            raise NotImplementedError("Specify an event")

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)

        in_room = yield self.auth.check_host_in_room(room_id, origin)
        if not in_room:
            raise AuthError(403, "Host not in room.")

        state_ids = yield self.handler.get_state_ids_for_pdu(
            room_id,
            event_id,
        )
        auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)

        defer.returnValue((200, {
            "pdu_ids": state_ids,
            "auth_chain_ids": auth_chain_ids,
        }))

    @defer.inlineCallbacks
    def _on_context_state_request_compute(self, room_id, event_id):
        pdus = yield self.handler.get_state_for_pdu(
            room_id,
            event_id,
        )
        auth_chain = yield self.store.get_auth_chain(
            [pdu.event_id for pdu in pdus])

        for event in auth_chain:
            # We sign these again because there was a bug where we
            # incorrectly signed things the first time round
            if self.hs.is_mine_id(event.event_id):
                event.signatures.update(
                    compute_event_signature(event, self.hs.hostname,
                                            self.hs.config.signing_key[0]))

        defer.returnValue({
            "pdus": [pdu.get_pdu_json() for pdu in pdus],
            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
        })

    @defer.inlineCallbacks
    @log_function
    def on_pdu_request(self, origin, event_id):
        pdu = yield self.handler.get_persisted_pdu(origin, event_id)

        if pdu:
            defer.returnValue(
                (200, self._transaction_from_pdus([pdu]).get_dict()))
        else:
            defer.returnValue((404, ""))

    @defer.inlineCallbacks
    @log_function
    def on_pull_request(self, origin, versions):
        raise NotImplementedError("Pull transactions not implemented")

    @defer.inlineCallbacks
    def on_query_request(self, query_type, args):
        received_queries_counter.labels(query_type).inc()
        resp = yield self.registry.on_query(query_type, args)
        defer.returnValue((200, resp))

    @defer.inlineCallbacks
    def on_make_join_request(self, origin, room_id, user_id,
                             supported_versions):
        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)

        room_version = yield self.store.get_room_version(room_id)
        if room_version not in supported_versions:
            logger.warn("Room version %s not in %s", room_version,
                        supported_versions)
            raise IncompatibleRoomVersionError(room_version=room_version)

        pdu = yield self.handler.on_make_join_request(room_id, user_id)
        time_now = self._clock.time_msec()
        defer.returnValue({
            "event": pdu.get_pdu_json(time_now),
            "room_version": room_version,
        })

    @defer.inlineCallbacks
    def on_invite_request(self, origin, content):
        pdu = event_from_pdu_json(content)
        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, pdu.room_id)
        ret_pdu = yield self.handler.on_invite_request(origin, pdu)
        time_now = self._clock.time_msec()
        defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))

    @defer.inlineCallbacks
    def on_send_join_request(self, origin, content):
        logger.debug("on_send_join_request: content: %s", content)
        pdu = event_from_pdu_json(content)

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, pdu.room_id)

        logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
        res_pdus = yield self.handler.on_send_join_request(origin, pdu)
        time_now = self._clock.time_msec()
        defer.returnValue((200, {
            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
            "auth_chain":
            [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
        }))

    @defer.inlineCallbacks
    def on_make_leave_request(self, origin, room_id, user_id):
        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, room_id)
        pdu = yield self.handler.on_make_leave_request(room_id, user_id)
        time_now = self._clock.time_msec()
        defer.returnValue({"event": pdu.get_pdu_json(time_now)})

    @defer.inlineCallbacks
    def on_send_leave_request(self, origin, content):
        logger.debug("on_send_leave_request: content: %s", content)
        pdu = event_from_pdu_json(content)

        origin_host, _ = parse_server_name(origin)
        yield self.check_server_matches_acl(origin_host, pdu.room_id)

        logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
        yield self.handler.on_send_leave_request(origin, pdu)
        defer.returnValue((200, {}))

    @defer.inlineCallbacks
    def on_event_auth(self, origin, room_id, event_id):
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            time_now = self._clock.time_msec()
            auth_pdus = yield self.handler.on_event_auth(event_id)
            res = {
                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
            }
        defer.returnValue((200, res))

    @defer.inlineCallbacks
    def on_query_auth_request(self, origin, content, room_id, event_id):
        """
        Content is a dict with keys::
            auth_chain (list): A list of events that give the auth chain.
            missing (list): A list of event_ids indicating what the other
              side (`origin`) thinks we're missing.
            rejects (dict): A mapping from event_id to a 2-tuple of reason
              string and a proof (or None) of why the event was rejected.
              The keys of this dict give the list of events the `origin` has
              rejected.

        Args:
            origin (str)
            content (dict)
            room_id (str)
            event_id (str)

        Returns:
            Deferred: Results in `dict` with the same format as `content`
        """
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            auth_chain = [
                event_from_pdu_json(e) for e in content["auth_chain"]
            ]

            signed_auth = yield self._check_sigs_and_hash_and_fetch(
                origin, auth_chain, outlier=True)

            ret = yield self.handler.on_query_auth(
                origin,
                event_id,
                room_id,
                signed_auth,
                content.get("rejects", []),
                content.get("missing", []),
            )

            time_now = self._clock.time_msec()
            send_content = {
                "auth_chain":
                [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
                "rejects": ret.get("rejects", []),
                "missing": ret.get("missing", []),
            }

        defer.returnValue((200, send_content))

    @log_function
    def on_query_client_keys(self, origin, content):
        return self.on_query_request("client_keys", content)

    def on_query_user_devices(self, origin, user_id):
        return self.on_query_request("user_devices", user_id)

    @defer.inlineCallbacks
    @log_function
    def on_claim_client_keys(self, origin, content):
        query = []
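        # Flatten the requested one-time keys into (user_id, device_id,
        # algorithm) tuples for the storage layer to claim in one go.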
        for user_id, device_keys in content.get("one_time_keys", {}).items():
            for device_id, algorithm in device_keys.items():
                query.append((user_id, device_id, algorithm))

        results = yield self.store.claim_e2e_one_time_keys(query)

        json_result = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join(("%s for %s:%s" % (key_id, user_id, device_id)
                      for user_id, user_keys in iteritems(json_result)
                      for device_id, device_keys in iteritems(user_keys)
                      for key_id, _ in iteritems(device_keys))),
        )

        defer.returnValue({"one_time_keys": json_result})

    @defer.inlineCallbacks
    @log_function
    def on_get_missing_events(self, origin, room_id, earliest_events,
                              latest_events, limit):
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

            logger.info(
                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                " limit: %d",
                earliest_events,
                latest_events,
                limit,
            )

            missing_events = yield self.handler.on_get_missing_events(
                origin,
                room_id,
                earliest_events,
                latest_events,
                limit,
            )

            if len(missing_events) < 5:
                logger.info("Returning %d events: %r", len(missing_events),
                            missing_events)
            else:
                logger.info("Returning %d events", len(missing_events))

            time_now = self._clock.time_msec()

        defer.returnValue({
            "events": [ev.get_pdu_json(time_now) for ev in missing_events],
        })

    @log_function
    def on_openid_userinfo(self, token):
        ts_now_ms = self._clock.time_msec()
        return self.store.get_user_id_for_open_id_token(token, ts_now_ms)

    def _transaction_from_pdus(self, pdu_list):
        """Returns a new Transaction containing the given PDUs suitable for
        transmission.
        """
        time_now = self._clock.time_msec()
        pdus = [p.get_pdu_json(time_now) for p in pdu_list]
        return Transaction(
            origin=self.server_name,
            pdus=pdus,
            origin_server_ts=int(time_now),
            destination=None,
        )

    @defer.inlineCallbacks
    def _handle_received_pdu(self, origin, pdu):
        """ Process a PDU received in a federation /send/ transaction.

        If the event is invalid, then this method throws a FederationError.
        (The error will then be logged and sent back to the sender (which
        probably won't do anything with it), and other events in the
        transaction will be processed as normal).

        It is likely that we'll then receive other events which refer to
        this rejected_event in their prev_events, etc.  When that happens,
        we'll attempt to fetch the rejected event again, which will presumably
        fail, so those second-generation events will also get rejected.

        Eventually, we get to the point where there are more than 10 events
        between any new events and the original rejected event. Since we
        only try to backfill 10 events deep on received pdu, we then accept the
        new event, possibly introducing a discontinuity in the DAG, with new
        forward extremities, so normal service is approximately returned,
        until we try to backfill across the discontinuity.

        Args:
            origin (str): server which sent the pdu
            pdu (FrozenEvent): received pdu

        Returns (Deferred): completes with None

        Raises: FederationError if the signatures / hash do not match, or
            if the event was unacceptable for any other reason (eg, too large,
            too many prev_events, couldn't find the prev_events)
        """
        # check that it's actually being sent from a valid destination to
        # workaround bug #1753 in 0.18.5 and 0.18.6
        if origin != get_domain_from_id(pdu.event_id):
            # We continue to accept join events from any server; this is
            # necessary for the federation join dance to work correctly.
            # (When we join over federation, the "helper" server is
            # responsible for sending out the join event, rather than the
            # origin. See bug #1893).
            if not (pdu.type == 'm.room.member' and pdu.content
                    and pdu.content.get("membership", None) == 'join'):
                logger.info("Discarding PDU %s from invalid origin %s",
                            pdu.event_id, origin)
                return
            else:
                logger.info("Accepting join PDU %s from %s", pdu.event_id,
                            origin)

        # Check signature.
        try:
            pdu = yield self._check_sigs_and_hash(pdu)
        except SynapseError as e:
            raise FederationError(
                "ERROR",
                e.code,
                e.msg,
                affected=pdu.event_id,
            )

        yield self.handler.on_receive_pdu(
            origin,
            pdu,
            sent_to_us_directly=True,
        )

    def __str__(self):
        return "<ReplicationLayer(%s)>" % self.server_name

    @defer.inlineCallbacks
    def exchange_third_party_invite(
        self,
        sender_user_id,
        target_user_id,
        room_id,
        signed,
    ):
        ret = yield self.handler.exchange_third_party_invite(
            sender_user_id,
            target_user_id,
            room_id,
            signed,
        )
        defer.returnValue(ret)

    @defer.inlineCallbacks
    def on_exchange_third_party_invite_request(self, origin, room_id,
                                               event_dict):
        ret = yield self.handler.on_exchange_third_party_invite_request(
            origin, room_id, event_dict)
        defer.returnValue(ret)

    @defer.inlineCallbacks
    def check_server_matches_acl(self, server_name, room_id):
        """Check if the given server is allowed by the server ACLs in the room

        Args:
            server_name (str): name of server, *without any port part*
            room_id (str): ID of the room to check

        Raises:
            AuthError if the server does not match the ACL
        """
        state_ids = yield self.store.get_current_state_ids(room_id)
        acl_event_id = state_ids.get((EventTypes.ServerACL, ""))

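        # If the room has no m.room.server_acl event then there are no
        # restrictions, so any server is allowed.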
        if not acl_event_id:
            return

        acl_event = yield self.store.get_event(acl_event_id)
        if server_matches_acl_event(server_name, acl_event):
            return

        raise AuthError(code=403, msg="Server is banned from room")
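
A hedged sketch of feeding a federation /send/ request into this class (assumptions:
`federation_server` is a constructed FederationServer and `transaction_data` is the
parsed JSON body of the request):

from twisted.internet import defer


@defer.inlineCallbacks
def handle_federation_send(federation_server, origin, transaction_data):
    # transaction_data must contain a transaction_id and may carry both PDUs
    # and EDUs; repeated transactions from the same origin are deduplicated via
    # the transaction linearizer and the transaction_actions response cache.
    code, body = yield federation_server.on_incoming_transaction(
        origin, transaction_data
    )
    # body has the form {"pdus": {event_id: {} or {"error": ...}, ...}}
    defer.returnValue((code, body))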
Example #44
0
class RoomCreationHandler(BaseHandler):

    PRESETS_DICT = {
        RoomCreationPreset.PRIVATE_CHAT: {
            "join_rules": JoinRules.INVITE,
            "history_visibility": "shared",
            "original_invitees_have_ops": False,
            "guest_can_join": True,
        },
        RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
            "join_rules": JoinRules.INVITE,
            "history_visibility": "shared",
            "original_invitees_have_ops": True,
            "guest_can_join": True,
        },
        RoomCreationPreset.PUBLIC_CHAT: {
            "join_rules": JoinRules.PUBLIC,
            "history_visibility": "shared",
            "original_invitees_have_ops": False,
            "guest_can_join": False,
        },
    }

    def __init__(self, hs):
        super(RoomCreationHandler, self).__init__(hs)

        self.spam_checker = hs.get_spam_checker()
        self.event_creation_handler = hs.get_event_creation_handler()
        self.room_member_handler = hs.get_room_member_handler()

        # linearizer to stop two upgrades happening at once
        self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")

    @defer.inlineCallbacks
    def upgrade_room(self, requester, old_room_id, new_version):
        """Replace a room with a new room with a different version

        Args:
            requester (synapse.types.Requester): the user requesting the upgrade
            old_room_id (unicode): the id of the room to be replaced
            new_version (unicode): the new room version to use

        Returns:
            Deferred[unicode]: the new room id
        """
        yield self.ratelimit(requester)

        user_id = requester.user.to_string()

        with (yield self._upgrade_linearizer.queue(old_room_id)):
            # start by allocating a new room id
            r = yield self.store.get_room(old_room_id)
            if r is None:
                raise NotFoundError("Unknown room id %s" % (old_room_id,))
            new_room_id = yield self._generate_room_id(
                creator_id=user_id, is_public=r["is_public"],
            )

            logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)

            # we create and auth the tombstone event before properly creating the new
            # room, to check our user has perms in the old room.
            tombstone_event, tombstone_context = (
                yield self.event_creation_handler.create_event(
                    requester, {
                        "type": EventTypes.Tombstone,
                        "state_key": "",
                        "room_id": old_room_id,
                        "sender": user_id,
                        "content": {
                            "body": "This room has been replaced",
                            "replacement_room": new_room_id,
                        }
                    },
                    token_id=requester.access_token_id,
                )
            )
            old_room_version = yield self.store.get_room_version(old_room_id)
            yield self.auth.check_from_context(
                old_room_version, tombstone_event, tombstone_context,
            )

            yield self.clone_existing_room(
                requester,
                old_room_id=old_room_id,
                new_room_id=new_room_id,
                new_room_version=new_version,
                tombstone_event_id=tombstone_event.event_id,
            )

            # now send the tombstone
            yield self.event_creation_handler.send_nonmember_event(
                requester, tombstone_event, tombstone_context,
            )

            old_room_state = yield tombstone_context.get_current_state_ids(self.store)

            # update any aliases
            yield self._move_aliases_to_new_room(
                requester, old_room_id, new_room_id, old_room_state,
            )

            # and finally, shut down the PLs in the old room, and update them in the new
            # room.
            yield self._update_upgraded_room_pls(
                requester, old_room_id, new_room_id, old_room_state,
            )

            defer.returnValue(new_room_id)

    @defer.inlineCallbacks
    def _update_upgraded_room_pls(
            self, requester, old_room_id, new_room_id, old_room_state,
    ):
        """Send updated power levels in both rooms after an upgrade

        Args:
            requester (synapse.types.Requester): the user requesting the upgrade
            old_room_id (unicode): the id of the room to be replaced
            new_room_id (unicode): the id of the replacement room
            old_room_state (dict[tuple[str, str], str]): the state map for the old room

        Returns:
            Deferred
        """
        old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))

        if old_room_pl_event_id is None:
            logger.warning(
                "Not supported: upgrading a room with no PL event. Not setting PLs "
                "in old room.",
            )
            return

        old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)

        # we try to stop regular users from speaking by setting the PL required
        # to send regular events and invites to 'Moderator' level. That's normally
        # 50, but if the default PL in a room is 50 or more, then we set the
        # required PL above that.

        pl_content = dict(old_room_pl_state.content)
        users_default = int(pl_content.get("users_default", 0))
        restricted_level = max(users_default + 1, 50)

        updated = False
        for v in ("invite", "events_default"):
            current = int(pl_content.get(v, 0))
            if current < restricted_level:
                logger.info(
                    "Setting level for %s in %s to %i (was %i)",
                    v, old_room_id, restricted_level, current,
                )
                pl_content[v] = restricted_level
                updated = True
            else:
                logger.info(
                    "Not setting level for %s (already %i)",
                    v, current,
                )

        if updated:
            try:
                yield self.event_creation_handler.create_and_send_nonmember_event(
                    requester, {
                        "type": EventTypes.PowerLevels,
                        "state_key": '',
                        "room_id": old_room_id,
                        "sender": requester.user.to_string(),
                        "content": pl_content,
                    }, ratelimit=False,
                )
            except AuthError as e:
                logger.warning("Unable to update PLs in old room: %s", e)

        logger.info("Setting correct PLs in new room")
        yield self.event_creation_handler.create_and_send_nonmember_event(
            requester, {
                "type": EventTypes.PowerLevels,
                "state_key": '',
                "room_id": new_room_id,
                "sender": requester.user.to_string(),
                "content": old_room_pl_state.content,
            }, ratelimit=False,
        )

    @defer.inlineCallbacks
    def clone_existing_room(
            self, requester, old_room_id, new_room_id, new_room_version,
            tombstone_event_id,
    ):
        """Populate a new room based on an old room

        Args:
            requester (synapse.types.Requester): the user requesting the upgrade
            old_room_id (unicode): the id of the room to be replaced
            new_room_id (unicode): the id to give the new room (should already have been
                created with _generate_room_id())
            new_room_version (unicode): the new room version to use
            tombstone_event_id (unicode|str): the ID of the tombstone event in the old
                room.
        Returns:
            Deferred[None]
        """
        user_id = requester.user.to_string()

        if not self.spam_checker.user_may_create_room(user_id):
            raise SynapseError(403, "You are not permitted to create rooms")

        creation_content = {
            "room_version": new_room_version,
            "predecessor": {
                "room_id": old_room_id,
                "event_id": tombstone_event_id,
            }
        }

        # Check if old room was non-federatable

        # Get old room's create event
        old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)

        # Check if the create event specified a non-federatable room
        if not old_room_create_event.content.get("m.federate", True):
            # If so, mark the new room as non-federatable as well
            creation_content["m.federate"] = False

        initial_state = dict()

        # Replicate relevant room events
        types_to_copy = (
            (EventTypes.JoinRules, ""),
            (EventTypes.Name, ""),
            (EventTypes.Topic, ""),
            (EventTypes.RoomHistoryVisibility, ""),
            (EventTypes.GuestAccess, ""),
            (EventTypes.RoomAvatar, ""),
            (EventTypes.Encryption, ""),
            (EventTypes.ServerACL, ""),
            (EventTypes.RelatedGroups, ""),
        )

        old_room_state_ids = yield self.store.get_filtered_current_state_ids(
            old_room_id, StateFilter.from_types(types_to_copy),
        )
        # map from event_id to BaseEvent
        old_room_state_events = yield self.store.get_events(old_room_state_ids.values())

        for k, old_event_id in iteritems(old_room_state_ids):
            old_event = old_room_state_events.get(old_event_id)
            if old_event:
                initial_state[k] = old_event.content

        yield self._send_events_for_new_room(
            requester,
            new_room_id,

            # we expect to override all the presets with initial_state, so this is
            # somewhat arbitrary.
            preset_config=RoomCreationPreset.PRIVATE_CHAT,

            invite_list=[],
            initial_state=initial_state,
            creation_content=creation_content,
        )

        # Transfer membership events
        old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
            old_room_id, StateFilter.from_types([(EventTypes.Member, None)]),
        )

        # map from event_id to BaseEvent
        old_room_member_state_events = yield self.store.get_events(
            old_room_member_state_ids.values(),
        )
        for k, old_event in iteritems(old_room_member_state_events):
            # Only transfer ban events
            if ("membership" in old_event.content and
                    old_event.content["membership"] == "ban"):
                yield self.room_member_handler.update_membership(
                    requester,
                    UserID.from_string(old_event['state_key']),
                    new_room_id,
                    "ban",
                    ratelimit=False,
                    content=old_event.content,
                )

        # XXX invites/joins
        # XXX 3pid invites

    @defer.inlineCallbacks
    def _move_aliases_to_new_room(
            self, requester, old_room_id, new_room_id, old_room_state,
    ):
        directory_handler = self.hs.get_handlers().directory_handler

        aliases = yield self.store.get_aliases_for_room(old_room_id)

        # check to see if we have a canonical alias.
        canonical_alias = None
        canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event_id:
            canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
            if canonical_alias_event:
                canonical_alias = canonical_alias_event.content.get("alias", "")

        # first we try to remove the aliases from the old room (we suppress sending
        # the room_aliases event until the end).
        #
        # Note that we'll only be able to remove aliases that (a) aren't owned by an AS,
        # and (b) that the user created, unless the user is a server admin.
        #
        # This is probably correct - given we don't allow such aliases to be deleted
        # normally, it would be odd to allow it in the case of doing a room upgrade -
        # but it makes the upgrade less effective, and you have to wonder why a room
        # admin can't remove aliases that point to that room anyway.
        # (cf https://github.com/matrix-org/synapse/issues/2360)
        #
        removed_aliases = []
        for alias_str in aliases:
            alias = RoomAlias.from_string(alias_str)
            try:
                yield directory_handler.delete_association(
                    requester, alias, send_event=False,
                )
                removed_aliases.append(alias_str)
            except SynapseError as e:
                logger.warning(
                    "Unable to remove alias %s from old room: %s",
                    alias, e,
                )

        # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest
        # of this.
        if not removed_aliases:
            return

        try:
            # this can fail if, for some reason, our user doesn't have perms to send
            # m.room.aliases events in the old room (note that we've already checked that
            # they have perms to send a tombstone event, so that's not terribly likely).
            #
            # If that happens, it's regrettable, but we should carry on: it's the same
            # as when you remove an alias from the directory normally - it just means that
            # the aliases event gets out of sync with the directory
            # (cf https://github.com/vector-im/riot-web/issues/2369)
            yield directory_handler.send_room_alias_update_event(
                requester, old_room_id,
            )
        except AuthError as e:
            logger.warning(
                "Failed to send updated alias event on old room: %s", e,
            )

        # we can now add any aliases we successfully removed to the new room.
        for alias in removed_aliases:
            try:
                yield directory_handler.create_association(
                    requester, RoomAlias.from_string(alias),
                    new_room_id, servers=(self.hs.hostname, ),
                    send_event=False, check_membership=False,
                )
                logger.info("Moved alias %s to new room", alias)
            except SynapseError as e:
                # I'm not really expecting this to happen, but it could if the spam
                # checking module decides it shouldn't, or similar.
                logger.error(
                    "Error adding alias %s to new room: %s",
                    alias, e,
                )

        try:
            if canonical_alias and (canonical_alias in removed_aliases):
                yield self.event_creation_handler.create_and_send_nonmember_event(
                    requester,
                    {
                        "type": EventTypes.CanonicalAlias,
                        "state_key": "",
                        "room_id": new_room_id,
                        "sender": requester.user.to_string(),
                        "content": {"alias": canonical_alias, },
                    },
                    ratelimit=False
                )

            yield directory_handler.send_room_alias_update_event(
                requester, new_room_id,
            )
        except SynapseError as e:
            # again I'm not really expecting this to fail, but if it does, I'd rather
            # we returned the new room to the client at this point.
            logger.error(
                "Unable to send updated alias events in new room: %s", e,
            )

    @defer.inlineCallbacks
    def create_room(self, requester, config, ratelimit=True,
                    creator_join_profile=None):
        """ Creates a new room.

        Args:
            requester (synapse.types.Requester):
                The user who requested the room creation.
            config (dict) : A dict of configuration options.
            ratelimit (bool): set to False to disable the rate limiter

            creator_join_profile (dict|None):
                Set to override the displayname and avatar for the creating
                user in this room. If unset, displayname and avatar will be
                derived from the user's profile. If set, should contain the
                values to go in the body of the 'join' event (typically
                `avatar_url` and/or `displayname`).

        Returns:
            Deferred[dict]:
                a dict containing the keys `room_id` and, if an alias was
                requested, `room_alias`.
        Raises:
            SynapseError if the room ID couldn't be stored, or something went
            horribly wrong.
            ResourceLimitError if the server is blocked due to some resource being
            exceeded
        """
        user_id = requester.user.to_string()

        yield self.auth.check_auth_blocking(user_id)

        if not self.spam_checker.user_may_create_room(user_id):
            raise SynapseError(403, "You are not permitted to create rooms")

        if ratelimit:
            yield self.ratelimit(requester)

        room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
        if not isinstance(room_version, string_types):
            raise SynapseError(
                400,
                "room_version must be a string",
                Codes.BAD_JSON,
            )

        if room_version not in KNOWN_ROOM_VERSIONS:
            raise SynapseError(
                400,
                "Your homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

        if "room_alias_name" in config:
            for wchar in string.whitespace:
                if wchar in config["room_alias_name"]:
                    raise SynapseError(400, "Invalid characters in room alias")

            room_alias = RoomAlias(
                config["room_alias_name"],
                self.hs.hostname,
            )
            mapping = yield self.store.get_association_from_room_alias(
                room_alias
            )

            if mapping:
                raise SynapseError(
                    400,
                    "Room alias already taken",
                    Codes.ROOM_IN_USE
                )
        else:
            room_alias = None

        invite_list = config.get("invite", [])
        for i in invite_list:
            try:
                UserID.from_string(i)
            except Exception:
                raise SynapseError(400, "Invalid user_id: %s" % (i,))

        yield self.event_creation_handler.assert_accepted_privacy_policy(
            requester,
        )

        invite_3pid_list = config.get("invite_3pid", [])

        visibility = config.get("visibility", None)
        is_public = visibility == "public"

        room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public)

        if room_alias:
            directory_handler = self.hs.get_handlers().directory_handler
            yield directory_handler.create_association(
                requester=requester,
                room_id=room_id,
                room_alias=room_alias,
                servers=[self.hs.hostname],
                send_event=False,
                check_membership=False,
            )

        preset_config = config.get(
            "preset",
            RoomCreationPreset.PRIVATE_CHAT
            if visibility == "private"
            else RoomCreationPreset.PUBLIC_CHAT
        )

        raw_initial_state = config.get("initial_state", [])

        initial_state = OrderedDict()
        for val in raw_initial_state:
            initial_state[(val["type"], val.get("state_key", ""))] = val["content"]

        creation_content = config.get("creation_content", {})

        # override any attempt to set room versions via the creation_content
        creation_content["room_version"] = room_version

        yield self._send_events_for_new_room(
            requester,
            room_id,
            preset_config=preset_config,
            invite_list=invite_list,
            initial_state=initial_state,
            creation_content=creation_content,
            room_alias=room_alias,
            power_level_content_override=config.get("power_level_content_override"),
            creator_join_profile=creator_join_profile,
        )

        if "name" in config:
            name = config["name"]
            yield self.event_creation_handler.create_and_send_nonmember_event(
                requester,
                {
                    "type": EventTypes.Name,
                    "room_id": room_id,
                    "sender": user_id,
                    "state_key": "",
                    "content": {"name": name},
                },
                ratelimit=False)

        if "topic" in config:
            topic = config["topic"]
            yield self.event_creation_handler.create_and_send_nonmember_event(
                requester,
                {
                    "type": EventTypes.Topic,
                    "room_id": room_id,
                    "sender": user_id,
                    "state_key": "",
                    "content": {"topic": topic},
                },
                ratelimit=False)

        for invitee in invite_list:
            content = {}
            is_direct = config.get("is_direct", None)
            if is_direct:
                content["is_direct"] = is_direct

            yield self.room_member_handler.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                ratelimit=False,
                content=content,
            )

        for invite_3pid in invite_3pid_list:
            id_server = invite_3pid["id_server"]
            address = invite_3pid["address"]
            medium = invite_3pid["medium"]
            yield self.hs.get_room_member_handler().do_3pid_invite(
                room_id,
                requester.user,
                medium,
                address,
                id_server,
                requester,
                txn_id=None,
            )

        result = {"room_id": room_id}

        if room_alias:
            result["room_alias"] = room_alias.to_string()
            yield directory_handler.send_room_alias_update_event(
                requester, room_id
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _send_events_for_new_room(
            self,
            creator,  # A Requester object.
            room_id,
            preset_config,
            invite_list,
            initial_state,
            creation_content,
            room_alias=None,
            power_level_content_override=None,
            creator_join_profile=None,
    ):
        def create(etype, content, **kwargs):
            e = {
                "type": etype,
                "content": content,
            }

            e.update(event_keys)
            e.update(kwargs)

            return e

        @defer.inlineCallbacks
        def send(etype, content, **kwargs):
            event = create(etype, content, **kwargs)
            logger.info("Sending %s in new room", etype)
            yield self.event_creation_handler.create_and_send_nonmember_event(
                creator,
                event,
                ratelimit=False
            )

        config = RoomCreationHandler.PRESETS_DICT[preset_config]

        creator_id = creator.user.to_string()

        event_keys = {
            "room_id": room_id,
            "sender": creator_id,
            "state_key": "",
        }

        creation_content.update({"creator": creator_id})
        yield send(
            etype=EventTypes.Create,
            content=creation_content,
        )

        logger.info("Sending %s in new room", EventTypes.Member)
        yield self.room_member_handler.update_membership(
            creator,
            creator.user,
            room_id,
            "join",
            ratelimit=False,
            content=creator_join_profile,
        )

        # We treat the power levels override specially as this needs to be one
        # of the first events that get sent into a room.
        pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
        if pl_content is not None:
            yield send(
                etype=EventTypes.PowerLevels,
                content=pl_content,
            )
        else:
            power_level_content = {
                "users": {
                    creator_id: 100,
                },
                "users_default": 0,
                "events": {
                    EventTypes.Name: 50,
                    EventTypes.PowerLevels: 100,
                    EventTypes.RoomHistoryVisibility: 100,
                    EventTypes.CanonicalAlias: 50,
                    EventTypes.RoomAvatar: 50,
                },
                "events_default": 0,
                "state_default": 50,
                "ban": 50,
                "kick": 50,
                "redact": 50,
                "invite": 0,
            }

            if config["original_invitees_have_ops"]:
                for invitee in invite_list:
                    power_level_content["users"][invitee] = 100

            if power_level_content_override:
                power_level_content.update(power_level_content_override)

            yield send(
                etype=EventTypes.PowerLevels,
                content=power_level_content,
            )

        if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
            yield send(
                etype=EventTypes.CanonicalAlias,
                content={"alias": room_alias.to_string()},
            )

        if (EventTypes.JoinRules, '') not in initial_state:
            yield send(
                etype=EventTypes.JoinRules,
                content={"join_rule": config["join_rules"]},
            )

        if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
            yield send(
                etype=EventTypes.RoomHistoryVisibility,
                content={"history_visibility": config["history_visibility"]}
            )

        if config["guest_can_join"]:
            if (EventTypes.GuestAccess, '') not in initial_state:
                yield send(
                    etype=EventTypes.GuestAccess,
                    content={"guest_access": "can_join"}
                )

        for (etype, state_key), content in initial_state.items():
            yield send(
                etype=etype,
                state_key=state_key,
                content=content,
            )

    @defer.inlineCallbacks
    def _generate_room_id(self, creator_id, is_public):
        # autogenerate a room ID and try to create it. We may clash, so just
        # try a few times till one goes through, giving up eventually.
        attempts = 0
        while attempts < 5:
            try:
                random_string = stringutils.random_string(18)
                gen_room_id = RoomID(
                    random_string,
                    self.hs.hostname,
                ).to_string()
                if isinstance(gen_room_id, bytes):
                    gen_room_id = gen_room_id.decode('utf-8')
                yield self.store.store_room(
                    room_id=gen_room_id,
                    room_creator_user_id=creator_id,
                    is_public=is_public,
                )
                defer.returnValue(gen_room_id)
            except StoreError:
                attempts += 1
        raise StoreError(500, "Couldn't generate a room ID.")
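A hedged usage sketch (not part of the scraped source above): the shape of the optional `config` dict that create_room() reads. Only keys the method actually consumes are shown; the handler and requester objects are assumed to come from the HomeServer and a REST servlet respectively.

from twisted.internet import defer


@defer.inlineCallbacks
def _example_create_room(room_creation_handler, requester):
    # Every key below is optional; keys create_room() doesn't read are ignored.
    config = {
        "visibility": "private",           # "public" marks the room as public
        "preset": "private_chat",          # i.e. RoomCreationPreset.PRIVATE_CHAT
        "room_alias_name": "my-room",      # local part of the alias; no whitespace allowed
        "name": "My room",
        "topic": "An example topic",
        "invite": ["@friend:example.com"],
        "invite_3pid": [],
        "is_direct": False,
        "initial_state": [
            {"type": "m.room.guest_access", "content": {"guest_access": "forbidden"}},
        ],
        "creation_content": {},            # any "room_version" in here is overridden
    }
    result = yield room_creation_handler.create_room(requester, config, ratelimit=False)
    # result contains "room_id" and, if an alias was requested, "room_alias".
    defer.returnValue(result)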
Пример #45
0
class _JoinedHostsCache(object):
    """Cache for joined hosts in a room that is optimised to handle updates
    via state deltas.
    """
    def __init__(self, store, room_id):
        self.store = store
        self.room_id = room_id

        self.hosts_to_joined_users = {}

        self.state_group = object()

        self.linearizer = Linearizer("_JoinedHostsCache")

        self._len = 0

    @defer.inlineCallbacks
    def get_destinations(self, state_entry):
        """Get set of destinations for a state entry

        Args:
            state_entry(synapse.state._StateCacheEntry)
        """
        if state_entry.state_group == self.state_group:
            return frozenset(self.hosts_to_joined_users)

        with (yield self.linearizer.queue(())):
            if state_entry.state_group == self.state_group:
                pass
            elif state_entry.prev_group == self.state_group:
                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
                    if typ != EventTypes.Member:
                        continue

                    host = intern_string(get_domain_from_id(state_key))
                    user_id = state_key
                    known_joins = self.hosts_to_joined_users.setdefault(
                        host, set())

                    event = yield self.store.get_event(event_id)
                    if event.membership == Membership.JOIN:
                        known_joins.add(user_id)
                    else:
                        known_joins.discard(user_id)

                        if not known_joins:
                            self.hosts_to_joined_users.pop(host, None)
            else:
                joined_users = yield self.store.get_joined_users_from_state(
                    self.room_id, state_entry)

                self.hosts_to_joined_users = {}
                for user_id in joined_users:
                    host = intern_string(get_domain_from_id(user_id))
                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)

            if state_entry.state_group:
                self.state_group = state_entry.state_group
            else:
                self.state_group = object()
            self._len = sum(
                len(v) for v in itervalues(self.hosts_to_joined_users))
        return frozenset(self.hosts_to_joined_users)

    def __len__(self):
        return self._len
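A minimal, generic sketch of the locking pattern _JoinedHostsCache.get_destinations() uses above: check the cached key without the lock, then re-check under the Linearizer before recomputing. The class name is illustrative, and `compute` is assumed to return a Deferred.

from twisted.internet import defer


class _DoubleCheckedCache(object):
    def __init__(self, linearizer):
        self.linearizer = linearizer    # e.g. Linearizer(name="example_cache")
        self.cached_key = object()      # sentinel that never compares equal
        self.cached_value = None

    @defer.inlineCallbacks
    def get(self, key, compute):
        if key == self.cached_key:
            # Fast path: no lock needed if the cache is already up to date.
            defer.returnValue(self.cached_value)

        with (yield self.linearizer.queue(())):
            if key != self.cached_key:
                # Re-check under the lock: another caller may have refreshed
                # the cache while we were waiting in the Linearizer queue.
                self.cached_value = yield compute()
                self.cached_key = key

        defer.returnValue(self.cached_value)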
Пример #46
0
class RoomCreationHandler(BaseHandler):

    PRESETS_DICT = {
        RoomCreationPreset.PRIVATE_CHAT: {
            "join_rules": JoinRules.INVITE,
            "history_visibility": "shared",
            "original_invitees_have_ops": False,
            "guest_can_join": True,
        },
        RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
            "join_rules": JoinRules.INVITE,
            "history_visibility": "shared",
            "original_invitees_have_ops": True,
            "guest_can_join": True,
        },
        RoomCreationPreset.PUBLIC_CHAT: {
            "join_rules": JoinRules.PUBLIC,
            "history_visibility": "shared",
            "original_invitees_have_ops": False,
            "guest_can_join": False,
        },
    }

    def __init__(self, hs):
        super(RoomCreationHandler, self).__init__(hs)

        self.spam_checker = hs.get_spam_checker()
        self.event_creation_handler = hs.get_event_creation_handler()
        self.room_member_handler = hs.get_room_member_handler()

        # linearizer to stop two upgrades happening at once
        self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")

    @defer.inlineCallbacks
    def upgrade_room(self, requester, old_room_id, new_version):
        """Replace a room with a new room with a different version

        Args:
            requester (synapse.types.Requester): the user requesting the upgrade
            old_room_id (unicode): the id of the room to be replaced
            new_version (unicode): the new room version to use

        Returns:
            Deferred[unicode]: the new room id
        """
        yield self.ratelimit(requester)

        user_id = requester.user.to_string()

        with (yield self._upgrade_linearizer.queue(old_room_id)):
            # start by allocating a new room id
            r = yield self.store.get_room(old_room_id)
            if r is None:
                raise NotFoundError("Unknown room id %s" % (old_room_id,))
            new_room_id = yield self._generate_room_id(
                creator_id=user_id, is_public=r["is_public"],
            )

            logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)

            # we create and auth the tombstone event before properly creating the new
            # room, to check our user has perms in the old room.
            tombstone_event, tombstone_context = (
                yield self.event_creation_handler.create_event(
                    requester, {
                        "type": EventTypes.Tombstone,
                        "state_key": "",
                        "room_id": old_room_id,
                        "sender": user_id,
                        "content": {
                            "body": "This room has been replaced",
                            "replacement_room": new_room_id,
                        }
                    },
                    token_id=requester.access_token_id,
                )
            )
            old_room_version = yield self.store.get_room_version(old_room_id)
            yield self.auth.check_from_context(
                old_room_version, tombstone_event, tombstone_context,
            )

            yield self.clone_existing_room(
                requester,
                old_room_id=old_room_id,
                new_room_id=new_room_id,
                new_room_version=new_version,
                tombstone_event_id=tombstone_event.event_id,
            )

            # now send the tombstone
            yield self.event_creation_handler.send_nonmember_event(
                requester, tombstone_event, tombstone_context,
            )

            old_room_state = yield tombstone_context.get_current_state_ids(self.store)

            # update any aliases
            yield self._move_aliases_to_new_room(
                requester, old_room_id, new_room_id, old_room_state,
            )

            # and finally, shut down the PLs in the old room, and update them in the new
            # room.
            yield self._update_upgraded_room_pls(
                requester, old_room_id, new_room_id, old_room_state,
            )

            defer.returnValue(new_room_id)

    @defer.inlineCallbacks
    def _update_upgraded_room_pls(
            self, requester, old_room_id, new_room_id, old_room_state,
    ):
        """Send updated power levels in both rooms after an upgrade

        Args:
            requester (synapse.types.Requester): the user requesting the upgrade
            old_room_id (unicode): the id of the room to be replaced
            new_room_id (unicode): the id of the replacement room
            old_room_state (dict[tuple[str, str], str]): the state map for the old room

        Returns:
            Deferred
        """
        old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))

        if old_room_pl_event_id is None:
            logger.warning(
                "Not supported: upgrading a room with no PL event. Not setting PLs "
                "in old room.",
            )
            return

        old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)

        # we try to stop regular users from speaking by setting the PL required
        # to send regular events and invites to 'Moderator' level. That's normally
        # 50, but if the default PL in a room is 50 or more, then we set the
        # required PL above that.

        pl_content = dict(old_room_pl_state.content)
        users_default = int(pl_content.get("users_default", 0))
        restricted_level = max(users_default + 1, 50)

        updated = False
        for v in ("invite", "events_default"):
            current = int(pl_content.get(v, 0))
            if current < restricted_level:
                logger.info(
                    "Setting level for %s in %s to %i (was %i)",
                    v, old_room_id, restricted_level, current,
                )
                pl_content[v] = restricted_level
                updated = True
            else:
                logger.info(
                    "Not setting level for %s (already %i)",
                    v, current,
                )

        if updated:
            try:
                yield self.event_creation_handler.create_and_send_nonmember_event(
                    requester, {
                        "type": EventTypes.PowerLevels,
                        "state_key": '',
                        "room_id": old_room_id,
                        "sender": requester.user.to_string(),
                        "content": pl_content,
                    }, ratelimit=False,
                )
            except AuthError as e:
                logger.warning("Unable to update PLs in old room: %s", e)

        logger.info("Setting correct PLs in new room")
        yield self.event_creation_handler.create_and_send_nonmember_event(
            requester, {
                "type": EventTypes.PowerLevels,
                "state_key": '',
                "room_id": new_room_id,
                "sender": requester.user.to_string(),
                "content": old_room_pl_state.content,
            }, ratelimit=False,
        )

    @defer.inlineCallbacks
    def clone_existing_room(
            self, requester, old_room_id, new_room_id, new_room_version,
            tombstone_event_id,
    ):
        """Populate a new room based on an old room

        Args:
            requester (synapse.types.Requester): the user requesting the upgrade
            old_room_id (unicode): the id of the room to be replaced
            new_room_id (unicode): the id to give the new room (should already have been
                created with _generate_room_id())
            new_room_version (unicode): the new room version to use
            tombstone_event_id (unicode|str): the ID of the tombstone event in the old
                room.
        Returns:
            Deferred[None]
        """
        user_id = requester.user.to_string()

        if not self.spam_checker.user_may_create_room(user_id):
            raise SynapseError(403, "You are not permitted to create rooms")

        creation_content = {
            "room_version": new_room_version,
            "predecessor": {
                "room_id": old_room_id,
                "event_id": tombstone_event_id,
            }
        }

        initial_state = dict()

        # Replicate relevant room events
        types_to_copy = (
            (EventTypes.JoinRules, ""),
            (EventTypes.Name, ""),
            (EventTypes.Topic, ""),
            (EventTypes.RoomHistoryVisibility, ""),
            (EventTypes.GuestAccess, ""),
            (EventTypes.RoomAvatar, ""),
            (EventTypes.Encryption, ""),
        )

        old_room_state_ids = yield self.store.get_filtered_current_state_ids(
            old_room_id, StateFilter.from_types(types_to_copy),
        )
        # map from event_id to BaseEvent
        old_room_state_events = yield self.store.get_events(old_room_state_ids.values())

        for k, old_event_id in iteritems(old_room_state_ids):
            old_event = old_room_state_events.get(old_event_id)
            if old_event:
                initial_state[k] = old_event.content

        yield self._send_events_for_new_room(
            requester,
            new_room_id,

            # we expect to override all the presets with initial_state, so this is
            # somewhat arbitrary.
            preset_config=RoomCreationPreset.PRIVATE_CHAT,

            invite_list=[],
            initial_state=initial_state,
            creation_content=creation_content,
        )

        # XXX invites/joins
        # XXX 3pid invites

    @defer.inlineCallbacks
    def _move_aliases_to_new_room(
            self, requester, old_room_id, new_room_id, old_room_state,
    ):
        directory_handler = self.hs.get_handlers().directory_handler

        aliases = yield self.store.get_aliases_for_room(old_room_id)

        # check to see if we have a canonical alias.
        canonical_alias = None
        canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event_id:
            canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
            if canonical_alias_event:
                canonical_alias = canonical_alias_event.content.get("alias", "")

        # first we try to remove the aliases from the old room (we suppress sending
        # the room_aliases event until the end).
        #
        # Note that we'll only be able to remove aliases that (a) aren't owned by an AS,
        # and (b) were created by this user (unless the user is a server admin).
        #
        # This is probably correct - given we don't allow such aliases to be deleted
        # normally, it would be odd to allow it in the case of doing a room upgrade -
        # but it makes the upgrade less effective, and you have to wonder why a room
        # admin can't remove aliases that point to that room anyway.
        # (cf https://github.com/matrix-org/synapse/issues/2360)
        #
        removed_aliases = []
        for alias_str in aliases:
            alias = RoomAlias.from_string(alias_str)
            try:
                yield directory_handler.delete_association(
                    requester, alias, send_event=False,
                )
                removed_aliases.append(alias_str)
            except SynapseError as e:
                logger.warning(
                    "Unable to remove alias %s from old room: %s",
                    alias, e,
                )

        # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest
        # of this.
        if not removed_aliases:
            return

        try:
            # this can fail if, for some reason, our user doesn't have perms to send
            # m.room.aliases events in the old room (note that we've already checked that
            # they have perms to send a tombstone event, so that's not terribly likely).
            #
            # If that happens, it's regrettable, but we should carry on: it's the same
            # as when you remove an alias from the directory normally - it just means that
            # the aliases event gets out of sync with the directory
            # (cf https://github.com/vector-im/riot-web/issues/2369)
            yield directory_handler.send_room_alias_update_event(
                requester, old_room_id,
            )
        except AuthError as e:
            logger.warning(
                "Failed to send updated alias event on old room: %s", e,
            )

        # we can now add any aliases we successfully removed to the new room.
        for alias in removed_aliases:
            try:
                yield directory_handler.create_association(
                    requester, RoomAlias.from_string(alias),
                    new_room_id, servers=(self.hs.hostname, ),
                    send_event=False,
                )
                logger.info("Moved alias %s to new room", alias)
            except SynapseError as e:
                # I'm not really expecting this to happen, but it could if the spam
                # checking module decides it shouldn't, or similar.
                logger.error(
                    "Error adding alias %s to new room: %s",
                    alias, e,
                )

        try:
            if canonical_alias and (canonical_alias in removed_aliases):
                yield self.event_creation_handler.create_and_send_nonmember_event(
                    requester,
                    {
                        "type": EventTypes.CanonicalAlias,
                        "state_key": "",
                        "room_id": new_room_id,
                        "sender": requester.user.to_string(),
                        "content": {"alias": canonical_alias, },
                    },
                    ratelimit=False
                )

            yield directory_handler.send_room_alias_update_event(
                requester, new_room_id,
            )
        except SynapseError as e:
            # again I'm not really expecting this to fail, but if it does, I'd rather
            # we returned the new room to the client at this point.
            logger.error(
                "Unable to send updated alias events in new room: %s", e,
            )

    @defer.inlineCallbacks
    def create_room(self, requester, config, ratelimit=True,
                    creator_join_profile=None):
        """ Creates a new room.

        Args:
            requester (synapse.types.Requester):
                The user who requested the room creation.
            config (dict) : A dict of configuration options.
            ratelimit (bool): set to False to disable the rate limiter

            creator_join_profile (dict|None):
                Set to override the displayname and avatar for the creating
                user in this room. If unset, displayname and avatar will be
                derived from the user's profile. If set, should contain the
                values to go in the body of the 'join' event (typically
                `avatar_url` and/or `displayname`).

        Returns:
            Deferred[dict]:
                a dict containing the keys `room_id` and, if an alias was
                requested, `room_alias`.
        Raises:
            SynapseError if the room ID couldn't be stored, or something went
            horribly wrong.
            ResourceLimitError if the server is blocked due to some resource being
            exceeded
        """
        user_id = requester.user.to_string()

        yield self.auth.check_auth_blocking(user_id)

        if not self.spam_checker.user_may_create_room(user_id):
            raise SynapseError(403, "You are not permitted to create rooms")

        if ratelimit:
            yield self.ratelimit(requester)

        room_version = config.get("room_version", DEFAULT_ROOM_VERSION)
        if not isinstance(room_version, string_types):
            raise SynapseError(
                400,
                "room_version must be a string",
                Codes.BAD_JSON,
            )

        if room_version not in KNOWN_ROOM_VERSIONS:
            raise SynapseError(
                400,
                "Your homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

        if "room_alias_name" in config:
            for wchar in string.whitespace:
                if wchar in config["room_alias_name"]:
                    raise SynapseError(400, "Invalid characters in room alias")

            room_alias = RoomAlias(
                config["room_alias_name"],
                self.hs.hostname,
            )
            mapping = yield self.store.get_association_from_room_alias(
                room_alias
            )

            if mapping:
                raise SynapseError(
                    400,
                    "Room alias already taken",
                    Codes.ROOM_IN_USE
                )
        else:
            room_alias = None

        invite_list = config.get("invite", [])
        for i in invite_list:
            try:
                UserID.from_string(i)
            except Exception:
                raise SynapseError(400, "Invalid user_id: %s" % (i,))

        yield self.event_creation_handler.assert_accepted_privacy_policy(
            requester,
        )

        invite_3pid_list = config.get("invite_3pid", [])

        visibility = config.get("visibility", None)
        is_public = visibility == "public"

        room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public)

        if room_alias:
            directory_handler = self.hs.get_handlers().directory_handler
            yield directory_handler.create_association(
                requester=requester,
                room_id=room_id,
                room_alias=room_alias,
                servers=[self.hs.hostname],
                send_event=False,
            )

        preset_config = config.get(
            "preset",
            RoomCreationPreset.PRIVATE_CHAT
            if visibility == "private"
            else RoomCreationPreset.PUBLIC_CHAT
        )

        raw_initial_state = config.get("initial_state", [])

        initial_state = OrderedDict()
        for val in raw_initial_state:
            initial_state[(val["type"], val.get("state_key", ""))] = val["content"]

        creation_content = config.get("creation_content", {})

        # override any attempt to set room versions via the creation_content
        creation_content["room_version"] = room_version

        yield self._send_events_for_new_room(
            requester,
            room_id,
            preset_config=preset_config,
            invite_list=invite_list,
            initial_state=initial_state,
            creation_content=creation_content,
            room_alias=room_alias,
            power_level_content_override=config.get("power_level_content_override"),
            creator_join_profile=creator_join_profile,
        )

        if "name" in config:
            name = config["name"]
            yield self.event_creation_handler.create_and_send_nonmember_event(
                requester,
                {
                    "type": EventTypes.Name,
                    "room_id": room_id,
                    "sender": user_id,
                    "state_key": "",
                    "content": {"name": name},
                },
                ratelimit=False)

        if "topic" in config:
            topic = config["topic"]
            yield self.event_creation_handler.create_and_send_nonmember_event(
                requester,
                {
                    "type": EventTypes.Topic,
                    "room_id": room_id,
                    "sender": user_id,
                    "state_key": "",
                    "content": {"topic": topic},
                },
                ratelimit=False)

        for invitee in invite_list:
            content = {}
            is_direct = config.get("is_direct", None)
            if is_direct:
                content["is_direct"] = is_direct

            yield self.room_member_handler.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                ratelimit=False,
                content=content,
            )

        for invite_3pid in invite_3pid_list:
            id_server = invite_3pid["id_server"]
            address = invite_3pid["address"]
            medium = invite_3pid["medium"]
            yield self.hs.get_room_member_handler().do_3pid_invite(
                room_id,
                requester.user,
                medium,
                address,
                id_server,
                requester,
                txn_id=None,
            )

        result = {"room_id": room_id}

        if room_alias:
            result["room_alias"] = room_alias.to_string()
            yield directory_handler.send_room_alias_update_event(
                requester, room_id
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _send_events_for_new_room(
            self,
            creator,  # A Requester object.
            room_id,
            preset_config,
            invite_list,
            initial_state,
            creation_content,
            room_alias=None,
            power_level_content_override=None,
            creator_join_profile=None,
    ):
        def create(etype, content, **kwargs):
            e = {
                "type": etype,
                "content": content,
            }

            e.update(event_keys)
            e.update(kwargs)

            return e

        @defer.inlineCallbacks
        def send(etype, content, **kwargs):
            event = create(etype, content, **kwargs)
            logger.info("Sending %s in new room", etype)
            yield self.event_creation_handler.create_and_send_nonmember_event(
                creator,
                event,
                ratelimit=False
            )

        config = RoomCreationHandler.PRESETS_DICT[preset_config]

        creator_id = creator.user.to_string()

        event_keys = {
            "room_id": room_id,
            "sender": creator_id,
            "state_key": "",
        }

        creation_content.update({"creator": creator_id})
        yield send(
            etype=EventTypes.Create,
            content=creation_content,
        )

        logger.info("Sending %s in new room", EventTypes.Member)
        yield self.room_member_handler.update_membership(
            creator,
            creator.user,
            room_id,
            "join",
            ratelimit=False,
            content=creator_join_profile,
        )

        # We treat the power levels override specially as this needs to be one
        # of the first events that get sent into a room.
        pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
        if pl_content is not None:
            yield send(
                etype=EventTypes.PowerLevels,
                content=pl_content,
            )
        else:
            power_level_content = {
                "users": {
                    creator_id: 100,
                },
                "users_default": 0,
                "events": {
                    EventTypes.Name: 50,
                    EventTypes.PowerLevels: 100,
                    EventTypes.RoomHistoryVisibility: 100,
                    EventTypes.CanonicalAlias: 50,
                    EventTypes.RoomAvatar: 50,
                },
                "events_default": 0,
                "state_default": 50,
                "ban": 50,
                "kick": 50,
                "redact": 50,
                "invite": 0,
            }

            if config["original_invitees_have_ops"]:
                for invitee in invite_list:
                    power_level_content["users"][invitee] = 100

            if power_level_content_override:
                power_level_content.update(power_level_content_override)

            yield send(
                etype=EventTypes.PowerLevels,
                content=power_level_content,
            )

        if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
            yield send(
                etype=EventTypes.CanonicalAlias,
                content={"alias": room_alias.to_string()},
            )

        if (EventTypes.JoinRules, '') not in initial_state:
            yield send(
                etype=EventTypes.JoinRules,
                content={"join_rule": config["join_rules"]},
            )

        if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
            yield send(
                etype=EventTypes.RoomHistoryVisibility,
                content={"history_visibility": config["history_visibility"]}
            )

        if config["guest_can_join"]:
            if (EventTypes.GuestAccess, '') not in initial_state:
                yield send(
                    etype=EventTypes.GuestAccess,
                    content={"guest_access": "can_join"}
                )

        for (etype, state_key), content in initial_state.items():
            yield send(
                etype=etype,
                state_key=state_key,
                content=content,
            )

    @defer.inlineCallbacks
    def _generate_room_id(self, creator_id, is_public):
        # autogenerate a room ID and try to create it. We may clash, so just
        # try a few times till one goes through, giving up eventually.
        attempts = 0
        while attempts < 5:
            try:
                random_string = stringutils.random_string(18)
                gen_room_id = RoomID(
                    random_string,
                    self.hs.hostname,
                ).to_string()
                if isinstance(gen_room_id, bytes):
                    gen_room_id = gen_room_id.decode('utf-8')
                yield self.store.store_room(
                    room_id=gen_room_id,
                    room_creator_user_id=creator_id,
                    is_public=is_public,
                )
                defer.returnValue(gen_room_id)
            except StoreError:
                attempts += 1
        raise StoreError(500, "Couldn't generate a room ID.")
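A hedged usage sketch (an assumption, not part of the scraped code): driving the upgrade_room() entry point defined above, e.g. from a REST servlet. By the time the Deferred fires, the handler has already cloned the copied state into the new room, sent the tombstone, moved aliases and restricted power levels in the old room.

from twisted.internet import defer


@defer.inlineCallbacks
def _example_upgrade_room(room_creation_handler, requester, old_room_id):
    new_room_id = yield room_creation_handler.upgrade_room(
        requester, old_room_id, new_version="4",   # the target room version string
    )
    defer.returnValue({"replacement_room": new_room_id})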
Пример #47
0
class RulesForRoom(object):
    """Caches push rules for users in a room.

    This efficiently handles users joining/leaving the room by not invalidating
    the entire cache for the room.
    """

    def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
        """
        Args:
            hs (HomeServer)
            room_id (str)
            rules_for_room_cache(Cache): The cache object that caches these
                RoomsForUser objects.
            room_push_rule_cache_metrics (CacheMetric)
        """
        self.room_id = room_id
        self.is_mine_id = hs.is_mine_id
        self.store = hs.get_datastore()
        self.room_push_rule_cache_metrics = room_push_rule_cache_metrics

        self.linearizer = Linearizer(name="rules_for_room")

        self.member_map = {}  # event_id -> (user_id, state)
        self.rules_by_user = {}  # user_id -> rules

        # The last state group we updated the caches for. If the state_group of
        # a new event matches this, we know that we can just return the cached
        # result.
        # On invalidation of the rules themselves (if the user changes them),
        # we invalidate everything and set state_group to `object()`
        self.state_group = object()

        # A sequence number to keep track of when we're allowed to update the
        # cache. We bump the sequence number when we invalidate the cache. If
        # the sequence number changes while we're calculating stuff we should
        # not update the cache with it.
        self.sequence = 0

        # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
        # owned by AS's, or remote users, etc. (I.e. users we will never need to
        # calculate push for)
        # These never need to be invalidated as we will never set up push for
        # them.
        self.uninteresting_user_set = set()

        # We need to be clever with the cache invalidation callbacks, as
        # otherwise the invalidation callback holds a reference to the object,
        # potentially causing it to leak.
        # To get around this we pass a function that, on invalidation, looks up
        # the RoomsForUser entry in the cache, rather than keeping a reference
        # to self around in the callback.
        self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)

    @defer.inlineCallbacks
    def get_rules(self, event, context):
        """Given an event context return the rules for all users who are
        currently in the room.
        """
        state_group = context.state_group

        if state_group and self.state_group == state_group:
            logger.debug("Using cached rules for %r", self.room_id)
            self.room_push_rule_cache_metrics.inc_hits()
            defer.returnValue(self.rules_by_user)

        with (yield self.linearizer.queue(())):
            if state_group and self.state_group == state_group:
                logger.debug("Using cached rules for %r", self.room_id)
                self.room_push_rule_cache_metrics.inc_hits()
                defer.returnValue(self.rules_by_user)

            self.room_push_rule_cache_metrics.inc_misses()

            ret_rules_by_user = {}
            missing_member_event_ids = {}
            if state_group and self.state_group == context.prev_group:
                # If we have a simple delta then we can reuse most of the previous
                # results.
                ret_rules_by_user = self.rules_by_user
                current_state_ids = context.delta_ids

                push_rules_delta_state_cache_metric.inc_hits()
            else:
                current_state_ids = yield context.get_current_state_ids(self.store)
                push_rules_delta_state_cache_metric.inc_misses()

            push_rules_state_size_counter.inc(len(current_state_ids))

            logger.debug(
                "Looking for member changes in %r %r", state_group, current_state_ids
            )

            # Loop through to see which member events we've seen and have rules
            # for and which we need to fetch
            for key in current_state_ids:
                typ, user_id = key
                if typ != EventTypes.Member:
                    continue

                if user_id in self.uninteresting_user_set:
                    continue

                if not self.is_mine_id(user_id):
                    self.uninteresting_user_set.add(user_id)
                    continue

                if self.store.get_if_app_services_interested_in_user(user_id):
                    self.uninteresting_user_set.add(user_id)
                    continue

                event_id = current_state_ids[key]

                res = self.member_map.get(event_id, None)
                if res:
                    user_id, state = res
                    if state == Membership.JOIN:
                        rules = self.rules_by_user.get(user_id, None)
                        if rules:
                            ret_rules_by_user[user_id] = rules
                    continue

                # If a user has left a room we remove their push rule. If they
                # joined then we re-add it later in _update_rules_with_member_event_ids
                ret_rules_by_user.pop(user_id, None)
                missing_member_event_ids[user_id] = event_id

            if missing_member_event_ids:
                # If we have some member events we haven't seen, look them up
                # and fetch push rules for them if appropriate.
                logger.debug("Found new member events %r", missing_member_event_ids)
                yield self._update_rules_with_member_event_ids(
                    ret_rules_by_user, missing_member_event_ids, state_group, event
                )
            else:
                # The push rules didn't change but let's update the cache anyway
                self.update_cache(
                    self.sequence,
                    members={},  # There were no membership changes
                    rules_by_user=ret_rules_by_user,
                    state_group=state_group
                )

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Returning push rules for %r %r",
                self.room_id, ret_rules_by_user.keys(),
            )
        defer.returnValue(ret_rules_by_user)

    @defer.inlineCallbacks
    def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
                                            state_group, event):
        """Update the partially filled rules_by_user dict by fetching rules for
        any newly joined users in the `member_event_ids` list.

        Args:
            ret_rules_by_user (dict): Partially filled dict of push rules. Gets
                updated with any new rules.
            member_event_ids (dict): Dict of user id to event id for membership events
                that have happened since the last time we filled rules_by_user
            state_group: The state group we are currently computing push rules
                for. Used when updating the cache.
        """
        sequence = self.sequence

        rows = yield self.store._simple_select_many_batch(
            table="room_memberships",
            column="event_id",
            iterable=member_event_ids.values(),
            retcols=('user_id', 'membership', 'event_id'),
            keyvalues={},
            batch_size=500,
            desc="_get_rules_for_member_event_ids",
        )

        members = {
            row["event_id"]: (row["user_id"], row["membership"])
            for row in rows
        }

        # If the event is a join event then it will be in the current state events
        # map but not in the DB, so we have to explicitly insert it.
        if event.type == EventTypes.Member:
            for event_id in itervalues(member_event_ids):
                if event_id == event.event_id:
                    members[event_id] = (event.state_key, event.membership)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Found members %r: %r", self.room_id, members.values())

        interested_in_user_ids = set(
            user_id for user_id, membership in itervalues(members)
            if membership == Membership.JOIN
        )

        logger.debug("Joined: %r", interested_in_user_ids)

        if_users_with_pushers = yield self.store.get_if_users_have_pushers(
            interested_in_user_ids,
            on_invalidate=self.invalidate_all_cb,
        )

        user_ids = set(
            uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
        )

        logger.debug("With pushers: %r", user_ids)

        users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
            self.room_id, on_invalidate=self.invalidate_all_cb,
        )

        logger.debug("With receipts: %r", users_with_receipts)

        # also include joined local users who have read receipts in the room
        for uid in users_with_receipts:
            if uid in interested_in_user_ids:
                user_ids.add(uid)

        rules_by_user = yield self.store.bulk_get_push_rules(
            user_ids, on_invalidate=self.invalidate_all_cb,
        )

        ret_rules_by_user.update(
            item for item in iteritems(rules_by_user) if item[0] is not None
        )

        self.update_cache(sequence, members, ret_rules_by_user, state_group)

    def invalidate_all(self):
        # Note: Don't hand this function directly to an invalidation callback
        # as it keeps a reference to self and will stop this instance from being
        # GC'd if it gets dropped from the rules_to_user cache. Instead use
        # `self.invalidate_all_cb`
        logger.debug("Invalidating RulesForRoom for %r", self.room_id)
        self.sequence += 1
        self.state_group = object()
        self.member_map = {}
        self.rules_by_user = {}
        push_rules_invalidation_counter.inc()

    def update_cache(self, sequence, members, rules_by_user, state_group):
        if sequence == self.sequence:
            self.member_map.update(members)
            self.rules_by_user = rules_by_user
            self.state_group = state_group
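A minimal sketch of the sequence-number guard that invalidate_all() and update_cache() above rely on: a computation snapshots the sequence number before it starts, and its result is discarded if an invalidation bumped the number in the meantime. Names are illustrative only.

class _SequenceGuardedCache(object):
    def __init__(self):
        self.sequence = 0
        self.value = None

    def invalidate(self):
        # Bump the sequence so any in-flight computation can't write back stale data.
        self.sequence += 1
        self.value = None

    def snapshot(self):
        # Take a copy of the sequence number before starting a computation.
        return self.sequence

    def update(self, sequence, value):
        # Only accept the result if nothing was invalidated while it was computed.
        if sequence == self.sequence:
            self.value = value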
Пример #48
0
class DeviceListEduUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]   # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning(
                "Got device list update edu for %r/%r from %r",
                user_id, device_id, origin,
            )
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            logger.warning(
                "Got device list update edu for %r/%r, but don't share a room",
                user_id, device_id,
            )
            return

        logger.debug(
            "Received device list update for %r/%r", user_id, device_id,
        )

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)
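
    # For reference, a hedged sketch of the fields this method pops from a
    # device list update EDU (m.device_list_update); the values below are
    # illustrative only, and any remaining keys stay in edu_content and are
    # stored as the update's content:
    #
    #     {
    #         "user_id": "@alice:remote.example",
    #         "device_id": "JLAFKJWSCS",
    #         "stream_id": 6,
    #         "prev_id": [5],
    #     }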

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            for device_id, stream_id, prev_ids, content in pending_updates:
                logger.debug(
                    "Handling update %r/%r, ID: %r, prev: %r ",
                    user_id, device_id, stream_id, prev_ids,
                )

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            logger.debug("Need to re-sync devices for %r? %r", user_id, resync)

            if resync:
                # Fetch all devices for the user.
                origin = get_domain_from_id(user_id)
                try:
                    result = yield self.federation.query_user_devices(origin, user_id)
                except (
                    NotRetryingDestination, RequestSendFailed, HttpResponseException,
                ):
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.warning(
                        "Failed to handle device list update for %s", user_id,
                    )
                    # We abort on exceptions rather than accepting the update
                    # as otherwise synapse will 'forget' that its device list
                    # is out of date. If we bail then we will retry the resync
                    # next time we get a device list update for this user_id.
                    # This makes it more likely that the device lists will
                    # eventually become consistent.
                    return
                except FederationDeniedError as e:
                    logger.info(e)
                    return
                except Exception:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.exception(
                        "Failed to handle device list update for %s", user_id
                    )
                    return

                stream_id = result["stream_id"]
                devices = result["devices"]

                # If the remote server has more than ~1000 devices for this user
                # we assume that something is going horribly wrong (e.g. a bot
                # that logs in and creates a new device every time it tries to
                # send a message).  Maintaining lots of devices per user in the
                # cache can cause serious performance issues: if this request
                # takes more than 60s to complete, internal replication from the
                # inbound federation worker to the synapse master may time out,
                # causing the inbound federation request to fail and the remote
                # server to retry, causing a DoS.  So in this scenario we give
                # up on storing the total list of devices and only handle the
                # delta instead.
                if len(devices) > 1000:
                    logger.warning(
                        "Ignoring device list snapshot for %s as it has >1K devs (%d)",
                        user_id, len(devices)
                    )
                    devices = []

                for device in devices:
                    logger.debug(
                        "Handling resync update %r/%r, ID: %r",
                        user_id, device["device_id"], stream_id,
                    )

                yield self.store.update_remote_device_list_cache(
                    user_id, devices, stream_id,
                )
                device_ids = [device["device_id"] for device in devices]
                yield self.device_handler.notify_device_update(user_id, device_ids)

                # We clobber the seen updates since we've re-synced from a given
                # point.
                self._seen_updates[user_id] = set([stream_id])
            else:
                # Simply update each device individually, since we know the
                # updates follow on from what is already in the cache (their
                # prev_ids all matched something we have seen).
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id,
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

                self._seen_updates.setdefault(user_id, set()).update(
                    stream_id for _, stream_id, _, _ in pending_updates
                )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
            user_id
        )

        logger.debug(
            "Current extremity for %r: %r",
            user_id, extremity,
        )

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                defer.returnValue(True)

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    defer.returnValue(True)

            stream_id_in_updates.add(stream_id)

        defer.returnValue(False)
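
The prev_id bookkeeping above is easier to follow in isolation. Below is a self-contained sketch of the same decision written as a pure function; this is not the Synapse API, and the Deferred/returnValue machinery is dropped for clarity:

# A pure-function sketch of the resync decision above (not the Synapse API).
def need_to_do_resync(extremity, seen_updates, updates):
    """updates is a list of (device_id, stream_id, prev_ids, content)."""
    stream_ids_in_updates = set()
    for _, stream_id, prev_ids, _ in updates:
        if not prev_ids:
            # No previous IDs at all: we cannot tell what we missed.
            return True

        for prev_id in prev_ids:
            known = (
                prev_id == extremity
                or prev_id in seen_updates
                or prev_id in stream_ids_in_updates
            )
            if not known:
                # The update builds on something we have never seen.
                return True

        stream_ids_in_updates.add(stream_id)

    return False


# The chain 5 -> 6 -> 7 applies cleanly on top of extremity "5" ...
assert not need_to_do_resync("5", set(), [("DEV", "6", ["5"], {}),
                                          ("DEV", "7", ["6"], {})])
# ... but an update whose prev_id we have never seen forces a full resync.
assert need_to_do_resync("5", set(), [("DEV", "9", ["8"], {})])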