def __init__(self, hostname, reactor=None, **kwargs):
    """Set up the core state of a homeserver instance.

    Args:
        hostname: The hostname for the server.
        reactor: Twisted reactor to use. When falsy, the global
            ``twisted.internet.reactor`` is imported and used instead.
        **kwargs: Explicit dependencies. Each one is set as an attribute of
            the same name after everything else, so it overrides any
            default assigned here.
    """
    if not reactor:
        from twisted.internet import reactor

    self._reactor = reactor
    self.hostname = hostname
    self._building = {}
    self._listening_services = []
    self.start_time = None

    self.clock = Clock(reactor)
    self.distributor = Distributor()

    # Three independent limiter instances.
    for limiter_attr in (
        "ratelimiter",
        "admin_redaction_ratelimiter",
        "registration_ratelimiter",
    ):
        setattr(self, limiter_attr, Ratelimiter())

    self.datastores = None

    # Explicit dependencies win over the defaults above.
    for dep_name, dep_value in kwargs.items():
        setattr(self, dep_name, dep_value)
def test_allowed_appservice_via_can_requester_do_action(self):
    """An appservice with ``rate_limited=False`` is allowed at every step,
    even inside the rate window, and the limiter reports -1 as the time
    allowed.
    """
    appservice = ApplicationService(
        None,
        "example.com",
        id="foo",
        rate_limited=False,
        sender="@as:example.com",
    )
    as_requester = create_requester("@user:example.com", app_service=appservice)

    limiter = Ratelimiter(
        store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
    )
    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(as_requester, _time_now_s=0)
    )
    self.assertTrue(allowed)
    self.assertEqual(-1, time_allowed)

    # t=5 would be rejected for a rate-limited requester (0.1 Hz, burst 1).
    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(as_requester, _time_now_s=5)
    )
    self.assertTrue(allowed)
    self.assertEqual(-1, time_allowed)

    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(as_requester, _time_now_s=10)
    )
    self.assertTrue(allowed)
    self.assertEqual(-1, time_allowed)
def __init__(self, hs: "HomeServer"):
    """Read login-related configuration and build the login ratelimiters.

    Args:
        hs: The homeserver whose config, handlers and clock are used.
    """
    super().__init__()
    self.hs = hs
    config = hs.config

    # JWT configuration variables.
    self.jwt_enabled = config.jwt_enabled
    self.jwt_secret = config.jwt_secret
    self.jwt_algorithm = config.jwt_algorithm
    self.jwt_issuer = config.jwt_issuer
    self.jwt_audiences = config.jwt_audiences

    # Which SSO mechanisms are switched on.
    self.saml2_enabled = config.saml2_enabled
    self.cas_enabled = config.cas_enabled
    self.oidc_enabled = config.oidc_enabled

    self.auth = hs.get_auth()
    self.auth_handler = self.hs.get_auth_handler()
    self.registration_handler = hs.get_registration_handler()
    self._well_known_builder = WellKnownBuilder(hs)

    clock = hs.get_clock()

    by_address = config.rc_login_address
    self._address_ratelimiter = Ratelimiter(
        clock=clock,
        rate_hz=by_address.per_second,
        burst_count=by_address.burst_count,
    )

    by_account = config.rc_login_account
    self._account_ratelimiter = Ratelimiter(
        clock=clock,
        rate_hz=by_account.per_second,
        burst_count=by_account.burst_count,
    )
