Example #1
0
    def __init__(self, hostname, reactor=None, **kwargs):
        """Initialise the basic state of the server object.

        Args:
            hostname: The hostname for the server.
            reactor: The reactor to use. When falsy, the global
                ``twisted.internet.reactor`` is imported and used instead.
            **kwargs: Explicit dependencies; each is stored on the instance
                as an attribute of the same name.
        """
        if not reactor:
            from twisted.internet import reactor

        self._reactor = reactor
        self.hostname = hostname
        self._building = {}
        self._listening_services = []
        self.start_time = None

        self.clock = Clock(reactor)
        self.distributor = Distributor()

        # Three separate limiter instances, all built with default arguments.
        self.ratelimiter = Ratelimiter()
        self.admin_redaction_ratelimiter = Ratelimiter()
        self.registration_ratelimiter = Ratelimiter()

        self.datastores = None

        # Whatever keyword arguments remain are explicit dependencies.
        for dep_name, dependency in kwargs.items():
            setattr(self, dep_name, dependency)
Example #2
0
    def test_allowed_appservice_via_can_requester_do_action(self):
        """An appservice created with ``rate_limited=False`` is always allowed.

        With ``rate_hz=0.1`` and ``burst_count=1``, the same requester is
        checked at t=0, t=5 and t=10; every check must report
        ``allowed=True`` and a ``time_allowed`` of -1.
        """
        appservice = ApplicationService(
            None,
            "example.com",
            id="foo",
            rate_limited=False,
            sender="@as:example.com",
        )
        as_requester = create_requester("@user:example.com", app_service=appservice)

        limiter = Ratelimiter(
            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
        )

        # The expected outcome is identical at every timestamp, so the three
        # checks are expressed as one loop.
        for now_s in (0, 5, 10):
            allowed, time_allowed = self.get_success_or_raise(
                limiter.can_do_action(as_requester, _time_now_s=now_s)
            )
            self.assertTrue(allowed)
            # assertEqual, not the deprecated assertEquals alias.
            self.assertEqual(-1, time_allowed)
Example #3
0
    def __init__(self, hs: "HomeServer"):
        """Capture login-related configuration and collaborators from ``hs``.

        Args:
            hs: The homeserver providing config, auth objects, handlers and
                the clock.
        """
        super().__init__()
        self.hs = hs
        config = hs.config

        # JWT login settings.
        self.jwt_enabled = config.jwt_enabled
        self.jwt_secret = config.jwt_secret
        self.jwt_algorithm = config.jwt_algorithm
        self.jwt_issuer = config.jwt_issuer
        self.jwt_audiences = config.jwt_audiences

        # Which SSO mechanisms are switched on.
        self.saml2_enabled = config.saml2_enabled
        self.cas_enabled = config.cas_enabled
        self.oidc_enabled = config.oidc_enabled

        self.auth = hs.get_auth()

        self.auth_handler = hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self._well_known_builder = WellKnownBuilder(hs)

        # Independent limiters configured from `rc_login_address` and
        # `rc_login_account` respectively.
        address_limits = config.rc_login_address
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=address_limits.per_second,
            burst_count=address_limits.burst_count,
        )
        account_limits = config.rc_login_account
        self._account_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=account_limits.per_second,
            burst_count=account_limits.burst_count,
        )
Example #4
0
    def __init__(self, hs):
        """Collect the shared collaborators and ratelimiters from ``hs``.

        Args:
            hs (synapse.server.HomeServer):
        """
        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
        self.auth = hs.get_auth()
        self.notifier = hs.get_notifier()
        self.state_handler = hs.get_state_handler()  # type: synapse.state.StateHandler
        self.distributor = hs.get_distributor()
        self.clock = hs.get_clock()
        self.hs = hs

        # Built with zero rate and burst: the real rate_hz and burst_count
        # are overridden on a per-user basis.
        self.request_ratelimiter = Ratelimiter(
            clock=self.clock, rate_hz=0, burst_count=0
        )
        self._rc_message = self.hs.config.rc_message

        # Ratelimiting of room admin message redaction is enabled only when
        # rate limits for it are present in the config.
        admin_redaction_limits = self.hs.config.rc_admin_redaction
        self.admin_redaction_ratelimiter = (
            Ratelimiter(
                clock=self.clock,
                rate_hz=admin_redaction_limits.per_second,
                burst_count=admin_redaction_limits.burst_count,
            )
            if admin_redaction_limits
            else None
        )

        self.server_name = hs.hostname

        self.event_builder_factory = hs.get_event_builder_factory()
Example #5
0
    def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
        """An appservice created with ``rate_limited=True`` is rate limited.

        With ``rate_hz=0.1`` and ``burst_count=1``: the action at t=0 is
        allowed (next allowed at 10.0), the one at t=5 is refused, and the
        one at t=10 is allowed again (next allowed at 20.0).
        """
        appservice = ApplicationService(
            None,
            "example.com",
            id="foo",
            rate_limited=True,
            sender="@as:example.com",
        )
        as_requester = create_requester("@user:example.com", app_service=appservice)

        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)

        # assertEqual is used throughout instead of the deprecated
        # assertEquals alias.
        allowed, time_allowed = limiter.can_requester_do_action(
            as_requester, _time_now_s=0
        )
        self.assertTrue(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.can_requester_do_action(
            as_requester, _time_now_s=5
        )
        self.assertFalse(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.can_requester_do_action(
            as_requester, _time_now_s=10
        )
        self.assertTrue(allowed)
        self.assertEqual(20.0, time_allowed)
Example #6
0
    def __init__(self, hs):
        """Create the HTTP clients and `/requestToken` ratelimiters.

        Args:
            hs: The homeserver providing config, clock and the federation
                HTTP client.
        """
        super().__init__(hs)

        # Client used for contacting trusted URLs.
        self.http_client = SimpleHttpClient(hs)
        # Client used for identity servers specified by clients; it is given
        # the federation IP range blacklist.
        self.blacklisting_http_client = SimpleHttpClient(
            hs,
            ip_blacklist=hs.config.federation_ip_range_blacklist,
        )
        self.federation_http_client = hs.get_federation_http_client()
        self.hs = hs

        self._web_client_location = hs.config.invite_client_location

        # Ratelimiters for `/requestToken` endpoints. Both take their limits
        # from `rc_3pid_validation`.
        validation_limits = hs.config.ratelimiting.rc_3pid_validation
        self._3pid_validation_ratelimiter_ip = Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            rate_hz=validation_limits.per_second,
            burst_count=validation_limits.burst_count,
        )
        self._3pid_validation_ratelimiter_address = Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            rate_hz=validation_limits.per_second,
            burst_count=validation_limits.burst_count,
        )
