Example #1
0
    def test_allowed_via_ratelimit_and_overriding_parameters(self):
        """Per-call ``rate_hz`` / ``burst_count`` overrides passed to
        ``ratelimit`` can let through an action that the limiter's own
        configuration would reject.
        """
        key = ("test_id",)

        # Deliberately strict limiter: rate_hz=0.1 with a burst of one.
        limiter = Ratelimiter(
            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
        )

        # t=0: the first action fits in the burst.
        self.get_success_or_raise(limiter.ratelimit(None, key=key, _time_now_s=0))

        # t=1: a second action is rejected under the configured limits.
        with self.assertRaises(LimitExceededError) as ctx:
            self.get_success_or_raise(
                limiter.ratelimit(None, key=key, _time_now_s=1)
            )
        self.assertEqual(ctx.exception.retry_after_ms, 9000)

        # Overriding the rate (10 actions/sec) for this call lets it through...
        self.get_success_or_raise(
            limiter.ratelimit(None, key=key, _time_now_s=1, rate_hz=10.0)
        )

        # ...as does overriding the burst size for this call.
        self.get_success_or_raise(
            limiter.ratelimit(None, key=key, _time_now_s=1, burst_count=10)
        )
Example #2
0
class LookupBindStatusForOpenid(RestServlet):
    """Report whether the requesting user has an external ID bound for a provider.

    ``POST /login/oauth2/bind/status`` with a JSON object body containing
    ``bind_type``. Responds ``200 {}`` when the datastore has an external ID
    for (``bind_type``, requesting user); otherwise raises ``LoginError``.
    """

    PATTERNS = client_patterns("/login/oauth2/bind/status$", v1=True)

    def __init__(self, hs):
        super().__init__()
        self.hs = hs
        self.auth = hs.get_auth()

        # Requests are throttled per client IP using the rc_login_address limits.
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
        )
        self.http_client = SimpleHttpClient(hs)

    async def on_POST(self, request: SynapseRequest):
        """Check the bind status for the authenticated user.

        Raises:
            LoginError(410, FORBIDDEN) if ``bind_type`` is missing or null.
            LoginError(400, OPENID_NOT_BIND) if no external ID is bound for
                ``bind_type`` and this user.
        """
        self._address_ratelimiter.ratelimit(request.getClientIP())

        params = parse_json_object_from_request(request)
        logger.info("----lookup bind--------param:%s", params)

        # Use .get() so an absent field reaches the explicit error below
        # instead of escaping as an unhandled KeyError.
        bind_type = params.get("bind_type")
        if bind_type is None:
            raise LoginError(
                410,
                "bind_type field for bind openid is missing",
                errcode=Codes.FORBIDDEN,
            )

        requester = await self.auth.get_user_by_req(request)
        logger.info("------requester: %s", requester)
        user_id = requester.user
        logger.info("------user: %s", user_id)

        # Look up the external ID bound to this user for the requested provider.
        openid = await self.hs.get_datastore().get_external_id_for_user_provider(
            bind_type,
            str(user_id),
        )
        if openid is None:
            raise LoginError(400, "openid not bind", errcode=Codes.OPENID_NOT_BIND)

        return 200, {}
Example #3
0
    def test_db_user_override(self):
        """A user whose ``ratelimit_override`` row has NULL limits (i.e.
        ratelimiting disabled in the DB) is never ratelimited.
        """
        store = self.hs.get_datastore()

        user_id = "@user:test"
        requester = create_requester(user_id)

        override_row = {
            "user_id": user_id,
            "messages_per_second": None,
            "burst_count": None,
        }
        self.get_success(
            store.db_pool.simple_insert(
                table="ratelimit_override",
                values=override_row,
                desc="test_db_user_override",
            )
        )

        limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)

        # Twenty actions at the same instant; burst_count=1 would normally
        # reject all but the first, so none of these may raise.
        for _attempt in range(20):
            self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
Example #4
0
    def test_allowed_via_ratelimit(self):
        """A key is rejected inside the rate window and accepted again after it."""
        limiter = Ratelimiter(
            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
        )

        def attempt(at_s):
            # One ratelimited action for "test_id" at the given fake time.
            return self.get_success_or_raise(
                limiter.ratelimit(None, key="test_id", _time_now_s=at_s)
            )

        # t=0: within the burst.
        attempt(0)

        # t=5: too soon.
        with self.assertRaises(LimitExceededError) as ctx:
            attempt(5)
        self.assertEqual(ctx.exception.retry_after_ms, 5000)

        # t=10: allowed again.
        attempt(10)

    # NOTE(review): this reuses the name ``test_allowed_via_ratelimit``. If it
    # shares a class with another method of that name, only the later
    # definition survives and runs. It also calls ``ratelimit`` directly
    # (no store, no requester, no get_success_or_raise), i.e. a different
    # Ratelimiter API than the variant that takes ``store=``.
    def test_allowed_via_ratelimit(self):
        """A key is rejected inside the rate window and accepted again after it."""
        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
        key = "test_id"

        # t=0: within the burst.
        limiter.ratelimit(key=key, _time_now_s=0)

        # t=5: too soon.
        with self.assertRaises(LimitExceededError) as ctx:
            limiter.ratelimit(key=key, _time_now_s=5)
        self.assertEqual(ctx.exception.retry_after_ms, 5000)

        # t=10: allowed again.
        limiter.ratelimit(key=key, _time_now_s=10)