def __init__(self, hs):
    """Capture the shared services every handler needs.

    Args:
        hs (synapse.server.HomeServer): the homeserver to pull services from.
    """
    self.store = hs.get_datastore()  # type: synapse.storage.DataStore
    self.auth = hs.get_auth()
    self.notifier = hs.get_notifier()
    self.state_handler = hs.get_state_handler()  # type: synapse.state.StateHandler
    self.distributor = hs.get_distributor()
    self.clock = hs.get_clock()
    self.hs = hs

    # Zero rate/burst here: callers supply per-user values when they use it.
    self.request_ratelimiter = Ratelimiter(
        clock=self.clock, rate_hz=0, burst_count=0
    )
    self._rc_message = self.hs.config.rc_message

    # Admin redaction ratelimiting is enabled only when the config has limits
    # for it. Otherwise the attribute is left as None.
    redaction_limits = self.hs.config.rc_admin_redaction
    self.admin_redaction_ratelimiter = (
        Ratelimiter(
            clock=self.clock,
            rate_hz=redaction_limits.per_second,
            burst_count=redaction_limits.burst_count,
        )
        if redaction_limits
        else None
    )

    self.server_name = hs.hostname
    self.event_builder_factory = hs.get_event_builder_factory()
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
    """An appservice with ``rate_limited=True`` is subject to ratelimiting:
    at 0.1 Hz with a burst of 1 it gets one action per 10 seconds.
    """
    appservice = ApplicationService(
        None,
        "example.com",
        id="foo",
        rate_limited=True,
        sender="@as:example.com",
    )
    as_requester = create_requester("@user:example.com", app_service=appservice)

    limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
    allowed, time_allowed = limiter.can_requester_do_action(
        as_requester, _time_now_s=0
    )
    self.assertTrue(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_requester_do_action(
        as_requester, _time_now_s=5
    )
    self.assertFalse(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_requester_do_action(
        as_requester, _time_now_s=10
    )
    self.assertTrue(allowed)
    self.assertEqual(20.0, time_allowed)
def __init__(self, hs):
    """Build the HTTP clients and `/requestToken` ratelimiters.

    Args:
        hs: the homeserver providing config, clock and federation client.
    """
    super().__init__(hs)

    # HTTP client for URLs we trust.
    self.http_client = SimpleHttpClient(hs)
    # HTTP client for identity servers named by clients. It applies the
    # federation IP blacklist.
    self.blacklisting_http_client = SimpleHttpClient(
        hs, ip_blacklist=hs.config.federation_ip_range_blacklist
    )
    self.federation_http_client = hs.get_federation_http_client()
    self.hs = hs

    self._web_client_location = hs.config.invite_client_location

    validation_limits = hs.config.ratelimiting.rc_3pid_validation

    def _new_validation_limiter():
        # Each call returns a fresh limiter with the rc_3pid_validation settings.
        return Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            rate_hz=validation_limits.per_second,
            burst_count=validation_limits.burst_count,
        )

    # Two independent ratelimiters for the `/requestToken` endpoints. The names
    # suggest IP and 3PID-address keys; the callers choose the actual keys.
    self._3pid_validation_ratelimiter_ip = _new_validation_limiter()
    self._3pid_validation_ratelimiter_address = _new_validation_limiter()
def test_db_user_override(self):
    """A user whose ``ratelimit_override`` row has NULL limits must never be
    ratelimited.
    """
    store = self.hs.get_datastore()

    user_id = "@user:test"
    requester = create_requester(user_id)

    override_row = {
        "user_id": user_id,
        "messages_per_second": None,
        "burst_count": None,
    }
    self.get_success(
        store.db_pool.simple_insert(
            table="ratelimit_override",
            values=override_row,
            desc="test_db_user_override",
        )
    )

    limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)

    # Twenty actions at the same instant far exceed burst_count=1. If none of
    # them raises, the DB override was honoured.
    for _attempt in range(20):
        self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
    """Create the core homeserver object.

    Args:
        hostname: The hostname for the server.
        config: The full config for the homeserver.
        reactor: Twisted reactor to use. The global reactor is imported and
            used when this is falsy.
        **kwargs: Explicit dependencies. They are set as attributes of the
            same name after everything else, so they override defaults.
    """
    if not reactor:
        from twisted.internet import reactor

    self._reactor = reactor
    self.hostname = hostname
    self.config = config
    self._building = {}
    self._listening_services = []
    self.start_time = None

    self._instance_id = random_string(5)
    # Fall back to "master" when no worker name is configured.
    self._instance_name = config.worker_name or "master"

    self.clock = Clock(reactor)
    self.distributor = Distributor()

    for limiter_attr in (
        "ratelimiter",
        "admin_redaction_ratelimiter",
        "registration_ratelimiter",
    ):
        setattr(self, limiter_attr, Ratelimiter())

    self.datastores = None

    for dep_name, dep_value in kwargs.items():
        setattr(self, dep_name, dep_value)
def __init__(self, hs: "HomeServer"):
    """Prepare the EDU/query handler tables and the replication clients.

    Args:
        hs: the homeserver supplying config, clock and instance name.
    """
    config = hs.config
    self.config = config
    self.clock = hs.get_clock()
    self._instance_name = hs.get_instance_name()

    # Loading these in monolith mode is harmless, but using them there would
    # fail. Callers guard against routing to ourselves, which is always the
    # situation in monolith mode, so they are never actually used there.
    self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
    self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)

    self.edu_handlers = {}  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
    self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]

    # EDU type -> instance names that may handle it. One of them is picked at
    # random for every incoming EDU of that type.
    self._edu_type_to_instance = {}  # type: Dict[str, List[str]]

    # Ratelimits incoming room key requests, per origin.
    key_request_limits = config.rc_key_requests
    self._room_key_request_rate_limiter = Ratelimiter(
        clock=self.clock,
        rate_hz=key_request_limits.per_second,
        burst_count=key_request_limits.burst_count,
    )
def test_allowed(self):
    """``send_message`` allows one message per 10 seconds (0.1 Hz, burst 1)."""
    limiter = Ratelimiter()
    allowed, time_allowed = limiter.send_message(
        user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
    )
    self.assertTrue(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.send_message(
        user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
    )
    self.assertFalse(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.send_message(
        user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1,
    )
    self.assertTrue(allowed)
    self.assertEqual(20.0, time_allowed)
def __init__(self, hs: "HomeServer"):
    """Build HTTP clients, read 3PID settings and create `/requestToken`
    ratelimiters.

    Args:
        hs: the homeserver providing store, config and clock.
    """
    self.store = hs.get_datastore()
    server_config = hs.config.server

    # Client for talking to URLs we trust.
    self.http_client = SimpleHttpClient(hs)
    # Client for identity servers named by clients. It is restricted by the
    # federation IP blacklist and whitelist.
    self.blacklisting_http_client = SimpleHttpClient(
        hs,
        ip_blacklist=server_config.federation_ip_range_blacklist,
        ip_whitelist=server_config.federation_ip_range_whitelist,
    )
    self.federation_http_client = hs.get_federation_http_client()
    self.hs = hs

    registration_config = hs.config.registration
    self.rewrite_identity_server_urls = (
        registration_config.rewrite_identity_server_urls
    )
    self._enable_lookup = registration_config.enable_3pid_lookup

    self._web_client_location = hs.config.email.invite_client_location

    # Two independent ratelimiters for the `/requestToken` endpoints, both
    # configured from `rc_3pid_validation`.
    validation = hs.config.ratelimiting.rc_3pid_validation
    (
        self._3pid_validation_ratelimiter_ip,
        self._3pid_validation_ratelimiter_address,
    ) = (
        Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            rate_hz=validation.per_second,
            burst_count=validation.burst_count,
        )
        for _ in range(2)
    )
def __init__(self, hs):
    """Read login configuration and create the login ratelimiters.

    Args:
        hs: the homeserver providing config, handlers and clock.
    """
    super(CustomRestServlet, self).__init__()
    self.hs = hs
    self.store = hs.get_datastore()
    config = hs.config

    self.jwt_enabled = config.jwt_enabled
    self.jwt_secret = config.jwt_secret
    self.jwt_algorithm = config.jwt_algorithm
    self.saml2_enabled = config.saml2_enabled
    self.cas_enabled = config.cas_enabled
    self.oidc_enabled = config.oidc_enabled

    self.auth_handler = self.hs.get_auth_handler()
    self.registration_handler = hs.get_registration_handler()
    self.handlers = hs.get_handlers()
    self._well_known_builder = WellKnownBuilder(hs)

    def _limiter_from(limits):
        # Fresh limiter using the given config section's rate and burst.
        return Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=limits.per_second,
            burst_count=limits.burst_count,
        )

    self._address_ratelimiter = _limiter_from(config.rc_login_address)
    self._account_ratelimiter = _limiter_from(config.rc_login_account)
    self._failed_attempts_ratelimiter = _limiter_from(
        config.rc_login_failed_attempts
    )