Example #7
0
    def test_db_user_override(self):
        """A user whose DB override disables ratelimiting is never limited.

        A `ratelimit_override` row with null limits is inserted for the user,
        after which 20 consecutive `ratelimit` calls at the same timestamp
        must all complete without raising.
        """
        store = self.hs.get_datastore()

        user_id = "@user:test"
        requester = create_requester(user_id)

        override_row = {
            "user_id": user_id,
            "messages_per_second": None,
            "burst_count": None,
        }
        self.get_success(
            store.db_pool.simple_insert(
                table="ratelimit_override",
                values=override_row,
                desc="test_db_user_override",
            )
        )

        limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)

        # None of these calls should raise.
        attempts = 20
        for _ in range(attempts):
            self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
Example #8
0
    def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
        """Initialise the basic state of the homeserver object.

        Args:
            hostname: The hostname for the server.
            config: The full config for the homeserver.
            reactor: The reactor to use. When falsy, the global
                ``twisted.internet.reactor`` is imported and used instead.
            **kwargs: Explicit dependencies; each is stored on the instance
                as an attribute of the same name.
        """
        if not reactor:
            from twisted.internet import reactor

        self._reactor = reactor
        self.hostname = hostname
        self.config = config
        self._building = {}
        self._listening_services = []
        self.start_time = None

        self._instance_id = random_string(5)
        # Falls back to "master" when no worker name is configured.
        self._instance_name = config.worker_name or "master"

        self.clock = Clock(reactor)
        self.distributor = Distributor()

        # Three separate limiter instances, all built with default arguments.
        self.ratelimiter = Ratelimiter()
        self.admin_redaction_ratelimiter = Ratelimiter()
        self.registration_ratelimiter = Ratelimiter()

        self.datastores = None

        # Whatever keyword arguments remain are explicit dependencies.
        for dep_name, dependency in kwargs.items():
            setattr(self, dep_name, dependency)
Example #9
0
    def __init__(self, hs: "HomeServer"):
        """Set up handler registries, replication clients and the key-request limiter.

        Args:
            hs: The homeserver providing config, clock and instance name.
        """
        self.config = hs.config
        self.clock = hs.get_clock()
        self._instance_name = hs.get_instance_name()

        # These are safe to load in monolith mode, but will explode if we try
        # and use them. However we have guards before we use them to ensure that
        # we don't route to ourselves, and in monolith mode that will always be
        # the case.
        self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
        self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)

        self.edu_handlers = {}  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
        self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]

        # Map from type to instance names that we should route EDU handling to.
        # We randomly choose one instance from the list to route to for each new
        # EDU received.
        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]

        # A rate limiter for incoming room key requests per origin, configured
        # from `rc_key_requests`.
        key_request_limits = self.config.rc_key_requests
        self._room_key_request_rate_limiter = Ratelimiter(
            clock=self.clock,
            rate_hz=key_request_limits.per_second,
            burst_count=key_request_limits.burst_count,
        )
Example #10
0
    def test_allowed(self):
        """Check `send_message` at 0.1 Hz with a burst of 1.

        The message at t=0 is allowed (next allowed at 10.0), the one at t=5
        is refused, and the one at t=10 is allowed again (next allowed at
        20.0).
        """
        limiter = Ratelimiter()

        # assertEqual is used throughout instead of the deprecated
        # assertEquals alias.
        allowed, time_allowed = limiter.send_message(
            user_id="test_id",
            time_now_s=0,
            msg_rate_hz=0.1,
            burst_count=1,
        )
        self.assertTrue(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.send_message(
            user_id="test_id",
            time_now_s=5,
            msg_rate_hz=0.1,
            burst_count=1,
        )
        self.assertFalse(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.send_message(
            user_id="test_id",
            time_now_s=10,
            msg_rate_hz=0.1,
            burst_count=1,
        )
        self.assertTrue(allowed)
        self.assertEqual(20.0, time_allowed)
Example #11
0
    def __init__(self, hs: "HomeServer"):
        """Create the HTTP clients, lookup settings and `/requestToken` ratelimiters.

        Args:
            hs: The homeserver providing the datastore, config, clock and
                federation HTTP client.
        """
        self.store = hs.get_datastore()
        # Client used for contacting trusted URLs.
        self.http_client = SimpleHttpClient(hs)
        # Client used for identity servers specified by clients; it is given
        # the federation IP range blacklist and whitelist.
        server_config = hs.config.server
        self.blacklisting_http_client = SimpleHttpClient(
            hs,
            ip_blacklist=server_config.federation_ip_range_blacklist,
            ip_whitelist=server_config.federation_ip_range_whitelist,
        )
        self.federation_http_client = hs.get_federation_http_client()
        self.hs = hs

        registration_config = hs.config.registration
        self.rewrite_identity_server_urls = (
            registration_config.rewrite_identity_server_urls
        )
        self._enable_lookup = registration_config.enable_3pid_lookup

        self._web_client_location = hs.config.email.invite_client_location

        # Ratelimiters for `/requestToken` endpoints. Both take their limits
        # from `rc_3pid_validation`.
        validation_limits = hs.config.ratelimiting.rc_3pid_validation
        self._3pid_validation_ratelimiter_ip = Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            rate_hz=validation_limits.per_second,
            burst_count=validation_limits.burst_count,
        )
        self._3pid_validation_ratelimiter_address = Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            rate_hz=validation_limits.per_second,
            burst_count=validation_limits.burst_count,
        )