Example #6
0
class AuthHandler(BaseHandler):
    SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000

    def __init__(self, hs):
        """
        Set up UI-auth checkers, the session store, password providers and the
        list of supported login types.

        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)

        # Keep only the interactive-auth checkers that report themselves as
        # enabled, keyed by the auth type each one handles.
        enabled_checkers = {}  # type: dict[str, UserInteractiveAuthChecker]
        for checker_cls in INTERACTIVE_AUTH_CHECKERS:
            checker = checker_cls(hs)
            if not checker.is_enabled():
                continue
            enabled_checkers[checker.AUTH_TYPE] = checker
        self.checkers = enabled_checkers

        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # Not really a cache: this holds every live UI-auth session. Entries
        # expire after SESSION_EXPIRE_MS, and the expiry is reset on each get.
        self.sessions = ExpiringCache(
            cache_name="register_sessions",
            clock=hs.get_clock(),
            expiry_ms=self.SESSION_EXPIRE_MS,
            reset_expiry_on_get=True,
        )

        account_handler = ModuleApi(hs, self)
        self.password_providers = [
            provider_module(config=provider_config, account_handler=account_handler)
            for provider_module, provider_config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.macaroon_gen = hs.get_macaroon_generator()
        self._password_enabled = hs.config.password_enabled

        # An ordered list rather than a set, despite the O(N^2) membership
        # checks, so PASSWORD (when enabled) stays first: some clients pick the
        # first listed type. (The spec doesn't require this ordering, and
        # clients that prefer types they don't understand are technically
        # broken.) Provider-declared types are appended without duplicates.
        login_types = [LoginType.PASSWORD] if self._password_enabled else []
        for provider in self.password_providers:
            if not hasattr(provider, "get_supported_login_types"):
                continue
            for login_type in provider.get_supported_login_types().keys():
                if login_type not in login_types:
                    login_types.append(login_type)
        self._supported_login_types = login_types

        # Ratelimiter for failed UI-auth attempts. It is built without limits;
        # callers pass the `rc_login.failed_attempts` values on every call.
        self._failed_uia_attempts_ratelimiter = Ratelimiter()

        self._clock = self.hs.get_clock()

    @defer.inlineCallbacks
    def validate_user_via_ui_auth(self, requester, request_body, clientip):
        """
        Re-authenticate an already-logged-in user through UI auth.

        Used for sensitive operations (e.g. device deletion, password reset)
        where the caller already holds a valid access token, but we want to
        double-check that it isn't stolen.

        Args:
            requester (Requester): the user, as identified by the access token
            request_body (dict): the body of the request sent by the client
            clientip (str): the IP address of the client

        Returns:
            defer.Deferred[dict]: the parameters for this request (possibly
                supplied only in an earlier call of the same session).

        Raises:
            InteractiveAuthIncompleteError if no permitted login flow has been
                completed yet
            AuthError if a flow completed but authenticated a different user
                from `requester`
            LimitExceededError if this user has too many recent failed attempts
        """
        user_id = requester.user.to_string()

        def _ratelimit_kwargs(update):
            # Shared arguments for both ratelimiter calls; limits come from
            # the rc_login_failed_attempts config.
            return {
                "time_now_s": self._clock.time(),
                "rate_hz": self.hs.config.rc_login_failed_attempts.per_second,
                "burst_count": self.hs.config.rc_login_failed_attempts.burst_count,
                "update": update,
            }

        # Refuse early if earlier failures already exceed the limit. With
        # update=False this call only checks and does not record an attempt.
        self._failed_uia_attempts_ratelimiter.ratelimit(
            user_id, **_ratelimit_kwargs(update=False)
        )

        # One single-stage flow per supported login type.
        flows = [[login_type] for login_type in self._supported_login_types]

        try:
            result, params, _ = yield self.check_auth(flows, request_body, clientip)
        except LoginError:
            # Record the failure. `can_do_action` doesn't raise, so the
            # original LoginError is what propagates.
            self._failed_uia_attempts_ratelimiter.can_do_action(
                user_id, **_ratelimit_kwargs(update=True)
            )
            raise

        # The first supported login type that completed tells us who authed.
        completed = [t for t in self._supported_login_types if t in result]
        if not completed:
            # this can't happen
            raise Exception("check_auth returned True but no successful login type")
        authed_user_id = result[completed[0]]

        # The UI-auth identity must match the access token's identity.
        if authed_user_id != user_id:
            raise AuthError(403, "Invalid auth")

        return params

    def get_enabled_auth_types(self):
        """List the user-interactive auth types this homeserver currently accepts.

        Returns:
            the keys view of ``self.checkers``: one auth type per UI-auth
            checker that was enabled for the current config.
        """
        enabled = self.checkers
        return enabled.keys()

    @defer.inlineCallbacks
    def check_auth(self, flows, clientdict, clientip):
        """
        Takes a dictionary sent by the client in the login / registration
        protocol and handles the User-Interactive Auth flow.

        As a side effect, this function fills in the 'creds' key on the user's
        session with a map, which maps each auth-type (str) to the relevant
        identity authenticated by that auth-type (mostly str, but for captcha, bool).
        It also removes the 'auth' key from ``clientdict`` in place.

        If no auth flows have been completed successfully, raises an
        InteractiveAuthIncompleteError. To handle this, you can use
        synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
        decorator.

        Args:
            flows (list): A list of login flows. Each flow is an ordered list of
                          strings representing auth-types. At least one full
                          flow must be completed in order for auth to be successful.

            clientdict: The dictionary from the client root level, not the
                        'auth' key: this method prompts for auth if none is sent.
                        If it is empty, any parameters stored on the session by
                        an earlier call are used instead.

            clientip (str): The IP address of the client.

        Returns:
            defer.Deferred[dict, dict, str]: a deferred tuple of
                (creds, params, session_id).

                'creds' contains the authenticated credentials of each stage.

                'params' contains the parameters for this request (which may
                have been given only in a previous call).

                'session_id' is the ID of this session, either passed in by the
                client or assigned by this call

        Raises:
            InteractiveAuthIncompleteError if the client has not yet completed
                all the stages in any of the permitted flows.
            LoginError if an EMAIL_IDENTITY stage fails (kept for old clients;
                see below).
        """

        authdict = None
        sid = None
        if clientdict and "auth" in clientdict:
            authdict = clientdict["auth"]
            # Consume the auth dict; what remains is the request's own params.
            del clientdict["auth"]
            if "session" in authdict:
                sid = authdict["session"]
        session = self._get_session_info(sid)

        # Truthiness (rather than len() > 0) also tolerates clientdict=None,
        # which the guard above already anticipates; with no auth dict that
        # case then ends in InteractiveAuthIncompleteError below instead of a
        # TypeError.
        if clientdict:
            # This was designed to allow the client to omit the parameters
            # and just supply the session in subsequent calls so it split
            # auth between devices by just sharing the session, (eg. so you
            # could continue registration from your phone having clicked the
            # email auth link on there). It's probably too open to abuse
            # because it lets unauthenticated clients store arbitrary objects
            # on a homeserver.
            # Revisit: Assuming the REST APIs do sensible validation, the data
            # isn't arbitrary.
            session["clientdict"] = clientdict
            self._save_session(session)
        elif "clientdict" in session:
            clientdict = session["clientdict"]

        if not authdict:
            raise InteractiveAuthIncompleteError(
                self._auth_dict_for_flows(flows, session)
            )

        creds = session.setdefault("creds", {})

        # check auth type currently being presented
        errordict = {}
        if "type" in authdict:
            login_type = authdict["type"]
            try:
                result = yield self._check_auth_dict(authdict, clientip)
                if result:
                    creds[login_type] = result
                    self._save_session(session)
            except LoginError as e:
                if login_type == LoginType.EMAIL_IDENTITY:
                    # riot used to have a bug where it would request a new
                    # validation token (thus sending a new email) each time it
                    # got a 401 with a 'flows' field.
                    # (https://github.com/vector-im/vector-web/issues/2447).
                    #
                    # Grandfather in the old behaviour for now to avoid
                    # breaking old riot deployments.
                    raise

                # this step failed. Merge the error dict into the response
                # so that the client can have another go.
                errordict = e.error_dict()

        # A flow is complete once every one of its stages has creds. The set of
        # completed stages does not change inside the loop, so build it once.
        completed_stages = set(creds)
        for f in flows:
            if completed_stages.issuperset(f):
                # it's very useful to know what args are stored, but this can
                # include the password in the case of registering, so only log
                # the keys (confusingly, clientdict may contain a password
                # param, creds is just what the user authed as for UI auth
                # and is not sensitive).
                logger.info(
                    "Auth completed with creds: %r. Client dict has keys: %r",
                    creds,
                    list(clientdict),
                )
                return creds, clientdict, session["id"]

        ret = self._auth_dict_for_flows(flows, session)
        ret["completed"] = list(creds)
        ret.update(errordict)
        raise InteractiveAuthIncompleteError(ret)

    @defer.inlineCallbacks
    def add_oob_auth(self, stagetype, authdict, clientip):
        """
        Record the result of out-of-band authentication in an existing auth
        session. Currently used for adding the result of fallback auth.

        Args:
            stagetype (str): the auth type being completed; must have an
                enabled checker in ``self.checkers``
            authdict (dict): the auth dict; must contain ``session``
            clientip (str): IP address of the client

        Returns:
            Deferred[bool]: True if the checker accepted ``authdict`` (the result
                is then stored in the session's creds), otherwise False.

        Raises:
            LoginError(400, Codes.MISSING_PARAM) if ``stagetype`` has no checker
                or ``authdict`` has no ``session``.
        """
        if stagetype not in self.checkers or "session" not in authdict:
            raise LoginError(400, "", Codes.MISSING_PARAM)

        session = self._get_session_info(authdict["session"])
        creds = session.setdefault("creds", {})

        outcome = yield self.checkers[stagetype].check_auth(authdict, clientip)
        if not outcome:
            return False

        creds[stagetype] = outcome
        self._save_session(session)
        return True

    def get_session_id(self, clientdict):
        """
        Extract the UI-auth session ID from a client's request body.

        Args:
            clientdict: The dictionary sent by the client in the request

        Returns:
            str|None: ``clientdict["auth"]["session"]`` if present, else None.
        """
        if not clientdict or "auth" not in clientdict:
            return None
        authdict = clientdict["auth"]
        return authdict["session"] if "session" in authdict else None

    def set_session_data(self, session_id, key, value):
        """
        Store a value in the server-side data of a UI-auth session.

        The value lives under the session's ``serverdict``, which is never
        taken from the client, so the client cannot modify it.

        Args:
            session_id (string): The ID of this session as returned from check_auth
            key (string): The key to store the data under
            value (any): The data to store
        """
        session = self._get_session_info(session_id)
        if "serverdict" not in session:
            session["serverdict"] = {}
        session["serverdict"][key] = value
        self._save_session(session)

    def get_session_data(self, session_id, key, default=None):
        """
        Retrieve a value previously stored with set_session_data.

        Args:
            session_id (string): The ID of this session as returned from check_auth
            key (string): The key to look up
            default (any): Value to return if the key has not been set

        Returns:
            the stored value, or ``default``. Note that this creates an empty
            ``serverdict`` on the session if there was none.
        """
        session = self._get_session_info(session_id)
        serverdict = session.setdefault("serverdict", {})
        return serverdict.get(key, default)

    @defer.inlineCallbacks
    def _check_auth_dict(self, authdict, clientip):
        """Attempt to validate the auth dict provided by a client

        Dispatches to the enabled checker for ``authdict["type"]``. If there is
        none, falls back to validate_login using ``authdict["user"]``.

        Args:
            authdict (object): auth dict provided by the client
            clientip (str): IP address of the client

        Returns:
            Deferred: result of the stage verification.

        Raises:
            StoreError if there was a problem accessing the database
            SynapseError if there was a problem with the request
            LoginError if there was an authentication problem.
        """
        login_type = authdict["type"]
        checker = self.checkers.get(login_type)
        if checker is not None:
            res = yield checker.check_auth(authdict, clientip=clientip)
            return res

        # build a v1-login-style dict out of the authdict and fall back to the
        # v1 code
        user_id = authdict.get("user")

        if user_id is None:
            raise SynapseError(400, "", Codes.MISSING_PARAM)

        # validate_login also returns a post-login callback, which UI auth
        # does not use.
        canonical_id, _ = yield self.validate_login(user_id, authdict)
        return canonical_id

    def _get_params_recaptcha(self):
        """Client-facing parameters for the recaptcha stage (the public key)."""
        public_key = self.hs.config.recaptcha_public_key
        return {"public_key": public_key}

    def _get_params_terms(self):
        """Client-facing parameters for the terms stage.

        Describes a single English ``privacy_policy`` whose name and version
        come from the user-consent config and whose URL points at this
        server's ``_matrix/consent`` page for that version.
        """
        config = self.hs.config
        version = config.user_consent_version
        english = {
            "name": config.user_consent_policy_name,
            "url": "%s_matrix/consent?v=%s" % (
                config.public_baseurl,
                config.user_consent_version,
            ),
        }
        privacy_policy = {"version": version, "en": english}
        return {"policies": {"privacy_policy": privacy_policy}}

    def _auth_dict_for_flows(self, flows, session):
        """Build the UI-auth description payload for InteractiveAuthIncompleteError.

        Args:
            flows (iterable): login flows, each an ordered list of stage names
            session (dict): the UI-auth session; only its ``id`` is read

        Returns:
            dict: ``session`` (the session ID), ``flows`` (a list of
            ``{"stages": flow}``) and ``params`` (per-stage parameters for
            stages that need them, e.g. recaptcha and terms).
        """
        # Materialise once: flows is iterated twice below.
        public_flows = list(flows)

        # Stages that need extra client-facing parameters, and their builders.
        get_params = {
            LoginType.RECAPTCHA: self._get_params_recaptcha,
            LoginType.TERMS: self._get_params_terms,
        }

        params = {}
        for flow in public_flows:
            for stage in flow:
                if stage in get_params and stage not in params:
                    params[stage] = get_params[stage]()

        return {
            "session": session["id"],
            "flows": [{"stages": f} for f in public_flows],
            "params": params,
        }

    def _get_session_info(self, session_id):
        """Return the session dict for ``session_id``, creating one if needed.

        An unknown or falsy ``session_id`` yields a brand-new session, with a
        fresh 24-character random ID that does not collide with a live one.
        """
        known = session_id in self.sessions
        if known and session_id:
            return self.sessions[session_id]

        new_id = None
        while new_id is None or new_id in self.sessions:
            new_id = stringutils.random_string(24)
        self.sessions[new_id] = {"id": new_id}

        return self.sessions[new_id]

    @defer.inlineCallbacks
    def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
        """
        Creates a new access token for the user with the given user ID.

        The user is assumed to have been authenticated by some other
        mechanism (e.g. CAS), and the user_id converted to the canonical case.

        If ``device_id`` is given but the device no longer exists (e.g. it was
        deleted concurrently), the new token is deleted again and StoreError
        is raised.

        Args:
            user_id (str): canonical User ID
            device_id (str|None): the device ID to associate with the tokens.
               None to leave the tokens unassociated with a device (deprecated:
               we should always have a device ID)
            valid_until_ms (int|None): when the token is valid until. None for
                no expiry.
        Returns:
              The access token for the user's session.
        Raises:
            StoreError if there was a problem storing the token.
        """
        if valid_until_ms is None:
            fmt_expiry = ""
        else:
            expiry_local = time.localtime(valid_until_ms / 1000.0)
            fmt_expiry = time.strftime(" until %Y-%m-%d %H:%M:%S", expiry_local)
        logger.info(
            "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
        )

        yield self.auth.check_auth_blocking(user_id)

        access_token = self.macaroon_gen.generate_access_token(user_id)
        yield self.store.add_access_token_to_user(
            user_id, access_token, device_id, valid_until_ms
        )

        if device_id is None:
            return access_token

        # The device should already be registered, but we may have raced a
        # DELETE. Never leave an active token without a device record.
        try:
            yield self.store.get_device(user_id, device_id)
        except StoreError:
            yield self.store.delete_access_token(access_token)
            raise StoreError(400, "Login raced against device deletion")

        return access_token

    @defer.inlineCallbacks
    def check_user_exists(self, user_id):
        """
        Look up the canonical form of a user ID, matching case-insensitively.

        Args:
            user_id (unicode|bytes): complete @user:id

        Returns:
            defer.Deferred: (unicode) canonical_user_id, or None if there are
            zero matches or several inexact ones.

        NOTE(review): an earlier docstring listed UserDeactivatedError as
        raised here, but nothing in this method or in
        _find_user_id_and_pwd_hash raises it — confirm.
        """
        match = yield self._find_user_id_and_pwd_hash(user_id)
        return None if match is None else match[0]

    @defer.inlineCallbacks
    def _find_user_id_and_pwd_hash(self, user_id):
        """Look up a user case-insensitively, with their password hash.

        A single match (even if inexact) is accepted. With several matches,
        the exact one wins; if none is exact the lookup is ambiguous.

        Returns:
            tuple: A 2-tuple of `(canonical_user_id, password_hash)`
            None: if there is not exactly one usable match
        """
        user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)

        if not user_infos:
            logger.warning("Attempted to login as %s but they do not exist",
                           user_id)
            return None

        if len(user_infos) == 1:
            return user_infos.popitem()

        if user_id in user_infos:
            return (user_id, user_infos[user_id])

        logger.warning(
            "Attempted to login as %s but it matches more than one user "
            "inexactly: %r",
            user_id,
            user_infos.keys(),
        )
        return None

    def get_supported_login_types(self):
        """Get the login types supported for the /login API.

        'm.login.password' comes first when password_enabled is set in the
        config; password auth providers may add further types.

        Returns:
            Iterable[str]: login types
        """
        supported = self._supported_login_types
        return supported

    @defer.inlineCallbacks
    def validate_login(self, username, login_submission):
        """Authenticates the user for the /login API

        Also used by the user-interactive auth flow to validate
        m.login.password auth types.

        For each password provider in turn, tries its ``check_password`` (for
        m.login.password only) and then its ``check_auth`` (for the custom
        login types it declares). If no provider accepts, m.login.password is
        finally checked against the local password database when
        ``password_localdb_enabled`` is set. The first successful check wins.

        Args:
            username (str): username supplied by the user; qualified with this
                server's hostname unless it already starts with "@"
            login_submission (dict): the whole of the login submission
                (including 'type' and other relevant fields)
        Returns:
            Deferred[str, func]: canonical user id, and optional callback
                to be called once the access token and device id are issued
        Raises:
            StoreError if there was a problem accessing the database
            SynapseError if there was a problem with the request (password
                login disabled, missing password or provider-required fields,
                or a login type that nothing handles)
            LoginError if there was an authentication problem.
        """

        if username.startswith("@"):
            qualified_user_id = username
        else:
            qualified_user_id = UserID(username, self.hs.hostname).to_string()

        login_type = login_submission.get("type")
        # Set once any provider (or the local DB) claims to handle login_type;
        # distinguishes "unknown login type" from "wrong credentials" below.
        known_login_type = False

        # special case to check for "password" for the check_password interface
        # for the auth providers
        password = login_submission.get("password")

        if login_type == LoginType.PASSWORD:
            if not self._password_enabled:
                raise SynapseError(400, "Password login has been disabled.")
            if not password:
                raise SynapseError(400, "Missing parameter: password")

        for provider in self.password_providers:
            # NOTE(review): check_password receives the qualified user ID; it is
            # returned as-is on success, without case canonicalisation.
            if hasattr(provider,
                       "check_password") and login_type == LoginType.PASSWORD:
                known_login_type = True
                is_valid = yield provider.check_password(
                    qualified_user_id, password)
                if is_valid:
                    return qualified_user_id, None

            if not hasattr(provider,
                           "get_supported_login_types") or not hasattr(
                               provider, "check_auth"):
                # this password provider doesn't understand custom login types
                continue

            supported_login_types = provider.get_supported_login_types()
            if login_type not in supported_login_types:
                # this password provider doesn't understand this login type
                continue

            known_login_type = True
            # Fields this provider requires for login_type; only those are
            # forwarded to the provider.
            login_fields = supported_login_types[login_type]

            missing_fields = []
            login_dict = {}
            for f in login_fields:
                if f not in login_submission:
                    missing_fields.append(f)
                else:
                    login_dict[f] = login_submission[f]
            if missing_fields:
                raise SynapseError(
                    400,
                    "Missing parameters for login type %s: %s" %
                    (login_type, missing_fields),
                )

            # NOTE: unlike check_password, check_auth is given the raw
            # `username`, not qualified_user_id.
            result = yield provider.check_auth(username, login_type,
                                               login_dict)
            if result:
                # Providers may return a bare user ID or (user_id, callback).
                if isinstance(result, str):
                    result = (result, None)
                return result

        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
            known_login_type = True

            canonical_user_id = yield self._check_local_password(
                qualified_user_id, password)

            if canonical_user_id:
                return canonical_user_id, None

        if not known_login_type:
            raise SynapseError(400, "Unknown login type %s" % login_type)

        # We raise a 403 here, but note that if we're doing user-interactive
        # login, it turns all LoginErrors into a 401 anyway.
        raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)

    @defer.inlineCallbacks
    def check_password_provider_3pid(self, medium, address, password):
        """Ask password providers to validate a third-party-ID login.

        Providers that implement ``check_3pid_auth`` are tried in order; the
        first truthy answer wins.

        Args:
            medium (str): The medium of the 3pid (ex. email).
            address (str): The address of the 3pid (ex. user@example.com).
            password (str): The password of the user.

        Returns:
            Deferred[(str|None, func|None)]: A tuple of `(user_id,
            callback)`. If authentication is successful, `user_id` is a `str`
            containing the authenticated, canonical user ID. `callback` is
            then either a function to be later run after the server has
            completed login/registration, or `None`. If authentication was
            unsuccessful, `user_id` and `callback` are both `None`.
        """
        for provider in self.password_providers:
            if not hasattr(provider, "check_3pid_auth"):
                continue

            # The provider may resolve to None (failure), a bare user_id str,
            # or a (user_id, callback_func) tuple whose callback should run
            # after everything else has finished.
            outcome = yield provider.check_3pid_auth(medium, address, password)
            if not outcome:
                continue
            if isinstance(outcome, str):
                return outcome, None
            return outcome

        return None, None

    @defer.inlineCallbacks
    def _check_local_password(self, user_id, password):
        """Authenticate a user against the local password database.

        user_id is matched case-insensitively; several inexact matches count
        as unknown.

        Args:
            user_id (unicode): complete @user:id
            password (unicode): the provided password
        Returns:
            Deferred[unicode] the canonical_user_id, or Deferred[None] if
                unknown user/bad password
        Raises:
            UserDeactivatedError if the account has no password hash and the
                store reports it as deactivated.
        """
        found = yield self._find_user_id_and_pwd_hash(user_id)
        if not found:
            return None
        canonical_user_id, password_hash = found

        # A missing hash most likely means the account was deactivated.
        if not password_hash:
            is_deactivated = yield self.store.get_user_deactivated_status(
                canonical_user_id
            )
            if is_deactivated:
                raise UserDeactivatedError("This account has been deactivated")

        password_ok = yield self.validate_hash(password, password_hash)
        if password_ok:
            return canonical_user_id

        logger.warning("Failed password login for user %s", canonical_user_id)
        return None

    @defer.inlineCallbacks
    def validate_short_term_login_token_and_get_user_id(self, login_token):
        """Validate a short-term login macaroon and return the user it names.

        Args:
            login_token (str): the serialized macaroon sent by the client.

        Returns:
            Deferred[str]: the user ID carried by the macaroon.

        Raises:
            AuthError: 403 / M_FORBIDDEN if the token cannot be deserialized,
                does not yield a user ID, or fails "login" validation.
            Anything raised by ``check_auth_blocking`` for that user.
        """
        auth = self.hs.get_auth()
        try:
            macaroon = pymacaroons.Macaroon.deserialize(login_token)
            token_user_id = auth.get_user_id_from_macaroon(macaroon)
            auth.validate_macaroon(macaroon, "login", token_user_id)
        except Exception:
            # Every decode/validation failure is reported to the client as the
            # same generic 403.
            raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)

        yield self.auth.check_auth_blocking(token_user_id)
        return token_user_id

    @defer.inlineCallbacks
    def delete_access_token(self, access_token):
        """Invalidate a single access token and clean up state tied to it.

        Password providers that implement ``on_logged_out`` are told about
        the logout, and any pushers registered under the token are removed.

        Args:
            access_token (str): access token to be deleted

        Returns:
            Deferred
        """
        token_info = yield self.auth.get_user_by_access_token(access_token)
        yield self.store.delete_access_token(access_token)

        # Notify interested auth providers of the logout.
        for provider in self.password_providers:
            if not hasattr(provider, "on_logged_out"):
                continue
            yield provider.on_logged_out(
                user_id=str(token_info["user"]),
                device_id=token_info["device_id"],
                access_token=access_token,
            )

        # Drop pushers that were registered with this token, if any.
        token_id = token_info["token_id"]
        if token_id is None:
            return
        pusher_pool = self.hs.get_pusherpool()
        yield pusher_pool.remove_pushers_by_access_token(
            str(token_info["user"]), (token_id, ))

    @defer.inlineCallbacks
    def delete_access_tokens_for_user(self, user_id, except_token_id=None,
                                      device_id=None):
        """Invalidate the access tokens belonging to a user.

        After deletion, password providers implementing ``on_logged_out`` are
        notified once per deleted token, and pushers registered under the
        deleted tokens are removed.

        Args:
            user_id (str): ID of user the tokens belong to
            except_token_id (str|None): access_token ID which should *not* be
                deleted
            device_id (str|None): ID of device the tokens are associated with.
                If None, tokens associated with any device (or no device) will
                be deleted

        Returns:
            Deferred
        """
        # Each entry is a (token, token_id, device_id) tuple.
        deleted = yield self.store.user_delete_access_tokens(
            user_id, except_token_id=except_token_id, device_id=device_id)

        for provider in self.password_providers:
            if not hasattr(provider, "on_logged_out"):
                continue
            for token, _token_id, token_device_id in deleted:
                yield provider.on_logged_out(
                    user_id=user_id,
                    device_id=token_device_id,
                    access_token=token,
                )

        # Remove pushers that were registered with any of the deleted tokens.
        deleted_token_ids = (tid for _, tid, _ in deleted)
        yield self.hs.get_pusherpool().remove_pushers_by_access_token(
            user_id, deleted_token_ids)

    @defer.inlineCallbacks
    def add_threepid(self, user_id, medium, address, validated_at):
        """Record a third-party identifier (3PID) for a user.

        Args:
            user_id (str): the user the 3PID belongs to
            medium (str): the 3PID medium, e.g. "email"
            address (str): the 3PID address; lowercased when medium is "email"
            validated_at: validation timestamp, passed through unchanged to
                the store (presumably ms since the epoch; confirm with caller)

        Returns:
            Deferred
        """
        # Canonicalise email addresses to lower case. The homeserver is moving
        # towards being responsible for validating the threepids used for
        # password resets, so Synapse will gain knowledge of specific mediums.
        # Lowercasing here also fixes email presence checks during password
        # reset having been case sensitive.
        canonical_address = address.lower() if medium == "email" else address
        added_at_ms = self.hs.get_clock().time_msec()

        yield self.store.user_add_threepid(
            user_id, medium, canonical_address, validated_at, added_at_ms)

    @defer.inlineCallbacks
    def delete_threepid(self, user_id, medium, address, id_server=None):
        """Try to unbind a 3PID on identity servers, then delete it locally.

        The local deletion happens regardless of the unbind outcome.

        Args:
            user_id (str)
            medium (str)
            address (str): lowercased first when medium is "email"
            id_server (str|None): Use the given identity server when unbinding
                any threepids. If None then will attempt to unbind using the
                identity server specified when binding (if known).

        Returns:
            Deferred[bool]: Returns True if successfully unbound the 3pid on
            the identity server, False if identity server doesn't support the
            unbind API.
        """
        # Emails are stored lowercased, so canonicalise before matching.
        if medium == "email":
            address = address.lower()

        identity_handler = self.hs.get_handlers().identity_handler
        threepid = {
            "medium": medium,
            "address": address,
            "id_server": id_server,
        }
        unbound = yield identity_handler.try_unbind_threepid(user_id, threepid)

        yield self.store.user_delete_threepid(user_id, medium, address)
        return unbound

    def _save_session(self, session):
        """Stamp ``session`` with the current time and keep it in memory.

        The session is stored in ``self.sessions`` under ``session["id"]``;
        nothing is persisted to disk or the database.

        Args:
            session (dict): the session; mutated to add/refresh "last_used".
        """
        # TODO: Persistent storage
        logger.debug("Saving session %s", session)
        now_ms = self.hs.get_clock().time_msec()
        session["last_used"] = now_ms
        self.sessions[session["id"]] = session

    def hash(self, password):
        """Compute a bcrypt hash of ``password`` on a worker thread.

        The password is NFKC-normalised, UTF-8 encoded and suffixed with the
        configured ``password_pepper`` before hashing with a salt generated
        for ``self.bcrypt_rounds`` rounds.

        Args:
            password (unicode): Password to hash.

        Returns:
            Deferred(unicode): Hashed password.
        """
        def _hash_in_thread():
            normalised = unicodedata.normalize("NFKC", password)
            peppered = (
                normalised.encode("utf8") +
                self.hs.config.password_pepper.encode("utf8")
            )
            salt = bcrypt.gensalt(self.bcrypt_rounds)
            return bcrypt.hashpw(peppered, salt).decode("ascii")

        # bcrypt is deliberately slow, so keep it off the reactor thread.
        return defer_to_thread(self.hs.get_reactor(), _hash_in_thread)

    def validate_hash(self, password, stored_hash):
        """Check whether ``password`` matches a bcrypt ``stored_hash``.

        The password is normalised and peppered exactly as in ``hash``.

        Args:
            password (unicode): Password to check.
            stored_hash (bytes|str): Expected hash value; a str is encoded as
                ASCII first. A falsy value never matches.

        Returns:
            Deferred(bool): Whether self.hash(password) == stored_hash.
        """
        if not stored_hash:
            return defer.succeed(False)

        if not isinstance(stored_hash, bytes):
            stored_hash = stored_hash.encode("ascii")

        def _check_in_thread():
            normalised = unicodedata.normalize("NFKC", password)
            peppered = (
                normalised.encode("utf8") +
                self.hs.config.password_pepper.encode("utf8")
            )
            return bcrypt.checkpw(peppered, stored_hash)

        # bcrypt is deliberately slow, so keep it off the reactor thread.
        return defer_to_thread(self.hs.get_reactor(), _check_in_thread)
Example #7
0
class LoginRestServlet(RestServlet):
    """Handles the client-server ``/login`` endpoint.

    ``GET`` advertises the supported login flows. ``POST`` performs a login
    (JWT, short-term login token, or a password / third-party-ID login via
    the auth handler), registers a device for the user and returns its access
    token.

    Ratelimiting (limits are read from homeserver config on every call):
      * ``_address_ratelimiter``: every POST, keyed on the client IP
        (``rc_login_address``).
      * ``_failed_attempts_ratelimiter``: failed logins, keyed on the
        lowercased user ID or the ``(medium, address)`` 3PID pair
        (``rc_login_failed_attempts``).
      * ``_account_ratelimiter``: successful logins, keyed on the lowercased
        user ID (``rc_login_account``).
    """

    PATTERNS = client_patterns("/login$", v1=True)
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "m.login.jwt"

    def __init__(self, hs):
        super(LoginRestServlet, self).__init__()
        self.hs = hs
        self.jwt_enabled = hs.config.jwt_enabled
        self.jwt_secret = hs.config.jwt_secret
        self.jwt_algorithm = hs.config.jwt_algorithm
        self.saml2_enabled = hs.config.saml2_enabled
        self.cas_enabled = hs.config.cas_enabled
        self.auth_handler = self.hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self.handlers = hs.get_handlers()
        self._clock = hs.get_clock()
        self._well_known_builder = WellKnownBuilder(hs)
        # The ratelimiters are not configured here: rate_hz / burst_count are
        # passed on every ratelimit / can_do_action call below.
        self._address_ratelimiter = Ratelimiter()
        self._account_ratelimiter = Ratelimiter()
        self._failed_attempts_ratelimiter = Ratelimiter()

    def on_GET(self, request):
        """Return the list of login flows this server supports."""
        flows = []
        if self.jwt_enabled:
            flows.append({"type": LoginRestServlet.JWT_TYPE})
        if self.saml2_enabled:
            flows.append({"type": LoginRestServlet.SSO_TYPE})
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
        if self.cas_enabled:
            flows.append({"type": LoginRestServlet.SSO_TYPE})

            # we advertise CAS for backwards compat, though MSC1721 renamed it
            # to SSO.
            flows.append({"type": LoginRestServlet.CAS_TYPE})

            # While its valid for us to advertise this login type generally,
            # synapse currently only gives out these tokens as part of the
            # CAS login flow.
            # Generally we don't want to advertise login flows that clients
            # don't know how to implement, since they (currently) will always
            # fall back to the fallback API if they don't understand one of the
            # login flow types returned.
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})

        flows.extend(({
            "type": t
        } for t in self.auth_handler.get_supported_login_types()))

        return 200, {"flows": flows}

    def on_OPTIONS(self, request):
        return 200, {}

    async def on_POST(self, request):
        """Perform a login and return the new device's credentials.

        Raises:
            LimitExceededError: if the client IP has made too many attempts.
            SynapseError: 400 if a required key is missing from the body.
        """
        # Every login attempt counts against the per-IP limit. Use the same
        # clock as the other ratelimit calls in this servlet.
        self._address_ratelimiter.ratelimit(
            request.getClientIP(),
            time_now_s=self._clock.time(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
            update=True,
        )

        login_submission = parse_json_object_from_request(request)
        try:
            if self.jwt_enabled and (login_submission["type"]
                                     == LoginRestServlet.JWT_TYPE):
                result = await self.do_jwt_login(login_submission)
            elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                result = await self.do_token_login(login_submission)
            else:
                result = await self._do_other_login(login_submission)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        return 200, result

    async def _do_other_login(self, login_submission):
        """Handle non-token/saml/jwt logins (password and 3PID logins).

        Third-party identifiers are first offered to password providers and
        then resolved to a local user ID; the login is then validated against
        that user. Failed attempts are counted against the failed-attempts
        ratelimiter.

        Args:
            login_submission (dict): the JSON body of the login request.

        Returns:
            dict: HTTP response

        Raises:
            SynapseError: 400 for a malformed identifier.
            LoginError: 403 for an unknown 3PID, or whatever validate_login
                raises for bad credentials.
            LimitExceededError: if too many failed attempts have been made.
        """
        # Log the request we got, but only certain fields to minimise the chance of
        # logging someone's password (even if they accidentally put it in the wrong
        # field)
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            login_submission.get("identifier"),
            login_submission.get("medium"),
            login_submission.get("address"),
            login_submission.get("user"),
        )
        login_submission_legacy_convert(login_submission)

        if "identifier" not in login_submission:
            raise SynapseError(400, "Missing param: identifier")

        identifier = login_submission["identifier"]
        if "type" not in identifier:
            raise SynapseError(400, "Login identifier has no type")

        # convert phone type identifiers to generic threepids
        if identifier["type"] == "m.id.phone":
            identifier = login_id_thirdparty_from_phone(identifier)

        # convert threepid identifiers to user IDs
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get("address")
            medium = identifier.get("medium")

            if medium is None or address is None:
                raise SynapseError(400, "Invalid thirdparty identifier")

            if medium == "email":
                # For emails, transform the address to lowercase.
                # We store all email addreses as lowercase in the DB.
                # (See add_threepid in synapse/handlers/auth.py)
                address = address.lower()

            # We also apply account rate limiting using the 3PID as a key, as
            # otherwise using 3PID bypasses the ratelimiting based on user ID.
            self._failed_attempts_ratelimiter.ratelimit(
                (medium, address),
                time_now_s=self._clock.time(),
                rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
                burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                update=False,
            )

            # Check for login providers that support 3pid login types
            (
                canonical_user_id,
                callback_3pid,
            ) = await self.auth_handler.check_password_provider_3pid(
                medium, address, login_submission["password"])
            if canonical_user_id:
                # Authentication through password provider and 3pid succeeded

                result = await self._complete_login(canonical_user_id,
                                                    login_submission,
                                                    callback_3pid)
                return result

            # No password providers were able to handle this 3pid
            # Check local store
            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
                medium, address)
            if not user_id:
                logger.warning("unknown 3pid identifier medium %s, address %r",
                               medium, address)
                # We mark that we've failed to log in here, as
                # `check_password_provider_3pid` might have returned `None` due
                # to an incorrect password, rather than the account not
                # existing.
                #
                # If it returned None but the 3PID was bound then we won't hit
                # this code path, which is fine as then the per-user ratelimit
                # will kick in below.
                self._failed_attempts_ratelimiter.can_do_action(
                    (medium, address),
                    time_now_s=self._clock.time(),
                    rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
                    burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                    update=True,
                )
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)

            identifier = {"type": "m.id.user", "user": user_id}

        # by this point, the identifier should be an m.id.user: if it's anything
        # else, we haven't understood it.
        if identifier["type"] != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type")
        if "user" not in identifier:
            raise SynapseError(400, "User identifier is missing 'user' key")

        if identifier["user"].startswith("@"):
            qualified_user_id = identifier["user"]
        else:
            qualified_user_id = UserID(identifier["user"],
                                       self.hs.hostname).to_string()

        # Check if we've hit the failed ratelimit (but don't update it)
        self._failed_attempts_ratelimiter.ratelimit(
            qualified_user_id.lower(),
            time_now_s=self._clock.time(),
            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
            update=False,
        )

        try:
            canonical_user_id, callback = await self.auth_handler.validate_login(
                identifier["user"], login_submission)
        except LoginError:
            # The user has failed to log in, so we need to update the rate
            # limiter. Using `can_do_action` avoids us raising a ratelimit
            # exception and masking the LoginError. The actual ratelimiting
            # should have happened above.
            self._failed_attempts_ratelimiter.can_do_action(
                qualified_user_id.lower(),
                time_now_s=self._clock.time(),
                rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
                burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                update=True,
            )
            raise

        result = await self._complete_login(canonical_user_id,
                                            login_submission, callback)
        return result

    async def _complete_login(self,
                              user_id,
                              login_submission,
                              callback=None,
                              create_non_existant_users=False):
        """Called when we've successfully authed the user and now need to
        actually log them in (e.g. create devices). This gets called on
        all successful logins.

        Applies the ratelimiting for successful login attempts against an
        account.

        Args:
            user_id (str): ID of the user to log in.
            login_submission (dict): Dictionary of login information.
            callback (func|None): Callback function to run after login; it is
                awaited with the result dict.
            create_non_existant_users (bool): Whether to create the user if
                they don't exist. Defaults to False.

        Returns:
            result (Dict[str,str]): Dictionary of account information after
                successful login.

        Raises:
            LimitExceededError: if the account has logged in too often.
        """

        # Before we actually log them in we check if they've already logged in
        # too often. This happens here rather than before as we don't
        # necessarily know the user before now.
        self._account_ratelimiter.ratelimit(
            user_id.lower(),
            time_now_s=self._clock.time(),
            rate_hz=self.hs.config.rc_login_account.per_second,
            burst_count=self.hs.config.rc_login_account.burst_count,
            update=True,
        )

        if create_non_existant_users:
            # Keep `user_id` intact until we know the canonical ID:
            # check_user_exists returns a falsy value for unknown users, and
            # the localpart for registration must come from the original ID.
            canonical_uid = await self.auth_handler.check_user_exists(user_id)
            if not canonical_uid:
                canonical_uid = await self.registration_handler.register_user(
                    localpart=UserID.from_string(user_id).localpart)
            user_id = canonical_uid

        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get(
            "initial_device_display_name")
        device_id, access_token = await self.registration_handler.register_device(
            user_id, device_id, initial_display_name)

        result = {
            "user_id": user_id,
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        if callback is not None:
            await callback(result)

        return result

    async def do_token_login(self, login_submission):
        """Log in with a short-term ``m.login.token`` login token."""
        token = login_submission["token"]
        auth_handler = self.auth_handler
        user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
            token)

        result = await self._complete_login(user_id, login_submission)
        return result

    async def do_jwt_login(self, login_submission):
        """Log in with a JWT signed with the configured secret and algorithm.

        The token's ``sub`` claim is used as the localpart on this server; the
        user is registered if they do not exist yet.

        Raises:
            LoginError: 401 if the token is missing, expired, invalid, or has
                no ``sub`` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(401,
                             "Token field for JWT is missing",
                             errcode=Codes.UNAUTHORIZED)

        import jwt
        from jwt.exceptions import InvalidTokenError

        try:
            payload = jwt.decode(token,
                                 self.jwt_secret,
                                 algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError, of which it is a subclass.
            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
        except InvalidTokenError:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user = payload.get("sub", None)
        if user is None:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user_id = UserID(user, self.hs.hostname).to_string()
        result = await self._complete_login(user_id,
                                            login_submission,
                                            create_non_existant_users=True)
        return result
Example #8
0
class AuthHandler(BaseHandler):
    SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000

    def __init__(self, hs):
        """Set up the handler: user-interactive auth checkers, password
        providers, the supported login / UI auth types, the failed-UIA
        ratelimiter, the UI auth session expiry loop and the SSO templates.

        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)

        # Keep only the UI auth checkers that report themselves as enabled,
        # keyed by the auth type each one handles.
        self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
        for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
            inst = auth_checker_class(hs)
            if inst.is_enabled():
                self.checkers[inst.AUTH_TYPE] = inst  # type: ignore

        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # Instantiate each configured password provider module with its own
        # config and a ModuleApi bound to this handler.
        # NOTE(review): ModuleApi receives `self` before __init__ has finished;
        # confirm it does not read attributes that are assigned further down.
        account_handler = ModuleApi(hs, self)
        self.password_providers = [
            module(config=config, account_handler=account_handler)
            for module, config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.macaroon_gen = hs.get_macaroon_generator()
        self._password_enabled = hs.config.password_enabled
        # SSO counts as enabled if any of CAS, SAML2 or OIDC is configured.
        self._sso_enabled = (hs.config.cas_enabled or hs.config.saml2_enabled
                             or hs.config.oidc_enabled)

        # we keep this as a list despite the O(N^2) implication so that we can
        # keep PASSWORD first and avoid confusing clients which pick the first
        # type in the list. (NB that the spec doesn't require us to do so and
        # clients which favour types that they don't understand over those that
        # they do are technically broken)
        login_types = []
        if self._password_enabled:
            login_types.append(LoginType.PASSWORD)
        for provider in self.password_providers:
            if hasattr(provider, "get_supported_login_types"):
                for t in provider.get_supported_login_types().keys():
                    if t not in login_types:
                        login_types.append(t)
        self._supported_login_types = login_types
        # Login types and UI Auth types have a heavy overlap, but are not
        # necessarily identical. Login types have SSO (and other login types)
        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
        ui_auth_types = login_types.copy()
        if self._sso_enabled:
            ui_auth_types.append(LoginType.SSO)
        self._supported_ui_auth_types = ui_auth_types

        # Ratelimiter for failed auth during UIA. Uses same ratelimit config
        # as per `rc_login.failed_attempts`.
        self._failed_uia_attempts_ratelimiter = Ratelimiter()

        self._clock = self.hs.get_clock()

        # Expire old UI auth sessions after a period of time. Only scheduled
        # when `worker_app` is None (i.e. not on a worker process).
        # NOTE(review): the 5 * 60 * 1000 interval is assumed to be in ms,
        # i.e. every 5 minutes; confirm against Clock.looping_call.
        if hs.config.worker_app is None:
            self._clock.looping_call(
                run_as_background_process,
                5 * 60 * 1000,
                "expire_old_sessions",
                self._expire_old_sessions,
            )

        # Load the SSO HTML templates.

        # The following template is shown to the user during a client login via SSO,
        # after the SSO completes and before redirecting them back to their client.
        # It notifies the user they are about to give access to their matrix account
        # to the client.
        self._sso_redirect_confirm_template = load_jinja2_templates(
            hs.config.sso_template_dir,
            ["sso_redirect_confirm.html"],
        )[0]
        # The following template is shown during user interactive authentication
        # in the fallback auth scenario. It notifies the user that they are
        # authenticating for an operation to occur on their account.
        self._sso_auth_confirm_template = load_jinja2_templates(
            hs.config.sso_template_dir,
            ["sso_auth_confirm.html"],
        )[0]
        # The following template is shown after a successful user interactive
        # authentication session. It tells the user they can close the window.
        self._sso_auth_success_template = hs.config.sso_auth_success_template
        # The following template is shown during the SSO authentication process if
        # the account is deactivated.
        self._sso_account_deactivated_template = (
            hs.config.sso_account_deactivated_template)

        self._server_name = hs.config.server_name

        # cast to tuple for use with str.startswith
        self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)

    async def validate_user_via_ui_auth(
        self,
        requester: Requester,
        request: SynapseRequest,
        request_body: Dict[str, Any],
        clientip: str,
        description: str,
    ) -> dict:
        """
        Confirm, via user-interactive auth, that the caller really is the user
        their access token says they are.

        Used for operations such as device deletion and password reset, where
        the user already holds a valid access token but we want to guard
        against that token having been stolen.

        Args:
            requester: The user, as given by the access token.

            request: The request sent by the client.

            request_body: The body of the request sent by the client.

            clientip: The IP address of the client.

            description: A human readable string to be displayed to the user that
                         describes the operation happening on their account.

        Returns:
            The parameters for this request (which may
                have been given only in a previous call).

        Raises:
            InteractiveAuthIncompleteError if the client has not yet completed
                any of the permitted login flows

            AuthError if the client has completed a login flow, and it gives
                a different user to `requester`

            LimitExceededError if the ratelimiter's failed request count for this
                user is too high to proceed
        """
        requester_id = requester.user.to_string()

        # Refuse early if this user already has too many failed attempts.
        # update=False: this check on its own does not count as an attempt.
        self._failed_uia_attempts_ratelimiter.ratelimit(
            requester_id,
            time_now_s=self._clock.time(),
            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
            update=False,
        )

        # Every supported UI auth type is offered as its own one-stage flow.
        single_stage_flows = [[t] for t in self._supported_ui_auth_types]

        try:
            creds, params, _session_id = await self.check_auth(
                single_stage_flows, request, request_body, clientip, description
            )
        except LoginError:
            # Count the failure. `can_do_action` never raises, so the
            # LoginError is what propagates to the caller.
            self._failed_uia_attempts_ratelimiter.can_do_action(
                requester_id,
                time_now_s=self._clock.time(),
                rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
                burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                update=True,
            )
            raise

        # Take the user ID recorded by the first supported type that completed.
        not_found = object()
        authed_user_id = next(
            (
                creds[login_type]
                for login_type in self._supported_ui_auth_types
                if login_type in creds
            ),
            not_found,
        )
        if authed_user_id is not_found:
            # this can't happen
            raise Exception(
                "check_auth returned True but no successful login type")

        # The user who completed UI auth must own the access token.
        if authed_user_id != requester.user.to_string():
            raise AuthError(403, "Invalid auth")

        return params

    def get_enabled_auth_types(self):
        """List the user-interactive auth types the homeserver supports.

        Returns:
            A keys view over ``self.checkers``: one auth type per UI auth
            checker that reported itself enabled when the handler was built.
        """
        enabled_checkers = self.checkers
        return enabled_checkers.keys()

    async def check_auth(
        self,
        flows: List[List[str]],
        request: SynapseRequest,
        clientdict: Dict[str, Any],
        clientip: str,
        description: str,
    ) -> Tuple[dict, dict, str]:
        """
        Takes a dictionary sent by the client in the login / registration
        protocol and handles the User-Interactive Auth flow.

        If no auth flows have been completed successfully, raises an
        InteractiveAuthIncompleteError. To handle this, you can use
        synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
        decorator.

        Args:
            flows: A list of login flows. Each flow is an ordered list of
                   strings representing auth-types. At least one full
                   flow must be completed in order for auth to be successful.

            request: The request sent by the client.

            clientdict: The dictionary from the client root level, not the
                        'auth' key: this method prompts for auth if none is sent.

            clientip: The IP address of the client.

            description: A human readable string to be displayed to the user that
                         describes the operation happening on their account.

        Returns:
            A tuple of (creds, params, session_id).

                'creds' contains the authenticated credentials of each stage.

                'params' contains the parameters for this request (which may
                have been given only in a previous call).

                'session_id' is the ID of this session, either passed in by the
                client or assigned by this call

        Raises:
            InteractiveAuthIncompleteError if the client has not yet completed
                all the stages in any of the permitted flows.
        """

        authdict = None
        sid = None  # type: Optional[str]
        if clientdict and "auth" in clientdict:
            authdict = clientdict["auth"]
            del clientdict["auth"]
            if "session" in authdict:
                sid = authdict["session"]

        # Convert the URI and method to strings.
        uri = request.uri.decode("utf-8")
        method = request.uri.decode("utf-8")

        # If there's no session ID, create a new session.
        if not sid:
            session = await self.store.create_ui_auth_session(
                clientdict, uri, method, description)

        else:
            try:
                session = await self.store.get_ui_auth_session(sid)
            except StoreError:
                raise SynapseError(400, "Unknown session ID: %s" % (sid, ))

            # If the client provides parameters, update what is persisted,
            # otherwise use whatever was last provided.
            #
            # This was designed to allow the client to omit the parameters
            # and just supply the session in subsequent calls so it split
            # auth between devices by just sharing the session, (eg. so you
            # could continue registration from your phone having clicked the
            # email auth link on there). It's probably too open to abuse
            # because it lets unauthenticated clients store arbitrary objects
            # on a homeserver.
            #
            # Revisit: Assuming the REST APIs do sensible validation, the data
            # isn't arbitrary.
            #
            # Note that the registration endpoint explicitly removes the
            # "initial_device_display_name" parameter if it is provided
            # without a "password" parameter. See the changes to
            # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST
            # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9.
            if not clientdict:
                clientdict = session.clientdict

            # Ensure that the queried operation does not vary between stages of
            # the UI authentication session. This is done by generating a stable
            # comparator and storing it during the initial query. Subsequent
            # queries ensure that this comparator has not changed.
            #
            # The comparator is based on the requested URI and HTTP method. The
            # client dict (minus the auth dict) should also be checked, but some
            # clients are not spec compliant, just warn for now if the client
            # dict changes.
            if (session.uri, session.method) != (uri, method):
                raise SynapseError(
                    403,
                    "Requested operation has changed during the UI authentication session.",
                )

            if session.clientdict != clientdict:
                logger.warning(
                    "Requested operation has changed during the UI "
                    "authentication session. A future version of Synapse "
                    "will remove this capability.")

            # For backwards compatibility, changes to the client dict are
            # persisted as clients modify them throughout their user interactive
            # authentication flow.
            await self.store.set_ui_auth_clientdict(sid, clientdict)

        if not authdict:
            raise InteractiveAuthIncompleteError(
                self._auth_dict_for_flows(flows, session.session_id))

        # check auth type currently being presented
        errordict = {}  # type: Dict[str, Any]
        if "type" in authdict:
            login_type = authdict["type"]  # type: str
            try:
                result = await self._check_auth_dict(authdict, clientip)
                if result:
                    await self.store.mark_ui_auth_stage_complete(
                        session.session_id, login_type, result)
            except LoginError as e:
                if login_type == LoginType.EMAIL_IDENTITY:
                    # riot used to have a bug where it would request a new
                    # validation token (thus sending a new email) each time it
                    # got a 401 with a 'flows' field.
                    # (https://github.com/vector-im/vector-web/issues/2447).
                    #
                    # Grandfather in the old behaviour for now to avoid
                    # breaking old riot deployments.
                    raise

                # this step failed. Merge the error dict into the response
                # so that the client can have another go.
                errordict = e.error_dict()

        creds = await self.store.get_completed_ui_auth_stages(
            session.session_id)
        for f in flows:
            if len(set(f) - set(creds)) == 0:
                # it's very useful to know what args are stored, but this can
                # include the password in the case of registering, so only log
                # the keys (confusingly, clientdict may contain a password
                # param, creds is just what the user authed as for UI auth
                # and is not sensitive).
                logger.info(
                    "Auth completed with creds: %r. Client dict has keys: %r",
                    creds,
                    list(clientdict),
                )

                return creds, clientdict, session.session_id

        ret = self._auth_dict_for_flows(flows, session.session_id)
        ret["completed"] = list(creds)
        ret.update(errordict)
        raise InteractiveAuthIncompleteError(ret)

    async def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any],
                           clientip: str) -> bool:
        """
        Record the outcome of an out-of-band auth stage (e.g. fallback auth)
        against an existing user-interactive auth session.

        Args:
            stagetype: the login type of the stage; must have a registered
                checker in self.checkers.
            authdict: the auth dict from the client; must carry a "session".
            clientip: IP address of the client.

        Returns:
            True if the checker accepted the auth and the stage was marked
            complete on the session, False if the checker rejected it.

        Raises:
            LoginError if the stage type is unknown or no session was given.
        """
        # Short-circuits exactly like the two separate checks: the stage
        # type is validated before the session key is looked at.
        if stagetype not in self.checkers or "session" not in authdict:
            raise LoginError(400, "", Codes.MISSING_PARAM)

        checker = self.checkers[stagetype]
        outcome = await checker.check_auth(authdict, clientip)
        if not outcome:
            return False

        await self.store.mark_ui_auth_stage_complete(
            authdict["session"], stagetype, outcome)
        return True

    def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
        """
        Extract the UI auth session ID from a client request body.

        Args:
            clientdict: The dictionary sent by the client in the request

        Returns:
            clientdict["auth"]["session"] when both keys are present,
            otherwise None.
        """
        if not clientdict or "auth" not in clientdict:
            return None

        auth_section = clientdict["auth"]
        if "session" not in auth_section:
            return None
        return auth_section["session"]

    async def set_session_data(self, session_id: str, key: str,
                               value: Any) -> None:
        """
        Save a key/value pair in the server-side data of a UI auth session.

        The data lives on the server only; the client cannot alter it.

        Args:
            session_id: ID of the session, as returned from check_auth
            key: Name to store the value under
            value: The value to store

        Raises:
            SynapseError (400) if the store raises StoreError, which is
                reported to the client as an unknown session ID.
        """
        store = self.store
        try:
            await store.set_ui_auth_session_data(session_id, key, value)
        except StoreError:
            raise SynapseError(400, "Unknown session ID: %s" % (session_id, ))

    async def get_session_data(self,
                               session_id: str,
                               key: str,
                               default: Optional[Any] = None) -> Any:
        """
        Fetch a value previously saved with set_session_data.

        Args:
            session_id: ID of the session, as returned from check_auth
            key: Name the value was stored under
            default: Passed through to the store as the fallback value for
                a key that has not been set

        Raises:
            SynapseError (400) if the store raises StoreError, which is
                reported to the client as an unknown session ID.
        """
        store = self.store
        try:
            found = await store.get_ui_auth_session_data(
                session_id, key, default)
        except StoreError:
            raise SynapseError(400, "Unknown session ID: %s" % (session_id, ))
        return found

    async def _expire_old_sessions(self):
        """
        Invalidate any user interactive authentication sessions that have expired.
        """
        now = self._clock.time_msec()
        expiration_time = now - self.SESSION_EXPIRE_MS
        await self.store.delete_old_ui_auth_sessions(expiration_time)

    async def _check_auth_dict(self, authdict: Dict[str, Any],
                               clientip: str) -> Union[Dict[str, Any], str]:
        """Attempt to validate the auth dict provided by a client

        Args:
            authdict: auth dict provided by the client
            clientip: IP address of the client

        Returns:
            Result of the stage verification.

        Raises:
            StoreError if there was a problem accessing the database
            SynapseError if there was a problem with the request
            LoginError if there was an authentication problem.
        """
        login_type = authdict["type"]
        checker = self.checkers.get(login_type)
        if checker is not None:
            res = await checker.check_auth(authdict, clientip=clientip)
            return res

        # build a v1-login-style dict out of the authdict and fall back to the
        # v1 code
        user_id = authdict.get("user")

        if user_id is None:
            raise SynapseError(400, "", Codes.MISSING_PARAM)

        (canonical_id, callback) = await self.validate_login(user_id, authdict)
        return canonical_id

    def _get_params_recaptcha(self) -> dict:
        return {"public_key": self.hs.config.recaptcha_public_key}

    def _get_params_terms(self) -> dict:
        return {
            "policies": {
                "privacy_policy": {
                    "version": self.hs.config.user_consent_version,
                    "en": {
                        "name":
                        self.hs.config.user_consent_policy_name,
                        "url":
                        "%s_matrix/consent?v=%s" % (
                            self.hs.config.public_baseurl,
                            self.hs.config.user_consent_version,
                        ),
                    },
                }
            }
        }

    def _auth_dict_for_flows(
        self,
        flows: List[List[str]],
        session_id: str,
    ) -> Dict[str, Any]:
        """Build the body returned to a client that still has auth to do.

        Args:
            flows: the permitted flows, each a list of stage types.
            session_id: the ID of the UI auth session.

        Returns:
            A dict with the session ID, the flows (as {"stages": [...]}) and
            a "params" map holding extra data for the recaptcha and terms
            stages when any flow contains them.
        """
        public_flows = list(flows)

        # Stages that need extra parameters, mapped to their builders.
        param_builders = {
            LoginType.RECAPTCHA: self._get_params_recaptcha,
            LoginType.TERMS: self._get_params_terms,
        }

        params = {}  # type: Dict[str, Any]
        for flow in public_flows:
            for stage in flow:
                builder = param_builders.get(stage)
                if builder is not None and stage not in params:
                    params[stage] = builder()

        return {
            "session": session_id,
            "flows": [{"stages": flow} for flow in public_flows],
            "params": params,
        }

    async def get_access_token_for_user_id(self, user_id: str,
                                           device_id: Optional[str],
                                           valid_until_ms: Optional[int]):
        """
        Issue and store a new access token for an already-authenticated user.

        The user is assumed to have been authenticated by some other
        mechanism (e.g. CAS), and the user_id converted to the canonical case.

        Args:
            user_id: canonical User ID
            device_id: the device ID to associate with the token.
               None to leave the token unassociated with a device (deprecated:
               we should always have a device ID)
            valid_until_ms: when the token is valid until. None for
                no expiry.
        Returns:
              The access token for the user's session.
        Raises:
            StoreError if there was a problem storing the token, or if the
                device disappeared while the token was being created.
        """
        if valid_until_ms is None:
            fmt_expiry = ""
        else:
            expiry_local = time.localtime(valid_until_ms / 1000.0)
            fmt_expiry = time.strftime(" until %Y-%m-%d %H:%M:%S",
                                       expiry_local)
        logger.info("Logging in user %s on device %s%s", user_id, device_id,
                    fmt_expiry)

        await self.auth.check_auth_blocking(user_id)

        token = self.macaroon_gen.generate_access_token(user_id)
        await self.store.add_access_token_to_user(user_id, token,
                                                  device_id, valid_until_ms)

        if device_id is None:
            return token

        # The device should already exist, but a concurrent DELETE may have
        # removed it; never leave a live token pointing at a missing device.
        try:
            await self.store.get_device(user_id, device_id)
        except StoreError:
            await self.store.delete_access_token(token)
            raise StoreError(400, "Login raced against device deletion")
        return token

    async def check_user_exists(self, user_id: str) -> Optional[str]:
        """
        Look up a user ID case-insensitively and return its canonical form.

        Args:
            user_id: complete @user:id

        Returns:
            The canonical user ID, or None if there is no match or the ID
            matches several users inexactly.
        """
        match = await self._find_user_id_and_pwd_hash(user_id)
        return None if match is None else match[0]

    async def _find_user_id_and_pwd_hash(
            self, user_id: str) -> Optional[Tuple[str, str]]:
        """Checks to see if a user with the given id exists. Will check case
        insensitively, but will return None if there are multiple inexact
        matches.

        Returns:
            A 2-tuple of `(canonical_user_id, password_hash)` or `None`
            if there is not exactly one match
        """
        user_infos = await self.store.get_users_by_id_case_insensitive(user_id)

        result = None
        if not user_infos:
            logger.warning("Attempted to login as %s but they do not exist",
                           user_id)
        elif len(user_infos) == 1:
            # a single match (possibly not exact)
            result = user_infos.popitem()
        elif user_id in user_infos:
            # multiple matches, but one is exact
            result = (user_id, user_infos[user_id])
        else:
            # multiple matches, none of them exact
            logger.warning(
                "Attempted to login as %s but it matches more than one user "
                "inexactly: %r",
                user_id,
                user_infos.keys(),
            )
        return result

    def get_supported_login_types(self) -> Iterable[str]:
        """Return the login types supported for the /login API.

        By default this is just 'm.login.password' (unless password_enabled is
        False in the config file), but password auth providers can provide
        other login types.

        Returns:
            the collection held in self._supported_login_types
        """
        supported = self._supported_login_types
        return supported

    async def validate_login(
        self, username: str, login_submission: Dict[str, Any]
    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
        """Authenticates the user for the /login API

        Also used by the user-interactive auth flow to validate
        m.login.password auth types.

        Configured password providers are consulted first, in order; the
        first one that accepts the submission determines the result. If none
        does and the login type is m.login.password, the local password
        database is tried when `password_localdb_enabled` is set.

        Args:
            username: username supplied by the user; either a full user ID
                (starting with "@") or a localpart on this homeserver
            login_submission: the whole of the login submission
                (including 'type' and other relevant fields)
        Returns:
            A tuple of the canonical user id, and optional callback
                to be called once the access token and device id are issued
        Raises:
            StoreError if there was a problem accessing the database
            SynapseError if there was a problem with the request (password
                login disabled, missing password, missing provider-declared
                fields, or a login type nothing recognises)
            LoginError if there was an authentication problem.
            UserDeactivatedError may propagate from _check_local_password.
        """

        # Qualify a bare localpart into a full user ID on this homeserver.
        if username.startswith("@"):
            qualified_user_id = username
        else:
            qualified_user_id = UserID(username, self.hs.hostname).to_string()

        login_type = login_submission.get("type")
        # Set once any provider (or the local db) claims this login type;
        # decides between "unknown login type" and "invalid password" below.
        known_login_type = False

        # special case to check for "password" for the check_password interface
        # for the auth providers
        password = login_submission.get("password")

        if login_type == LoginType.PASSWORD:
            if not self._password_enabled:
                raise SynapseError(400, "Password login has been disabled.")
            if not password:
                raise SynapseError(400, "Missing parameter: password")

        # Give each configured password provider a chance to authenticate;
        # the first successful one returns immediately.
        for provider in self.password_providers:
            # Providers exposing check_password handle plain password logins.
            if hasattr(provider,
                       "check_password") and login_type == LoginType.PASSWORD:
                known_login_type = True
                is_valid = await provider.check_password(
                    qualified_user_id, password)
                if is_valid:
                    return qualified_user_id, None

            if not hasattr(provider,
                           "get_supported_login_types") or not hasattr(
                               provider, "check_auth"):
                # this password provider doesn't understand custom login types
                continue

            supported_login_types = provider.get_supported_login_types()
            if login_type not in supported_login_types:
                # this password provider doesn't understand this login type
                continue

            known_login_type = True
            login_fields = supported_login_types[login_type]

            # Copy the fields this provider declares for the login type out
            # of the submission, failing fast if the client omitted any.
            missing_fields = []
            login_dict = {}
            for f in login_fields:
                if f not in login_submission:
                    missing_fields.append(f)
                else:
                    login_dict[f] = login_submission[f]
            if missing_fields:
                raise SynapseError(
                    400,
                    "Missing parameters for login type %s: %s" %
                    (login_type, missing_fields),
                )

            # NOTE(review): check_auth receives the raw `username`, whereas
            # check_password above receives `qualified_user_id`; confirm this
            # asymmetry is what the provider interface expects.
            result = await provider.check_auth(username, login_type,
                                               login_dict)
            if result:
                # Providers may return a bare user ID string instead of a
                # (user_id, callback) tuple; normalise to the tuple form.
                if isinstance(result, str):
                    result = (result, None)
                return result

        # No provider accepted the login: fall back to the local password
        # database, if it is enabled.
        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
            known_login_type = True

            canonical_user_id = await self._check_local_password(
                qualified_user_id,
                password  # type: ignore
            )

            if canonical_user_id:
                return canonical_user_id, None

        # Neither a provider nor the local db recognised this login type.
        if not known_login_type:
            raise SynapseError(400, "Unknown login type %s" % login_type)

        # We raise a 403 here, but note that if we're doing user-interactive
        # login, it turns all LoginErrors into a 401 anyway.
        raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)

    async def check_password_provider_3pid(
        self, medium: str, address: str, password: str
    ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
        """Ask password providers to authenticate a third-party-ID login.

        Providers that implement `check_3pid_auth` are tried in order; the
        first truthy answer wins. A provider may answer with a bare user ID
        string or with a `(user_id, callback)` tuple.

        Args:
            medium: The medium of the 3pid (ex. email).
            address: The address of the 3pid (ex. [email protected]).
            password: The password of the user.

        Returns:
            A tuple of `(user_id, callback)`. On success `user_id` is the
            authenticated, canonical user ID and `callback` is either a
            function to run once login/registration has completed, or
            `None`. If no provider authenticates the user, both are `None`.
        """
        capable_providers = (
            provider for provider in self.password_providers
            if hasattr(provider, "check_3pid_auth")
        )
        for provider in capable_providers:
            outcome = await provider.check_3pid_auth(medium, address,
                                                     password)
            if not outcome:
                continue
            if isinstance(outcome, str):
                # A bare user ID means there is no post-login callback.
                return outcome, None
            return outcome

        return None, None

    async def _check_local_password(self, user_id: str,
                                    password: str) -> Optional[str]:
        """Authenticate a user against the local password database.

        user_id is checked case insensitively, but will return None if there are
        multiple inexact matches.

        Args:
            user_id: complete @user:id
            password: the provided password
        Returns:
            The canonical_user_id, or None if unknown user/bad password
        """
        lookupres = await self._find_user_id_and_pwd_hash(user_id)
        if not lookupres:
            return None
        (user_id, password_hash) = lookupres

        # If the password hash is None, the account has likely been deactivated
        if not password_hash:
            deactivated = await self.store.get_user_deactivated_status(user_id)
            if deactivated:
                raise UserDeactivatedError("This account has been deactivated")

        result = await self.validate_hash(password, password_hash)
        if not result:
            logger.warning("Failed password login for user %s", user_id)
            return None
        return user_id

    async def validate_short_term_login_token_and_get_user_id(
            self, login_token: str):
        """Check a short-term login token and return the user ID it carries.

        Args:
            login_token: the serialized macaroon supplied by the client.

        Returns:
            The user ID extracted from the token.

        Raises:
            AuthError (403) if deserializing, extracting the user ID, or
                validating the macaroon raises any exception.
        """
        token_auth = self.hs.get_auth()
        try:
            macaroon = pymacaroons.Macaroon.deserialize(login_token)
            user_id = token_auth.get_user_id_from_macaroon(macaroon)
            token_auth.validate_macaroon(macaroon, "login", user_id)
        except Exception:
            # Every failure is reported to the client the same way.
            raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)

        await self.auth.check_auth_blocking(user_id)
        return user_id

    async def delete_access_token(self, access_token: str):
        """Revoke one access token and clean up what depends on it.

        Password providers with an `on_logged_out` hook are notified, and
        pushers tied to the token are removed when it has a token ID.

        Args:
            access_token: access token to be deleted
        """
        token_info = await self.auth.get_user_by_access_token(access_token)
        await self.store.delete_access_token(access_token)

        # Let interested password providers know about the logout.
        for provider in self.password_providers:
            if not hasattr(provider, "on_logged_out"):
                continue
            await provider.on_logged_out(
                user_id=str(token_info["user"]),
                device_id=token_info["device_id"],
                access_token=access_token,
            )

        token_id = token_info["token_id"]
        if token_id is None:
            return

        # Drop any pushers registered through this token.
        pusher_pool = self.hs.get_pusherpool()
        await pusher_pool.remove_pushers_by_access_token(
            str(token_info["user"]), (token_id, ))

    async def delete_access_tokens_for_user(
        self,
        user_id: str,
        except_token_id: Optional[str] = None,
        device_id: Optional[str] = None,
    ):
        """Revoke a user's access tokens and clean up what depends on them.

        Args:
            user_id:  ID of user the tokens belong to
            except_token_id: access_token ID which should *not* be deleted
            device_id:  ID of device the tokens are associated with.
                If None, tokens associated with any device (or no device) will
                be deleted
        """
        deleted = await self.store.user_delete_access_tokens(
            user_id, except_token_id=except_token_id, device_id=device_id)

        # Notify each password provider with an on_logged_out hook about
        # every token that was deleted, using that token's own device ID.
        for provider in self.password_providers:
            if not hasattr(provider, "on_logged_out"):
                continue
            for token, _token_id, token_device_id in deleted:
                await provider.on_logged_out(user_id=user_id,
                                             device_id=token_device_id,
                                             access_token=token)

        # Remove pushers registered through any of the deleted tokens.
        await self.hs.get_pusherpool().remove_pushers_by_access_token(
            user_id, (token_id for _, token_id, _ in deleted))

    async def add_threepid(self, user_id: str, medium: str, address: str,
                           validated_at: int):
        """Store a validated third-party identifier for a user.

        Args:
            user_id: the user to attach the 3pid to.
            medium: "email" or "msisdn".
            address: the 3pid address; email addresses are lower-cased.
            validated_at: when the 3pid was validated (passed to the store).

        Raises:
            SynapseError (400, INVALID_PARAM) for any other medium.
        """
        if medium not in ("email", "msisdn"):
            raise SynapseError(
                code=400,
                msg=("'%s' is not a valid value for 'medium'" % (medium, )),
                errcode=Codes.INVALID_PARAM,
            )

        # Email addresses are canonicalised to lower case. The homeserver is
        # becoming responsible for validating the threepids used for password
        # resets, so it is starting to understand specific mediums; this also
        # makes the email lookup during password reset case-insensitive.
        stored_address = address.lower() if medium == "email" else address

        await self.store.user_add_threepid(user_id, medium, stored_address,
                                           validated_at,
                                           self.hs.get_clock().time_msec())

    async def delete_threepid(self,
                              user_id: str,
                              medium: str,
                              address: str,
                              id_server: Optional[str] = None) -> bool:
        """Unbind a 3pid on identity servers, then delete it locally.

        Args:
            user_id: ID of user to remove the 3pid from.
            medium: The medium of the 3pid being removed: "email" or "msisdn".
            address: The 3pid address to remove; email addresses are
                lower-cased first, matching add_threepid.
            id_server: Use the given identity server when unbinding
                any threepids. If None then will attempt to unbind using the
                identity server specified when binding (if known).

        Returns:
            The result of the identity handler's try_unbind_threepid: True if
            the 3pid was unbound on the identity server, False if the
            identity server doesn't support the unbind API.
        """
        if medium == "email":
            address = address.lower()

        identity_handler = self.hs.get_handlers().identity_handler
        unbind_request = {
            "medium": medium,
            "address": address,
            "id_server": id_server,
        }
        unbound = await identity_handler.try_unbind_threepid(
            user_id, unbind_request)

        # The local row is removed regardless of the unbind outcome.
        await self.store.user_delete_threepid(user_id, medium, address)
        return unbound

    async def hash(self, password: str) -> str:
        """Compute a bcrypt hash of a password on a worker thread.

        The password is NFKC-normalised and the configured pepper appended
        before hashing with `self.bcrypt_rounds` rounds.

        Args:
            password: Password to hash.

        Returns:
            Hashed password, as an ASCII string.
        """
        def _hash_in_thread():
            # NFKC normalisation makes compatibility-equivalent Unicode
            # spellings of the same password hash identically.
            normalised = unicodedata.normalize("NFKC", password)
            pepper = self.hs.config.password_pepper.encode("utf8")
            secret = normalised.encode("utf8") + pepper
            salt = bcrypt.gensalt(self.bcrypt_rounds)
            return bcrypt.hashpw(secret, salt).decode("ascii")

        return await defer_to_thread(self.hs.get_reactor(), _hash_in_thread)

    async def validate_hash(self, password: str,
                            stored_hash: Union[bytes, str]) -> bool:
        """Check whether a password matches a stored bcrypt hash.

        The password is NFKC-normalised and peppered exactly as in `hash`,
        and the comparison runs on a worker thread.

        Args:
            password: Password to check.
            stored_hash: Expected hash value; a str is ASCII-encoded first.

        Returns:
            True if the password matches; False otherwise, including when
            `stored_hash` is empty or None.
        """
        if not stored_hash:
            return False

        if not isinstance(stored_hash, bytes):
            stored_hash = stored_hash.encode("ascii")

        def _check_in_thread():
            normalised = unicodedata.normalize("NFKC", password)
            candidate = (normalised.encode("utf8") +
                         self.hs.config.password_pepper.encode("utf8"))
            return bcrypt.checkpw(candidate, stored_hash)

        return await defer_to_thread(self.hs.get_reactor(), _check_in_thread)

    async def start_sso_ui_auth(self, redirect_url: str,
                                session_id: str) -> str:
        """
        Render the confirmation page shown before redirecting to SSO.

        Args:
            redirect_url: The URL to redirect to the SSO provider.
            session_id: The user interactive authentication session ID.

        Returns:
            The rendered HTML, which includes the session's description.

        Raises:
            SynapseError (400) if the store raises StoreError while loading
                the session (reported as an unknown session ID).
        """
        try:
            session = await self.store.get_ui_auth_session(session_id)
        except StoreError:
            raise SynapseError(400, "Unknown session ID: %s" % (session_id, ))

        template = self._sso_auth_confirm_template
        return template.render(
            description=session.description,
            redirect_url=redirect_url,
        )

    async def complete_sso_ui_auth(
        self,
        registered_user_id: str,
        session_id: str,
        request: SynapseRequest,
    ):
        """Mark the SSO stage of a UI auth session as complete and send the
        success page.

        Args:
            registered_user_id: The user ID that authenticated via SSO. It is
                stored as the stage result so it can later be used to ensure
                the account being modified belongs to whoever logged in.
            session_id: The ID of the user-interactive auth session.
            request: The request to write the HTML success page to.
        """
        await self.store.mark_ui_auth_stage_complete(session_id, LoginType.SSO,
                                                     registered_user_id)

        page = self._sso_auth_success_template.encode("utf-8")
        request.setResponseCode(200)
        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
        request.setHeader(b"Content-Length", b"%d" % (len(page), ))

        request.write(page)
        finish_request(request)

    async def complete_sso_login(
        self,
        registered_user_id: str,
        request: SynapseRequest,
        client_redirect_url: str,
    ):
        """Having figured out a mxid for this user, complete the HTTP request

        Deactivated accounts receive a 403 page instead of a login token.

        Args:
            registered_user_id: The registered user ID to complete SSO login for.
            request: The request to complete.
            client_redirect_url: The URL to which to redirect the user at the end of the
                process.
        """
        deactivated = await self.store.get_user_deactivated_status(
            registered_user_id)
        if not deactivated:
            self._complete_sso_login(registered_user_id, request,
                                     client_redirect_url)
            return

        page = self._sso_account_deactivated_template.encode("utf-8")
        request.setResponseCode(403)
        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
        request.setHeader(b"Content-Length", b"%d" % (len(page), ))
        request.write(page)
        finish_request(request)

    def _complete_sso_login(
        self,
        registered_user_id: str,
        request: SynapseRequest,
        client_redirect_url: str,
    ):
        """
        The synchronous portion of complete_sso_login.

        Whitelisted clients are redirected straight to their URL with a login
        token appended; all others get a confirmation page first.

        This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
        """
        login_token = self.macaroon_gen.generate_short_term_login_token(
            registered_user_id)

        # The client's own query parameters are passed through, with
        # loginToken added, to build the final redirect target.
        redirect_url = self.add_query_param_to_url(client_redirect_url,
                                                   "loginToken", login_token)

        if client_redirect_url.startswith(self._whitelisted_sso_clients):
            # Trusted client: redirect immediately, no confirmation page.
            request.redirect(redirect_url)
            finish_request(request)
            return

        # Show the user a shorter URL (everything before the first "?");
        # the link itself still points at the full redirect_url.
        display_url, _, _ = client_redirect_url.partition("?")

        page = self._sso_redirect_confirm_template.render(
            display_url=display_url,
            redirect_url=redirect_url,
            server_name=self._server_name,
        ).encode("utf-8")

        request.setResponseCode(200)
        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
        request.setHeader(b"Content-Length", b"%d" % (len(page), ))
        request.write(page)
        finish_request(request)

    @staticmethod
    def add_query_param_to_url(url: str, param_name: str, param: Any):
        url_parts = list(urllib.parse.urlparse(url))
        query = dict(urllib.parse.parse_qsl(url_parts[4]))
        query.update({param_name: param})
        url_parts[4] = urllib.parse.urlencode(query)
        return urllib.parse.urlunparse(url_parts)
Example #9
0
class IdentityHandler(BaseHandler):
    """Handles communication with identity servers on behalf of users.

    Covers validating third-party identifiers (3PIDs), binding and unbinding
    them, proxying `/requestToken` and `submitToken` calls, looking up 3PIDs,
    and storing third-party invites.
    """

    def __init__(self, hs):
        super().__init__(hs)

        # An HTTP client for contacting trusted URLs.
        self.http_client = SimpleHttpClient(hs)
        # An HTTP client for contacting identity servers specified by clients.
        self.blacklisting_http_client = SimpleHttpClient(
            hs, ip_blacklist=hs.config.federation_ip_range_blacklist)
        self.federation_http_client = hs.get_federation_http_client()
        self.hs = hs

        self._web_client_location = hs.config.invite_client_location

        # Ratelimiters for `/requestToken` endpoints. Both share the same
        # config; one is keyed on the requester's IP and the other on the
        # threepid address (see ratelimit_request_token_requests).
        self._3pid_validation_ratelimiter_ip = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
        )
        self._3pid_validation_ratelimiter_address = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
        )

    def ratelimit_request_token_requests(
        self,
        request: SynapseRequest,
        medium: str,
        address: str,
    ):
        """Used to ratelimit requests to `/requestToken` by IP and address.

        Args:
            request: The associated request
            medium: The type of threepid, e.g. "msisdn" or "email"
            address: The actual threepid ID, e.g. the phone number or email address
        """

        self._3pid_validation_ratelimiter_ip.ratelimit(
            (medium, request.getClientIP()))
        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))

    async def threepid_from_creds(self, id_server: str,
                                  creds: Dict[str, str]) -> Optional[JsonDict]:
        """
        Retrieve and validate a threepid identifier from a "credentials" dictionary against a
        given identity server

        Args:
            id_server: The identity server to validate 3PIDs against. Must be a
                complete URL including the protocol (http(s)://)
            creds: Dictionary containing the following keys:
                * client_secret|clientSecret: A unique secret str provided by the client
                * sid: The ID of the validation session

        Returns:
            A dictionary consisting of response params to the /getValidated3pid
            endpoint of the Identity Service API, or None if the threepid was not found

        Raises:
            SynapseError: if a required parameter is missing, or the identity
                server timed out
        """
        client_secret = creds.get("client_secret") or creds.get("clientSecret")
        if not client_secret:
            raise SynapseError(400,
                               "Missing param client_secret in creds",
                               errcode=Codes.MISSING_PARAM)
        assert_valid_client_secret(client_secret)

        session_id = creds.get("sid")
        if not session_id:
            raise SynapseError(400,
                               "Missing param session_id in creds",
                               errcode=Codes.MISSING_PARAM)

        query_params = {"sid": session_id, "client_secret": client_secret}

        url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"

        try:
            data = await self.http_client.get_json(url, query_params)
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")
        except HttpResponseException as e:
            # Log only the session ID: `creds` also carries the client secret,
            # which must not end up in server logs.
            logger.info(
                "%s returned %i for threepid validation for session %s",
                id_server,
                e.code,
                session_id,
            )
            return None

        # Old versions of Sydent return a 200 http code even on a failed validation
        # check. Thus, in addition to the HttpResponseException check above (which
        # checks for non-200 errors), we need to make sure validation_session isn't
        # actually an error, identified by the absence of a "medium" key
        # See https://github.com/matrix-org/sydent/issues/215 for details
        if "medium" in data:
            return data

        logger.info("%s reported non-validated threepid for session %s",
                    id_server, session_id)
        return None

    async def bind_threepid(
        self,
        client_secret: str,
        sid: str,
        mxid: str,
        id_server: str,
        id_access_token: Optional[str] = None,
        use_v2: bool = True,
    ) -> JsonDict:
        """Bind a 3PID to an identity server

        Args:
            client_secret: A unique secret provided by the client
            sid: The ID of the validation session
            mxid: The MXID to bind the 3PID to
            id_server: The domain of the identity server to query
            id_access_token: The access token to authenticate to the identity
                server with, if necessary. Required if use_v2 is true
            use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True

        Returns:
            The response from the identity server
        """
        logger.debug("Proxying threepid bind request for %s to %s", mxid,
                     id_server)

        # If an id_access_token is not supplied, force usage of v1
        if id_access_token is None:
            use_v2 = False

        # Decide which API endpoint URLs to use
        headers = {}
        bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
        if use_v2:
            bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (
                id_server, )
            headers["Authorization"] = create_id_access_token_header(
                id_access_token)  # type: ignore
        else:
            bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (
                id_server, )

        try:
            # Use the blacklisting http client as this call is only to identity servers
            # provided by a client
            data = await self.blacklisting_http_client.post_json_get_json(
                bind_url, bind_data, headers=headers)

            # Remember where we bound the threepid
            await self.store.add_user_bound_threepid(
                user_id=mxid,
                medium=data["medium"],
                address=data["address"],
                id_server=id_server,
            )

            return data
        except HttpResponseException as e:
            if e.code != 404 or not use_v2:
                logger.error("3PID bind failed with Matrix error: %r", e)
                raise e.to_synapse_error()
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")
        except CodeMessageException as e:
            data = json_decoder.decode(e.msg)  # XXX WAT?
            return data

        # Only reached when the v2 endpoint answered 404: retry against v1.
        logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL",
                    bind_url)
        res = await self.bind_threepid(client_secret,
                                       sid,
                                       mxid,
                                       id_server,
                                       id_access_token,
                                       use_v2=False)
        return res

    async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
        """Attempt to remove a 3PID from an identity server, or if one is not provided, all
        identity servers we're aware the binding is present on

        Args:
            mxid: Matrix user ID of binding to be removed
            threepid: Dict with medium & address of binding to be
                removed, and an optional id_server.

        Raises:
            SynapseError: If we failed to contact the identity server

        Returns:
            True on success, otherwise False if the identity
            server doesn't support unbinding (or no identity server found to
            contact).
        """
        if threepid.get("id_server"):
            id_servers = [threepid["id_server"]]
        else:
            id_servers = await self.store.get_id_servers_user_bound(
                user_id=mxid,
                medium=threepid["medium"],
                address=threepid["address"])

        # We don't know where to unbind, so we don't have a choice but to return
        if not id_servers:
            return False

        # Unbind from every server; the result is True only if all succeeded.
        changed = True
        for id_server in id_servers:
            changed &= await self.try_unbind_threepid_with_id_server(
                mxid, threepid, id_server)

        return changed

    async def try_unbind_threepid_with_id_server(self, mxid: str,
                                                 threepid: dict,
                                                 id_server: str) -> bool:
        """Removes a binding from an identity server

        Args:
            mxid: Matrix user ID of binding to be removed
            threepid: Dict with medium & address of binding to be removed
            id_server: Identity server to unbind from

        Raises:
            SynapseError: If we failed to contact the identity server

        Returns:
            True on success, otherwise False if the identity
            server doesn't support unbinding
        """
        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server, )
        url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")

        content = {
            "mxid": mxid,
            "threepid": {
                "medium": threepid["medium"],
                "address": threepid["address"]
            },
        }

        # we abuse the federation http client to sign the request, but we have to send it
        # using the normal http client since we don't want the SRV lookup and want normal
        # 'browser-like' HTTPS.
        auth_headers = self.federation_http_client.build_auth_headers(
            destination=None,
            method=b"POST",
            url_bytes=url_bytes,
            content=content,
            destination_is=id_server.encode("ascii"),
        )
        headers = {b"Authorization": auth_headers}

        try:
            # Use the blacklisting http client as this call is only to identity servers
            # provided by a client
            await self.blacklisting_http_client.post_json_get_json(
                url, content, headers)
            changed = True
        except HttpResponseException as e:
            changed = False
            if e.code in (400, 404, 501):
                # The remote server probably doesn't support unbinding (yet)
                logger.warning("Received %d response while unbinding threepid",
                               e.code)
            else:
                logger.error(
                    "Failed to unbind threepid on identity server: %s", e)
                raise SynapseError(500, "Failed to contact identity server")
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")

        # The local record of the binding is removed even when the identity
        # server did not support unbinding.
        await self.store.remove_user_bound_threepid(
            user_id=mxid,
            medium=threepid["medium"],
            address=threepid["address"],
            id_server=id_server,
        )

        return changed

    async def send_threepid_validation(
        self,
        email_address: str,
        client_secret: str,
        send_attempt: int,
        send_email_func: Callable[[str, str, str, str], Awaitable],
        next_link: Optional[str] = None,
    ) -> str:
        """Send a threepid validation email for password reset or
        registration purposes

        Args:
            email_address: The user's email address
            client_secret: The provided client secret
            send_attempt: Which send attempt this is
            send_email_func: A function that takes an email address, token,
                             client_secret and session_id, sends an email
                             and returns an Awaitable.
            next_link: The URL to redirect the user to after validation

        Returns:
            The new session_id upon success

        Raises:
            SynapseError if an error occurred when sending the email
        """
        # Check that this email/client_secret/send_attempt combo is new or
        # greater than what we've seen previously
        session = await self.store.get_threepid_validation_session(
            "email", client_secret, address=email_address, validated=False)

        # Check to see if a session already exists and that it is not yet
        # marked as validated
        if session and session.get("validated_at") is None:
            session_id = session["session_id"]
            last_send_attempt = session["last_send_attempt"]

            # Check that the send_attempt is higher than previous attempts
            if send_attempt <= last_send_attempt:
                # If not, just return a success without sending an email
                return session_id
        else:
            # An non-validated session does not exist yet.
            # Generate a session id
            session_id = random_string(16)

        if next_link:
            # Manipulate the next_link to add the sid, because the caller won't get
            # it until we send a response, by which time we've sent the mail.
            if "?" in next_link:
                next_link += "&"
            else:
                next_link += "?"
            next_link += "sid=" + urllib.parse.quote(session_id)

        # Generate a new validation token
        token = random_string(32)

        # Send the mail with the link containing the token, client_secret
        # and session_id
        try:
            await send_email_func(email_address, token, client_secret,
                                  session_id)
        except Exception:
            logger.exception("Error sending threepid validation email to %s",
                             email_address)
            raise SynapseError(
                500, "An error was encountered when sending the email")

        token_expires = (self.hs.get_clock().time_msec() +
                         self.hs.config.email_validation_token_lifetime)

        await self.store.start_or_continue_validation_session(
            "email",
            email_address,
            session_id,
            client_secret,
            send_attempt,
            next_link,
            token,
            token_expires,
        )

        return session_id

    async def requestEmailToken(
        self,
        id_server: str,
        email: str,
        client_secret: str,
        send_attempt: int,
        next_link: Optional[str] = None,
    ) -> JsonDict:
        """
        Request an external server send an email on our behalf for the purposes of threepid
        validation.

        Args:
            id_server: The identity server to proxy to
            email: The email to send the message to
            client_secret: The unique client_secret sent by the user
            send_attempt: Which attempt this is
            next_link: A link to redirect the user to once they submit the token

        Returns:
            The json response body from the server
        """
        params = {
            "email": email,
            "client_secret": client_secret,
            "send_attempt": send_attempt,
        }
        if next_link:
            params["next_link"] = next_link

        if self.hs.config.using_identity_server_from_trusted_list:
            # Warn that a deprecated config option is in use
            logger.warning(
                'The config option "trust_identity_server_for_password_resets" '
                'has been replaced by "account_threepid_delegate". '
                "Please consult the sample config at docs/sample_config.yaml for "
                "details and update your config file.")

        try:
            data = await self.http_client.post_json_get_json(
                id_server +
                "/_matrix/identity/api/v1/validate/email/requestToken",
                params,
            )
            return data
        except HttpResponseException as e:
            logger.info("Proxied requestToken failed: %r", e)
            raise e.to_synapse_error()
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")

    async def requestMsisdnToken(
        self,
        id_server: str,
        country: str,
        phone_number: str,
        client_secret: str,
        send_attempt: int,
        next_link: Optional[str] = None,
    ) -> JsonDict:
        """
        Request an external server send an SMS message on our behalf for the purposes of
        threepid validation.

        Args:
            id_server: The identity server to proxy to
            country: The country code of the phone number
            phone_number: The number to send the message to
            client_secret: The unique client_secret sent by the user
            send_attempt: Which attempt this is
            next_link: A link to redirect the user to once they submit the token

        Returns:
            The json response body from the server, with an added
            "submit_url" pointing back at this homeserver
        """
        params = {
            "country": country,
            "phone_number": phone_number,
            "client_secret": client_secret,
            "send_attempt": send_attempt,
        }
        if next_link:
            params["next_link"] = next_link

        if self.hs.config.using_identity_server_from_trusted_list:
            # Warn that a deprecated config option is in use
            logger.warning(
                'The config option "trust_identity_server_for_password_resets" '
                'has been replaced by "account_threepid_delegate". '
                "Please consult the sample config at docs/sample_config.yaml for "
                "details and update your config file.")

        try:
            data = await self.http_client.post_json_get_json(
                id_server +
                "/_matrix/identity/api/v1/validate/msisdn/requestToken",
                params,
            )
        except HttpResponseException as e:
            logger.info("Proxied requestToken failed: %r", e)
            raise e.to_synapse_error()
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")

        # It is already checked that public_baseurl is configured since this code
        # should only be used if account_threepid_delegate_msisdn is true.
        assert self.hs.config.public_baseurl

        # we need to tell the client to send the token back to us, since it doesn't
        # otherwise know where to send it, so add submit_url response parameter
        # (see also MSC2078)
        data["submit_url"] = (
            self.hs.config.public_baseurl +
            "_matrix/client/unstable/add_threepid/msisdn/submit_token")
        return data

    async def validate_threepid_session(self, client_secret: str,
                                        sid: str) -> Optional[JsonDict]:
        """Validates a threepid session with only the client secret and session ID
        Tries validating against any configured account_threepid_delegates as well as locally.

        Args:
            client_secret: A secret provided by the client
            sid: The ID of the session

        Returns:
            The json response if validation was successful, otherwise None
        """
        # XXX: We shouldn't need to keep wrapping and unwrapping this value
        threepid_creds = {"client_secret": client_secret, "sid": sid}

        # We don't actually know which medium this 3PID is. Thus we first assume it's email,
        # and if validation fails we try msisdn
        validation_session = None

        # Try to validate as email
        if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
            # Ask our delegated email identity server
            validation_session = await self.threepid_from_creds(
                self.hs.config.account_threepid_delegate_email, threepid_creds)
        elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
            # Get a validated session matching these details
            validation_session = await self.store.get_threepid_validation_session(
                "email", client_secret, sid=sid, validated=True)

        if validation_session:
            return validation_session

        # Try to validate as msisdn
        if self.hs.config.account_threepid_delegate_msisdn:
            # Ask our delegated msisdn identity server
            validation_session = await self.threepid_from_creds(
                self.hs.config.account_threepid_delegate_msisdn,
                threepid_creds)

        return validation_session

    async def proxy_msisdn_submit_token(self, id_server: str,
                                        client_secret: str, sid: str,
                                        token: str) -> JsonDict:
        """Proxy a POST submitToken request to an identity server for verification purposes

        Args:
            id_server: The identity server URL to contact
            client_secret: Secret provided by the client
            sid: The ID of the session
            token: The verification token

        Raises:
            SynapseError: If we failed to contact the identity server

        Returns:
            The response dict from the identity server
        """
        body = {"client_secret": client_secret, "sid": sid, "token": token}

        try:
            return await self.http_client.post_json_get_json(
                id_server +
                "/_matrix/identity/api/v1/validate/msisdn/submitToken",
                body,
            )
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")
        except HttpResponseException as e:
            logger.warning(
                "Error contacting msisdn account_threepid_delegate: %s", e)
            raise SynapseError(400, "Error contacting the identity server")

    async def lookup_3pid(
        self,
        id_server: str,
        medium: str,
        address: str,
        id_access_token: Optional[str] = None,
    ) -> Optional[str]:
        """Looks up a 3pid in the passed identity server.

        A v2 lookup is tried first when an access token is supplied; a 404 from
        the v2 endpoint falls back to the v1 lookup, any other error returns None.

        Args:
            id_server: The server name (including port, if required)
                of the identity server to use.
            medium: The type of the third party identifier (e.g. "email").
            address: The third party identifier (e.g. "*****@*****.**").
            id_access_token: The access token to authenticate to the identity
                server with

        Returns:
            the matrix ID of the 3pid, or None if it is not recognized.
        """
        if id_access_token is not None:
            try:
                results = await self._lookup_3pid_v2(id_server,
                                                     id_access_token, medium,
                                                     address)
                return results

            except Exception as e:
                # Catch HttpResponseExcept for a non-200 response code
                # Check if this identity server does not know about v2 lookups
                if isinstance(e, HttpResponseException) and e.code == 404:
                    # This is an old identity server that does not yet support v2 lookups
                    logger.warning(
                        "Attempted v2 lookup on v1 identity server %s. Falling "
                        "back to v1",
                        id_server,
                    )
                else:
                    logger.warning("Error when looking up hashing details: %s",
                                   e)
                    return None

        return await self._lookup_3pid_v1(id_server, medium, address)

    async def _lookup_3pid_v1(self, id_server: str, medium: str,
                              address: str) -> Optional[str]:
        """Looks up a 3pid in the passed identity server using v1 lookup.

        Args:
            id_server: The server name (including port, if required)
                of the identity server to use.
            medium: The type of the third party identifier (e.g. "email").
            address: The third party identifier (e.g. "*****@*****.**").

        Returns:
            the matrix ID of the 3pid, or None if it is not recognized.
        """
        # NOTE(review): id_server_scheme is not defined in this class; presumably
        # a module-level constant defined elsewhere in this file — confirm.
        try:
            data = await self.blacklisting_http_client.get_json(
                "%s%s/_matrix/identity/api/v1/lookup" %
                (id_server_scheme, id_server),
                {
                    "medium": medium,
                    "address": address
                },
            )

            if "mxid" in data:
                # note: we used to verify the identity server's signature here, but no longer
                # require or validate it. See the following for context:
                # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
                return data["mxid"]
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")
        except IOError as e:
            logger.warning("Error from v1 identity server lookup: %s", e)

        return None

    async def _lookup_3pid_v2(self, id_server: str, id_access_token: str,
                              medium: str, address: str) -> Optional[str]:
        """Looks up a 3pid in the passed identity server using v2 lookup.

        Args:
            id_server: The server name (including port, if required)
                of the identity server to use.
            id_access_token: The access token to authenticate to the identity server with
            medium: The type of the third party identifier (e.g. "email").
            address: The third party identifier (e.g. "*****@*****.**").

        Returns:
            the matrix ID of the 3pid, or None if it is not recognised.
        """
        # Check what hashing details are supported by this identity server
        try:
            hash_details = await self.blacklisting_http_client.get_json(
                "%s%s/_matrix/identity/v2/hash_details" %
                (id_server_scheme, id_server),
                {"access_token": id_access_token},
            )
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")

        if not isinstance(hash_details, dict):
            logger.warning(
                "Got non-dict object when checking hash details of %s%s: %s",
                id_server_scheme,
                id_server,
                hash_details,
            )
            raise SynapseError(
                400,
                "Non-dict object from %s%s during v2 hash_details request: %s"
                % (id_server_scheme, id_server, hash_details),
            )

        # Extract information from hash_details
        supported_lookup_algorithms = hash_details.get("algorithms")
        lookup_pepper = hash_details.get("lookup_pepper")
        if (not supported_lookup_algorithms
                or not isinstance(supported_lookup_algorithms, list)
                or not lookup_pepper or not isinstance(lookup_pepper, str)):
            raise SynapseError(
                400,
                "Invalid hash details received from identity server %s%s: %s" %
                (id_server_scheme, id_server, hash_details),
            )

        # Check if any of the supported lookup algorithms are present
        if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
            # Perform a hashed lookup
            lookup_algorithm = LookupAlgorithm.SHA256

            # Hash address, medium and the pepper with sha256
            to_hash = "%s %s %s" % (address, medium, lookup_pepper)
            lookup_value = sha256_and_url_safe_base64(to_hash)

        elif LookupAlgorithm.NONE in supported_lookup_algorithms:
            # Perform a non-hashed lookup
            lookup_algorithm = LookupAlgorithm.NONE

            # Combine together plaintext address and medium
            lookup_value = "%s %s" % (address, medium)

        else:
            logger.warning(
                "None of the provided lookup algorithms of %s are supported: %s",
                id_server,
                supported_lookup_algorithms,
            )
            raise SynapseError(
                400,
                "Provided identity server does not support any v2 lookup "
                "algorithms that this homeserver supports.",
            )

        # Authenticate with identity server given the access token from the client
        headers = {
            "Authorization": create_id_access_token_header(id_access_token)
        }

        try:
            lookup_results = await self.blacklisting_http_client.post_json_get_json(
                "%s%s/_matrix/identity/v2/lookup" %
                (id_server_scheme, id_server),
                {
                    "addresses": [lookup_value],
                    "algorithm": lookup_algorithm,
                    "pepper": lookup_pepper,
                },
                headers=headers,
            )
        except RequestTimedOutError:
            raise SynapseError(500, "Timed out contacting identity server")
        except Exception as e:
            logger.warning("Error when performing a v2 3pid lookup: %s", e)
            raise SynapseError(
                500, "Unknown error occurred during identity server lookup")

        # Check for a mapping from what we looked up to an MXID
        if "mappings" not in lookup_results or not isinstance(
                lookup_results["mappings"], dict):
            logger.warning("No results from 3pid lookup")
            return None

        # Return the MXID if it's available, or None otherwise
        mxid = lookup_results["mappings"].get(lookup_value)
        return mxid

    async def ask_id_server_for_third_party_invite(
        self,
        requester: Requester,
        id_server: str,
        medium: str,
        address: str,
        room_id: str,
        inviter_user_id: str,
        room_alias: str,
        room_avatar_url: str,
        room_join_rules: str,
        room_name: str,
        inviter_display_name: str,
        inviter_avatar_url: str,
        id_access_token: Optional[str] = None,
    ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
        """
        Asks an identity server for a third party invite.

        Args:
            requester
            id_server: hostname + optional port for the identity server.
            medium: The literal string "email".
            address: The third party address being invited.
            room_id: The ID of the room to which the user is invited.
            inviter_user_id: The user ID of the inviter.
            room_alias: An alias for the room, for cosmetic notifications.
            room_avatar_url: The URL of the room's avatar, for cosmetic
                notifications.
            room_join_rules: The join rules of the email (e.g. "public").
            room_name: The m.room.name of the room.
            inviter_display_name: The current display name of the
                inviter.
            inviter_avatar_url: The URL of the inviter's avatar.
            id_access_token (str|None): The access token to authenticate to the identity
                server with

        Returns:
            A tuple containing:
                token: The token which must be signed to prove authenticity.
                public_keys ([{"public_key": str, "key_validity_url": str}]):
                    public_key is a base64-encoded ed25519 public key.
                fallback_public_key: One element from public_keys.
                display_name: A user-friendly name to represent the invited user.

        Raises:
            SynapseError: if the identity server timed out, or its response
                contained neither "public_key" nor a non-empty "public_keys".
        """
        invite_config = {
            "medium": medium,
            "address": address,
            "room_id": room_id,
            "room_alias": room_alias,
            "room_avatar_url": room_avatar_url,
            "room_join_rules": room_join_rules,
            "room_name": room_name,
            "sender": inviter_user_id,
            "sender_display_name": inviter_display_name,
            "sender_avatar_url": inviter_avatar_url,
        }
        # If a custom web client location is available, include it in the request.
        if self._web_client_location:
            invite_config[
                "org.matrix.web_client_location"] = self._web_client_location

        # Add the identity service access token to the JSON body and use the v2
        # Identity Service endpoints if id_access_token is present
        data = None
        base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)

        if id_access_token:
            key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
                id_server_scheme,
                id_server,
            )

            # Attempt a v2 lookup
            url = base_url + "/v2/store-invite"
            try:
                data = await self.blacklisting_http_client.post_json_get_json(
                    url,
                    invite_config,
                    {
                        "Authorization":
                        create_id_access_token_header(id_access_token)
                    },
                )
            except RequestTimedOutError:
                raise SynapseError(500, "Timed out contacting identity server")
            except HttpResponseException as e:
                if e.code != 404:
                    logger.info("Failed to POST %s with JSON: %s", url, e)
                    raise e

        # Either no access token was given or the v2 endpoint returned 404:
        # fall back to the v1 API.
        if data is None:
            key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
                id_server_scheme,
                id_server,
            )
            url = base_url + "/api/v1/store-invite"

            try:
                data = await self.blacklisting_http_client.post_json_get_json(
                    url, invite_config)
            except RequestTimedOutError:
                raise SynapseError(500, "Timed out contacting identity server")
            except HttpResponseException as e:
                logger.warning(
                    "Error trying to call /store-invite on %s%s: %s",
                    id_server_scheme,
                    id_server,
                    e,
                )

            if data is None:
                # Some identity servers may only support application/x-www-form-urlencoded
                # types. This is especially true with old instances of Sydent, see
                # https://github.com/matrix-org/sydent/pull/170
                try:
                    data = await self.blacklisting_http_client.post_urlencoded_get_json(
                        url, invite_config)
                except HttpResponseException as e:
                    logger.warning(
                        "Error calling /store-invite on %s%s with fallback "
                        "encoding: %s",
                        id_server_scheme,
                        id_server,
                        e,
                    )
                    raise e

        # TODO: Check for success
        token = data["token"]
        public_keys = data.get("public_keys", [])
        if "public_key" in data:
            fallback_public_key = {
                "public_key": data["public_key"],
                "key_validity_url": key_validity_url,
            }
        elif public_keys:
            fallback_public_key = public_keys[0]
        else:
            # Previously this surfaced as an opaque IndexError.
            logger.warning(
                "Identity server %s%s returned no public key for /store-invite",
                id_server_scheme,
                id_server,
            )
            raise SynapseError(
                500, "Identity server did not return a public key for the invite")

        if not public_keys:
            public_keys.append(fallback_public_key)
        display_name = data["display_name"]
        return token, public_keys, fallback_public_key, display_name
Example #10
0
class LoginRestServlet(ClientV1RestServlet):
    """Servlet for ``/login``.

    GET advertises the supported login flows. POST logs a user in via JWT
    (when enabled), a short-term login token, or the password / third-party
    identifier flows validated by the auth handler. Every POST is ratelimited
    per client IP address before anything else happens.
    """

    PATTERNS = client_path_patterns("/login$")
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "m.login.jwt"

    def __init__(self, hs):
        super(LoginRestServlet, self).__init__(hs)
        self.jwt_enabled = hs.config.jwt_enabled
        self.jwt_secret = hs.config.jwt_secret
        self.jwt_algorithm = hs.config.jwt_algorithm
        self.cas_enabled = hs.config.cas_enabled
        self.auth_handler = self.hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self.handlers = hs.get_handlers()
        self._well_known_builder = WellKnownBuilder(hs)
        self._address_ratelimiter = Ratelimiter()

    def on_GET(self, request):
        """Return the login flows supported by this server."""
        flows = []
        if self.jwt_enabled:
            flows.append({"type": LoginRestServlet.JWT_TYPE})
        if self.cas_enabled:
            flows.append({"type": LoginRestServlet.SSO_TYPE})

            # we advertise CAS for backwards compat, though MSC1721 renamed it
            # to SSO.
            flows.append({"type": LoginRestServlet.CAS_TYPE})

            # While its valid for us to advertise this login type generally,
            # synapse currently only gives out these tokens as part of the
            # CAS login flow.
            # Generally we don't want to advertise login flows that clients
            # don't know how to implement, since they (currently) will always
            # fall back to the fallback API if they don't understand one of the
            # login flow types returned.
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})

        flows.extend((
            {"type": t} for t in self.auth_handler.get_supported_login_types()
        ))

        return (200, {"flows": flows})

    def on_OPTIONS(self, request):
        return (200, {})

    @defer.inlineCallbacks
    def on_POST(self, request):
        """Ratelimit by client IP, then dispatch on the submission's ``type``.

        Raises:
            SynapseError: 400 if a required JSON key is missing.
        """
        self._address_ratelimiter.ratelimit(
            request.getClientIP(), time_now_s=self.hs.clock.time(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
            update=True,
        )

        login_submission = parse_json_object_from_request(request)
        try:
            if self.jwt_enabled and (login_submission["type"] ==
                                     LoginRestServlet.JWT_TYPE):
                result = yield self.do_jwt_login(login_submission)
            elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                result = yield self.do_token_login(login_submission)
            else:
                result = yield self._do_other_login(login_submission)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        defer.returnValue((200, result))

    @defer.inlineCallbacks
    def _do_other_login(self, login_submission):
        """Handle non-token/saml/jwt logins

        Args:
            login_submission (dict): the JSON login request body.

        Returns:
            dict: HTTP response

        Raises:
            SynapseError: 400 for a missing or malformed identifier.
            LoginError: 403 if a third-party identifier is not bound to any
                local user.
        """
        # Log the request we got, but only certain fields to minimise the chance of
        # logging someone's password (even if they accidentally put it in the wrong
        # field)
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            login_submission.get('identifier'),
            login_submission.get('medium'),
            login_submission.get('address'),
            login_submission.get('user'),
        )
        login_submission_legacy_convert(login_submission)

        if "identifier" not in login_submission:
            raise SynapseError(400, "Missing param: identifier")

        identifier = login_submission["identifier"]
        if "type" not in identifier:
            raise SynapseError(400, "Login identifier has no type")

        # convert phone type identifiers to generic threepids
        if identifier["type"] == "m.id.phone":
            identifier = login_id_thirdparty_from_phone(identifier)

        # convert threepid identifiers to user IDs
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get('address')
            medium = identifier.get('medium')

            if medium is None or address is None:
                raise SynapseError(400, "Invalid thirdparty identifier")

            if medium == 'email':
                # For emails, transform the address to lowercase.
                # We store all email addreses as lowercase in the DB.
                # (See add_threepid in synapse/handlers/auth.py)
                address = address.lower()

            # Check for login providers that support 3pid login types
            canonical_user_id, callback_3pid = (
                yield self.auth_handler.check_password_provider_3pid(
                    medium,
                    address,
                    login_submission["password"],
                )
            )
            if canonical_user_id:
                # Authentication through password provider and 3pid succeeded
                result = yield self._register_device_with_callback(
                    canonical_user_id, login_submission, callback_3pid,
                )
                defer.returnValue(result)

            # No password providers were able to handle this 3pid
            # Check local store
            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                medium, address,
            )
            if not user_id:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning(
                    "unknown 3pid identifier medium %s, address %r",
                    medium, address,
                )
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)

            identifier = {
                "type": "m.id.user",
                "user": user_id,
            }

        # by this point, the identifier should be an m.id.user: if it's anything
        # else, we haven't understood it.
        if identifier["type"] != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type")
        if "user" not in identifier:
            raise SynapseError(400, "User identifier is missing 'user' key")

        canonical_user_id, callback = yield self.auth_handler.validate_login(
            identifier["user"],
            login_submission,
        )

        result = yield self._register_device_with_callback(
            canonical_user_id, login_submission, callback,
        )
        defer.returnValue(result)

    @defer.inlineCallbacks
    def _register_device_with_callback(
        self,
        user_id,
        login_submission,
        callback=None,
    ):
        """ Registers a device with a given user_id. Optionally run a callback
        function after registration has completed.

        Args:
            user_id (str): ID of the user to register.
            login_submission (dict): Dictionary of login information.
            callback (func|None): Callback function to run after registration.

        Returns:
            result (Dict[str,str]): Dictionary of account information after
                successful registration.
        """
        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get("initial_device_display_name")
        device_id, access_token = yield self.registration_handler.register_device(
            user_id, device_id, initial_display_name,
        )

        result = {
            "user_id": user_id,
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        if callback is not None:
            yield callback(result)

        defer.returnValue(result)

    @defer.inlineCallbacks
    def do_token_login(self, login_submission):
        """Log in with a short-term login token.

        Per the note in on_GET, these tokens are currently only handed out by
        the CAS login flow.

        Returns:
            dict: the login response (user_id, access_token, home_server,
                device_id).
        """
        token = login_submission['token']
        user_id = (
            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
                token,
            )
        )

        # The user id returned by the auth handler may differ from anything in
        # the submission, so the response is built from it.
        result = yield self._register_device_with_callback(
            user_id, login_submission,
        )
        defer.returnValue(result)

    @defer.inlineCallbacks
    def do_jwt_login(self, login_submission):
        """Log in with a signed JWT, registering the user if necessary.

        The JWT's ``sub`` claim is used as the localpart on this server.

        Returns:
            dict: the login response (user_id, access_token, home_server,
                device_id).

        Raises:
            LoginError: 401 if the token is missing, expired or invalid, or
                has no ``sub`` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(
                401, "Token field for JWT is missing",
                errcode=Codes.UNAUTHORIZED
            )

        import jwt
        from jwt.exceptions import InvalidTokenError

        try:
            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
        except InvalidTokenError:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user = payload.get("sub", None)
        if user is None:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user_id = UserID(user, self.hs.hostname).to_string()

        registered_user_id = yield self.auth_handler.check_user_exists(user_id)
        if not registered_user_id:
            # Unknown user: register them first. The device must be created
            # for the id returned here (it may differ from `user_id`), not for
            # the falsy `registered_user_id` from the lookup above. The access
            # token that register() also returns is not used; a fresh one is
            # issued with the device below.
            registered_user_id, _ = (
                yield self.handlers.registration_handler.register(localpart=user)
            )

        result = yield self._register_device_with_callback(
            registered_user_id, login_submission,
        )
        defer.returnValue(result)
Example #11
0
class BindOpenidtoMXID(RestServlet):
    """Bind a remote OAuth2 identity to the requesting Matrix user.

    POST body fields:
        bind_type: provider id stored with the binding (e.g. ``"alipay"`` or
            ``"weixin"``).
        auth_code: the provider's authorization code.
        device_type: optional, weixin only; selects a device-specific app on
            the auth service.
    """

    logger.info("------------init--1----")
    PATTERNS = client_patterns("/login/oauth2/bind$", v1=True)

    def __init__(self, hs):
        super().__init__()
        self.hs = hs
        logger.info("------------init------")
        self.auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        # Openid cache from the homeserver. on_POST looks up auth_code keys in
        # it and treats the value as the remote user id (TODO: confirm who
        # populates it).
        self._cache = hs.get_eachchat_cache_for_openid()

        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
        )
        self.http_client = SimpleHttpClient(hs)

    async def on_POST(self, request: SynapseRequest):
        """Bind the remote identity behind ``auth_code`` to the requester.

        Raises:
            LoginError: 410 if ``bind_type`` or ``auth_code`` is missing; 416
                if the user already has an identity bound for ``bind_type``;
                414 if the auth service returns no remote user id.
            SynapseError: 502 if the auth service reports code 502, 500 for
                any other non-200 code.
        """
        self._address_ratelimiter.ratelimit(request.getClientIP())

        params = parse_json_object_from_request(request)
        logger.info("------------param:%s", params)
        # Use .get() so that an absent field reaches the intended 410 error
        # below instead of escaping as an unhandled KeyError (HTTP 500).
        bind_type = params.get("bind_type")
        auth_code = params.get("auth_code")
        if bind_type is None:
            raise LoginError(410,
                             "bind_type field for bind openid is missing",
                             errcode=Codes.FORBIDDEN)
        if auth_code is None:
            raise LoginError(410,
                             "auth_code field for bind openid is missing",
                             errcode=Codes.FORBIDDEN)
        requester = await self.auth.get_user_by_req(request)
        logger.info("------requester: %s", requester)
        user_id = requester.user
        logger.info("------user: %s", user_id)

        store = self.hs.get_datastore()
        openid = await store.get_external_id_for_user_provider(
            bind_type, str(user_id))
        if openid is not None:
            raise LoginError(416,
                             "openid is binded,not support reeapt bind",
                             errcode=Codes.OPENID_BINDED)

        # Prefer a remote user id already cached against this auth_code.
        if auth_code in self._cache:
            remote_user_id = self._cache[auth_code]

            logger.info("--cache----remote_user_id: %s", remote_user_id)
            if remote_user_id is not None:
                # complete bind, then drop the cached code
                await store.record_user_external_id(
                    bind_type,
                    remote_user_id,
                    str(user_id),
                )
                self._cache.pop(auth_code, None)

                return 200, {}
        else:
            logger.info("-no-cache-by-auth_code--")

        # Map bind_type to the auth service's app type. For weixin an optional
        # device_type selects a desktop/device-specific app ("weixin-<type>").
        app_type = ''
        if bind_type == 'alipay':
            app_type = 'zhifubao'
        elif bind_type == 'weixin':
            app_type = 'weixin'
            device_type = params.get("device_type")
            if device_type:
                app_type = 'weixin-' + device_type

        # Ask the auth service which remote user the auth_code belongs to.
        remote_user = await self.get_openid_by_code(app_type, auth_code)

        code = remote_user["code"]
        if code == 502:
            raise SynapseError(502, remote_user["message"])
        # Every other non-200 code is a failure. (A plain `!=` is required
        # here: `&` binds tighter than `!=`, so combining the checks with `&`
        # builds a chained comparison that lets some error codes through.)
        if code != 200:
            raise SynapseError(500, remote_user["message"])

        remote_user_id = remote_user.get("obj")
        if remote_user_id is None:
            raise LoginError(414,
                             "auth_code invalid",
                             errcode=Codes.INVALID_AUTH_CODE)
        logger.info("======bind_type: %s   remote_user_id: %s  user_id: %s",
                    bind_type, remote_user_id, str(user_id))

        # complete bind
        await store.record_user_external_id(
            bind_type,
            remote_user_id,
            str(user_id),
        )

        return 200, {}

    async def get_openid_by_code(self, app_type: str, auth_code: str):
        """Ask the auth service which remote user ``auth_code`` belongs to.

        Args:
            app_type: the auth service's app type (e.g. ``"zhifubao"``,
                ``"weixin"``).
            auth_code: the provider's authorization code.

        Returns:
            The decoded JSON response. on_POST reads its ``code``,
            ``message`` and ``obj`` fields.

        Raises:
            SynapseError: translated from an HTTP error response, or 500 if
                the request times out.
        """
        params = {
            "type": app_type,
            "value": auth_code,
        }
        logger.info("------------param:%s", params)
        try:
            result = await self.http_client.post_json_get_json(
                self.hs.config.auth_baseurl + self.hs.config.auth_get_vercode,
                params,
            )
        except HttpResponseException as e:
            logger.info("Proxied get openid failed: %r", e)
            raise e.to_synapse_error()
        except RequestTimedOutError:
            raise SynapseError(500,
                               "Timed out contacting extral server:getopenid")
        logger.info("result: %s", result)
        return result
Example #12
0
class LoginRestServlet(RestServlet):
    PATTERNS = client_patterns("/login$", v1=True)
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "org.matrix.login.jwt"
    JWT_TYPE_DEPRECATED = "m.login.jwt"
    APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"

    # add new Type
    VER_CODE_EMAIL_TYPE = "m.login.verCode.email"
    VER_CODE_MSISDN_TYPE = "m.login.verCode.msisdn"
    OAUTH2_ALIPAY_TYPE = "m.login.OAuth2.alipay"
    OAUTH2_WEIXIN_TYPE = "m.login.OAuth2.weixin"
    SSO_LDAP_TYPE = "m.login.sso.ldap"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.hs = hs
        config = hs.config

        # Settings for the JWT login flow.
        self.jwt_enabled = config.jwt_enabled
        self.jwt_secret = config.jwt_secret
        self.jwt_algorithm = config.jwt_algorithm
        self.jwt_issuer = config.jwt_issuer
        self.jwt_audiences = config.jwt_audiences

        # Which SSO-style login flows are switched on.
        self.saml2_enabled = config.saml2_enabled
        self.cas_enabled = config.cas_enabled
        self.oidc_enabled = config.oidc_enabled
        self._msc2858_enabled = config.experimental.msc2858_enabled

        # Collaborators obtained from the homeserver.
        self.auth = hs.get_auth()

        self.auth_handler = hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self._sso_handler = hs.get_sso_handler()

        self._well_known_builder = WellKnownBuilder(hs)

        # Two limiters: one keyed by client IP (applied in on_POST), one keyed
        # by account (applied in _complete_login).
        address_limits = config.rc_login_address
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=address_limits.per_second,
            burst_count=address_limits.burst_count,
        )
        account_limits = config.rc_login_account
        self._account_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=account_limits.per_second,
            burst_count=account_limits.burst_count,
        )

        # Client for the external auth / verification-code service.
        self.http_client = SimpleHttpClient(hs)
        # Openid cache from the homeserver (key/value schema not visible
        # here; TODO confirm).
        self._cache = hs.get_eachchat_cache_for_openid()

    def on_GET(self, request: SynapseRequest):
        """Advertise the login flows this server accepts.

        Returns:
            ``(200, {"flows": [...]})`` where each flow is a dict with at least
            a ``type`` key.
        """
        cls = LoginRestServlet
        flows = []

        if self.jwt_enabled:
            flows += [{"type": cls.JWT_TYPE}, {"type": cls.JWT_TYPE_DEPRECATED}]

        if self.cas_enabled:
            # CAS stays advertised for backwards compatibility, even though
            # MSC1721 renamed it to SSO.
            flows.append({"type": cls.CAS_TYPE})

        if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
            sso_flow = {"type": cls.SSO_TYPE}  # type: JsonDict

            if self._msc2858_enabled:
                providers = self._sso_handler.get_identity_providers()
                sso_flow["org.matrix.msc2858.identity_providers"] = [
                    _get_auth_flow_dict_for_idp(provider)
                    for provider in providers.values()
                ]

            flows.append(sso_flow)

            # Login tokens are currently only issued as part of the SSO flow,
            # so the token type is advertised only alongside it. Clients that
            # see a flow type they don't understand fall back to the fallback
            # API, so unimplementable flows are not advertised.
            flows.append({"type": cls.TOKEN_TYPE})

        flows.extend(
            {"type": login_type}
            for login_type in self.auth_handler.get_supported_login_types()
        )

        # Always-available flows: appservice login plus the custom
        # verification-code, OAuth2 and LDAP flows.
        flows.extend(
            {"type": login_type}
            for login_type in (
                cls.APPSERVICE_TYPE,
                cls.VER_CODE_EMAIL_TYPE,
                cls.VER_CODE_MSISDN_TYPE,
                cls.OAUTH2_ALIPAY_TYPE,
                cls.OAUTH2_WEIXIN_TYPE,
                cls.SSO_LDAP_TYPE,
            )
        )

        return 200, {"flows": flows}

    async def on_POST(self, request: SynapseRequest):
        """Log a user in using the flow named by the submission's ``type``.

        Appservice logins are ratelimited by client IP only when the
        appservice itself is ratelimited. Every other login type, including
        the verification-code, OAuth2 and LDAP flows, is ratelimited by client
        IP before any credential is checked.

        Returns:
            ``(200, result)``, where ``result`` is the login response plus a
            ``well_known`` entry when well-known data is available.

        Raises:
            SynapseError: 400 if a required JSON key is missing.
        """
        login_submission = parse_json_object_from_request(request)
        try:
            # Read "type" inside the try so that a submission without it is
            # reported as a 400 instead of escaping as an unhandled KeyError.
            login_type = login_submission["type"]
            logger.info("login type------%s", login_type)

            if login_type == LoginRestServlet.APPSERVICE_TYPE:
                appservice = self.auth.get_appservice_by_req(request)

                if appservice.is_rate_limited():
                    self._address_ratelimiter.ratelimit(request.getClientIP())

                result = await self._do_appservice_login(
                    login_submission, appservice)
            else:
                # Per-IP limit for all non-appservice flows. The verification
                # code flows need it too: otherwise codes can be guessed
                # without limit.
                self._address_ratelimiter.ratelimit(request.getClientIP())

                if self.jwt_enabled and login_type in (
                        LoginRestServlet.JWT_TYPE,
                        LoginRestServlet.JWT_TYPE_DEPRECATED):
                    result = await self._do_jwt_login(login_submission)
                elif login_type == LoginRestServlet.TOKEN_TYPE:
                    result = await self._do_token_login(login_submission)
                elif login_type == LoginRestServlet.VER_CODE_EMAIL_TYPE:
                    # Email verification code login
                    result = await self._do_ver_code_email_login(
                        login_submission)
                elif login_type == LoginRestServlet.VER_CODE_MSISDN_TYPE:
                    # Msisdn verification code login
                    result = await self._do_ver_code_msisdn_login(
                        login_submission)
                elif login_type in (LoginRestServlet.OAUTH2_ALIPAY_TYPE,
                                    LoginRestServlet.OAUTH2_WEIXIN_TYPE):
                    # OAuth2 login (alipay / weixin)
                    result = await self._do_oauth2_login(login_submission)
                elif login_type == LoginRestServlet.SSO_LDAP_TYPE:
                    # SSO of ldap login
                    result = await self._do_sso_ldap_login(login_submission)
                else:
                    result = await self._do_other_login(login_submission)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        return 200, result

    async def _do_appservice_login(self, login_submission: JsonDict,
                                   appservice: ApplicationService):
        """Log a user in on behalf of an application service.

        Only ``m.id.user`` identifiers are accepted. The user may be given as
        a full user ID (starting with ``@``) or as a bare localpart on this
        server, and the appservice must report interest in that user.

        Raises:
            SynapseError: 400 for a malformed identifier.
            LoginError: 403 if the appservice is not interested in the user.
        """
        identifier = login_submission.get("identifier")
        logger.info("Got appservice login request with identifier: %r",
                    identifier)

        # Validate the identifier's shape before looking inside it.
        if not isinstance(identifier, dict):
            raise SynapseError(400, "Invalid identifier in login submission",
                               Codes.INVALID_PARAM)
        if identifier.get("type") != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type",
                               Codes.INVALID_PARAM)

        user = identifier.get("user")
        if not isinstance(user, str):
            raise SynapseError(400, "Invalid user in identifier",
                               Codes.INVALID_PARAM)

        qualified_user_id = (
            user if user.startswith("@")
            else UserID(user, self.hs.hostname).to_string()
        )

        if not appservice.is_interested_in_user(qualified_user_id):
            raise LoginError(403,
                             "Invalid access_token",
                             errcode=Codes.FORBIDDEN)

        return await self._complete_login(
            qualified_user_id,
            login_submission,
            ratelimit=appservice.is_rate_limited(),
        )

    async def _do_other_login(self,
                              login_submission: JsonDict) -> Dict[str, str]:
        """Handle logins without a dedicated flow (e.g. password login).

        Credential checking is delegated to the auth handler with
        ratelimiting enabled, then the login is completed as usual.

        Args:
            login_submission: the JSON login request body.

        Returns:
            The login response body.
        """
        # Only these fields are logged, to minimise the chance of a password
        # reaching the logs even if it was sent in the wrong field.
        logged_fields = ("identifier", "medium", "address", "user")
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            *(login_submission.get(field) for field in logged_fields)
        )

        user_id, on_success = await self.auth_handler.validate_login(
            login_submission, ratelimit=True)
        return await self._complete_login(user_id, login_submission,
                                          on_success)

    async def _complete_login(
        self,
        user_id: str,
        login_submission: JsonDict,
        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
        create_non_existent_users: bool = False,
        ratelimit: bool = True,
    ) -> Dict[str, str]:
        """Finish a login whose credentials have already been accepted.

        Applies the per-account ratelimit, optionally registers a missing
        user, creates a device and access token, and runs ``callback`` on the
        response. Every successful login goes through here.

        Args:
            user_id: ID of the user being logged in.
            login_submission: the login request body; ``device_id`` and
                ``initial_device_display_name`` are read from it.
            callback: awaited with the response dict after the device exists.
            create_non_existent_users: register the user (by localpart) if
                they do not exist yet. Defaults to False.
            ratelimit: whether to apply the per-account ratelimit.

        Returns:
            Dict with ``user_id``, ``access_token``, ``home_server`` and
            ``device_id``.
        """
        # The account is only known once credentials have been checked, so
        # the per-account ratelimit lives here rather than in on_POST.
        if ratelimit:
            self._account_ratelimiter.ratelimit(user_id.lower())

        if create_non_existent_users:
            existing_user_id = await self.auth_handler.check_user_exists(user_id)
            if existing_user_id:
                user_id = existing_user_id
            else:
                localpart = UserID.from_string(user_id).localpart
                user_id = await self.registration_handler.register_user(
                    localpart=localpart)

        device_id, access_token = await self.registration_handler.register_device(
            user_id,
            login_submission.get("device_id"),
            login_submission.get("initial_device_display_name"),
        )

        response = dict(
            user_id=user_id,
            access_token=access_token,
            home_server=self.hs.hostname,
            device_id=device_id,
        )

        if callback is not None:
            await callback(response)

        return response

    async def _do_token_login(self,
                              login_submission: JsonDict) -> Dict[str, str]:
        """Finish an SSO login by redeeming a short-term login token.

        Args:
            login_submission: the JSON request body; must contain ``token``.

        Returns:
            The login response body, passed through the auth handler's SSO
            login callback.
        """
        handler = self.auth_handler
        user_id = await handler.validate_short_term_login_token_and_get_user_id(
            login_submission["token"])

        return await self._complete_login(user_id, login_submission,
                                          handler._sso_login_callback)

    # the login of email verification code
    async def _do_ver_code_email_login(
            self, login_submission: JsonDict) -> Dict[str, str]:
        """Log in with an email address plus a verification code.

        The email must already be bound to a local user. The code is checked
        by the external auth service at ``auth_baseurl +
        auth_code_validation``.

        Args:
            login_submission: JSON body with ``email`` and ``ver_code``.

        Returns:
            The login response body.

        Raises:
            LoginError: 410 if ``email`` or ``ver_code`` is missing; 412 if
                the auth service does not accept the code.
            SynapseError: 400 if the email is not bound to a user; 500 if the
                auth service times out; or the error translated from an HTTP
                failure response.
        """
        email = login_submission.get("email", None)
        if email is None:
            raise LoginError(410,
                             "Email field for ver_code_email is missing",
                             errcode=Codes.FORBIDDEN)
        # The email must already be bound to a local user.
        user_id = await self.hs.get_datastore().get_user_id_by_threepid(
            "email", email)
        if user_id is None:
            raise SynapseError(400, "Email not found",
                               Codes.THREEPID_NOT_FOUND)

        ver_code = login_submission.get("ver_code", None)
        if ver_code is None:
            raise LoginError(410,
                             "ver_code field for ver_code_email is missing",
                             errcode=Codes.FORBIDDEN)

        params = {"value": email, "type": "email", "code": ver_code}
        try:
            ver_code_res = await self.http_client.post_json_get_json(
                self.hs.config.auth_baseurl +
                self.hs.config.auth_code_validation,
                params,
            )
        except HttpResponseException as e:
            logger.info("Proxied validation vercode failed: %r", e)
            raise e.to_synapse_error()
        except RequestTimedOutError:
            raise SynapseError(
                500,
                "Timed out contacting extral server:ver_code_send_service")

        logger.info("email ver_code_res: %s", ver_code_res)
        # A reply without a 200 "code" means the code was rejected. Using
        # .get() avoids a KeyError, which on_POST would misreport to the
        # client as "Missing JSON keys."
        if not isinstance(ver_code_res, dict) or ver_code_res.get("code") != 200:
            raise LoginError(412,
                             "ver_code invalid",
                             errcode=Codes.FORBIDDEN)

        return await self._complete_login(user_id,
                                          login_submission,
                                          create_non_existent_users=True)

    # the login of msisdn verification code
    async def _do_ver_code_msisdn_login(
            self, login_submission: JsonDict) -> Dict[str, str]:
        """Log in with a phone number (msisdn) plus a verification code.

        The msisdn must already be bound to a local user. The code is checked
        by the external auth service at ``auth_baseurl +
        auth_code_validation`` with type ``"mobile"``.

        Args:
            login_submission: JSON body with ``msisdn`` and ``ver_code``.

        Returns:
            The login response body.

        Raises:
            LoginError: 410 if ``msisdn`` is missing; 411 if ``ver_code`` is
                missing; 412 if the auth service does not accept the code.
            SynapseError: 400 if the msisdn is not bound to a user; 500 if the
                auth service times out; or the error translated from an HTTP
                failure response.
        """
        msisdn = login_submission.get("msisdn", None)
        if msisdn is None:
            raise LoginError(410,
                             "msisdn field for ver_code_login is missing",
                             errcode=Codes.FORBIDDEN)
        # The msisdn must already be bound to a local user.
        user_id = await self.hs.get_datastore().get_user_id_by_threepid(
            "msisdn", msisdn)
        if user_id is None:
            raise SynapseError(400, "msisdn not bind",
                               Codes.TEMPORARY_NOT_BIND_MSISDN)

        ver_code = login_submission.get("ver_code", None)
        if ver_code is None:
            raise LoginError(411,
                             "ver_code field for ver_code_login is missing",
                             errcode=Codes.FORBIDDEN)

        params = {"value": msisdn, "type": "mobile", "code": ver_code}
        try:
            ver_code_res = await self.http_client.post_json_get_json(
                self.hs.config.auth_baseurl +
                self.hs.config.auth_code_validation,
                params,
            )
        except HttpResponseException as e:
            logger.info("Proxied validation vercode failed: %r", e)
            raise e.to_synapse_error()
        except RequestTimedOutError:
            raise SynapseError(
                500,
                "Timed out contacting extral server:ver_code_send_service")

        logger.info("msisdn ver_code_res: %s", ver_code_res)
        # A reply without a 200 "code" means the code was rejected. Using
        # .get() avoids a KeyError, which on_POST would misreport to the
        # client as "Missing JSON keys."
        if not isinstance(ver_code_res, dict) or ver_code_res.get("code") != 200:
            raise LoginError(412,
                             "ver_code invalid",
                             errcode=Codes.FORBIDDEN)

        return await self._complete_login(user_id,
                                          login_submission,
                                          create_non_existent_users=True)

    # the login of oauth2 e.g :alipay  weixin
    async def _do_oauth2_login(self,
                               login_submission: JsonDict) -> Dict[str, str]:
        """Log a user in with an OAuth2 authorization code (WeChat or Alipay).

        The auth code is exchanged for the provider's openid through
        ``get_openid_by_code``. The openid is then looked up as an external ID
        of the ``weixin`` or ``alipay`` auth provider.

        Args:
            login_submission: the login request body. Reads ``type``,
                ``auth_code`` and, for WeChat logins, the optional
                ``device_type``. A ``device_type`` selects the
                ``weixin-<device_type>`` app type (e.g. for desktop clients).

        Returns:
            The result of ``_complete_login`` for the bound user.

        Raises:
            LoginError: (410) ``auth_code`` is missing; (414) the auth service
                did not return an openid for the code.
            SynapseError: (400, ``OPENID_NOT_BIND``) the openid is not bound to
                a local user. Before raising, the openid is stored in
                ``self._cache`` keyed by the auth code (presumably so that a
                later bind request can use it). Errors raised by
                ``get_openid_by_code`` propagate unchanged.
        """
        login_type = login_submission.get("type", None)
        # Do not log the whole submission: it carries the auth code, which is
        # a credential.
        logger.info("-----_do_oauth2_login-------login type:%s", login_type)

        auth_code = login_submission.get("auth_code", None)
        if auth_code is None:
            raise LoginError(410,
                             "auth_code field for oauth2 login is missing",
                             errcode=Codes.FORBIDDEN)

        # Map the login type onto the auth provider ID used for the
        # external-id lookup and the app type understood by the auth service.
        _auth_provider_id = 'weixin'
        app_type = 'weixin'
        if login_type == LoginRestServlet.OAUTH2_ALIPAY_TYPE:
            _auth_provider_id = 'alipay'
            app_type = 'zhifubao'
        elif login_type == LoginRestServlet.OAUTH2_WEIXIN_TYPE:
            # WeChat clients on other devices (e.g. desktop) use a
            # device-specific app type.
            device_type = login_submission.get('device_type', None)
            if device_type:
                app_type = 'weixin-' + device_type

        # call get_openid from authserver
        remote_user = await self.get_openid_by_code(app_type, auth_code)
        # .get() so that a reply without "obj" is reported as an invalid auth
        # code (414) instead of escaping as a KeyError.
        remote_user_id = remote_user.get('obj')
        if remote_user_id is None:
            raise LoginError(414,
                             "auth_code invalid",
                             errcode=Codes.INVALID_AUTH_CODE)

        # look up the local user bound to this openid
        user_id = await self.hs.get_datastore().get_user_by_external_id(
            _auth_provider_id, remote_user_id)
        if user_id is None:
            # Remember the openid under the auth code. The original note says
            # entries live 4 minutes. NOTE(review): confirm against the expiry
            # configured for self._cache.
            self._cache[auth_code] = remote_user_id
            raise SynapseError(400, "oppenid not bind", Codes.OPENID_NOT_BIND)

        return await self._complete_login(user_id,
                                          login_submission,
                                          create_non_existent_users=True)

    async def get_openid_by_code(self, app_type: str, auth_code: str):
        """Exchange an OAuth2 authorization code for the provider's openid.

        Sends ``{"type": app_type, "value": auth_code}`` to the external auth
        service at ``auth_baseurl + auth_get_vercode`` via
        ``http_client.post_json_get_json``.

        NOTE(review): the endpoint setting is named ``auth_get_vercode``
        although it is used here to resolve an openid. Confirm this is the
        intended endpoint.

        Args:
            app_type: the app identifier expected by the auth service
                (``_do_oauth2_login`` passes ``weixin``,
                ``weixin-<device_type>`` or ``zhifubao``).
            auth_code: the authorization code supplied by the client.

        Returns:
            The decoded JSON reply. ``_do_oauth2_login`` reads the openid from
            its ``obj`` key.

        Raises:
            SynapseError: (500) the reply's ``code`` is not 200 (or missing),
                or the auth service timed out; or the error mapped from an HTTP
                error response of the auth service.
        """
        params = {
            "type": app_type,
            "value": auth_code,
        }
        # Only log the app type: the auth code is a credential.
        logger.info("------------get openid for app type:%s", app_type)
        try:
            result = await self.http_client.post_json_get_json(
                self.hs.config.auth_baseurl + self.hs.config.auth_get_vercode,
                params,
            )
        except HttpResponseException as e:
            logger.info("Proxied get openid failed: %r", e)
            raise e.to_synapse_error() from e
        except RequestTimedOutError as e:
            raise SynapseError(
                500, "Timed out contacting extral server:getopenid") from e

        logger.info("result: %s", result)
        # .get() so that a malformed reply yields a SynapseError instead of a
        # KeyError.
        if result.get("code") != 200:
            raise SynapseError(500, result.get("message", "get openid failed"))
        return result

    # the login of sso ldap
    async def _do_sso_ldap_login(self,
                                 login_submission: JsonDict) -> Dict[str, str]:
        """Log a user in by validating their credentials against LDAP.

        The ``user`` field may be either a full user ID (starting with ``@``)
        or a localpart on this homeserver. The user must already exist
        locally. The username and password are then sent to the external auth
        service at ``auth_baseurl + auth_sso_ldap_validation``, and any reply
        whose ``code`` is not 200 (or missing) is treated as a failed login.

        Args:
            login_submission: the login request body. Reads the ``user`` and
                ``password`` fields.

        Returns:
            The result of ``_complete_login`` for the user.

        Raises:
            LoginError: (410) ``user`` or ``password`` is missing.
            SynapseError: (400) no such local user; (500) the auth service
                rejected the credentials or timed out; or the error mapped from
                an HTTP error response of the auth service.
        """
        username = login_submission.get("user", None)
        if username is None:
            raise LoginError(410,
                             "user field for ldap login is missing",
                             errcode=Codes.FORBIDDEN)
        # Never log the password.
        password = login_submission.get("password", None)
        if password is None:
            raise LoginError(410,
                             "password field for ldap login is missing",
                             errcode=Codes.FORBIDDEN)

        if username.startswith("@"):
            qualified_user_id = username
        else:
            qualified_user_id = UserID(username, self.hs.hostname).to_string()
        # the user must already exist on this homeserver
        user_id = await self.auth_handler.check_user_exists(qualified_user_id)
        logger.info("----------------------exists user_id:%s", user_id)
        if user_id is None:
            raise SynapseError(400, "user not exists", Codes.INVALID_USERNAME)

        # NOTE(review): the raw ``user`` value (which may be a full user ID)
        # is sent as the LDAP account name, not the localpart. Confirm the
        # auth service expects that.
        params = {
            "account": username,
            "password": password,
        }
        # Keep the try narrow: only the HTTP call can raise the errors below.
        try:
            ldap_ver_res = await self.http_client.post_json_get_json(
                self.hs.config.auth_baseurl +
                self.hs.config.auth_sso_ldap_validation,
                params,
            )
        except HttpResponseException as e:
            logger.info("Proxied ldap verification failed: %r", e)
            raise e.to_synapse_error() from e
        except RequestTimedOutError as e:
            raise SynapseError(
                500,
                "Timed out contacting extral server:ldap verification") from e

        logger.info("ldap verification result: %s", ldap_ver_res)
        # .get() so that a malformed reply yields a SynapseError instead of a
        # KeyError.
        if ldap_ver_res.get("code") != 200:
            raise SynapseError(
                500, ldap_ver_res.get("message", "ldap verification failed"))

        return await self._complete_login(user_id,
                                          login_submission,
                                          create_non_existent_users=True)

    async def _do_jwt_login(self,
                            login_submission: JsonDict) -> Dict[str, str]:
        """Log a user in from a signed JSON Web Token.

        The token is decoded with ``jwt.decode`` using the configured secret,
        algorithm, issuer and audiences. Its ``sub`` claim is the localpart of
        the user on this homeserver.

        Args:
            login_submission: the login request body. Reads ``token``.

        Returns:
            The result of ``_complete_login`` for the token's subject.

        Raises:
            LoginError: (403) the token is missing, fails validation, or has
                no ``sub`` claim.
        """
        jwt_token = login_submission.get("token", None)
        if jwt_token is None:
            raise LoginError(403,
                             "Token field for JWT is missing",
                             errcode=Codes.FORBIDDEN)

        import jwt

        try:
            claims = jwt.decode(
                jwt_token,
                self.jwt_secret,
                algorithms=[self.jwt_algorithm],
                issuer=self.jwt_issuer,
                audience=self.jwt_audiences,
            )
        except jwt.PyJWTError as err:
            # Report the validation problem back to the client.
            raise LoginError(
                403,
                "JWT validation failed: %s" % (err, ),
                errcode=Codes.FORBIDDEN,
            )

        subject = claims.get("sub", None)
        if subject is None:
            raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

        return await self._complete_login(
            UserID(subject, self.hs.hostname).to_string(),
            login_submission,
            create_non_existent_users=True,
        )