def __init__(self, hs):
    """Initialise the auth handler.

    Registers the enabled user-interactive auth checkers, creates the UI-auth
    session store, instantiates the configured password provider modules,
    computes the list of supported login types, and creates the failed-UIA
    ratelimiter.

    Args:
        hs (synapse.server.HomeServer):
    """
    super(AuthHandler, self).__init__(hs)

    # Only checkers that report themselves as enabled are registered, keyed by
    # their AUTH_TYPE.
    self.checkers = {}  # type: dict[str, UserInteractiveAuthChecker]
    for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
        inst = auth_checker_class(hs)
        if inst.is_enabled():
            self.checkers[inst.AUTH_TYPE] = inst

    self.bcrypt_rounds = hs.config.bcrypt_rounds

    # This is not a cache per se, but a store of all current sessions that
    # expire after N hours
    self.sessions = ExpiringCache(
        cache_name="register_sessions",
        clock=hs.get_clock(),
        expiry_ms=self.SESSION_EXPIRE_MS,
        reset_expiry_on_get=True,
    )

    # Each configured password provider module is instantiated with its own
    # config and a ModuleApi that wraps this (partially initialised) handler.
    # NOTE(review): the attributes assigned below do not exist yet while the
    # providers are constructed. Keep this ordering in mind before moving
    # statements around.
    account_handler = ModuleApi(hs, self)
    self.password_providers = [
        module(config=config, account_handler=account_handler)
        for module, config in hs.config.password_providers
    ]

    logger.info("Extra password_providers: %r", self.password_providers)

    self.hs = hs  # FIXME better possibility to access registrationHandler later?
    self.macaroon_gen = hs.get_macaroon_generator()
    self._password_enabled = hs.config.password_enabled

    # we keep this as a list despite the O(N^2) implication so that we can
    # keep PASSWORD first and avoid confusing clients which pick the first
    # type in the list. (NB that the spec doesn't require us to do so and
    # clients which favour types that they don't understand over those that
    # they do are technically broken)
    login_types = []
    if self._password_enabled:
        login_types.append(LoginType.PASSWORD)
    # Providers may add extra login types. Duplicates are skipped and the
    # first-seen order is kept.
    for provider in self.password_providers:
        if hasattr(provider, "get_supported_login_types"):
            for t in provider.get_supported_login_types().keys():
                if t not in login_types:
                    login_types.append(t)
    self._supported_login_types = login_types

    # Ratelimiter for failed auth during UIA. Uses same ratelimit config
    # as per `rc_login.failed_attempts`.
    # NOTE(review): constructed with no arguments, so the rate and burst are
    # presumably passed on each call. Confirm at the call sites.
    self._failed_uia_attempts_ratelimiter = Ratelimiter()

    self._clock = self.hs.get_clock()
def __init__(self, hs):
    """Initialise the auth handler.

    Maps each supported user-interactive auth type to its checker method,
    creates the UI-auth session store, instantiates the configured password
    provider modules, computes the list of supported login types, and
    creates the login ratelimiters.

    Args:
        hs (synapse.server.HomeServer):
    """
    super(AuthHandler, self).__init__(hs)

    # UI auth type -> bound checker method on this handler.
    self.checkers = {
        LoginType.RECAPTCHA: self._check_recaptcha,
        LoginType.EMAIL_IDENTITY: self._check_email_identity,
        LoginType.MSISDN: self._check_msisdn,
        LoginType.DUMMY: self._check_dummy_auth,
        LoginType.TERMS: self._check_terms_auth,
    }

    self.bcrypt_rounds = hs.config.bcrypt_rounds

    # This is not a cache per se, but a store of all current sessions that
    # expire after N hours
    self.sessions = ExpiringCache(
        cache_name="register_sessions",
        clock=hs.get_clock(),
        expiry_ms=self.SESSION_EXPIRE_MS,
        reset_expiry_on_get=True,
    )

    # Each configured password provider module is instantiated with its own
    # config and a ModuleApi that wraps this (partially initialised) handler.
    # NOTE(review): the attributes assigned below do not exist yet while the
    # providers are constructed. Keep this ordering in mind before moving
    # statements around.
    account_handler = ModuleApi(hs, self)
    self.password_providers = [
        module(config=config, account_handler=account_handler)
        for module, config in hs.config.password_providers
    ]

    logger.info("Extra password_providers: %r", self.password_providers)

    self.hs = hs  # FIXME better possibility to access registrationHandler later?
    self.macaroon_gen = hs.get_macaroon_generator()
    self._password_enabled = hs.config.password_enabled

    # we keep this as a list despite the O(N^2) implication so that we can
    # keep PASSWORD first and avoid confusing clients which pick the first
    # type in the list. (NB that the spec doesn't require us to do so and
    # clients which favour types that they don't understand over those that
    # they do are technically broken)
    login_types = []
    if self._password_enabled:
        login_types.append(LoginType.PASSWORD)
    # Providers may add extra login types. Duplicates are skipped and the
    # first-seen order is kept.
    for provider in self.password_providers:
        if hasattr(provider, "get_supported_login_types"):
            for t in provider.get_supported_login_types().keys():
                if t not in login_types:
                    login_types.append(t)
    self._supported_login_types = login_types

    # NOTE(review): both limiters are constructed with no arguments, so the
    # rate and burst are presumably passed on each call. Confirm at the call
    # sites.
    self._account_ratelimiter = Ratelimiter()
    self._failed_attempts_ratelimiter = Ratelimiter()

    self._clock = self.hs.get_clock()