Example #12
0
 def __init__(self, hs):
     """Capture login configuration, handlers and three login ratelimiters.

     Args:
         hs: The homeserver providing config, datastore, handlers and clock.
     """
     super(CustomRestServlet, self).__init__()
     self.hs = hs
     self.store = hs.get_datastore()
     config = hs.config

     # JWT login settings.
     self.jwt_enabled = config.jwt_enabled
     self.jwt_secret = config.jwt_secret
     self.jwt_algorithm = config.jwt_algorithm

     # Which SSO mechanisms are switched on.
     self.saml2_enabled = config.saml2_enabled
     self.cas_enabled = config.cas_enabled
     self.oidc_enabled = config.oidc_enabled

     self.auth_handler = hs.get_auth_handler()
     self.registration_handler = hs.get_registration_handler()
     self.handlers = hs.get_handlers()
     self._well_known_builder = WellKnownBuilder(hs)

     # One limiter per `rc_login_*` config section.
     address_limits = config.rc_login_address
     self._address_ratelimiter = Ratelimiter(
         clock=hs.get_clock(),
         rate_hz=address_limits.per_second,
         burst_count=address_limits.burst_count,
     )
     account_limits = config.rc_login_account
     self._account_ratelimiter = Ratelimiter(
         clock=hs.get_clock(),
         rate_hz=account_limits.per_second,
         burst_count=account_limits.burst_count,
     )
     failed_attempt_limits = config.rc_login_failed_attempts
     self._failed_attempts_ratelimiter = Ratelimiter(
         clock=hs.get_clock(),
         rate_hz=failed_attempt_limits.per_second,
         burst_count=failed_attempt_limits.burst_count,
     )