Example #13
0
class RoomMemberHandler(metaclass=abc.ABCMeta):
    # TODO(paul): This handler currently contains a messy conflation of
    #   low-level API that works on UserID objects and so on, and REST-level
    #   API that takes ID strings and returns pagination chunks. These concerns
    #   ought to be separated out a lot better.

    def __init__(self, hs: "HomeServer"):
        """Wire up the stores, handlers, config values and ratelimiters used
        for membership changes.

        Args:
            hs: the homeserver that every dependency is fetched from.
        """
        self.hs = hs
        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
        self.state_handler = hs.get_state_handler()
        config = self.config = hs.config

        # Handlers this class delegates to.
        self.federation_handler = hs.get_federation_handler()
        self.directory_handler = hs.get_directory_handler()
        self.identity_handler = hs.get_identity_handler()
        self.registration_handler = hs.get_registration_handler()
        self.profile_handler = hs.get_profile_handler()
        self.event_creation_handler = hs.get_event_creation_handler()
        self.account_data_handler = hs.get_account_data_handler()

        # update_membership queues on this, keyed by room.
        self.member_linearizer = Linearizer(name="member")

        self.clock = hs.get_clock()
        self.spam_checker = hs.get_spam_checker()
        self.third_party_event_rules = hs.get_third_party_event_rules()
        self._server_notices_mxid = config.server_notices_mxid
        self._enable_lookup = config.enable_3pid_lookup
        self.allow_per_room_profiles = config.allow_per_room_profiles

        limits = config.ratelimiting

        def build_limiter(settings):
            # Every limiter shares this handler's clock.
            return Ratelimiter(
                clock=self.clock,
                rate_hz=settings.per_second,
                burst_count=settings.burst_count,
            )

        # Joins are limited separately for rooms this server is already in
        # (local) and rooms it has to join over federation (remote).
        self._join_rate_limiter_local = build_limiter(limits.rc_joins_local)
        self._join_rate_limiter_remote = build_limiter(limits.rc_joins_remote)

        # Invites are limited both per room and per invitee.
        self._invites_per_room_limiter = build_limiter(
            limits.rc_invites_per_room)
        self._invites_per_user_limiter = build_limiter(
            limits.rc_invites_per_user)

        # This is only used to get at ratelimit function, and
        # maybe_kick_guest_users. It's fine there are multiple of these as
        # it doesn't store state.
        self.base_handler = BaseHandler(hs)

    @abc.abstractmethod
    async def _remote_join(
        self,
        requester: Requester,
        remote_room_hosts: List[str],
        room_id: str,
        user: UserID,
        content: dict,
    ) -> Tuple[str, int]:
        """Join a room that this server is not currently participating in.

        Must be implemented by subclasses.

        Args:
            requester: the requester performing the join.
            remote_room_hosts: servers that may be used to perform the join.
            room_id: the room to join.
            user: the user who is joining.
            content: the content to use for the join event.

        Returns:
            A tuple of (event ID, stream ID), presumably of the resulting join
            event.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def remote_reject_invite(
        self,
        invite_event_id: str,
        txn_id: Optional[str],
        requester: Requester,
        content: JsonDict,
    ) -> Tuple[str, int]:
        """Reject an out-of-band invite that was received from a remote server.

        Must be implemented by subclasses.

        Args:
            invite_event_id: the ID of the invite event being rejected.
            txn_id: the client's transaction ID, if one was supplied.
            requester: the user rejecting the invite, as identified by the
                access token.
            content: extra content for the rejection event; normally an empty
                dict.

        Returns:
            The event ID and stream ID of the resulting leave event.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _user_left_room(self, target: UserID, room_id: str) -> None:
        """Tell the distributor on the master process that a user left a room.

        Must be implemented by subclasses.

        Args:
            target: the user who left.
            room_id: the room that was left.
        """
        raise NotImplementedError

    def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
        """Apply the invite ratelimits for one invite.

        The per-invitee limit always applies. The per-room limit applies only
        when a (non-empty) room ID is given.
        """
        per_room = self._invites_per_room_limiter
        per_invitee = self._invites_per_user_limiter

        if room_id:
            per_room.ratelimit(room_id)
        per_invitee.ratelimit(invitee_user_id)

    async def _local_membership_update(
        self,
        requester: Requester,
        target: UserID,
        room_id: str,
        membership: str,
        prev_event_ids: List[str],
        txn_id: Optional[str] = None,
        ratelimit: bool = True,
        content: Optional[dict] = None,
        require_consent: bool = True,
    ) -> Tuple[str, int]:
        """Create and send a membership event for ``target`` locally.

        Args:
            requester: the requester sending the event; used as its sender.
            target: the user whose membership is changing (the event's state
                key).
            room_id: the room the event is sent to.
            membership: the membership value to set.
            prev_event_ids: passed to ``create_event`` as the new event's prev
                events.
            txn_id: optional client transaction ID. If an event with the same
                transaction ID already exists for this requester's access
                token, that event's ID and stream position are returned and no
                new event is created.
            ratelimit: whether to apply the local join ratelimiter (to new
                joins only). Also passed to ``handle_new_client_event``.
            content: extra event content. NOTE: this dict is modified in place;
                ``membership`` and, for guests, ``kind`` are set on it.
            require_consent: passed through to ``create_event``.

        Returns:
            The event ID and stream ordering of the new (or already existing)
            event.

        Raises:
            LimitExceededError: if the user is newly joining, ``ratelimit`` is
                set and the local join ratelimit is exceeded.
        """
        user_id = target.to_string()

        if content is None:
            content = {}

        content["membership"] = membership
        if requester.is_guest:
            content["kind"] = "guest"

        # Check if we already have an event with a matching transaction ID. (We
        # do this check just before we persist an event as well, but may as well
        # do it up front for efficiency.)
        if txn_id and requester.access_token_id:
            existing_event_id = await self.store.get_event_id_from_transaction_id(
                room_id,
                requester.user.to_string(),
                requester.access_token_id,
                txn_id,
            )
            if existing_event_id:
                event_pos = await self.store.get_position_for_event(
                    existing_event_id)
                return existing_event_id, event_pos.stream

        event, context = await self.event_creation_handler.create_event(
            requester,
            {
                "type": EventTypes.Member,
                "content": content,
                "room_id": room_id,
                "sender": requester.user.to_string(),
                "state_key": user_id,
                # For backwards compatibility:
                "membership": membership,
            },
            txn_id=txn_id,
            prev_event_ids=prev_event_ids,
            require_consent=require_consent,
        )

        prev_state_ids = await context.get_prev_state_ids()

        # The target's membership event before this one, if any. Used below to
        # detect a fresh join and a leave from a joined state.
        prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id),
                                                  None)

        if event.membership == Membership.JOIN:
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = await self.store.get_event(
                    prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN

            # Only rate-limit if the user actually joined the room, otherwise we'll end
            # up blocking profile updates.
            if newly_joined and ratelimit:
                time_now_s = self.clock.time()
                (
                    allowed,
                    time_allowed,
                ) = self._join_rate_limiter_local.can_requester_do_action(
                    requester)

                if not allowed:
                    # retry_after_ms is the time until the action is allowed,
                    # converted from seconds to milliseconds.
                    raise LimitExceededError(
                        retry_after_ms=int(1000 * (time_allowed - time_now_s)))

        result_event = await self.event_creation_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target],
            ratelimit=ratelimit,
        )

        # Only a leave from a previously joined state counts as leaving the
        # room (e.g. rejecting an invite does not).
        if event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = await self.store.get_event(
                    prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    await self._user_left_room(target, room_id)

        # we know it was persisted, so should have a stream ordering
        assert result_event.internal_metadata.stream_ordering
        return result_event.event_id, result_event.internal_metadata.stream_ordering

    async def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id,
                                                user_id) -> None:
        """Copies the tags and direct room state from one room to another.

        If ``old_room_id`` is listed under some key of the user's ``m.direct``
        account data (and ``new_room_id`` is not), ``new_room_id`` is appended
        under the first such key and the updated ``m.direct`` is saved. Every
        tag on the old room is then added to the new room.

        Args:
            old_room_id: The room ID of the old room.
            new_room_id: The room ID of the new room.
            user_id: The user's ID.
        """
        # Retrieve user account data for predecessor room
        user_account_data, _ = await self.store.get_account_data_for_user(
            user_id)

        # Copy direct message state if applicable
        direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})

        # m.direct is client-supplied account data, so its shape is not
        # guaranteed: only act on a mapping whose values are lists of room IDs.
        if isinstance(direct_rooms, dict):
            for key, room_id_list in direct_rooms.items():
                if not isinstance(room_id_list, (list, tuple)):
                    # e.g. a string value: `in` would be a substring test and
                    # appending to it would raise.
                    continue
                if old_room_id in room_id_list and new_room_id not in room_id_list:
                    # Build an updated copy instead of mutating the object
                    # returned by the store. NOTE(review): that object may be
                    # a cached value shared with other callers.
                    updated_direct_rooms = dict(direct_rooms)
                    updated_direct_rooms[key] = list(room_id_list) + [new_room_id]

                    # Save back to user's m.direct account data
                    await self.account_data_handler.add_account_data_for_user(
                        user_id, AccountDataTypes.DIRECT, updated_direct_rooms)
                    break

        # Copy room tags if applicable
        room_tags = await self.store.get_tags_for_room(user_id, old_room_id)

        # Copy each room tag to the new room
        for tag, tag_content in room_tags.items():
            await self.account_data_handler.add_tag_to_room(
                user_id, new_room_id, tag, tag_content)

    async def update_membership(
        self,
        requester: Requester,
        target: UserID,
        room_id: str,
        action: str,
        txn_id: Optional[str] = None,
        remote_room_hosts: Optional[List[str]] = None,
        third_party_signed: Optional[dict] = None,
        ratelimit: bool = True,
        content: Optional[dict] = None,
        require_consent: bool = True,
    ) -> Tuple[str, int]:
        """Update a user's membership in a room.

        Changes for the same room are serialised through
        ``self.member_linearizer``. The work itself is done by
        ``update_membership_locked``.

        Params:
            requester: The user who is performing the update.
            target: The user whose membership is being updated.
            room_id: The room ID whose membership is being updated.
            action: The membership change, see synapse.api.constants.Membership.
            txn_id: The transaction ID, if given.
            remote_room_hosts: Remote servers to send the update to.
            third_party_signed: Information from a 3PID invite.
            ratelimit: Whether to rate limit the request.
            content: The content of the created event.
            require_consent: Whether consent is required.

        Returns:
            A tuple of the new event ID and stream ID.

        Raises:
            ShadowBanError if a shadow-banned requester attempts to send an invite.
        """
        is_shadow_banned_invite = (
            action == Membership.INVITE and requester.shadow_banned
        )
        if is_shadow_banned_invite:
            # Wait a random while before failing, just to annoy the requester.
            await self.clock.sleep(random.randint(1, 10))
            raise ShadowBanError()

        # Hold the per-room lock for the whole update.
        room_lock = self.member_linearizer.queue((room_id, ))
        with (await room_lock):
            return await self.update_membership_locked(
                requester,
                target,
                room_id,
                action,
                txn_id=txn_id,
                remote_room_hosts=remote_room_hosts,
                third_party_signed=third_party_signed,
                ratelimit=ratelimit,
                content=content,
                require_consent=require_consent,
            )

    async def update_membership_locked(
        self,
        requester: Requester,
        target: UserID,
        room_id: str,
        action: str,
        txn_id: Optional[str] = None,
        remote_room_hosts: Optional[List[str]] = None,
        third_party_signed: Optional[dict] = None,
        ratelimit: bool = True,
        content: Optional[dict] = None,
        require_consent: bool = True,
    ) -> Tuple[str, int]:
        """Helper for update_membership.

        Assumes that the membership linearizer is already held for the room.

        Takes the same arguments as ``update_membership``. ``action`` may be a
        membership value, or ``"kick"``/``"unban"``, which are both sent as a
        ``leave`` event. The caller's ``content`` dict is copied, not modified.
        NOTE(review): a non-empty ``remote_room_hosts`` list from the caller
        may be appended to in place (remote join via a remote inviter).

        Returns:
            A tuple of the event ID and stream ID of the resulting event. If
            the target's current membership event already has the same sender,
            membership and content, that existing event is returned instead.

        Raises:
            SynapseError: e.g. the room is blocked; the invite is forbidden or
                blocked; an invalid ban/unban transition; a leave for an
                unknown room without a pending invite; or rejecting an invite
                to the server notices room.
            AuthError: kicking a user who is not in the room, or a guest join
                when guest access is not allowed.
            LimitExceededError: the remote join ratelimit was exceeded. Invite
                ratelimiting via ``ratelimit_invite`` may also raise.
        """
        # Remember whether the caller supplied any content. If not, a remote
        # join fills in the target's displayname/avatar from their profile.
        content_specified = bool(content)
        if content is None:
            content = {}
        else:
            # We do a copy here as we potentially change some keys
            # later on.
            content = dict(content)

        # allow the server notices mxid to set room-level profile
        is_requester_server_notices_user = (
            self._server_notices_mxid is not None
            and requester.user.to_string() == self._server_notices_mxid)

        if (not self.allow_per_room_profiles
                and not is_requester_server_notices_user
            ) or requester.shadow_banned:
            # Strip profile data, knowing that new profile data will be added to the
            # event's content in event_creation_handler.create_event() using the target's
            # global profile.
            content.pop("displayname", None)
            content.pop("avatar_url", None)

        # "kick" and "unban" are not membership values themselves; both are
        # expressed as a leave event.
        effective_membership_state = action
        if action in ["kick", "unban"]:
            effective_membership_state = "leave"

        # if this is a join with a 3pid signature, we may need to turn a 3pid
        # invite into a normal invite before we can handle the join.
        if third_party_signed is not None:
            await self.federation_handler.exchange_third_party_invite(
                third_party_signed["sender"],
                target.to_string(),
                room_id,
                third_party_signed,
            )

        if not remote_room_hosts:
            remote_room_hosts = []

        # Leaving and banning are still allowed in blocked rooms.
        if effective_membership_state not in ("leave", "ban"):
            is_blocked = await self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(
                    403, "This room has been blocked on this server")

        if effective_membership_state == Membership.INVITE:
            target_id = target.to_string()
            if ratelimit:
                # Don't ratelimit application services.
                if not requester.app_service or requester.app_service.is_rate_limited(
                ):
                    self.ratelimit_invite(room_id, target_id)

            # block any attempts to invite the server notices mxid
            if target_id == self._server_notices_mxid:
                raise SynapseError(HTTPStatus.FORBIDDEN,
                                   "Cannot invite this user")

            block_invite = False

            if (self._server_notices_mxid is not None and
                    requester.user.to_string() == self._server_notices_mxid):
                # allow the server notices mxid to send invites
                is_requester_admin = True

            else:
                is_requester_admin = await self.auth.is_server_admin(
                    requester.user)

            # Admins bypass both the non-admin invite block and the spam checker.
            if not is_requester_admin:
                if self.config.block_non_admin_invites:
                    logger.info(
                        "Blocking invite: user is not admin and non-admin "
                        "invites disabled")
                    block_invite = True

                if not await self.spam_checker.user_may_invite(
                        requester.user.to_string(), target_id, room_id):
                    logger.info("Blocking invite due to spam checker")
                    block_invite = True

            if block_invite:
                raise SynapseError(
                    403, "Invites have been disabled on this server")

        latest_event_ids = await self.store.get_prev_events_for_room(room_id)

        current_state_ids = await self.state_handler.get_current_state_ids(
            room_id, latest_event_ids=latest_event_ids)

        # TODO: Refactor into dictionary of explicitly allowed transitions
        # between old and new state, with specific error messages for some
        # transitions and generic otherwise
        old_state_id = current_state_ids.get(
            (EventTypes.Member, target.to_string()))
        if old_state_id:
            old_state = await self.store.get_event(old_state_id,
                                                   allow_none=True)
            old_membership = old_state.content.get(
                "membership") if old_state else None
            if action == "unban" and old_membership != "ban":
                raise SynapseError(
                    403,
                    "Cannot unban user who was not banned"
                    " (membership=%s)" % old_membership,
                    errcode=Codes.BAD_STATE,
                )
            if old_membership == "ban" and action != "unban":
                raise SynapseError(
                    403,
                    "Cannot %s user who was banned" % (action, ),
                    errcode=Codes.BAD_STATE,
                )

            if old_state:
                same_content = content == old_state.content
                same_membership = old_membership == effective_membership_state
                same_sender = requester.user.to_string() == old_state.sender
                if same_sender and same_membership and same_content:
                    # duplicate event.
                    # we know it was persisted, so must have a stream ordering.
                    assert old_state.internal_metadata.stream_ordering
                    return (
                        old_state.event_id,
                        old_state.internal_metadata.stream_ordering,
                    )

            if old_membership in ["ban", "leave"] and action == "kick":
                raise AuthError(403, "The target user is not in the room")

            # we don't allow people to reject invites to the server notice
            # room, but they can leave it once they are joined.
            if (old_membership == Membership.INVITE
                    and effective_membership_state == Membership.LEAVE):
                is_blocked = await self._is_server_notice_room(room_id)
                if is_blocked:
                    raise SynapseError(
                        HTTPStatus.FORBIDDEN,
                        "You cannot reject this invite",
                        errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
                    )
        else:
            # No membership event at all for the target: nothing to kick.
            if action == "kick":
                raise AuthError(403, "The target user is not in the room")

        is_host_in_room = await self._is_host_in_room(current_state_ids)

        if effective_membership_state == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = await self._can_guest_join(current_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

            # This server is not in the room yet, so the join goes through
            # _remote_join instead of a local membership event.
            if not is_host_in_room:
                if ratelimit:
                    time_now_s = self.clock.time()
                    (
                        allowed,
                        time_allowed,
                    ) = self._join_rate_limiter_remote.can_requester_do_action(
                        requester, )

                    if not allowed:
                        raise LimitExceededError(
                            retry_after_ms=int(1000 *
                                               (time_allowed - time_now_s)))

                # A remote inviter's server is a candidate to join via.
                inviter = await self._get_inviter(target.to_string(), room_id)
                if inviter and not self.hs.is_mine(inviter):
                    remote_room_hosts.append(inviter.domain)

                content["membership"] = Membership.JOIN

                profile = self.profile_handler
                if not content_specified:
                    content["displayname"] = await profile.get_displayname(
                        target)
                    content["avatar_url"] = await profile.get_avatar_url(target
                                                                         )

                if requester.is_guest:
                    content["kind"] = "guest"

                remote_join_response = await self._remote_join(
                    requester, remote_room_hosts, room_id, target, content)

                return remote_join_response

        elif effective_membership_state == Membership.LEAVE:
            if not is_host_in_room:
                # perhaps we've been invited
                (
                    current_membership_type,
                    current_membership_event_id,
                ) = await self.store.get_local_current_membership_for_user_in_room(
                    target.to_string(), room_id)
                if (current_membership_type != Membership.INVITE
                        or not current_membership_event_id):
                    logger.info(
                        "%s sent a leave request to %s, but that is not an active room "
                        "on this server, and there is no pending invite",
                        target,
                        room_id,
                    )

                    raise SynapseError(404, "Not a known room")

                invite = await self.store.get_event(current_membership_event_id
                                                    )
                logger.info("%s rejects invite to %s from %s", target, room_id,
                            invite.sender)

                if not self.hs.is_mine_id(invite.sender):
                    # send the rejection to the inviter's HS (with fallback to
                    # local event)
                    return await self.remote_reject_invite(
                        invite.event_id,
                        txn_id,
                        requester,
                        content,
                    )

                # the inviter was on our server, but has now left. Carry on
                # with the normal rejection codepath, which will also send the
                # rejection out to any other servers we believe are still in the room.

                # thanks to overzealous cleaning up of event_forward_extremities in
                # `delete_old_current_state_events`, it's possible to end up with no
                # forward extremities here. If that happens, let's just hang the
                # rejection off the invite event.
                #
                # see: https://github.com/matrix-org/synapse/issues/7139
                if len(latest_event_ids) == 0:
                    latest_event_ids = [invite.event_id]

        # Everything else is handled by a local membership event.
        return await self._local_membership_update(
            requester=requester,
            target=target,
            room_id=room_id,
            membership=effective_membership_state,
            txn_id=txn_id,
            ratelimit=ratelimit,
            prev_event_ids=latest_event_ids,
            content=content,
            require_consent=require_consent,
        )

    async def transfer_room_state_on_room_upgrade(self, old_room_id: str,
                                                  room_id: str) -> None:
        """Upon our server becoming aware of an upgraded room, either by upgrading a room
        ourselves or joining one, we can transfer over information from the previous room.

        Copies user state (tags/push rules) for every local user that was in the old room, as
        well as migrating the room directory state, alias mappings and the
        local groups that contained the old room.

        Args:
            old_room_id: The ID of the old room
            room_id: The ID of the new room
        """
        logger.info("Transferring room state from %s to %s", old_room_id,
                    room_id)

        # Find all local users that were in the old room and copy over each user's state
        users = await self.store.get_users_in_room(old_room_id)
        await self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)

        # get_room may return no room. Resolve the old room's publicity once,
        # treating a missing room as not public, so that both the directory
        # migration and the group loop below are safe in that case.
        old_room = await self.store.get_room(old_room_id)
        old_room_is_public = bool(old_room) and old_room["is_public"]

        # Add new room to the room directory if the old room was there
        # Remove old room from the room directory
        if old_room_is_public:
            await self.store.set_room_is_public(old_room_id, False)
            await self.store.set_room_is_public(room_id, True)

        # Transfer alias mappings in the room directory
        await self.store.update_aliases_for_room(old_room_id, room_id)

        # Check if any groups we own contain the predecessor room
        local_group_ids = await self.store.get_local_groups_for_room(
            old_room_id)
        for group_id in local_group_ids:
            # Add the new room to those groups, keeping the old room's publicity
            await self.store.add_room_to_group(group_id, room_id,
                                               old_room_is_public)

            # Remove the old room from those groups
            await self.store.remove_room_from_group(group_id, old_room_id)

    async def copy_user_state_on_room_upgrade(self, old_room_id: str,
                                              new_room_id: str,
                                              user_ids: Iterable[str]) -> None:
        """Carry per-user data from an upgraded room over to its replacement.

        For every user, room tags are copied via `copy_room_tags_and_direct_to_room`
        and push rules via the store. A failure for one user is logged and the
        remaining users are still processed.

        Args:
            old_room_id: The ID of upgraded room
            new_room_id: The ID of the new room
            user_ids: User IDs to copy state for
        """
        logger.debug(
            "Copying over room tags and push rules from %s to %s for users %s",
            old_room_id,
            new_room_id,
            user_ids,
        )

        for uid in user_ids:
            try:
                await self.copy_room_tags_and_direct_to_room(
                    old_room_id, new_room_id, uid)
                await self.store.copy_push_rules_from_room_to_room_for_user(
                    old_room_id, new_room_id, uid)
            except Exception:
                # Best effort: one user's failure must not block the others.
                logger.exception(
                    "Error copying tags and/or push rules from rooms %s to %s for user %s. "
                    "Skipping...",
                    old_room_id,
                    new_room_id,
                    uid,
                )

    async def send_membership_event(
        self,
        requester: Optional[Requester],
        event: EventBase,
        context: EventContext,
        ratelimit: bool = True,
    ):
        """
        Send an already-built membership event on behalf of a local user.

        Args:
            requester: The local user who requested the membership
                event. If None, certain checks, like whether this homeserver can
                act as the sender, will be skipped and a requester is built for
                the event's target user.
            event: The membership event.
            context: The context of the event.
            ratelimit: Whether to rate limit this request.
        Raises:
            AuthError (403) if a guest tries to join a room that does not allow it.
            SynapseError if there was a problem changing the membership, e.g. the
                room is blocked on this server.
        """
        target_user = UserID.from_string(event.state_key)
        room_id = event.room_id

        if requester is None:
            # No requester supplied: act on behalf of the target user.
            requester = types.create_requester(target_user)
        else:
            sender = UserID.from_string(event.sender)
            assert sender == requester.user, (
                "Sender (%s) must be same as requester (%s)" % (
                    sender, requester.user))
            assert self.hs.is_mine(sender), (
                "Sender must be our own: %s" % (sender, ))

        prev_state_ids = await context.get_prev_state_ids()

        if event.membership == Membership.JOIN and requester.is_guest:
            # Guests are a local concept and don't fit into the general auth
            # process, so the guest-access check is done here instead.
            if not await self._can_guest_join(prev_state_ids):
                raise AuthError(403, "Guest access not allowed")

        if event.membership not in (Membership.LEAVE, Membership.BAN):
            if await self.store.is_room_blocked(room_id):
                raise SynapseError(
                    403, "This room has been blocked on this server")

        event = await self.event_creation_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target_user],
            ratelimit=ratelimit)

        # The remaining work only applies to a user leaving a room they had joined.
        if event.membership != Membership.LEAVE:
            return

        prev_member_event_id = prev_state_ids.get(
            (EventTypes.Member, event.state_key), None)
        if not prev_member_event_id:
            return

        prev_member_event = await self.store.get_event(prev_member_event_id)
        if prev_member_event.membership == Membership.JOIN:
            await self._user_left_room(target_user, room_id)

    async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
        """
        Decide from the room's current state whether a guest may join.

        True only when a guest-access state event (`EventTypes.GuestAccess`, empty
        state key) exists, can be fetched, and its content has `guest_access`
        set to `"can_join"`.
        """
        event_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
        if not event_id:
            return False

        guest_access_event = await self.store.get_event(event_id)
        if not guest_access_event:
            return False

        content = guest_access_event.content
        if not content or "guest_access" not in content:
            return False
        return bool(content["guest_access"] == "can_join")

    async def lookup_room_alias(
            self, room_alias: RoomAlias) -> Tuple[RoomID, List[str]]:
        """
        Resolve a room alias to its room ID and candidate servers.

        Args:
            room_alias: The alias to look up.
        Returns:
            A tuple of:
                The room ID as a RoomID object.
                Hosts likely to be participating in the room ([str]), with the
                server that owns the alias always first.
        Raises:
            SynapseError if room alias could not be found.
        """
        mapping = await self.directory_handler.get_association(room_alias)
        if not mapping:
            raise SynapseError(404, "No such room alias")

        room_id = mapping["room_id"]
        servers = mapping["servers"]

        # Move the alias-owning server to the head of the list, adding it if absent.
        owner = room_alias.domain
        if owner in servers:
            servers.remove(owner)
        servers.insert(0, owner)

        return RoomID.from_string(room_id), servers

    async def _get_inviter(self, user_id: str,
                           room_id: str) -> Optional[UserID]:
        """Return the sender of the invite that the store holds for this local user
        in this room, or None if there is none."""
        invite = await self.store.get_invite_for_local_user_in_room(
            user_id=user_id, room_id=room_id)
        return UserID.from_string(invite.sender) if invite else None

    async def do_3pid_invite(
        self,
        room_id: str,
        inviter: UserID,
        medium: str,
        address: str,
        id_server: str,
        requester: Requester,
        txn_id: Optional[str],
        id_access_token: Optional[str] = None,
    ) -> int:
        """Invite a third-party identifier (3PID) to a room.

        If the identity server maps the 3PID to a Matrix user, that user is
        invited directly. Otherwise a third-party invite is created through the
        identity server and stored in the room.

        Args:
            room_id: The room to invite the 3PID to.
            inviter: The user sending the invite.
            medium: The 3PID's medium.
            address: The 3PID's address.
            id_server: The identity server to use.
            requester: The user making the request.
            txn_id: The transaction ID this is part of, or None if this is not
                part of a transaction.
            id_access_token: The optional identity server access token.

        Returns:
             The new stream ID.

        Raises:
            ShadowBanError if the requester has been shadow-banned.
            SynapseError (403) if invites are limited to server admins and the
                requester is not one, if the third-party event rules refuse the
                3PID, or if 3PID lookups are disabled on this server.
        """
        if (self.config.block_non_admin_invites
                and not await self.auth.is_server_admin(requester.user)):
            raise SynapseError(
                403, "Invites have been disabled on this server",
                Codes.FORBIDDEN)

        if requester.shadow_banned:
            # Wait a random interval before failing, purely to slow down the
            # shadow-banned requester.
            await self.clock.sleep(random.randint(1, 10))
            raise ShadowBanError()

        # 3PID invites go out before any event is created, so the normal
        # per-event ratelimiting would happen too late; ratelimit here instead.
        await self.base_handler.ratelimit(requester)

        if not await self.third_party_event_rules.check_threepid_can_be_invited(
                medium, address, room_id):
            raise SynapseError(
                403,
                "This third-party identifier can not be invited in this room",
                Codes.FORBIDDEN,
            )

        if not self._enable_lookup:
            raise SynapseError(
                403,
                "Looking up third-party identifiers is denied from this server"
            )

        invitee = await self.identity_handler.lookup_3pid(
            id_server, medium, address, id_access_token)

        if not invitee:
            # The 3PID isn't bound to a Matrix user: store a third-party invite.
            return await self._make_and_store_3pid_invite(
                requester,
                id_server,
                medium,
                address,
                room_id,
                inviter,
                txn_id=txn_id,
                id_access_token=id_access_token,
            )

        # update_membership with "invite" can raise ShadowBanError, but the
        # shadow-ban case has already been handled above.
        _, stream_id = await self.update_membership(
            requester,
            UserID.from_string(invitee),
            room_id,
            "invite",
            txn_id=txn_id)
        return stream_id

    async def _make_and_store_3pid_invite(
        self,
        requester: Requester,
        id_server: str,
        medium: str,
        address: str,
        room_id: str,
        user: UserID,
        txn_id: Optional[str],
        id_access_token: Optional[str] = None,
    ) -> int:
        """Create a third-party invite through the identity server and send the
        resulting `EventTypes.ThirdPartyInvite` state event into the room.

        The inviter's display name and avatar, plus the room's canonical alias,
        name, join rule and avatar, are read from the room's current state and
        passed to the identity server.

        Returns:
            The stream ID of the sent invite event.
        """
        current_state = await self.state_handler.get_current_state(room_id)

        def state_content_field(event_type, state_key, field):
            # `field` from the content of the given current-state event, or ""
            # if the room has no such event.
            state_event = current_state.get((event_type, state_key))
            if not state_event:
                return ""
            return state_event.content.get(field, "")

        inviter_mxid = user.to_string()

        # An inviter without a display name is presented by their MXID instead.
        inviter_display_name = state_content_field(
            EventTypes.Member, inviter_mxid, "displayname") or inviter_mxid
        inviter_avatar_url = state_content_field(
            EventTypes.Member, inviter_mxid, "avatar_url")

        canonical_room_alias = state_content_field(
            EventTypes.CanonicalAlias, "", "alias")
        room_name = state_content_field(EventTypes.Name, "", "name")
        room_join_rules = state_content_field(
            EventTypes.JoinRules, "", "join_rule")
        room_avatar_url = state_content_field(EventTypes.RoomAvatar, "", "url")

        invite_details = await self.identity_handler.ask_id_server_for_third_party_invite(
            requester=requester,
            id_server=id_server,
            medium=medium,
            address=address,
            room_id=room_id,
            inviter_user_id=inviter_mxid,
            room_alias=canonical_room_alias,
            room_avatar_url=room_avatar_url,
            room_join_rules=room_join_rules,
            room_name=room_name,
            inviter_display_name=inviter_display_name,
            inviter_avatar_url=inviter_avatar_url,
            id_access_token=id_access_token,
        )
        token, public_keys, fallback_public_key, display_name = invite_details

        invite_content = {
            "display_name": display_name,
            "public_keys": public_keys,
            # For backwards compatibility:
            "key_validity_url": fallback_public_key["key_validity_url"],
            "public_key": fallback_public_key["public_key"],
        }
        _, stream_id = await self.event_creation_handler.create_and_send_nonmember_event(
            requester,
            {
                "type": EventTypes.ThirdPartyInvite,
                "content": invite_content,
                "room_id": room_id,
                "sender": inviter_mxid,
                "state_key": token,
            },
            ratelimit=False,
            txn_id=txn_id,
        )
        return stream_id

    async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
        """Return whether a local user is joined according to `current_state_ids`.

        Also returns True when the create event is the only piece of state, which
        can only happen while we are creating the room and are about to send its
        very first membership event.
        """
        if len(current_state_ids) == 1 and current_state_ids.get(
                ("m.room.create", "")):
            return True

        for (event_type, state_key), event_id in current_state_ids.items():
            # Only membership events for users on this server are relevant.
            if event_type != EventTypes.Member or not self.hs.is_mine_id(state_key):
                continue

            member_event = await self.store.get_event(event_id, allow_none=True)
            if member_event and member_event.membership == Membership.JOIN:
                return True

        return False

    async def _is_server_notice_room(self, room_id: str) -> bool:
        """Return True if the configured server-notices user is among the room's
        users; always False when no server-notices MXID is configured."""
        notices_mxid = self._server_notices_mxid
        if notices_mxid is None:
            return False
        return notices_mxid in await self.store.get_users_in_room(room_id)