def __init__(self, hs: "HomeServer"):
    """Collect collaborating handlers and build the join/invite ratelimiters.

    Args:
        hs: the homeserver supplying handlers, store, clock and config.
    """
    self.hs = hs
    self.store = hs.get_datastore()
    self.auth = hs.get_auth()
    self.state_handler = hs.get_state_handler()
    self.config = hs.config

    self.federation_handler = hs.get_federation_handler()
    self.directory_handler = hs.get_directory_handler()
    self.identity_handler = hs.get_identity_handler()
    self.registration_handler = hs.get_registration_handler()
    self.profile_handler = hs.get_profile_handler()
    self.event_creation_handler = hs.get_event_creation_handler()
    self.account_data_handler = hs.get_account_data_handler()
    self.event_auth_handler = hs.get_event_auth_handler()

    self.member_linearizer = Linearizer(name="member")

    self.clock = hs.get_clock()
    self.spam_checker = hs.get_spam_checker()
    self.third_party_event_rules = hs.get_third_party_event_rules()
    self._server_notices_mxid = self.config.server_notices_mxid
    self._enable_lookup = hs.config.enable_3pid_lookup
    self.allow_per_room_profiles = self.config.allow_per_room_profiles

    limits = hs.config.ratelimiting

    def _store_backed_limiter(setting):
        # Fresh store-backed limiter using the rate and burst of `setting`.
        return Ratelimiter(
            store=self.store,
            clock=self.clock,
            rate_hz=setting.per_second,
            burst_count=setting.burst_count,
        )

    self._join_rate_limiter_local = _store_backed_limiter(limits.rc_joins_local)
    self._join_rate_limiter_remote = _store_backed_limiter(limits.rc_joins_remote)
    self._invites_per_room_limiter = _store_backed_limiter(
        limits.rc_invites_per_room
    )
    self._invites_per_user_limiter = _store_backed_limiter(
        limits.rc_invites_per_user
    )

    # Only used for its ratelimit helper and maybe_kick_guest_users. Having
    # several instances is harmless because it holds no state.
    self.base_handler = BaseHandler(hs)
def __init__(self, hs):
    """Read JWT/CAS settings and set up login helpers.

    Args:
        hs: the homeserver providing config and handlers.
    """
    super(LoginRestServlet, self).__init__(hs)
    config = hs.config

    self.jwt_enabled = config.jwt_enabled
    self.jwt_secret = config.jwt_secret
    self.jwt_algorithm = config.jwt_algorithm
    self.cas_enabled = config.cas_enabled

    self.auth_handler = self.hs.get_auth_handler()
    self.registration_handler = hs.get_registration_handler()
    self.handlers = hs.get_handlers()
    self._well_known_builder = WellKnownBuilder(hs)

    # Built with no arguments. The limits are supplied on each ratelimit call.
    self._address_ratelimiter = Ratelimiter()