Example #13
0
    def __init__(self, hs):
        """Set up auth checkers, sessions, password providers and login types.

        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)

        # Instantiate every interactive-auth checker class and keep only the
        # enabled ones, keyed by their auth type.
        self.checkers = {}  # type: dict[str, UserInteractiveAuthChecker]
        for checker_cls in INTERACTIVE_AUTH_CHECKERS:
            checker = checker_cls(hs)
            if not checker.is_enabled():
                continue
            self.checkers[checker.AUTH_TYPE] = checker

        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # This is not a cache per se, but a store of all current sessions that
        # expire after N hours
        self.sessions = ExpiringCache(
            cache_name="register_sessions",
            clock=hs.get_clock(),
            expiry_ms=self.SESSION_EXPIRE_MS,
            reset_expiry_on_get=True,
        )

        # Every configured password provider module is instantiated with its
        # own config and a shared ModuleApi.
        account_handler = ModuleApi(hs, self)
        self.password_providers = [
            provider_cls(config=provider_config, account_handler=account_handler)
            for provider_cls, provider_config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.macaroon_gen = hs.get_macaroon_generator()
        self._password_enabled = hs.config.password_enabled

        # we keep this as a list despite the O(N^2) implication so that we can
        # keep PASSWORD first and avoid confusing clients which pick the first
        # type in the list. (NB that the spec doesn't require us to do so and
        # clients which favour types that they don't understand over those that
        # they do are technically broken)
        supported = []
        if self._password_enabled:
            supported.append(LoginType.PASSWORD)
        for provider in self.password_providers:
            if not hasattr(provider, "get_supported_login_types"):
                continue
            for login_type in provider.get_supported_login_types().keys():
                if login_type not in supported:
                    supported.append(login_type)
        self._supported_login_types = supported

        # Ratelimiter for failed auth during UIA. Uses same ratelimit config
        # as per `rc_login.failed_attempts`.
        self._failed_uia_attempts_ratelimiter = Ratelimiter()

        self._clock = self.hs.get_clock()
Example #14
0
    def __init__(self, hs):
        """Set up auth checkers, sessions, password providers and login types.

        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)

        # Dispatch table from interactive-auth login type to the bound method
        # that checks it.
        self.checkers = {
            LoginType.RECAPTCHA: self._check_recaptcha,
            LoginType.EMAIL_IDENTITY: self._check_email_identity,
            LoginType.MSISDN: self._check_msisdn,
            LoginType.DUMMY: self._check_dummy_auth,
            LoginType.TERMS: self._check_terms_auth,
        }
        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # This is not a cache per se, but a store of all current sessions that
        # expire after N hours
        self.sessions = ExpiringCache(
            cache_name="register_sessions",
            clock=hs.get_clock(),
            expiry_ms=self.SESSION_EXPIRE_MS,
            reset_expiry_on_get=True,
        )

        # Every configured password provider module is instantiated with its
        # own config and a shared ModuleApi.
        account_handler = ModuleApi(hs, self)
        self.password_providers = [
            provider_cls(config=provider_config, account_handler=account_handler)
            for provider_cls, provider_config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.macaroon_gen = hs.get_macaroon_generator()
        self._password_enabled = hs.config.password_enabled

        # we keep this as a list despite the O(N^2) implication so that we can
        # keep PASSWORD first and avoid confusing clients which pick the first
        # type in the list. (NB that the spec doesn't require us to do so and
        # clients which favour types that they don't understand over those that
        # they do are technically broken)
        supported = []
        if self._password_enabled:
            supported.append(LoginType.PASSWORD)
        for provider in self.password_providers:
            if not hasattr(provider, "get_supported_login_types"):
                continue
            for login_type in provider.get_supported_login_types().keys():
                if login_type not in supported:
                    supported.append(login_type)
        self._supported_login_types = supported

        # Two limiter instances, both built with default arguments.
        self._account_ratelimiter = Ratelimiter()
        self._failed_attempts_ratelimiter = Ratelimiter()

        self._clock = self.hs.get_clock()
Example #15
0
    def __init__(self, hs: "HomeServer"):
        """Gather the handlers, settings and ratelimiters used for membership changes.

        Args:
            hs: The homeserver providing the datastore, config, clock and
                the various handlers.
        """
        self.hs = hs
        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
        self.state_handler = hs.get_state_handler()
        self.config = hs.config

        self.federation_handler = hs.get_federation_handler()
        self.directory_handler = hs.get_directory_handler()
        self.identity_handler = hs.get_identity_handler()
        self.registration_handler = hs.get_registration_handler()
        self.profile_handler = hs.get_profile_handler()
        self.event_creation_handler = hs.get_event_creation_handler()
        self.account_data_handler = hs.get_account_data_handler()
        self.event_auth_handler = hs.get_event_auth_handler()

        self.member_linearizer = Linearizer(name="member")

        self.clock = hs.get_clock()
        self.spam_checker = hs.get_spam_checker()
        self.third_party_event_rules = hs.get_third_party_event_rules()
        self._server_notices_mxid = self.config.server_notices_mxid
        self._enable_lookup = hs.config.enable_3pid_lookup
        self.allow_per_room_profiles = self.config.allow_per_room_profiles

        ratelimit_config = hs.config.ratelimiting

        # Join limiters, configured from `rc_joins_local` and
        # `rc_joins_remote`.
        local_join_limits = ratelimit_config.rc_joins_local
        self._join_rate_limiter_local = Ratelimiter(
            store=self.store,
            clock=self.clock,
            rate_hz=local_join_limits.per_second,
            burst_count=local_join_limits.burst_count,
        )
        remote_join_limits = ratelimit_config.rc_joins_remote
        self._join_rate_limiter_remote = Ratelimiter(
            store=self.store,
            clock=self.clock,
            rate_hz=remote_join_limits.per_second,
            burst_count=remote_join_limits.burst_count,
        )

        # Invite limiters, configured from `rc_invites_per_room` and
        # `rc_invites_per_user`.
        room_invite_limits = ratelimit_config.rc_invites_per_room
        self._invites_per_room_limiter = Ratelimiter(
            store=self.store,
            clock=self.clock,
            rate_hz=room_invite_limits.per_second,
            burst_count=room_invite_limits.burst_count,
        )
        user_invite_limits = ratelimit_config.rc_invites_per_user
        self._invites_per_user_limiter = Ratelimiter(
            store=self.store,
            clock=self.clock,
            rate_hz=user_invite_limits.per_second,
            burst_count=user_invite_limits.burst_count,
        )

        # This is only used to get at ratelimit function, and
        # maybe_kick_guest_users. It's fine there are multiple of these as
        # it doesn't store state.
        self.base_handler = BaseHandler(hs)
Example #16
0
 def __init__(self, hs):
     """Capture login configuration, handlers and the address ratelimiter.

     Args:
         hs: The homeserver providing config and handlers.
     """
     super(LoginRestServlet, self).__init__(hs)
     config = hs.config

     # JWT login settings.
     self.jwt_enabled = config.jwt_enabled
     self.jwt_secret = config.jwt_secret
     self.jwt_algorithm = config.jwt_algorithm

     self.cas_enabled = config.cas_enabled

     # NOTE(review): `self.hs` is read here without being assigned in this
     # method; presumably the parent initialiser sets it - confirm.
     self.auth_handler = self.hs.get_auth_handler()
     self.registration_handler = hs.get_registration_handler()
     self.handlers = hs.get_handlers()
     self._well_known_builder = WellKnownBuilder(hs)
     self._address_ratelimiter = Ratelimiter()
Example #17
0
    def test_allowed_via_can_do_action(self):
        """Check `can_do_action` at 0.1 Hz with a burst of 1.

        The action at t=0 is allowed (next allowed at 10.0), the one at t=5
        is refused, and the one at t=10 is allowed again (next allowed at
        20.0).
        """
        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)

        # assertEqual is used throughout instead of the deprecated
        # assertEquals alias.
        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
        self.assertTrue(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
        self.assertFalse(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
        self.assertTrue(allowed)
        self.assertEqual(20.0, time_allowed)
Example #18
0
    def test_pruning(self):
        """An old key is dropped from `message_counts` by a later action.

        After acting as "test_id_1" at t=0 the key is tracked; acting as
        "test_id_2" at t=10 must leave "test_id_1" no longer tracked.
        """
        limiter = Ratelimiter()
        limits = {"rate_hz": 0.1, "burst_count": 1}

        _allowed, _time_allowed = limiter.can_do_action(
            key="test_id_1", time_now_s=0, **limits
        )
        self.assertIn("test_id_1", limiter.message_counts)

        _allowed, _time_allowed = limiter.can_do_action(
            key="test_id_2", time_now_s=10, **limits
        )
        self.assertNotIn("test_id_1", limiter.message_counts)
Example #19
0
    def test_pruning(self):
        """An old user is dropped from `message_counts` by a later message.

        After sending as "test_id_1" at t=0 the user is tracked; sending as
        "test_id_2" at t=10 must leave "test_id_1" no longer tracked.
        """
        limiter = Ratelimiter()
        limits = {"msg_rate_hz": 0.1, "burst_count": 1}

        _allowed, _time_allowed = limiter.send_message(
            user_id="test_id_1", time_now_s=0, **limits
        )
        self.assertIn("test_id_1", limiter.message_counts)

        _allowed, _time_allowed = limiter.send_message(
            user_id="test_id_2", time_now_s=10, **limits
        )
        self.assertNotIn("test_id_1", limiter.message_counts)
Example #20
0
class LookupBindStatusForOpenid(RestServlet):
    """Servlet reporting whether the requesting user has an external id bound.

    A POST names a provider via ``bind_type``. It succeeds with an empty body
    when the datastore returns an external id for that provider and the
    requesting user, and raises ``LoginError`` otherwise.
    """

    PATTERNS = client_patterns("/login/oauth2/bind/status$", v1=True)

    def __init__(self, hs):
        """
        Args:
            hs: The homeserver providing auth, config, clock and datastore.
        """
        super().__init__()
        self.hs = hs
        self.auth = hs.get_auth()

        # Limiter applied to the client IP of each request, configured from
        # `rc_login_address`.
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
        )
        self.http_client = SimpleHttpClient(hs)

    async def on_POST(self, request: SynapseRequest):
        """Look up the bind status of the requesting user for one provider.

        Args:
            request: The incoming request; its JSON body must carry a
                ``bind_type`` field.

        Returns:
            ``(200, {})`` when an external id is bound.

        Raises:
            LoginError: with status 410 when ``bind_type`` is absent or null,
                or with status 400 when no external id is bound.
        """
        self._address_ratelimiter.ratelimit(request.getClientIP())

        params = parse_json_object_from_request(request)
        # Lazy logger arguments: the message is only formatted if emitted.
        logger.info("----lookup bind--------param:%s", params)

        # `.get` so that an absent field reaches the explicit error below
        # instead of escaping as a KeyError.
        bind_type = params.get("bind_type")
        if bind_type is None:
            raise LoginError(
                410,
                "bind_type field for bind openid is missing",
                errcode=Codes.FORBIDDEN,
            )

        requester = await self.auth.get_user_by_req(request)
        logger.info("------requester: %s", requester)
        user_id = requester.user
        logger.info("------user: %s", user_id)

        # Look up the external id bound to this user for the given provider.
        openid = await self.hs.get_datastore().get_external_id_for_user_provider(
            bind_type,
            str(user_id),
        )
        if openid is None:
            raise LoginError(
                400,
                "openid not bind",
                errcode=Codes.OPENID_NOT_BIND,
            )

        return 200, {}
Example #21
0
    def test_pruning(self):
        """An old key is dropped from `actions` by a later action.

        After acting as "test_id_1" at t=0 the key is tracked; acting as
        "test_id_2" at t=10 must leave "test_id_1" no longer tracked.
        """
        limiter = Ratelimiter(
            store=self.hs.get_datastores().main,
            clock=None,
            rate_hz=0.1,
            burst_count=1,
        )

        first_action = limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
        self.get_success_or_raise(first_action)
        self.assertIn("test_id_1", limiter.actions)

        second_action = limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
        self.get_success_or_raise(second_action)
        self.assertNotIn("test_id_1", limiter.actions)
Example #22
0
    def __init__(self, hs: "HomeServer"):
        """Capture login-related configuration and collaborators from ``hs``.

        Args:
            hs: The homeserver providing config, datastores, auth objects,
                handlers and the clock.
        """
        super().__init__()
        self.hs = hs

        # JWT configuration variables.
        jwt_config = hs.config.jwt
        self.jwt_enabled = jwt_config.jwt_enabled
        self.jwt_secret = jwt_config.jwt_secret
        self.jwt_subject_claim = jwt_config.jwt_subject_claim
        self.jwt_algorithm = jwt_config.jwt_algorithm
        self.jwt_issuer = jwt_config.jwt_issuer
        self.jwt_audiences = jwt_config.jwt_audiences

        # SSO configuration.
        self.saml2_enabled = hs.config.saml2.saml2_enabled
        self.cas_enabled = hs.config.cas.cas_enabled
        self.oidc_enabled = hs.config.oidc.oidc_enabled

        # Refresh tokens count as enabled whenever a refreshable access token
        # lifetime is configured.
        token_lifetime = hs.config.registration.refreshable_access_token_lifetime
        self._refresh_tokens_enabled = token_lifetime is not None

        self.auth = hs.get_auth()

        self.clock = hs.get_clock()

        self.auth_handler = hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self._sso_handler = hs.get_sso_handler()

        self._well_known_builder = WellKnownBuilder(hs)

        # Independent limiters configured from `rc_login_address` and
        # `rc_login_account` respectively.
        address_limits = hs.config.ratelimiting.rc_login_address
        self._address_ratelimiter = Ratelimiter(
            store=hs.get_datastores().main,
            clock=hs.get_clock(),
            rate_hz=address_limits.per_second,
            burst_count=address_limits.burst_count,
        )
        account_limits = hs.config.ratelimiting.rc_login_account
        self._account_ratelimiter = Ratelimiter(
            store=hs.get_datastores().main,
            clock=hs.get_clock(),
            rate_hz=account_limits.per_second,
            burst_count=account_limits.burst_count,
        )

        # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
        # The reason for this is to ensure that the auth_provider_ids are registered
        # with SsoHandler, which in turn ensures that the login/registration prometheus
        # counters are initialised for the auth_provider_ids.
        _load_sso_handlers(hs)
    def test_pruning(self):
        """An old key is dropped from `message_counts` by a later action.

        After acting as "test_id_1" at t=0 the key is tracked; acting as
        "test_id_2" at t=10 must leave "test_id_1" no longer tracked.
        """
        limiter = Ratelimiter()
        limits = {"rate_hz": 0.1, "burst_count": 1}

        _allowed, _time_allowed = limiter.can_do_action(
            key="test_id_1", time_now_s=0, **limits
        )
        self.assertIn("test_id_1", limiter.message_counts)

        _allowed, _time_allowed = limiter.can_do_action(
            key="test_id_2", time_now_s=10, **limits
        )
        self.assertNotIn("test_id_1", limiter.message_counts)
    def test_pruning(self):
        """An old user is dropped from `message_counts` by a later message.

        After sending as "test_id_1" at t=0 the user is tracked; sending as
        "test_id_2" at t=10 must leave "test_id_1" no longer tracked.
        """
        limiter = Ratelimiter()
        limits = {"msg_rate_hz": 0.1, "burst_count": 1}

        _allowed, _time_allowed = limiter.send_message(
            user_id="test_id_1", time_now_s=0, **limits
        )
        self.assertIn("test_id_1", limiter.message_counts)

        _allowed, _time_allowed = limiter.send_message(
            user_id="test_id_2", time_now_s=10, **limits
        )
        self.assertNotIn("test_id_1", limiter.message_counts)
Example #25
0
    def test_rate_limit_burst_only_given_once(self) -> None:
        """
        Regression test for a bug where carefully timed requests could
        accumulate extra tokens beyond the burst allowance.
        """
        limiter = Ratelimiter(
            store=self.hs.get_datastores().main,
            clock=None,
            rate_hz=0.1,
            burst_count=3,
        )

        def consume_at(now_s: float) -> bool:
            """Attempt one action for key "a" at `now_s`; report if allowed."""
            outcome = self.get_success_or_raise(
                limiter.can_do_action(requester=None, key="a", _time_now_s=now_s)
            )
            was_allowed, _ = outcome
            return was_allowed

        # Spend all three burst tokens.
        for burst_time in (0.0, 0.1, 0.2):
            self.assertTrue(consume_at(burst_time))

        # Wait to recover 1 token (10 seconds at 0.1 Hz).
        self.assertTrue(consume_at(10.1))

        # Having used that token, the next attempt must be rate limited.
        self.assertFalse(consume_at(11.1))
Example #26
0
 def get_registration_ratelimiter(self) -> Ratelimiter:
     """Build a ``Ratelimiter`` configured from ``rc_registration``.

     Returns:
         A new limiter using this object's datastore and clock.
     """
     registration_limits = self.config.ratelimiting.rc_registration
     limiter = Ratelimiter(
         store=self.get_datastore(),
         clock=self.get_clock(),
         rate_hz=registration_limits.per_second,
         burst_count=registration_limits.burst_count,
     )
     return limiter
Example #27
0
    def __init__(self, hs):
        """Create the address ratelimiter and HTTP client for this servlet.

        Args:
            hs: The homeserver providing config and clock.
        """
        super().__init__()
        self.hs = hs
        # NOTE: a verification-code ExpiringCache ("get_ver_code_cache") was
        # sketched here previously but is disabled.

        # Limiter configured from `rc_login_address`.
        address_limits = self.hs.config.rc_login_address
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=address_limits.per_second,
            burst_count=address_limits.burst_count,
        )
        self.http_client = SimpleHttpClient(hs)