Example #14
0
class LoginRestServlet(RestServlet):
    """Serves `/login`: GET lists the supported login flows, POST performs a
    login with the flow named by the submission's `type`."""

    PATTERNS = client_patterns("/login$", v1=True)
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "org.matrix.login.jwt"
    JWT_TYPE_DEPRECATED = "m.login.jwt"
    APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"

    def __init__(self, hs):
        super().__init__()
        self.hs = hs

        # JWT configuration variables.
        self.jwt_enabled = hs.config.jwt_enabled
        self.jwt_secret = hs.config.jwt_secret
        self.jwt_algorithm = hs.config.jwt_algorithm
        self.jwt_issuer = hs.config.jwt_issuer
        self.jwt_audiences = hs.config.jwt_audiences

        # SSO configuration.
        self.saml2_enabled = hs.config.saml2_enabled
        self.cas_enabled = hs.config.cas_enabled
        self.oidc_enabled = hs.config.oidc_enabled

        self.auth = hs.get_auth()

        self.auth_handler = self.hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self._well_known_builder = WellKnownBuilder(hs)
        # Three separate limits: per client IP (every POST), per account
        # (successful logins, see _complete_login) and failed attempts
        # (keyed by user ID or by (medium, address) for 3PID logins).
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
        )
        self._account_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_account.per_second,
            burst_count=self.hs.config.rc_login_account.burst_count,
        )
        self._failed_attempts_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
        )

    def on_GET(self, request: SynapseRequest):
        """Return the login flows this server advertises."""
        flows = []
        if self.jwt_enabled:
            flows.append({"type": LoginRestServlet.JWT_TYPE})
            flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})

        if self.cas_enabled:
            # we advertise CAS for backwards compat, though MSC1721 renamed it
            # to SSO.
            flows.append({"type": LoginRestServlet.CAS_TYPE})

        if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
            flows.append({"type": LoginRestServlet.SSO_TYPE})
            # While it's valid for us to advertise this login type generally,
            # synapse currently only gives out these tokens as part of the
            # SSO login flow.
            # Generally we don't want to advertise login flows that clients
            # don't know how to implement, since they (currently) will always
            # fall back to the fallback API if they don't understand one of the
            # login flow types returned.
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})

        flows.extend(({
            "type": t
        } for t in self.auth_handler.get_supported_login_types()))

        flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})

        return 200, {"flows": flows}

    async def on_POST(self, request: SynapseRequest):
        """Dispatch a login submission to the handler for its `type`.

        Raises:
            SynapseError (400): if a required JSON key is missing.
        """
        self._address_ratelimiter.ratelimit(request.getClientIP())

        login_submission = parse_json_object_from_request(request)

        try:
            login_type = login_submission["type"]
            if login_type == LoginRestServlet.APPSERVICE_TYPE:
                appservice = self.auth.get_appservice_by_req(request)
                result = await self._do_appservice_login(
                    login_submission, appservice)
            elif self.jwt_enabled and login_type in (
                    LoginRestServlet.JWT_TYPE,
                    LoginRestServlet.JWT_TYPE_DEPRECATED):
                result = await self._do_jwt_login(login_submission)
            elif login_type == LoginRestServlet.TOKEN_TYPE:
                result = await self._do_token_login(login_submission)
            else:
                result = await self._do_other_login(login_submission)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        return 200, result

    def _get_qualified_user_id(self, identifier):
        """Turn an `m.id.user` identifier into a fully-qualified user ID.

        A value starting with "@" is returned as-is; anything else is treated as
        a localpart on this server.

        Raises:
            SynapseError (400): if the identifier is not `m.id.user`, has no
                `user` key, or its `user` value is not a string.
        """
        if identifier["type"] != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type")
        if "user" not in identifier:
            raise SynapseError(400, "User identifier is missing 'user' key")

        user = identifier["user"]
        # The value comes straight from the client; a non-string would otherwise
        # fail on `.startswith` below and surface as a 500.
        if not isinstance(user, str):
            raise SynapseError(400, "User identifier 'user' must be a string")

        if user.startswith("@"):
            return user
        return UserID(user, self.hs.hostname).to_string()

    async def _do_appservice_login(self, login_submission: JsonDict,
                                   appservice: ApplicationService):
        """Log in as a user the requesting application service is interested in.

        Raises:
            LoginError (403): if the application service is not interested in
                the requested user.
        """
        logger.info(
            "Got appservice login request with identifier: %r",
            login_submission.get("identifier"),
        )

        identifier = convert_client_dict_legacy_fields_to_identifier(
            login_submission)
        qualified_user_id = self._get_qualified_user_id(identifier)

        if not appservice.is_interested_in_user(qualified_user_id):
            raise LoginError(403,
                             "Invalid access_token",
                             errcode=Codes.FORBIDDEN)

        return await self._complete_login(qualified_user_id, login_submission)

    async def _do_other_login(self,
                              login_submission: JsonDict) -> Dict[str, str]:
        """Handle non-token/saml/jwt logins.

        Phone and other third-party identifiers are first resolved to a user
        ID (via password providers, then the local store); the resulting
        `m.id.user` identifier is then validated by the auth handler. Failed
        attempts are counted against the failed-attempts ratelimiter.

        Args:
            login_submission: The JSON request body.

        Returns:
            HTTP response
        """
        # Log the request we got, but only certain fields to minimise the chance of
        # logging someone's password (even if they accidentally put it in the wrong
        # field)
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            login_submission.get("identifier"),
            login_submission.get("medium"),
            login_submission.get("address"),
            login_submission.get("user"),
        )
        identifier = convert_client_dict_legacy_fields_to_identifier(
            login_submission)

        # convert phone type identifiers to generic threepids
        if identifier["type"] == "m.id.phone":
            identifier = login_id_phone_to_thirdparty(identifier)

        # convert threepid identifiers to user IDs
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get("address")
            medium = identifier.get("medium")

            if medium is None or address is None:
                raise SynapseError(400, "Invalid thirdparty identifier")

            # For emails, canonicalise the address.
            # We store all email addresses canonicalised in the DB.
            # (See add_threepid in synapse/handlers/auth.py)
            if medium == "email":
                try:
                    address = canonicalise_email(address)
                except ValueError as e:
                    raise SynapseError(400, str(e))

            # We also apply account rate limiting using the 3PID as a key, as
            # otherwise using 3PID bypasses the ratelimiting based on user ID.
            self._failed_attempts_ratelimiter.ratelimit((medium, address),
                                                        update=False)

            # Check for login providers that support 3pid login types
            (
                canonical_user_id,
                callback_3pid,
            ) = await self.auth_handler.check_password_provider_3pid(
                medium, address, login_submission["password"])
            if canonical_user_id:
                # Authentication through password provider and 3pid succeeded

                result = await self._complete_login(canonical_user_id,
                                                    login_submission,
                                                    callback_3pid)
                return result

            # No password providers were able to handle this 3pid
            # Check local store
            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
                medium, address)
            if not user_id:
                logger.warning("unknown 3pid identifier medium %s, address %r",
                               medium, address)
                # We mark that we've failed to log in here, as
                # `check_password_provider_3pid` might have returned `None` due
                # to an incorrect password, rather than the account not
                # existing.
                #
                # If it returned None but the 3PID was bound then we won't hit
                # this code path, which is fine as then the per-user ratelimit
                # will kick in below.
                self._failed_attempts_ratelimiter.can_do_action(
                    (medium, address))
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)

            identifier = {"type": "m.id.user", "user": user_id}

        # by this point, the identifier should be an m.id.user: if it's anything
        # else, we haven't understood it.
        qualified_user_id = self._get_qualified_user_id(identifier)

        # Check if we've hit the failed ratelimit (but don't update it)
        self._failed_attempts_ratelimiter.ratelimit(qualified_user_id.lower(),
                                                    update=False)

        try:
            canonical_user_id, callback = await self.auth_handler.validate_login(
                identifier["user"], login_submission)
        except LoginError:
            # The user has failed to log in, so we need to update the rate
            # limiter. Using `can_do_action` avoids us raising a ratelimit
            # exception and masking the LoginError. The actual ratelimiting
            # should have happened above.
            self._failed_attempts_ratelimiter.can_do_action(
                qualified_user_id.lower())
            raise

        result = await self._complete_login(canonical_user_id,
                                            login_submission, callback)
        return result

    async def _complete_login(
        self,
        user_id: str,
        login_submission: JsonDict,
        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
        create_non_existent_users: bool = False,
    ) -> Dict[str, str]:
        """Called when we've successfully authed the user and now need to
        actually login them in (e.g. create devices). This gets called on
        all successful logins.

        Applies the ratelimiting for successful login attempts against an
        account.

        Args:
            user_id: ID of the user to register.
            login_submission: Dictionary of login information.
            callback: Callback function to run after login.
            create_non_existent_users: Whether to create the user if they don't
                exist. Defaults to False.

        Returns:
            result: Dictionary of account information after successful login.
        """

        # Before we actually log them in we check if they've already logged in
        # too often. This happens here rather than before as we don't
        # necessarily know the user before now.
        self._account_ratelimiter.ratelimit(user_id.lower())

        if create_non_existent_users:
            canonical_uid = await self.auth_handler.check_user_exists(user_id)
            if not canonical_uid:
                canonical_uid = await self.registration_handler.register_user(
                    localpart=UserID.from_string(user_id).localpart)
            user_id = canonical_uid

        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get(
            "initial_device_display_name")
        device_id, access_token = await self.registration_handler.register_device(
            user_id, device_id, initial_display_name)

        result = {
            "user_id": user_id,
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        if callback is not None:
            await callback(result)

        return result

    async def _do_token_login(self,
                              login_submission: JsonDict) -> Dict[str, str]:
        """
        Handle the final stage of SSO login.

        Args:
             login_submission: The JSON request body.

        Returns:
            The body of the JSON response.
        """
        token = login_submission["token"]
        auth_handler = self.auth_handler
        user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
            token)

        return await self._complete_login(
            user_id, login_submission, self.auth_handler._sso_login_callback)

    async def _do_jwt_login(self,
                            login_submission: JsonDict) -> Dict[str, str]:
        """Log in with a JWT whose `sub` claim is the localpart of the user,
        creating the user if they do not exist yet.

        Raises:
            LoginError (403): if the token is missing, fails validation, or has
                no string `sub` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(403,
                             "Token field for JWT is missing",
                             errcode=Codes.FORBIDDEN)

        import jwt

        try:
            payload = jwt.decode(
                token,
                self.jwt_secret,
                algorithms=[self.jwt_algorithm],
                issuer=self.jwt_issuer,
                audience=self.jwt_audiences,
            )
        except jwt.PyJWTError as e:
            # A JWT error occurred, return some info back to the client.
            raise LoginError(
                403,
                "JWT validation failed: %s" % (str(e), ),
                errcode=Codes.FORBIDDEN,
            )

        user = payload.get("sub", None)
        # `sub` must be a string (RFC 7519 StringOrURI); reject missing or
        # non-string values rather than building a bogus user ID from them.
        if not isinstance(user, str):
            raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

        user_id = UserID(user, self.hs.hostname).to_string()
        result = await self._complete_login(user_id,
                                            login_submission,
                                            create_non_existent_users=True)
        return result
Example #15
0
class GetVerificationCodeServlet(RestServlet):
    """Serves `/login/getvercode`: asks an external service to send a
    verification code to a 3PID (email or msisdn) that is already bound to an
    account on this server."""

    PATTERNS = client_patterns("/login/getvercode", v1=True)

    def __init__(self, hs):
        super().__init__()
        self.hs = hs

        # Requests are ratelimited per client IP, reusing the login-by-address
        # ratelimit settings.
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
        )
        self.http_client = SimpleHttpClient(hs)

    async def on_POST(self, request: SynapseRequest):
        """Request that a verification code be sent to a bound 3PID.

        The JSON body must contain `medium` ("email" or "msisdn") and
        `address`.

        Returns:
            `(200, {})` once the external service has accepted the request.

        Raises:
            LoginError: if `medium`/`address` is missing or the medium is not
                supported.
            SynapseError: if the address is malformed, the 3PID is not bound to
                any account, or the external service fails or times out.
        """
        self._address_ratelimiter.ratelimit(request.getClientIP())

        body = parse_json_object_from_request(request)
        # Use .get() so that a missing field reaches the explicit checks below
        # instead of escaping as an unhandled KeyError (HTTP 500).
        medium = body.get("medium")
        address = body.get("address")
        # NOTE(review): 410/411 are unusual statuses for validation errors (400
        # would be conventional); kept as-is since clients may depend on them.
        if medium is None:
            raise LoginError(410,
                             "medium field for get_ver_code is missing",
                             errcode=Codes.FORBIDDEN)
        if address is None:
            raise LoginError(410,
                             "address field for get_ver_code is missing",
                             errcode=Codes.FORBIDDEN)
        if medium not in ("email", "msisdn"):
            raise LoginError(411, "no support medium", errcode=Codes.FORBIDDEN)
        if not isinstance(address, str):
            raise SynapseError(400, "address field for get_ver_code must be a string")

        # Email 3PIDs are stored canonicalised in the DB, so look the address up
        # in canonical form. The address as submitted is still what gets sent to
        # the external service.
        lookup_address = address
        if medium == "email":
            from synapse.util.threepids import canonicalise_email

            try:
                lookup_address = canonicalise_email(address)
            except ValueError as e:
                raise SynapseError(400, str(e))

        existing_user_id = await self.hs.get_datastore(
        ).get_user_id_by_threepid(medium, lookup_address)
        if existing_user_id is None:
            if medium == "msisdn":
                raise SynapseError(400, "msisdn not bind",
                                   Codes.TEMPORARY_NOT_BIND_MSISDN)
            raise SynapseError(400, "email not bind", Codes.EMAIL_NOT_BIND)

        # The external service names the phone medium "mobile" rather than
        # "msisdn".
        payload = {
            "value": address,
            "type": "mobile" if medium == "msisdn" else medium,
        }
        try:
            result = await self.http_client.post_json_get_json(
                self.hs.config.auth_baseurl + self.hs.config.auth_get_vercode,
                payload,
            )
        except HttpResponseException as e:
            logger.info("Proxied getvercode failed: %r", e)
            raise e.to_synapse_error()
        except RequestTimedOutError:
            raise SynapseError(
                500,
                "Timed out contacting external server:ver_code_send_service")

        logger.info("result: %s", result)
        # The service signals success with "code": 200 in its JSON body; treat
        # anything else (including a malformed reply) as a failure.
        if not isinstance(result, dict) or result.get("code") != 200:
            message = result.get("message") if isinstance(result, dict) else None
            raise SynapseError(
                500, message or "Failed to request a verification code")
        return 200, {}
Example #16
0
class LoginRestServlet(RestServlet):
    """Servlet for the client-server ``/login`` endpoint.

    ``GET`` lists the login flows this homeserver supports; ``POST`` performs
    a login, dispatching on the ``type`` field of the submitted JSON body.
    """

    PATTERNS = client_patterns("/login$", v1=True)
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "m.login.jwt"

    def __init__(self, hs):
        super(LoginRestServlet, self).__init__()
        self.hs = hs
        self.jwt_enabled = hs.config.jwt_enabled
        self.jwt_secret = hs.config.jwt_secret
        self.jwt_algorithm = hs.config.jwt_algorithm
        self.saml2_enabled = hs.config.saml2_enabled
        self.cas_enabled = hs.config.cas_enabled
        self.auth_handler = self.hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self.handlers = hs.get_handlers()
        self._well_known_builder = WellKnownBuilder(hs)
        # Keyed by client IP; the rate and burst are supplied on each call in
        # on_POST rather than at construction time.
        self._address_ratelimiter = Ratelimiter()

    def on_GET(self, request):
        """Return the login flows supported by this homeserver.

        Each flow type is listed at most once, even when both SAML2 and CAS
        are enabled.

        Returns:
            tuple[int, dict]: ``(200, {"flows": [{"type": ...}, ...]})``
        """
        flows = []
        if self.jwt_enabled:
            flows.append({"type": LoginRestServlet.JWT_TYPE})

        sso_enabled = self.saml2_enabled or self.cas_enabled

        if sso_enabled:
            flows.append({"type": LoginRestServlet.SSO_TYPE})

        if self.cas_enabled:
            # we advertise CAS for backwards compat, though MSC1721 renamed it
            # to SSO.
            flows.append({"type": LoginRestServlet.CAS_TYPE})

        if sso_enabled:
            # While it's valid for us to advertise this login type generally,
            # synapse currently only gives out these tokens as part of the
            # SSO login flows.
            # Generally we don't want to advertise login flows that clients
            # don't know how to implement, since they (currently) will always
            # fall back to the fallback API if they don't understand one of the
            # login flow types returned.
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})

        flows.extend(({
            "type": t
        } for t in self.auth_handler.get_supported_login_types()))

        return (200, {"flows": flows})

    def on_OPTIONS(self, request):
        return (200, {})

    @defer.inlineCallbacks
    def on_POST(self, request):
        """Perform a login.

        The client IP is ratelimited first (using the ``rc_login_address``
        config), then the request is dispatched on its ``type``: JWT (when
        enabled), short-term login token, or anything else via
        ``_do_other_login``.

        Raises:
            SynapseError: 400 if a required JSON key is missing.
        """
        self._address_ratelimiter.ratelimit(
            request.getClientIP(),
            time_now_s=self.hs.clock.time(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
            update=True,
        )

        login_submission = parse_json_object_from_request(request)
        try:
            if self.jwt_enabled and (login_submission["type"]
                                     == LoginRestServlet.JWT_TYPE):
                result = yield self.do_jwt_login(login_submission)
            elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                result = yield self.do_token_login(login_submission)
            else:
                result = yield self._do_other_login(login_submission)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        return (200, result)

    @defer.inlineCallbacks
    def _do_other_login(self, login_submission):
        """Handle non-token/saml/jwt logins

        Phone identifiers are converted to third-party identifiers, and
        third-party identifiers are resolved to a user ID (via a password
        provider or the local threepid store). The resulting ``m.id.user``
        identifier is then validated by the auth handler.

        Args:
            login_submission (dict): The JSON body of the login request.

        Returns:
            dict: HTTP response

        Raises:
            SynapseError: 400 if the identifier is missing, not an object,
                has no/unknown type, or lacks the expected fields.
            LoginError: 403 if a third-party identifier is not bound to any
                local user.
        """
        # Log the request we got, but only certain fields to minimise the chance of
        # logging someone's password (even if they accidentally put it in the wrong
        # field)
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            login_submission.get("identifier"),
            login_submission.get("medium"),
            login_submission.get("address"),
            login_submission.get("user"),
        )
        login_submission_legacy_convert(login_submission)

        if "identifier" not in login_submission:
            raise SynapseError(400, "Missing param: identifier")

        identifier = login_submission["identifier"]
        # Reject non-object identifiers up front: a string would otherwise pass
        # the `in` check below as a substring test and then fail on indexing.
        if not isinstance(identifier, dict):
            raise SynapseError(400, "Invalid identifier in login submission")
        if "type" not in identifier:
            raise SynapseError(400, "Login identifier has no type")

        # convert phone type identifiers to generic threepids
        if identifier["type"] == "m.id.phone":
            identifier = login_id_thirdparty_from_phone(identifier)

        # convert threepid identifiers to user IDs
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get("address")
            medium = identifier.get("medium")

            if medium is None or address is None:
                raise SynapseError(400, "Invalid thirdparty identifier")

            if medium == "email":
                # For emails, transform the address to lowercase.
                # We store all email addreses as lowercase in the DB.
                # (See add_threepid in synapse/handlers/auth.py)
                address = address.lower()

            # Check for login providers that support 3pid login types
            canonical_user_id, callback_3pid = (
                yield self.auth_handler.check_password_provider_3pid(
                    medium, address, login_submission["password"]))
            if canonical_user_id:
                # Authentication through password provider and 3pid succeeded
                result = yield self._register_device_with_callback(
                    canonical_user_id, login_submission, callback_3pid)
                return result

            # No password providers were able to handle this 3pid
            # Check local store
            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                medium, address)
            if not user_id:
                logger.warning("unknown 3pid identifier medium %s, address %r",
                               medium, address)
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)

            identifier = {"type": "m.id.user", "user": user_id}

        # by this point, the identifier should be an m.id.user: if it's anything
        # else, we haven't understood it.
        if identifier["type"] != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type")
        if "user" not in identifier:
            raise SynapseError(400, "User identifier is missing 'user' key")

        canonical_user_id, callback = yield self.auth_handler.validate_login(
            identifier["user"], login_submission)

        result = yield self._register_device_with_callback(
            canonical_user_id, login_submission, callback)
        return result

    @defer.inlineCallbacks
    def _register_device_with_callback(self,
                                       user_id,
                                       login_submission,
                                       callback=None):
        """ Registers a device with a given user_id. Optionally run a callback
        function after registration has completed.

        Args:
            user_id (str): ID of the user to register.
            login_submission (dict): Dictionary of login information.
            callback (func|None): Callback function to run after registration.

        Returns:
            result (Dict[str,str]): Dictionary of account information after
                successful registration.
        """
        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get(
            "initial_device_display_name")
        device_id, access_token = yield self.registration_handler.register_device(
            user_id, device_id, initial_display_name)

        result = {
            "user_id": user_id,
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        if callback is not None:
            yield callback(result)

        return result

    @defer.inlineCallbacks
    def do_token_login(self, login_submission):
        """Log in using a short-term login token.

        Args:
            login_submission (dict): The login body; must contain ``token``.

        Returns:
            dict: Account information from ``_register_device_with_callback``.
        """
        token = login_submission["token"]
        auth_handler = self.auth_handler
        user_id = (yield auth_handler.
                   validate_short_term_login_token_and_get_user_id(token))

        result = yield self._register_device_with_callback(
            user_id, login_submission)
        return result

    @defer.inlineCallbacks
    def do_jwt_login(self, login_submission):
        """Log in using a JSON Web Token.

        The token is decoded with the configured secret and algorithm; its
        ``sub`` claim is used as the localpart. The user is registered if it
        does not already exist.

        Args:
            login_submission (dict): The login body; must contain ``token``.

        Returns:
            dict: Account information from ``_register_device_with_callback``.

        Raises:
            LoginError: 401 if the token is missing, expired, invalid, or has
                no ``sub`` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(401,
                             "Token field for JWT is missing",
                             errcode=Codes.UNAUTHORIZED)

        import jwt
        from jwt.exceptions import InvalidTokenError

        try:
            payload = jwt.decode(token,
                                 self.jwt_secret,
                                 algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
        except InvalidTokenError:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user = payload.get("sub", None)
        if user is None:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user_id = UserID(user, self.hs.hostname).to_string()

        registered_user_id = yield self.auth_handler.check_user_exists(user_id)
        if not registered_user_id:
            registered_user_id = yield self.registration_handler.register_user(
                localpart=user)

        result = yield self._register_device_with_callback(
            registered_user_id, login_submission)
        return result
Example #17
0
class LoginRestServlet(RestServlet):
    """Handles the client-server ``/login`` endpoint.

    ``GET`` advertises the login flows that are available. ``POST`` logs a
    user in according to the submitted ``type``: the application service
    type, JWT (when enabled), a short-term SSO login token, or anything else
    the auth handler knows how to validate. Logins are ratelimited per
    client address and, on success, per account.
    """

    PATTERNS = client_patterns("/login$", v1=True)
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "org.matrix.login.jwt"
    JWT_TYPE_DEPRECATED = "m.login.jwt"
    APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.hs = hs
        config = hs.config

        # Settings for the JWT login type.
        self.jwt_enabled = config.jwt_enabled
        self.jwt_secret = config.jwt_secret
        self.jwt_algorithm = config.jwt_algorithm
        self.jwt_issuer = config.jwt_issuer
        self.jwt_audiences = config.jwt_audiences

        # Which single-sign-on backends are switched on.
        self.saml2_enabled = config.saml2_enabled
        self.cas_enabled = config.cas_enabled
        self.oidc_enabled = config.oidc_enabled

        self.auth = hs.get_auth()

        self.auth_handler = self.hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self._well_known_builder = WellKnownBuilder(hs)

        clock = hs.get_clock()
        address_limits = config.rc_login_address
        account_limits = config.rc_login_account
        # One limiter keyed on client IP, one keyed on the (lowercased) user ID.
        self._address_ratelimiter = Ratelimiter(
            clock=clock,
            rate_hz=address_limits.per_second,
            burst_count=address_limits.burst_count,
        )
        self._account_ratelimiter = Ratelimiter(
            clock=clock,
            rate_hz=account_limits.per_second,
            burst_count=account_limits.burst_count,
        )

    def on_GET(self, request: SynapseRequest):
        """Return ``(200, {"flows": [...]})`` listing supported login types."""
        flow_types = []

        if self.jwt_enabled:
            flow_types += [
                LoginRestServlet.JWT_TYPE,
                LoginRestServlet.JWT_TYPE_DEPRECATED,
            ]

        if self.cas_enabled:
            # CAS is still advertised for backwards compatibility, even though
            # MSC1721 renamed it to SSO.
            flow_types.append(LoginRestServlet.CAS_TYPE)

        if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
            # The token login type is only advertised alongside SSO: synapse
            # currently hands out these tokens only as part of the SSO flow,
            # and clients that don't understand a listed flow tend to fall
            # back to the fallback API.
            flow_types += [
                LoginRestServlet.SSO_TYPE,
                LoginRestServlet.TOKEN_TYPE,
            ]

        flow_types.extend(self.auth_handler.get_supported_login_types())
        flow_types.append(LoginRestServlet.APPSERVICE_TYPE)

        return 200, {"flows": [{"type": t} for t in flow_types]}

    async def on_POST(self, request: SynapseRequest):
        """Log a user in and return their credentials.

        Raises:
            SynapseError: 400 if a required JSON key is missing.
        """
        login_submission = parse_json_object_from_request(request)

        try:
            login_type = login_submission["type"]

            if login_type == LoginRestServlet.APPSERVICE_TYPE:
                appservice = self.auth.get_appservice_by_req(request)
                # Only appservices that opted into ratelimiting are limited.
                if appservice.is_rate_limited():
                    self._address_ratelimiter.ratelimit(request.getClientIP())
                result = await self._do_appservice_login(
                    login_submission, appservice)
            else:
                # Every other login type is limited by client address first.
                self._address_ratelimiter.ratelimit(request.getClientIP())
                jwt_types = (
                    LoginRestServlet.JWT_TYPE,
                    LoginRestServlet.JWT_TYPE_DEPRECATED,
                )
                if self.jwt_enabled and login_type in jwt_types:
                    result = await self._do_jwt_login(login_submission)
                elif login_type == LoginRestServlet.TOKEN_TYPE:
                    result = await self._do_token_login(login_submission)
                else:
                    result = await self._do_other_login(login_submission)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        return 200, result

    async def _do_appservice_login(self, login_submission: JsonDict,
                                   appservice: ApplicationService):
        """Log in on behalf of an application service.

        Only ``m.id.user`` identifiers are accepted, and the target user must
        be one the appservice is interested in.

        Raises:
            SynapseError: 400 for a malformed or unsupported identifier.
            LoginError: 403 if the appservice may not act for the user.
        """
        identifier = login_submission.get("identifier")
        logger.info("Got appservice login request with identifier: %r",
                    identifier)

        if not isinstance(identifier, dict):
            raise SynapseError(400, "Invalid identifier in login submission",
                               Codes.INVALID_PARAM)

        if identifier.get("type") != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type",
                               Codes.INVALID_PARAM)

        user = identifier.get("user")
        if not isinstance(user, str):
            raise SynapseError(400, "Invalid user in identifier",
                               Codes.INVALID_PARAM)

        # Accept either a full MXID or a bare localpart on this server.
        qualified_user_id = (
            user if user.startswith("@")
            else UserID(user, self.hs.hostname).to_string()
        )

        if not appservice.is_interested_in_user(qualified_user_id):
            raise LoginError(403,
                             "Invalid access_token",
                             errcode=Codes.FORBIDDEN)

        return await self._complete_login(
            qualified_user_id,
            login_submission,
            ratelimit=appservice.is_rate_limited(),
        )

    async def _do_other_login(self,
                              login_submission: JsonDict) -> Dict[str, str]:
        """Handle logins that are not token, SAML, JWT or appservice logins.

        Args:
            login_submission: The JSON body of the login request.

        Returns:
            The body of the HTTP response.
        """
        # Only a handful of fields are logged, to reduce the risk of a
        # password ending up in the logs if a client puts it in the wrong
        # field.
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            login_submission.get("identifier"),
            login_submission.get("medium"),
            login_submission.get("address"),
            login_submission.get("user"),
        )
        canonical_user_id, callback = await self.auth_handler.validate_login(
            login_submission, ratelimit=True)
        return await self._complete_login(canonical_user_id,
                                          login_submission, callback)

    async def _complete_login(
        self,
        user_id: str,
        login_submission: JsonDict,
        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
        create_non_existent_users: bool = False,
        ratelimit: bool = True,
    ) -> Dict[str, str]:
        """Finish a login once the user has been authenticated.

        Every successful login path ends here. Applies the per-account
        ratelimit, optionally creates the user, registers a device, and runs
        the optional callback.

        Args:
            user_id: ID of the authenticated user.
            login_submission: Dictionary of login information.
            callback: Awaited with the result dict after the device is
                registered, if given.
            create_non_existent_users: Whether to register the user if they
                don't exist yet. Defaults to False.
            ratelimit: Whether to apply the per-account ratelimit.

        Returns:
            Dictionary with ``user_id``, ``access_token``, ``home_server`` and
            ``device_id``.
        """
        # The account isn't known until authentication has succeeded, which
        # is why the per-account limit is applied here rather than earlier.
        if ratelimit:
            self._account_ratelimiter.ratelimit(user_id.lower())

        if create_non_existent_users:
            existing_user_id = await self.auth_handler.check_user_exists(
                user_id)
            if existing_user_id:
                user_id = existing_user_id
            else:
                localpart = UserID.from_string(user_id).localpart
                user_id = await self.registration_handler.register_user(
                    localpart=localpart)

        device_id, access_token = await self.registration_handler.register_device(
            user_id,
            login_submission.get("device_id"),
            login_submission.get("initial_device_display_name"),
        )

        result = {
            "user_id": user_id,
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        if callback is not None:
            await callback(result)

        return result

    async def _do_token_login(self,
                              login_submission: JsonDict) -> Dict[str, str]:
        """Complete an SSO login using a short-term login token.

        Args:
             login_submission: The JSON request body; must contain ``token``.

        Returns:
            The body of the JSON response.
        """
        user_id = await self.auth_handler.validate_short_term_login_token_and_get_user_id(
            login_submission["token"])

        return await self._complete_login(
            user_id, login_submission, self.auth_handler._sso_login_callback)

    async def _do_jwt_login(self,
                            login_submission: JsonDict) -> Dict[str, str]:
        """Log in with a JWT, registering the ``sub`` user if needed.

        Raises:
            LoginError: 403 if the token is missing, fails validation, or has
                no ``sub`` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(403,
                             "Token field for JWT is missing",
                             errcode=Codes.FORBIDDEN)

        import jwt

        try:
            payload = jwt.decode(
                token,
                self.jwt_secret,
                algorithms=[self.jwt_algorithm],
                issuer=self.jwt_issuer,
                audience=self.jwt_audiences,
            )
        except jwt.PyJWTError as e:
            # Surface the validation failure reason to the client.
            raise LoginError(
                403,
                "JWT validation failed: %s" % (str(e), ),
                errcode=Codes.FORBIDDEN,
            )

        user = payload.get("sub", None)
        if user is None:
            raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

        return await self._complete_login(
            UserID(user, self.hs.hostname).to_string(),
            login_submission,
            create_non_existent_users=True,
        )
Example #18
0
class BaseHandler:
    """
    Shared base class for event handlers.

    Holds references to commonly used homeserver components and provides
    per-user message ratelimiting plus helpers for removing guests from a
    room when guest access is revoked.
    """
    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer):
        """
        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
        self.auth = hs.get_auth()
        self.notifier = hs.get_notifier()
        self.state_handler = hs.get_state_handler()  # type: synapse.state.StateHandler
        self.distributor = hs.get_distributor()
        self.clock = hs.get_clock()
        self.hs = hs

        # rate_hz/burst_count of 0 are placeholders: the real values are
        # passed per user on every ratelimit() call.
        self.request_ratelimiter = Ratelimiter(
            clock=self.clock, rate_hz=0, burst_count=0)
        self._rc_message = self.hs.config.rc_message

        # A dedicated limiter for admin redactions exists only when the config
        # provides limits for it.
        admin_limits = self.hs.config.rc_admin_redaction
        self.admin_redaction_ratelimiter = (
            Ratelimiter(
                clock=self.clock,
                rate_hz=admin_limits.per_second,
                burst_count=admin_limits.burst_count,
            )
            if admin_limits
            else None
        )

        self.server_name = hs.hostname

        self.event_builder_factory = hs.get_event_builder_factory()

    async def ratelimit(self,
                        requester,
                        update=True,
                        is_admin_redaction=False):
        """Apply message ratelimiting to a requester.

        Args:
            requester (Requester)
            update (bool): Whether to record that a request is being processed.
                Pass False for up-front checks when several checks are made
                for one request, and True for the final check.
            is_admin_redaction (bool): Whether this is a room admin/moderator
                redacting an event; if so, the separate admin-redaction
                limiter is used when one is configured.

        Raises:
            LimitExceededError if the request should be ratelimited
        """
        user_id = requester.user.to_string()

        # The AS user itself is never ratelimited.
        if self.store.get_app_service_by_user_id(user_id) is not None:
            return

        # Neither are users of an AS registered with `rate_limited: false`.
        app_service = requester.app_service
        if app_service and not app_service.is_rate_limited():
            return

        rate_hz = self._rc_message.per_second
        burst_count = self._rc_message.burst_count

        # A per-user override in the DB takes precedence over the config;
        # an override with a falsy rate disables ratelimiting for the user.
        override = await self.store.get_ratelimit_for_user(user_id)
        if override:
            if not override.messages_per_second:
                return
            rate_hz = override.messages_per_second
            burst_count = override.burst_count

        if is_admin_redaction and self.admin_redaction_ratelimiter:
            # Separate limiter so admin redactions don't share buckets with
            # the user's ordinary requests.
            self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
            return

        self.request_ratelimiter.ratelimit(
            user_id,
            rate_hz=rate_hz,
            burst_count=burst_count,
            update=update,
        )

    async def maybe_kick_guest_users(self, event, context=None):
        """Make local guests leave if ``event`` turns guest access off.

        Only acts on guest-access events whose ``guest_access`` is anything
        other than ``"can_join"``. Room state is read from ``context`` when
        given, otherwise from the state handler.
        """
        # Note: this effectively invalidates current_state by changing it;
        # callers are assumed not to rely on it afterwards.
        if event.type != EventTypes.GuestAccess:
            return
        if event.content.get("guest_access", "forbidden") == "can_join":
            return

        if context:
            state_ids = await context.get_current_state_ids()
            state_map = await self.store.get_events(list(state_ids.values()))
        else:
            state_map = await self.state_handler.get_current_state(
                event.room_id)

        current_state = list(state_map.values())

        logger.info("maybe_kick_guest_users %r", current_state)
        await self.kick_guest_users(current_state)

    async def kick_guest_users(self, current_state):
        """Have each local, joined-or-invited guest member leave the room.

        Errors for an individual member are logged and skipped so one failure
        does not stop the rest.
        """
        for member_event in current_state:
            try:
                if member_event.type != EventTypes.Member:
                    continue

                target_user = UserID.from_string(member_event.state_key)
                if not self.hs.is_mine(target_user):
                    continue

                membership = member_event.content["membership"]
                if membership not in {Membership.JOIN, Membership.INVITE}:
                    continue

                if member_event.content.get("kind") != "guest":
                    continue

                # The guest "chooses" to leave rather than being kicked: that
                # sidesteps power-level checks and keeps the decision local to
                # the homeserver that owns the guest, since guests don't work
                # well over federation.
                requester = synapse.types.create_requester(target_user,
                                                           is_guest=True)
                handler = self.hs.get_room_member_handler()
                await handler.update_membership(
                    requester,
                    target_user,
                    member_event.room_id,
                    "leave",
                    ratelimit=False,
                    require_consent=False,
                )
            except Exception as e:
                logger.exception("Error kicking guest user: %s", e)