def test_allowed_via_can_do_action(self):
    """``can_do_action`` allows one action per 10 seconds (0.1 Hz, burst 1)."""
    limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
    allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
    self.assertTrue(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
    self.assertFalse(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
    self.assertTrue(allowed)
    self.assertEqual(20.0, time_allowed)
def test_pruning(self):
    """A stale key disappears from ``message_counts`` after a later action."""
    limiter = Ratelimiter()
    rate = {"rate_hz": 0.1, "burst_count": 1}

    _, _ = limiter.can_do_action(key="test_id_1", time_now_s=0, **rate)
    self.assertIn("test_id_1", limiter.message_counts)

    # An action on a different key at t=10 drops the first key's entry.
    _, _ = limiter.can_do_action(key="test_id_2", time_now_s=10, **rate)
    self.assertNotIn("test_id_1", limiter.message_counts)
def test_pruning(self):
    """Old users are removed from ``message_counts`` after a later send."""
    limiter = Ratelimiter()
    limits = {"msg_rate_hz": 0.1, "burst_count": 1}

    _, _ = limiter.send_message(user_id="test_id_1", time_now_s=0, **limits)
    self.assertIn("test_id_1", limiter.message_counts)

    # A send by another user at t=10 drops the first user's entry.
    _, _ = limiter.send_message(user_id="test_id_2", time_now_s=10, **limits)
    self.assertNotIn("test_id_1", limiter.message_counts)
class LookupBindStatusForOpenid(RestServlet):
    """Check whether the requesting user has an external ID bound.

    The client sends ``bind_type``, which is passed as the provider argument
    to ``get_external_id_for_user_provider``. The endpoint responds
    ``200 {}`` when a binding exists. It raises ``LoginError`` when
    ``bind_type`` is missing or when no binding is found.
    """

    PATTERNS = client_patterns("/login/oauth2/bind/status$", v1=True)

    def __init__(self, hs):
        super().__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        # Shares the login-by-address limits from the config.
        self._address_ratelimiter = Ratelimiter(
            clock=hs.get_clock(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
        )
        self.http_client = SimpleHttpClient(hs)

    async def on_POST(self, request: SynapseRequest):
        # Ratelimit by client IP before doing any work.
        self._address_ratelimiter.ratelimit(request.getClientIP())

        params = parse_json_object_from_request(request)
        logger.info("----lookup bind--------param:%s", str(params))

        # Use .get() so a missing field reaches the explicit error below
        # instead of raising KeyError (which surfaces as a 500).
        bind_type = params.get("bind_type")
        if bind_type is None:
            # NOTE(review): 410 with M_FORBIDDEN is unusual for a missing
            # field (400 / M_MISSING_PARAM is conventional). It is kept as-is
            # so existing clients are not broken.
            raise LoginError(
                410,
                "bind_type field for bind openid is missing",
                errcode=Codes.FORBIDDEN,
            )

        requester = await self.auth.get_user_by_req(request)
        logger.info("------requester: %s", requester)
        user_id = requester.user
        logger.info("------user: %s", user_id)

        # Look up the external ID bound to this user for the given provider.
        openid = await self.hs.get_datastore().get_external_id_for_user_provider(
            bind_type,
            str(user_id),
        )
        if openid is None:
            raise LoginError(400, "openid not bind", errcode=Codes.OPENID_NOT_BIND)

        return 200, {}
def test_pruning(self):
    """A stale key disappears from ``limiter.actions`` after a later action."""
    limiter = Ratelimiter(
        store=self.hs.get_datastores().main,
        clock=None,
        rate_hz=0.1,
        burst_count=1,
    )

    def act(key, now):
        self.get_success_or_raise(
            limiter.can_do_action(None, key=key, _time_now_s=now)
        )

    act("test_id_1", 0)
    self.assertIn("test_id_1", limiter.actions)

    # An action on another key at t=10 drops the first key's entry.
    act("test_id_2", 10)
    self.assertNotIn("test_id_1", limiter.actions)
def __init__(self, hs: "HomeServer"):
    """Read login configuration, grab collaborators and build ratelimiters.

    Args:
        hs: the homeserver providing config, handlers, store and clock.
    """
    super().__init__()
    self.hs = hs
    config = hs.config

    # JWT configuration variables.
    jwt = config.jwt
    self.jwt_enabled = jwt.jwt_enabled
    self.jwt_secret = jwt.jwt_secret
    self.jwt_subject_claim = jwt.jwt_subject_claim
    self.jwt_algorithm = jwt.jwt_algorithm
    self.jwt_issuer = jwt.jwt_issuer
    self.jwt_audiences = jwt.jwt_audiences

    # Which SSO mechanisms are switched on.
    self.saml2_enabled = config.saml2.saml2_enabled
    self.cas_enabled = config.cas.cas_enabled
    self.oidc_enabled = config.oidc.oidc_enabled

    # Refresh tokens are on exactly when a refreshable lifetime is configured.
    refresh_lifetime = config.registration.refreshable_access_token_lifetime
    self._refresh_tokens_enabled = refresh_lifetime is not None

    self.auth = hs.get_auth()
    self.clock = hs.get_clock()
    self.auth_handler = self.hs.get_auth_handler()
    self.registration_handler = hs.get_registration_handler()
    self._sso_handler = hs.get_sso_handler()
    self._well_known_builder = WellKnownBuilder(hs)

    ratelimiting = config.ratelimiting
    for attr_name, limits in (
        ("_address_ratelimiter", ratelimiting.rc_login_address),
        ("_account_ratelimiter", ratelimiting.rc_login_account),
    ):
        setattr(
            self,
            attr_name,
            Ratelimiter(
                store=hs.get_datastores().main,
                clock=hs.get_clock(),
                rate_hz=limits.per_second,
                burst_count=limits.burst_count,
            ),
        )

    # Load the CAS/SAML/OIDC handlers on this worker so that their
    # auth_provider_ids are registered with SsoHandler. That in turn makes
    # sure the login/registration prometheus counters are initialised for
    # those ids.
    _load_sso_handlers(hs)
def test_pruning(self):
    """Entries for idle keys are pruned from ``message_counts``."""
    limiter = Ratelimiter()

    def act(key, now):
        allowed, time_allowed = limiter.can_do_action(
            key=key, time_now_s=now, rate_hz=0.1, burst_count=1
        )
        return allowed, time_allowed

    act("test_id_1", 0)
    self.assertIn("test_id_1", limiter.message_counts)

    act("test_id_2", 10)
    self.assertNotIn("test_id_1", limiter.message_counts)
def test_pruning(self):
    """Entries for idle users are pruned from ``message_counts``."""
    limiter = Ratelimiter()

    def send(user, now):
        allowed, time_allowed = limiter.send_message(
            user_id=user, time_now_s=now, msg_rate_hz=0.1, burst_count=1
        )
        return allowed, time_allowed

    send("test_id_1", 0)
    self.assertIn("test_id_1", limiter.message_counts)

    send("test_id_2", 10)
    self.assertNotIn("test_id_1", limiter.message_counts)
def test_rate_limit_burst_only_given_once(self) -> None:
    """
    Regression test: carefully timed requests must not let a client build
    up more tokens than the burst allowance.
    """
    limiter = Ratelimiter(
        store=self.hs.get_datastores().main,
        clock=None,
        rate_hz=0.1,
        burst_count=3,
    )

    def try_consume(at: float) -> bool:
        ok, _ = self.get_success_or_raise(
            limiter.can_do_action(requester=None, key="a", _time_now_s=at)
        )
        return ok

    # Spend all three burst tokens in quick succession.
    for moment in (0.0, 0.1, 0.2):
        self.assertTrue(try_consume(moment))

    # After 10 seconds at 0.1 Hz one token has come back...
    self.assertTrue(try_consume(10.1))

    # ...and once it is spent we are limited again.
    self.assertFalse(try_consume(11.1))
def get_registration_ratelimiter(self) -> Ratelimiter:
    """Return a new ``Ratelimiter`` configured from ``rc_registration``.

    A fresh instance is constructed on every call.
    """
    store = self.get_datastore()
    clock = self.get_clock()
    registration_limits = self.config.ratelimiting.rc_registration
    return Ratelimiter(
        store=store,
        clock=clock,
        rate_hz=registration_limits.per_second,
        burst_count=registration_limits.burst_count,
    )
def __init__(self, hs):
    """Set up the IP ratelimiter and HTTP client for this servlet.

    Args:
        hs: the homeserver providing config and clock.
    """
    super().__init__()
    self.hs = hs

    # Ratelimiter reusing the login-by-address limits from the config.
    login_address_limits = self.hs.config.rc_login_address
    self._address_ratelimiter = Ratelimiter(
        clock=hs.get_clock(),
        rate_hz=login_address_limits.per_second,
        burst_count=login_address_limits.burst_count,
    )
    self.http_client = SimpleHttpClient(hs)
def test_allowed_user_via_can_requester_do_action(self):
    """A normal user requester gets one action per 10 seconds (0.1 Hz, burst 1)."""
    user_requester = create_requester("@user:example.com")
    limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
    allowed, time_allowed = limiter.can_requester_do_action(
        user_requester, _time_now_s=0
    )
    self.assertTrue(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_requester_do_action(
        user_requester, _time_now_s=5
    )
    self.assertFalse(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_requester_do_action(
        user_requester, _time_now_s=10
    )
    self.assertTrue(allowed)
    self.assertEqual(20.0, time_allowed)
def test_multiple_actions(self):
    """``can_do_action`` enforces ``n_actions`` against a burst of 3 at 0.1 Hz."""
    limiter = Ratelimiter(
        store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3
    )
    # Test that 4 actions aren't allowed with a maximum burst of 3.
    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(None, key="test_id", n_actions=4, _time_now_s=0)
    )
    self.assertFalse(allowed)

    # Test that 3 actions are allowed with a maximum burst of 3.
    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0)
    )
    self.assertTrue(allowed)
    self.assertEqual(10.0, time_allowed)

    # Test that, after doing these 3 actions, we can't do any more action without
    # waiting.
    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0)
    )
    self.assertFalse(allowed)
    self.assertEqual(10.0, time_allowed)

    # Test that after waiting we can do only 1 action. update=False means this
    # only checks and does not record the action.
    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(
            None,
            key="test_id",
            update=False,
            n_actions=1,
            _time_now_s=10,
        )
    )
    self.assertTrue(allowed)
    # The time allowed is the current time because we could still repeat the action
    # once.
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10)
    )
    self.assertFalse(allowed)
    # The time allowed doesn't change despite allowed being False because, while we
    # don't allow 2 actions, we could still do 1.
    self.assertEqual(10.0, time_allowed)

    # Test that after waiting a bit more we can do 2 actions.
    allowed, time_allowed = self.get_success_or_raise(
        limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=20)
    )
    self.assertTrue(allowed)
    # The time allowed is the current time because we could still repeat the action
    # once.
    self.assertEqual(20.0, time_allowed)