Example #28
0
    def test_allowed_user_via_can_requester_do_action(self):
        """Check `can_requester_do_action` for a plain user requester.

        With ``rate_hz=0.1`` and ``burst_count=1``: the action at t=0 is
        allowed (next allowed at 10.0), the one at t=5 is refused, and the
        one at t=10 is allowed again (next allowed at 20.0).
        """
        user_requester = create_requester("@user:example.com")
        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)

        # assertEqual is used throughout instead of the deprecated
        # assertEquals alias.
        allowed, time_allowed = limiter.can_requester_do_action(
            user_requester, _time_now_s=0
        )
        self.assertTrue(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.can_requester_do_action(
            user_requester, _time_now_s=5
        )
        self.assertFalse(allowed)
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = limiter.can_requester_do_action(
            user_requester, _time_now_s=10
        )
        self.assertTrue(allowed)
        self.assertEqual(20.0, time_allowed)
    def test_multiple_actions(self):
        """Check ``n_actions`` handling against a burst count of 3.

        Exercises ``can_do_action`` with several ``n_actions`` values at
        different simulated times and checks both the allow/deny decision and
        the reported ``time_allowed``.
        """
        limiter = Ratelimiter(
            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3
        )
        # Test that 4 actions aren't allowed with a maximum burst of 3.
        allowed, time_allowed = self.get_success_or_raise(
            limiter.can_do_action(None, key="test_id", n_actions=4, _time_now_s=0)
        )
        self.assertFalse(allowed)

        # Test that 3 actions are allowed with a maximum burst of 3.
        allowed, time_allowed = self.get_success_or_raise(
            limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0)
        )
        self.assertTrue(allowed)
        self.assertEqual(10.0, time_allowed)

        # Test that, after doing these 3 actions, we can't do any more actions
        # without waiting.
        allowed, time_allowed = self.get_success_or_raise(
            limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0)
        )
        self.assertFalse(allowed)
        self.assertEqual(10.0, time_allowed)

        # Test that after waiting we can do only 1 action.
        # `update=False` is passed here, so this check is not expected to
        # consume the allowance.
        allowed, time_allowed = self.get_success_or_raise(
            limiter.can_do_action(
                None,
                key="test_id",
                update=False,
                n_actions=1,
                _time_now_s=10,
            )
        )
        self.assertTrue(allowed)
        # The time allowed is the current time because we could still repeat the
        # action once.
        self.assertEqual(10.0, time_allowed)

        allowed, time_allowed = self.get_success_or_raise(
            limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10)
        )
        self.assertFalse(allowed)
        # The time allowed doesn't change despite allowed being False because,
        # while we don't allow 2 actions, we could still do 1.
        self.assertEqual(10.0, time_allowed)

        # Test that after waiting a bit more we can do 2 actions.
        allowed, time_allowed = self.get_success_or_raise(
            limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=20)
        )
        self.assertTrue(allowed)
        # The time allowed is the current time because we could still repeat the
        # action once.
        self.assertEqual(20.0, time_allowed)
Example #30
0
    def test_allowed_via_can_do_action(self):
        """``can_do_action`` allows, refuses, then allows again over time.

        Uses ``rate_hz=0.1`` / ``burst_count=1`` and the key ``"test_id"``.
        """
        main_store = self.hs.get_datastores().main
        limiter = Ratelimiter(
            store=main_store, clock=None, rate_hz=0.1, burst_count=1
        )

        # Each row: (simulated time, expected decision, expected time_allowed)
        scenarios = (
            (0, True, 10.0),
            (5, False, 10.0),
            (10, True, 20.0),
        )
        for now_s, should_allow, expected_time in scenarios:
            pending = limiter.can_do_action(None, key="test_id", _time_now_s=now_s)
            allowed, time_allowed = self.get_success_or_raise(pending)
            if should_allow:
                self.assertTrue(allowed)
            else:
                self.assertFalse(allowed)
            self.assertEqual(expected_time, time_allowed)
Example #31
0
    def test_allowed(self):
        """``send_message`` allows, refuses, then allows again over time.

        Every call passes ``msg_rate_hz=0.1`` and ``burst_count=1`` for the
        user ``"test_id"``; only the simulated time changes.
        """
        limiter = Ratelimiter()

        # t=0: first message is allowed; the next slot is at t=10.
        allowed, time_allowed = limiter.send_message(
            user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
        )
        self.assertTrue(allowed)
        self.assertEqual(10., time_allowed)

        # t=5: too early, the message is refused.
        allowed, time_allowed = limiter.send_message(
            user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
        )
        self.assertFalse(allowed)
        self.assertEqual(10., time_allowed)

        # t=10: allowed again, and the next slot moves to t=20.
        allowed, time_allowed = limiter.send_message(
            user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
        )
        self.assertTrue(allowed)
        self.assertEqual(20., time_allowed)
Example #32
0
    def test_allowed_via_ratelimit(self):
        """``ratelimit`` raises ``LimitExceededError`` only inside the window.

        Uses ``rate_hz=0.1`` / ``burst_count=1`` and the key ``"test_id"``.
        """
        store = self.hs.get_datastore()
        limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)

        # t=0: the first action must go through without raising.
        first_attempt = limiter.ratelimit(None, key="test_id", _time_now_s=0)
        self.get_success_or_raise(first_attempt)

        # t=5: the second action must be rejected, with 5s left to wait.
        with self.assertRaises(LimitExceededError) as caught:
            self.get_success_or_raise(
                limiter.ratelimit(None, key="test_id", _time_now_s=5)
            )
        self.assertEqual(caught.exception.retry_after_ms, 5000)

        # t=10: the action must go through again without raising.
        third_attempt = limiter.ratelimit(None, key="test_id", _time_now_s=10)
        self.get_success_or_raise(third_attempt)