def test_allowed_via_can_do_action(self):
    """Walk the limiter through allow / deny / allow at t = 0, 5, 10."""
    limiter = Ratelimiter(
        store=self.hs.get_datastores().main,
        clock=None,
        rate_hz=0.1,
        burst_count=1,
    )

    # (time, should be allowed, expected time_allowed)
    timeline = ((0, True, 10.0), (5, False, 10.0), (10, True, 20.0))
    for now, should_allow, expected_time in timeline:
        allowed, time_allowed = self.get_success_or_raise(
            limiter.can_do_action(None, key="test_id", _time_now_s=now)
        )
        check = self.assertTrue if should_allow else self.assertFalse
        check(allowed)
        self.assertEqual(expected_time, time_allowed)
def test_allowed(self):
    """``send_message`` allows one message per 10 seconds (0.1 Hz, burst 1)."""
    limiter = Ratelimiter()
    allowed, time_allowed = limiter.send_message(
        user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
    )
    self.assertTrue(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.send_message(
        user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
    )
    self.assertFalse(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.send_message(
        user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
    )
    self.assertTrue(allowed)
    self.assertEqual(20.0, time_allowed)
def test_allowed_via_ratelimit(self):
    """``ratelimit`` raises ``LimitExceededError`` inside the window only."""
    store = self.hs.get_datastore()
    limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)

    def attempt(now):
        return self.get_success_or_raise(
            limiter.ratelimit(None, key="test_id", _time_now_s=now)
        )

    # First action at t=0 fits in the burst: no error.
    attempt(0)

    # A second action at t=5 is rejected with 5 seconds left to wait.
    with self.assertRaises(LimitExceededError) as ctx:
        attempt(5)
    self.assertEqual(ctx.exception.retry_after_ms, 5000)

    # By t=10 the limiter has recovered: no error.
    attempt(10)
def test_allowed(self):
    """``can_do_action`` allows one action per 10 seconds (0.1 Hz, burst 1)."""
    limiter = Ratelimiter()
    allowed, time_allowed = limiter.can_do_action(
        key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
    )
    self.assertTrue(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_do_action(
        key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
    )
    self.assertFalse(allowed)
    self.assertEqual(10.0, time_allowed)

    allowed, time_allowed = limiter.can_do_action(
        key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
    )
    self.assertTrue(allowed)
    self.assertEqual(20.0, time_allowed)
def __init__(self, hs: "HomeServer"):
    """Wire up to-device EDU handling, device resync and key-request limits.

    Args:
        hs: server
    """
    self.store = hs.get_datastore()
    self.notifier = hs.get_notifier()
    self.is_mine = hs.is_mine

    # Poke the federation sender directly only when it runs in this process.
    # Other federation sender instances hear about to-device messages via
    # `synapse.app.generic_worker.FederationSenderHandler` when it reads the
    # to-device replication stream.
    self.federation_sender = (
        hs.get_federation_sender() if hs.should_send_federation() else None
    )

    # Handle to-device EDUs here if this instance is a to-device writer.
    # Otherwise route them to the instances that are.
    to_device_writers = hs.config.worker.writers.to_device
    is_writer = hs.get_instance_name() in to_device_writers
    registry = hs.get_federation_registry()
    if is_writer:
        registry.register_edu_handler(
            "m.direct_to_device", self.on_direct_to_device_edu
        )
    else:
        registry.register_instances_for_edu(
            "m.direct_to_device",
            to_device_writers,
        )

    # Called when a user's device list may be out of sync. All resyncing
    # happens on the master, so workers go through the replication API.
    on_master = hs.config.worker.worker_app is None
    if on_master:
        updater = hs.get_device_handler().device_list_updater
        self._user_device_resync = updater.user_device_resync
    else:
        self._user_device_resync = (
            ReplicationUserDevicesResyncRestServlet.make_client(hs)
        )

    key_request_limits = hs.config.rc_key_requests
    self._ratelimiter = Ratelimiter(
        clock=hs.get_clock(),
        rate_hz=key_request_limits.per_second,
        burst_count=key_request_limits.burst_count,
    )
class LoginRestServlet(ClientV1RestServlet):
    """Servlet for the client-server ``/login`` endpoint.

    ``GET`` lists the supported login flows. ``POST`` performs a login (JWT,
    short-term login token, or any other flow handled by the auth handler),
    registers a device for the user and returns an access token for it.
    """

    PATTERNS = client_path_patterns("/login$")
    CAS_TYPE = "m.login.cas"
    SSO_TYPE = "m.login.sso"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "m.login.jwt"

    def __init__(self, hs):
        super(LoginRestServlet, self).__init__(hs)
        self.jwt_enabled = hs.config.jwt_enabled
        self.jwt_secret = hs.config.jwt_secret
        self.jwt_algorithm = hs.config.jwt_algorithm
        self.cas_enabled = hs.config.cas_enabled
        self.auth_handler = self.hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()
        self.handlers = hs.get_handlers()
        self._well_known_builder = WellKnownBuilder(hs)
        # Per-client-IP limiter applied at the top of on_POST.
        self._address_ratelimiter = Ratelimiter()

    def on_GET(self, request):
        """Return the login flows this server advertises.

        Returns:
            tuple[int, dict]: status 200 and ``{"flows": [{"type": ...}, ...]}``.
        """
        flows = []
        if self.jwt_enabled:
            flows.append({"type": LoginRestServlet.JWT_TYPE})
        if self.cas_enabled:
            flows.append({"type": LoginRestServlet.SSO_TYPE})

            # we advertise CAS for backwards compat, though MSC1721 renamed it
            # to SSO.
            flows.append({"type": LoginRestServlet.CAS_TYPE})

            # While it's valid for us to advertise this login type generally,
            # synapse currently only gives out these tokens as part of the
            # CAS login flow.
            # Generally we don't want to advertise login flows that clients
            # don't know how to implement, since they (currently) will always
            # fall back to the fallback API if they don't understand one of the
            # login flow types returned.
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})

        flows.extend((
            {"type": t} for t in self.auth_handler.get_supported_login_types()
        ))

        return (200, {"flows": flows})

    def on_OPTIONS(self, request):
        return (200, {})

    @defer.inlineCallbacks
    def on_POST(self, request):
        """Handle a login request.

        The client IP is rate-limited first, then the request is dispatched on
        its ``type``: JWT (if enabled), short-term login token, or any other
        flow via ``_do_other_login``.

        Returns:
            Deferred[tuple[int, dict]]: status 200 and the login response,
            with a ``well_known`` entry added when one is configured.

        Raises:
            SynapseError: 400 if a required key is missing from the request
                (any KeyError raised while handling the login is mapped to it).
        """
        self._address_ratelimiter.ratelimit(
            request.getClientIP(),
            time_now_s=self.hs.clock.time(),
            rate_hz=self.hs.config.rc_login_address.per_second,
            burst_count=self.hs.config.rc_login_address.burst_count,
            update=True,
        )

        login_submission = parse_json_object_from_request(request)
        try:
            if self.jwt_enabled and (
                login_submission["type"] == LoginRestServlet.JWT_TYPE
            ):
                result = yield self.do_jwt_login(login_submission)
            elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                result = yield self.do_token_login(login_submission)
            else:
                result = yield self._do_other_login(login_submission)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()
        if well_known_data:
            result["well_known"] = well_known_data
        defer.returnValue((200, result))

    @defer.inlineCallbacks
    def _do_other_login(self, login_submission):
        """Handle non-token/saml/jwt logins

        Third-party identifiers (``m.id.phone`` / ``m.id.thirdparty``) are
        first offered to password providers, then resolved to a local user ID;
        the resulting ``m.id.user`` identifier is validated by the auth handler.

        Args:
            login_submission (dict): the JSON body of the login request.

        Returns:
            Deferred[dict]: HTTP response body.

        Raises:
            SynapseError: 400 on a missing/malformed identifier.
            LoginError: 403 if a 3pid does not map to a known user.
        """
        # Log the request we got, but only certain fields to minimise the chance of
        # logging someone's password (even if they accidentally put it in the wrong
        # field)
        logger.info(
            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
            login_submission.get('identifier'),
            login_submission.get('medium'),
            login_submission.get('address'),
            login_submission.get('user'),
        )
        login_submission_legacy_convert(login_submission)

        if "identifier" not in login_submission:
            raise SynapseError(400, "Missing param: identifier")

        identifier = login_submission["identifier"]
        # The identifier is client-supplied JSON: reject non-objects up front
        # instead of failing with a TypeError (HTTP 500) on the lookups below.
        if not isinstance(identifier, dict):
            raise SynapseError(400, "Invalid login identifier")
        if "type" not in identifier:
            raise SynapseError(400, "Login identifier has no type")

        # convert phone type identifiers to generic threepids
        if identifier["type"] == "m.id.phone":
            identifier = login_id_thirdparty_from_phone(identifier)

        # convert threepid identifiers to user IDs
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get('address')
            medium = identifier.get('medium')

            if medium is None or address is None:
                raise SynapseError(400, "Invalid thirdparty identifier")

            if medium == 'email':
                # For emails, transform the address to lowercase.
                # We store all email addresses as lowercase in the DB.
                # (See add_threepid in synapse/handlers/auth.py)
                address = address.lower()

            # Check for login providers that support 3pid login types
            canonical_user_id, callback_3pid = (
                yield self.auth_handler.check_password_provider_3pid(
                    medium,
                    address,
                    login_submission["password"],
                )
            )
            if canonical_user_id:
                # Authentication through password provider and 3pid succeeded
                result = yield self._register_device_with_callback(
                    canonical_user_id, login_submission, callback_3pid,
                )
                defer.returnValue(result)

            # No password providers were able to handle this 3pid
            # Check local store
            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                medium, address,
            )
            if not user_id:
                logger.warning(
                    "unknown 3pid identifier medium %s, address %r",
                    medium, address,
                )
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)

            identifier = {
                "type": "m.id.user",
                "user": user_id,
            }

        # by this point, the identifier should be an m.id.user: if it's anything
        # else, we haven't understood it.
        if identifier["type"] != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type")
        if "user" not in identifier:
            raise SynapseError(400, "User identifier is missing 'user' key")

        canonical_user_id, callback = yield self.auth_handler.validate_login(
            identifier["user"],
            login_submission,
        )

        result = yield self._register_device_with_callback(
            canonical_user_id, login_submission, callback,
        )
        defer.returnValue(result)

    @defer.inlineCallbacks
    def _register_device_with_callback(
        self,
        user_id,
        login_submission,
        callback=None,
    ):
        """ Registers a device with a given user_id. Optionally run a callback
        function after registration has completed.

        Args:
            user_id (str): ID of the user to register.
            login_submission (dict): Dictionary of login information; its
                optional ``device_id`` and ``initial_device_display_name``
                are passed to the device registration.
            callback (func|None): Callback function to run after registration.
                It is called with the result dict and its return value is
                yielded on.

        Returns:
            Deferred[dict]: ``user_id``, ``access_token``, ``home_server``
            and ``device_id``.
        """
        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get("initial_device_display_name")
        device_id, access_token = yield self.registration_handler.register_device(
            user_id, device_id, initial_display_name,
        )

        result = {
            "user_id": user_id,
            "access_token": access_token,
            "home_server": self.hs.hostname,
            "device_id": device_id,
        }

        if callback is not None:
            yield callback(result)

        defer.returnValue(result)

    @defer.inlineCallbacks
    def do_token_login(self, login_submission):
        """Log in with a short-term login token (``m.login.token``).

        Args:
            login_submission (dict): the JSON body of the login request; its
                ``token`` key is required (a missing key raises KeyError,
                which on_POST maps to a 400).

        Returns:
            Deferred[dict]: ``user_id``, ``access_token``, ``home_server``
            and ``device_id``.
        """
        token = login_submission['token']
        user_id = (
            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
                token
            )
        )

        # Same device registration and response shape as the other flows.
        result = yield self._register_device_with_callback(
            user_id, login_submission,
        )
        defer.returnValue(result)

    @defer.inlineCallbacks
    def do_jwt_login(self, login_submission):
        """Log in with a JWT, registering the user on first login.

        The token is verified with the configured secret and algorithm. Its
        ``sub`` claim is used as the localpart of the user ID on this server.

        Args:
            login_submission (dict): the JSON body of the login request.

        Returns:
            Deferred[dict]: ``user_id``, ``access_token``, ``home_server``
            and ``device_id``.

        Raises:
            LoginError: 401 if the token is missing, expired, invalid, or has
                no ``sub`` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(
                401, "Token field for JWT is missing",
                errcode=Codes.UNAUTHORIZED
            )

        import jwt
        from jwt.exceptions import InvalidTokenError

        try:
            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError, of which it is a subclass.
            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
        except InvalidTokenError:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user = payload.get("sub", None)
        if user is None:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user_id = UserID(user, self.hs.hostname).to_string()
        registered_user_id = yield self.auth_handler.check_user_exists(user_id)
        if not registered_user_id:
            # First login for this user: register them. The access token from
            # register() is discarded; the client receives the token issued
            # for the device below (as before).
            registered_user_id, _ = (
                yield self.handlers.registration_handler.register(localpart=user)
            )

        # Register the device against the user ID we actually have, whether
        # existing or just created. (The new-user path previously passed a
        # falsy value here.)
        result = yield self._register_device_with_callback(
            registered_user_id, login_submission,
        )
        defer.returnValue(result)