Example #33
0
    def test_allowed(self):
        """``can_do_action`` allows, refuses, then allows again over time.

        Every call passes ``rate_hz=0.1`` and ``burst_count=1`` for the key
        ``"test_id"``; only the simulated time changes.
        """
        limiter = Ratelimiter()

        # t=0: first action is allowed; the next slot is at t=10.
        allowed, time_allowed = limiter.can_do_action(
            key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
        )
        self.assertTrue(allowed)
        self.assertEqual(10.0, time_allowed)

        # t=5: too early, the action is refused.
        allowed, time_allowed = limiter.can_do_action(
            key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
        )
        self.assertFalse(allowed)
        self.assertEqual(10.0, time_allowed)

        # t=10: allowed again, and the next slot moves to t=20.
        allowed, time_allowed = limiter.can_do_action(
            key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
        )
        self.assertTrue(allowed)
        self.assertEqual(20.0, time_allowed)
Example #34
0
    def __init__(self, hs: "HomeServer"):
        """Wire this handler up to the homeserver's services.

        Args:
            hs: the homeserver, used to look up the datastore, notifier,
                federation components, worker configuration and clock.
        """
        self.store = hs.get_datastore()
        self.notifier = hs.get_notifier()
        self.is_mine = hs.is_mine

        # The federation sender only needs an explicit poke when it lives on
        # this instance. Senders on other instances are notified by
        # `synapse.app.generic_worker.FederationSenderHandler` when it sees
        # the change in the to-device replication stream.
        self.federation_sender = (
            hs.get_federation_sender() if hs.should_send_federation() else None
        )

        # Handle to-device EDUs here when this instance is a configured
        # to-device writer; otherwise route them to the writer instances.
        instance_name = hs.get_instance_name()
        to_device_writers = hs.config.worker.writers.to_device
        federation_registry = hs.get_federation_registry()
        if instance_name in to_device_writers:
            federation_registry.register_edu_handler(
                "m.direct_to_device", self.on_direct_to_device_edu
            )
        else:
            federation_registry.register_instances_for_edu(
                "m.direct_to_device", to_device_writers
            )

        # Callback used when a user's device list might be out of sync. All
        # device list resyncing happens on the master instance, so workers go
        # through the device resync replication API instead.
        if hs.config.worker.worker_app is None:
            device_list_updater = hs.get_device_handler().device_list_updater
            self._user_device_resync = device_list_updater.user_device_resync
        else:
            self._user_device_resync = (
                ReplicationUserDevicesResyncRestServlet.make_client(hs)
            )

        # Rate limiter whose limits come from the `rc_key_requests` config.
        clock = hs.get_clock()
        key_request_limits = hs.config.rc_key_requests
        self._ratelimiter = Ratelimiter(
            clock=clock,
            rate_hz=key_request_limits.per_second,
            burst_count=key_request_limits.burst_count,
        )
Example #35
0
 def __init__(self, hs):
     """Cache login-related config and handlers from the homeserver.

     Args:
         hs: the homeserver object providing configuration and handlers.
     """
     super(LoginRestServlet, self).__init__(hs)

     # Login-related configuration flags and JWT settings.
     config = hs.config
     self.jwt_enabled = config.jwt_enabled
     self.jwt_secret = config.jwt_secret
     self.jwt_algorithm = config.jwt_algorithm
     self.cas_enabled = config.cas_enabled

     # Handlers and helpers used while processing login requests.
     self.auth_handler = self.hs.get_auth_handler()
     self.registration_handler = hs.get_registration_handler()
     self.handlers = hs.get_handlers()
     self._well_known_builder = WellKnownBuilder(hs)
     self._address_ratelimiter = Ratelimiter()
Example #36
0
class LoginRestServlet(ClientV1RestServlet):
    """REST servlet for the client ``/login`` endpoint.

    ``GET`` advertises the login flows that are enabled, and ``POST`` performs
    a login using JWT, a short-term login token, or any other identifier-based
    flow handled by the auth handler.
    """

    PATTERNS = client_path_patterns("/login$")
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "m.login.jwt"

    def __init__(self, hs):
        """Cache login-related config and handlers from the homeserver.

        Args:
            hs: the homeserver object providing configuration and handlers.
        """
        super(LoginRestServlet, self).__init__(hs)
        self.jwt_enabled = hs.config.jwt_enabled
        self.jwt_secret = hs.config.jwt_secret
        self.jwt_algorithm = hs.config.jwt_algorithm
        self.cas_enabled = hs.config.cas_enabled
        self.auth_handler = self.hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self.handlers = hs.get_handlers()
        self._well_known_builder = WellKnownBuilder(hs)
        self._address_ratelimiter = Ratelimiter()

    def on_GET(self, request):
        """List the login flows this server supports.

        Args:
            request: the incoming request (not inspected).

        Returns:
            tuple: ``(200, {"flows": [...]})`` where each flow is a dict with
            a ``"type"`` key.
        """
        flows = []
        if self.jwt_enabled:
            flows.append({"type": LoginRestServlet.JWT_TYPE})
        if self.cas_enabled:
            flows.append({"type": LoginRestServlet.SSO_TYPE})

            # we advertise CAS for backwards compat, though MSC1721 renamed it
            # to SSO.
            flows.append({"type": LoginRestServlet.CAS_TYPE})

            # While it's valid for us to advertise this login type generally,
            # synapse currently only gives out these tokens as part of the
            # CAS login flow.
            # Generally we don't want to advertise login flows that clients
            # don't know how to implement, since they (currently) will always
            # fall back to the fallback API if they don't understand one of the
            # login flow types returned.
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})

        flows.extend((
            {"type": t} for t in self.auth_handler.get_supported_login_types()
        ))

        return (200, {"flows": flows})

    def on_OPTIONS(self, request):
        """Answer an ``OPTIONS`` request with an empty 200 response."""
        return (200, {})

    @defer.inlineCallbacks
    def on_POST(self, request):
        """Perform a login.

        The request is first counted against the per-address rate limiter
        (keyed on the client IP), then dispatched on the submitted ``type``.

        Args:
            request: the incoming request carrying a JSON login submission.

        Returns:
            Deferred[tuple]: ``(200, result)`` where ``result`` is the dict
            produced by the selected login method, plus a ``"well_known"``
            entry when the well-known builder returns data.

        Raises:
            SynapseError: 400 if a required JSON key is missing.
        """
        self._address_ratelimiter.ratelimit(
            request.getClientIP(), time_now_s=self.hs.clock.time(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
            update=True,
        )

        login_submission = parse_json_object_from_request(request)
        try:
            if self.jwt_enabled and (login_submission["type"] ==
                                     LoginRestServlet.JWT_TYPE):
                result = yield self.do_jwt_login(login_submission)
            elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                result = yield self.do_token_login(login_submission)
            else:
                result = yield self._do_other_login(login_submission)
        except KeyError:
            # Covers a missing "type" as well as keys the login methods index
            # directly (e.g. "token", "password").
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        defer.returnValue((200, result))

    @defer.inlineCallbacks
    def _do_other_login(self, login_submission):
        """Handle non-token/saml/jwt logins

        Args:
            login_submission (dict): the parsed login request body.

        Returns:
            dict: HTTP response

        Raises:
            SynapseError: 400 if the identifier is missing or malformed.
            LoginError: 403 if a third-party identifier matches no user.
        """
        # Log the request we got, but only certain fields to minimise the chance of
        # logging someone's password (even if they accidentally put it in the wrong
        # field)
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            login_submission.get('identifier'),
            login_submission.get('medium'),
            login_submission.get('address'),
            login_submission.get('user'),
        )
        login_submission_legacy_convert(login_submission)

        if "identifier" not in login_submission:
            raise SynapseError(400, "Missing param: identifier")

        identifier = login_submission["identifier"]
        if "type" not in identifier:
            raise SynapseError(400, "Login identifier has no type")

        # convert phone type identifiers to generic threepids
        if identifier["type"] == "m.id.phone":
            identifier = login_id_thirdparty_from_phone(identifier)

        # convert threepid identifiers to user IDs
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get('address')
            medium = identifier.get('medium')

            if medium is None or address is None:
                raise SynapseError(400, "Invalid thirdparty identifier")

            if medium == 'email':
                # For emails, transform the address to lowercase.
                # We store all email addresses as lowercase in the DB.
                # (See add_threepid in synapse/handlers/auth.py)
                address = address.lower()

            # Check for login providers that support 3pid login types
            canonical_user_id, callback_3pid = (
                yield self.auth_handler.check_password_provider_3pid(
                    medium,
                    address,
                    login_submission["password"],
                )
            )
            if canonical_user_id:
                # Authentication through password provider and 3pid succeeded
                result = yield self._register_device_with_callback(
                    canonical_user_id, login_submission, callback_3pid,
                )
                defer.returnValue(result)

            # No password providers were able to handle this 3pid
            # Check local store
            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                medium, address,
            )
            if not user_id:
                # `Logger.warn` is a deprecated alias; use `warning`.
                logger.warning(
                    "unknown 3pid identifier medium %s, address %r",
                    medium, address,
                )
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)

            identifier = {
                "type": "m.id.user",
                "user": user_id,
            }

        # by this point, the identifier should be an m.id.user: if it's anything
        # else, we haven't understood it.
        if identifier["type"] != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type")
        if "user" not in identifier:
            raise SynapseError(400, "User identifier is missing 'user' key")

        canonical_user_id, callback = yield self.auth_handler.validate_login(
            identifier["user"],
            login_submission,
        )

        result = yield self._register_device_with_callback(
            canonical_user_id, login_submission, callback,
        )
        defer.returnValue(result)

    @defer.inlineCallbacks
    def _register_device_with_callback(
        self,
        user_id,
        login_submission,
        callback=None,
    ):
        """ Registers a device with a given user_id. Optionally run a callback
        function after registration has completed.

        Args:
            user_id (str): ID of the user to register.
            login_submission (dict): Dictionary of login information.
            callback (func|None): Callback function to run after registration.

        Returns:
            result (Dict[str,str]): Dictionary of account information after
                successful registration.
        """
        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get("initial_device_display_name")
        device_id, access_token = yield self.registration_handler.register_device(
            user_id, device_id, initial_display_name,
        )

        result = {
            "user_id": user_id,
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        if callback is not None:
            yield callback(result)

        defer.returnValue(result)

    @defer.inlineCallbacks
    def do_token_login(self, login_submission):
        """Log a user in with a short-term login token.

        Args:
            login_submission (dict): the parsed login request body; must
                contain a ``"token"`` key.

        Returns:
            Deferred[dict]: the user ID, access token, home server name and
            device ID for the new session.

        Raises:
            KeyError: if ``"token"`` is absent (converted to a 400 by
                ``on_POST``).
        """
        token = login_submission['token']
        auth_handler = self.auth_handler
        user_id = (
            yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
        )

        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get("initial_device_display_name")
        device_id, access_token = yield self.registration_handler.register_device(
            user_id, device_id, initial_display_name,
        )

        result = {
            "user_id": user_id,  # may have changed
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        defer.returnValue(result)

    @defer.inlineCallbacks
    def do_jwt_login(self, login_submission):
        """Log a user in with a JSON Web Token, registering them if needed.

        The token is decoded with the configured secret and algorithm and its
        ``sub`` claim is used as the user's localpart. If no such user exists
        yet, one is registered first.

        Args:
            login_submission (dict): the parsed login request body; the JWT
                is read from its ``"token"`` key.

        Returns:
            Deferred[dict]: the user ID, access token, home server name and
            device ID for the new session.

        Raises:
            LoginError: 401 if the token is missing, expired, invalid, or has
                no ``sub`` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(
                401, "Token field for JWT is missing",
                errcode=Codes.UNAUTHORIZED
            )

        import jwt
        from jwt.exceptions import InvalidTokenError

        try:
            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
        except InvalidTokenError:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user = payload.get("sub", None)
        if user is None:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user_id = UserID(user, self.hs.hostname).to_string()

        auth_handler = self.auth_handler
        login_user_id = yield auth_handler.check_user_exists(user_id)
        if not login_user_id:
            # Unknown user: register them. The access token returned here is
            # superseded by the one issued for the device below.
            login_user_id, _ = (
                yield self.handlers.registration_handler.register(localpart=user)
            )

        # Register the device against the user we are actually logging in.
        # (Previously the new-user path passed the falsy result of
        # `check_user_exists` here instead of the freshly registered user ID.)
        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get("initial_device_display_name")
        device_id, access_token = yield self.registration_handler.register_device(
            login_user_id, device_id, initial_display_name,
        )

        result = {
            "user_id": login_user_id,  # may have changed
            "access_token": access_token,
            "home_server": self.hs.hostname,
            # Returned for consistency with the other login paths in this
            # class, which all report the registered device ID.
            "device_id": device_id,
        }

        defer.returnValue(result)