예제 #1
0
class DatabaseUserService:
    """Database-backed user service.

    Provides user/email lookups, password and TOTP verification (with rate
    limiting and metrics), account updates, and WebAuthn credential helpers.
    """

    def __init__(self, session, *, ratelimiters=None, metrics):
        """Build the service.

        :param session: database session used for every query (presumably a
            SQLAlchemy session, since ``NoResultFound`` is caught below).
        :param ratelimiters: optional mapping of limiter name (``"global"``,
            ``"user"``) to rate limiter; names not supplied fall back to a
            ``DummyRateLimiter``.
        :param metrics: metrics client exposing ``increment(name, tags=...)``.
        """
        if ratelimiters is None:
            ratelimiters = {}
        ratelimiters = collections.defaultdict(DummyRateLimiter, ratelimiters)

        self.db = session
        self.ratelimiters = ratelimiters
        self.hasher = CryptContext(
            schemes=[
                "argon2",
                "bcrypt_sha256",
                "bcrypt",
                "django_bcrypt",
                "unix_disabled",
            ],
            deprecated=["auto"],
            truncate_error=True,
            # Argon 2 Configuration
            argon2__memory_cost=1024,
            argon2__parallelism=6,
            argon2__time_cost=6,
        )
        self._metrics = metrics
        # Cache user lookups per instance. A class-level ``lru_cache`` on the
        # method would be keyed on ``self`` and keep up to 128 service
        # instances (and their sessions) alive after they are otherwise unused.
        self.cached_get_user = functools.lru_cache()(self._get_user)

    def _get_user(self, userid):
        """Uncached lookup of the ``User`` whose primary key is ``userid``."""
        # TODO: We probably don't actually want to just return the database
        #       object here.
        # TODO: We need some sort of Anonymous User.
        return self.db.query(User).get(userid)

    def get_user(self, userid):
        """Return the ``User`` for ``userid`` (or ``None``), cached per instance."""
        return self.cached_get_user(userid)

    # NOTE(review): the lookups below still use a class-level ``lru_cache``,
    # which is keyed on ``self``; consider the per-instance pattern used by
    # ``get_user`` if retaining old service instances becomes a problem.
    @functools.lru_cache()
    def get_user_by_username(self, username):
        """Return the ``User`` named ``username``, or ``None`` if there is none."""
        user_id = self.find_userid(username)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def get_user_by_email(self, email):
        """Return the ``User`` owning ``email``, or ``None`` if there is none."""
        user_id = self.find_userid_by_email(email)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def find_userid(self, username):
        """Return the id of the user named ``username``, or ``None``."""
        try:
            user = self.db.query(User.id).filter(User.username == username).one()
        except NoResultFound:
            return None

        return user.id

    @functools.lru_cache()
    def find_userid_by_email(self, email):
        """Return the id of the user that owns ``email``, or ``None``."""
        try:
            user_id = (
                self.db.query(Email.user_id).filter(Email.email == email).one()
            )[0]
        except NoResultFound:
            return None

        return user_id

    def check_password(self, userid, password, *, tags=None):
        """Return True if ``password`` is correct for the user ``userid``.

        :param tags: optional extra metric tags; the caller's list is never
            modified (new lists are built for the extra tags).
        :raises TooManyFailedLogins: if the global limiter, or the per-user
            limiter for an existing user, rejects the attempt.

        On success, an upgraded hash offered by the hasher is stored on the
        user. On failure the per-user limiter (only when the user exists) and
        the global limiter are hit before returning False.
        """
        tags = tags if tags is not None else []

        self._metrics.increment("warehouse.authentication.start", tags=tags)

        # The very first thing we want to do is check to see if we've hit our
        # global rate limit or not, assuming that we've been configured with a
        # global rate limiter anyways.
        if not self.ratelimiters["global"].test():
            logger.warning("Global failed login threshold reached.")
            self._metrics.increment(
                "warehouse.authentication.ratelimited",
                tags=tags + ["ratelimiter:global"],
            )
            raise TooManyFailedLogins(
                resets_in=self.ratelimiters["global"].resets_in()
            )

        user = self.get_user(userid)
        if user is not None:
            # Now, check to make sure that we haven't hit a rate limit on a
            # per user basis.
            if not self.ratelimiters["user"].test(user.id):
                self._metrics.increment(
                    "warehouse.authentication.ratelimited",
                    tags=tags + ["ratelimiter:user"],
                )
                raise TooManyFailedLogins(
                    resets_in=self.ratelimiters["user"].resets_in(user.id)
                )

            # Actually check our hash, optionally getting a new hash for it if
            # we should upgrade our saved hash.
            ok, new_hash = self.hasher.verify_and_update(password, user.password)

            if ok:
                # The password was OK; if the hasher produced an upgraded hash,
                # save it.
                if new_hash:
                    user.password = new_hash

                self._metrics.increment("warehouse.authentication.ok", tags=tags)

                return True
            else:
                self._metrics.increment(
                    "warehouse.authentication.failure",
                    tags=tags + ["failure_reason:password"],
                )
        else:
            self._metrics.increment(
                "warehouse.authentication.failure",
                tags=tags + ["failure_reason:user"],
            )

        # If we've gotten here, then we'll want to record a failed login in our
        # rate limiting before returning False to indicate a failed password
        # verification.
        if user is not None:
            self.ratelimiters["user"].hit(user.id)
        self.ratelimiters["global"].hit()

        return False

    def create_user(self, username, name, password):
        """Create, add and flush a new ``User`` with a hashed ``password``."""
        user = User(
            username=username, name=name, password=self.hasher.hash(password)
        )
        self.db.add(user)
        self.db.flush()  # flush the db now so user.id is available

        return user

    def add_email(self, user_id, email_address, primary=None, verified=False):
        """Attach a new ``Email`` to the user and return it (flushed, so it has an id).

        When ``primary`` is None it is auto-detected: the address becomes
        primary only if the user does not already have a primary address.
        """
        user = self.get_user(user_id)

        if primary is None:
            primary = user.primary_email is None

        email = Email(
            email=email_address,
            user=user,
            primary=primary,
            verified=verified,
        )
        self.db.add(email)
        self.db.flush()  # flush the db now so email.id is available

        return email

    def update_user(self, user_id, **changes):
        """Set each attribute in ``changes`` on the user and return the user.

        The password field is hashed before being stored, and changing it
        clears ``disabled_for``.
        """
        user = self.get_user(user_id)
        for attr, value in changes.items():
            if attr == PASSWORD_FIELD:
                value = self.hasher.hash(value)
            setattr(user, attr, value)

        # A new password means the account is no longer disabled, so drop the
        # recorded reason.
        if PASSWORD_FIELD in changes:
            user.disabled_for = None

        return user

    def disable_password(self, user_id, reason=None):
        """Replace the user's password with the hasher's disabled marker."""
        user = self.get_user(user_id)
        user.password = self.hasher.disable()
        user.disabled_for = reason

    def is_disabled(self, user_id):
        """Return ``(False, None)`` if enabled, else ``(True, user.disabled_for)``."""
        user = self.get_user(user_id)

        # User is not disabled.
        if self.hasher.is_enabled(user.password):
            return (False, None)
        # User is disabled.
        else:
            return (True, user.disabled_for)

    def has_two_factor(self, user_id):
        """
        Returns True if the user has any form of two factor
        authentication and is allowed to use it.
        """
        user = self.get_user(user_id)

        return user.has_two_factor

    def has_totp(self, user_id):
        """
        Returns True if the user has a TOTP device provisioned.
        """
        user = self.get_user(user_id)

        return user.totp_secret is not None

    def has_webauthn(self, user_id):
        """
        Returns True if the user has a security key provisioned.
        """
        user = self.get_user(user_id)

        return len(user.webauthn) > 0

    def get_totp_secret(self, user_id):
        """
        Returns the user's TOTP secret as bytes.

        If the user doesn't have a TOTP, returns None.
        """
        user = self.get_user(user_id)

        return user.totp_secret

    def check_totp_value(self, user_id, totp_value, *, tags=None):
        """
        Returns True if the given TOTP is valid against the user's secret.

        If the user doesn't have a TOTP secret, returns False. Raises
        TooManyFailedLogins if the global or per-user limiter rejects the
        attempt; every failed check hits both limiters.
        """
        tags = tags if tags is not None else []
        self._metrics.increment(
            "warehouse.authentication.two_factor.start", tags=tags
        )

        # The very first thing we want to do is check to see if we've hit our
        # global rate limit or not, assuming that we've been configured with a
        # global rate limiter anyways.
        if not self.ratelimiters["global"].test():
            logger.warning("Global failed login threshold reached.")
            self._metrics.increment(
                "warehouse.authentication.two_factor.ratelimited",
                tags=tags + ["ratelimiter:global"],
            )
            raise TooManyFailedLogins(
                resets_in=self.ratelimiters["global"].resets_in()
            )

        # Now, check to make sure that we haven't hit a rate limit on a
        # per user basis.
        if not self.ratelimiters["user"].test(user_id):
            self._metrics.increment(
                "warehouse.authentication.two_factor.ratelimited",
                tags=tags + ["ratelimiter:user"],
            )
            raise TooManyFailedLogins(
                resets_in=self.ratelimiters["user"].resets_in(user_id)
            )

        totp_secret = self.get_totp_secret(user_id)

        if totp_secret is None:
            self._metrics.increment(
                "warehouse.authentication.two_factor.failure",
                tags=tags + ["failure_reason:no_totp"],
            )
            # Record a failed attempt in our rate limiting before returning
            # False to indicate a failed TOTP verification.
            self.ratelimiters["user"].hit(user_id)
            self.ratelimiters["global"].hit()
            return False

        valid = otp.verify_totp(totp_secret, totp_value)

        if valid:
            self._metrics.increment(
                "warehouse.authentication.two_factor.ok", tags=tags
            )
        else:
            self._metrics.increment(
                "warehouse.authentication.two_factor.failure",
                tags=tags + ["failure_reason:invalid_totp"],
            )
            # Record a failed attempt in our rate limiting before returning
            # False to indicate a failed TOTP verification.
            self.ratelimiters["user"].hit(user_id)
            self.ratelimiters["global"].hit()

        return valid

    def get_webauthn_credential_options(self, user_id, *, challenge, rp_name, rp_id):
        """
        Returns a dictionary of credential options suitable for beginning the WebAuthn
        provisioning process for the given user.
        """
        user = self.get_user(user_id)

        return webauthn.get_credential_options(
            user, challenge=challenge, rp_name=rp_name, rp_id=rp_id
        )

    def get_webauthn_assertion_options(self, user_id, *, challenge, rp_id):
        """
        Returns a dictionary of assertion options suitable for beginning the WebAuthn
        authentication process for the given user.
        """
        user = self.get_user(user_id)

        return webauthn.get_assertion_options(user, challenge=challenge, rp_id=rp_id)

    def verify_webauthn_credential(self, credential, *, challenge, rp_id, origin):
        """
        Checks whether the given credential is valid, i.e. suitable for generating
        assertions during authentication.

        Returns the validated credential on success, raises
        webauthn.RegistrationRejectedException on failure, including when the
        credential ID is already stored for any user.
        """
        validated_credential = webauthn.verify_registration_response(
            credential, challenge=challenge, rp_id=rp_id, origin=origin
        )

        webauthn_cred = (
            self.db.query(WebAuthn)
            .filter_by(credential_id=validated_credential.credential_id.decode())
            .first()
        )

        if webauthn_cred is not None:
            raise webauthn.RegistrationRejectedException(
                "Credential ID already in use"
            )

        return validated_credential

    def verify_webauthn_assertion(
        self, user_id, assertion, *, challenge, origin, rp_id
    ):
        """
        Checks whether the given assertion was produced by the given user's WebAuthn
        device.

        Returns the updated sign count on success, raises
        webauthn.AuthenticationRejectedException on failure.
        """
        user = self.get_user(user_id)

        return webauthn.verify_assertion_response(
            assertion, challenge=challenge, user=user, origin=origin, rp_id=rp_id
        )

    def add_webauthn(self, user_id, **kwargs):
        """
        Adds a WebAuthn credential to the given user.

        Returns the new WebAuthn row, flushed so that its id is populated.
        """
        user = self.get_user(user_id)

        # Named ``credential`` to avoid shadowing the ``webauthn`` module.
        credential = WebAuthn(user=user, **kwargs)
        self.db.add(credential)
        self.db.flush()  # flush the db now so credential.id is available

        return credential

    def get_webauthn_by_label(self, user_id, label):
        """
        Returns a WebAuthn credential for the given user by its label,
        or None if no credential for the user has this label.
        """
        user = self.get_user(user_id)

        return next(
            (credential for credential in user.webauthn if credential.label == label),
            None,
        )

    def get_webauthn_by_credential_id(self, user_id, credential_id):
        """
        Returns a WebAuthn credential for the given user by its credential ID,
        or None if the user doesn't have a credential with this ID.
        """
        user = self.get_user(user_id)

        return next(
            (
                credential
                for credential in user.webauthn
                if credential.credential_id == credential_id
            ),
            None,
        )
예제 #2
0
class DatabaseUserService:
    """Database-backed user service.

    Covers user/email lookups, password, TOTP and recovery-code checks (with
    IP, global and per-user rate limiting plus metrics), account updates,
    WebAuthn helpers and user event recording.
    """

    def __init__(self, session, *, ratelimiters=None, remote_addr, metrics):
        """Build the service.

        :param session: database session used for every query.
        :param ratelimiters: optional mapping of limiter name (``"ip.login"``,
            ``"global.login"``, ``"user.login"``, ``"email.add"``) to rate
            limiter; names not supplied fall back to a ``DummyRateLimiter``.
        :param remote_addr: client IP address used for IP rate limiting and
            event recording; may be None.
        :param metrics: metrics client exposing ``increment(name, tags=...)``.
        """
        if ratelimiters is None:
            ratelimiters = {}
        ratelimiters = collections.defaultdict(DummyRateLimiter, ratelimiters)

        self.db = session
        self.ratelimiters = ratelimiters
        self.hasher = CryptContext(
            schemes=[
                "argon2",
                "bcrypt_sha256",
                "bcrypt",
                "django_bcrypt",
                "unix_disabled",
            ],
            deprecated=["auto"],
            truncate_error=True,
            # Argon 2 Configuration
            argon2__memory_cost=1024,
            argon2__parallelism=6,
            argon2__time_cost=6,
        )
        self.remote_addr = remote_addr
        self._metrics = metrics
        # Per-instance cache, so cached users (and this service) are released
        # with the instance instead of living on in a class-level cache.
        self.cached_get_user = functools.lru_cache()(self._get_user)

    def _get_user(self, userid):
        """Uncached lookup of a ``User`` by primary key, eager-loading ``webauthn``."""
        # TODO: We probably don't actually want to just return the database
        #       object here.
        # TODO: We need some sort of Anonymous User.
        return self.db.query(User).options(joinedload(User.webauthn)).get(userid)

    def get_user(self, userid):
        """Return the ``User`` for ``userid`` (or ``None``), cached per instance."""
        return self.cached_get_user(userid)

    @functools.lru_cache()
    def get_user_by_username(self, username):
        """Return the ``User`` named ``username``, or ``None`` if there is none."""
        user_id = self.find_userid(username)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def get_user_by_email(self, email):
        """Return the ``User`` owning ``email``, or ``None`` if there is none."""
        user_id = self.find_userid_by_email(email)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def get_admins(self):
        """Return every user whose ``is_superuser`` flag is True."""
        return self.db.query(User).filter(User.is_superuser.is_(True)).all()

    def username_is_prohibited(self, username):
        """Return True if the lower-cased ``username`` is a ``ProhibitedUserName``."""
        return self.db.query(
            exists().where(ProhibitedUserName.name == username.lower())
        ).scalar()

    @functools.lru_cache()
    def find_userid(self, username):
        """Return the id of the user named ``username``, or ``None``."""
        try:
            user = self.db.query(User.id).filter(User.username == username).one()
        except NoResultFound:
            return None

        return user.id

    @functools.lru_cache()
    def find_userid_by_email(self, email):
        """Return the id of the user that owns ``email``, or ``None``."""
        try:
            user_id = (
                self.db.query(Email.user_id).filter(Email.email == email).one()
            )[0]
        except NoResultFound:
            return None

        return user_id

    def _check_ratelimits(self, userid=None, tags=None):
        """Raise ``TooManyFailedLogins`` if an applicable login limiter is exhausted.

        Checks, in order: the IP limiter (only when ``remote_addr`` is known),
        the global limiter, and the per-user limiter (only when ``userid`` is
        given). Each rejection emits a ``ratelimited`` metric before raising.
        """
        tags = tags if tags is not None else []

        # First we want to check if a single IP is exceeding our rate limiter.
        if self.remote_addr is not None:
            if not self.ratelimiters["ip.login"].test(self.remote_addr):
                logger.warning("IP failed login threshold reached.")
                self._metrics.increment(
                    "warehouse.authentication.ratelimited",
                    tags=tags + ["ratelimiter:ip"],
                )
                raise TooManyFailedLogins(
                    resets_in=self.ratelimiters["ip.login"].resets_in(self.remote_addr)
                )

        # Next check to see if we've hit our global rate limit or not,
        # assuming that we've been configured with a global rate limiter anyways.
        if not self.ratelimiters["global.login"].test():
            logger.warning("Global failed login threshold reached.")
            self._metrics.increment(
                "warehouse.authentication.ratelimited",
                tags=tags + ["ratelimiter:global"],
            )
            raise TooManyFailedLogins(
                resets_in=self.ratelimiters["global.login"].resets_in()
            )

        # Now, check to make sure that we haven't hit a rate limit on a
        # per user basis.
        if userid is not None:
            if not self.ratelimiters["user.login"].test(userid):
                self._metrics.increment(
                    "warehouse.authentication.ratelimited",
                    tags=tags + ["ratelimiter:user"],
                )
                raise TooManyFailedLogins(
                    resets_in=self.ratelimiters["user.login"].resets_in(userid)
                )

    def _hit_ratelimits(self, userid=None):
        """Record one failed attempt against the login rate limiters."""
        if userid is not None:
            self.ratelimiters["user.login"].hit(userid)
        self.ratelimiters["global.login"].hit()
        # Mirror _check_ratelimits: the IP limiter is only consulted when the
        # address is known, so there is nothing useful to record under None.
        if self.remote_addr is not None:
            self.ratelimiters["ip.login"].hit(self.remote_addr)

    def check_password(self, userid, password, *, tags=None):
        """Return True if ``password`` is correct for the user ``userid``.

        :param tags: optional extra metric tags; copied, never mutated.
        :raises TooManyFailedLogins: if the IP, global or (for an existing
            user) per-user login limiter rejects the attempt.

        On success, an upgraded hash offered by the hasher is stored on the
        user. On failure the login limiters are hit and False is returned.
        """
        # Copy so the caller's list is not modified by the append below.
        tags = list(tags) if tags is not None else []
        tags.append("mechanism:check_password")

        self._metrics.increment("warehouse.authentication.start", tags=tags)

        self._check_ratelimits(userid=None, tags=tags)

        user = self.get_user(userid)
        if user is not None:
            self._check_ratelimits(userid=user.id, tags=tags)

            # Actually check our hash, optionally getting a new hash for it if
            # we should upgrade our saved hash.
            ok, new_hash = self.hasher.verify_and_update(password, user.password)

            if ok:
                # The password was OK; if the hasher produced an upgraded hash,
                # save it.
                if new_hash:
                    user.password = new_hash

                self._metrics.increment("warehouse.authentication.ok", tags=tags)

                return True
            else:
                self._metrics.increment(
                    "warehouse.authentication.failure",
                    tags=tags + ["failure_reason:password"],
                )
        else:
            self._metrics.increment(
                "warehouse.authentication.failure", tags=tags + ["failure_reason:user"]
            )

        # If we've gotten here, then we'll want to record a failed login in our
        # rate limiting before returning False to indicate a failed password
        # verification.
        self._hit_ratelimits(userid=(user.id if user is not None else None))
        return False

    def create_user(self, username, name, password):
        """Create, add and flush a new ``User`` with a hashed ``password``."""
        user = User(username=username, name=name, password=self.hasher.hash(password))
        self.db.add(user)
        self.db.flush()  # flush the db now so user.id is available

        return user

    def add_email(
        self,
        user_id,
        email_address,
        primary=None,
        verified=False,
        public=False,
    ):
        """Attach a new ``Email`` to the user and return it (flushed, so it has an id).

        When ``primary`` is None the address becomes primary only if the user
        has no primary address yet.

        :raises TooManyEmailsAdded: if the ``email.add`` limiter rejects
            ``remote_addr``.
        """
        # Check to make sure that we haven't hit the rate limit for this IP.
        if not self.ratelimiters["email.add"].test(self.remote_addr):
            self._metrics.increment(
                "warehouse.email.add.ratelimited", tags=["ratelimiter:email.add"]
            )
            raise TooManyEmailsAdded(
                resets_in=self.ratelimiters["email.add"].resets_in(self.remote_addr)
            )

        user = self.get_user(user_id)

        if primary is None:
            primary = user.primary_email is None

        email = Email(
            email=email_address,
            user=user,
            primary=primary,
            verified=verified,
            public=public,
        )
        self.db.add(email)
        self.db.flush()  # flush the db now so email.id is available

        self.ratelimiters["email.add"].hit(self.remote_addr)
        self._metrics.increment("warehouse.email.add.ok")

        return email

    def update_user(self, user_id, **changes):
        """Set each attribute in ``changes`` on the user and return the user.

        The password field is hashed before being stored, and changing it
        clears ``disabled_for``.
        """
        user = self.get_user(user_id)
        for attr, value in changes.items():
            if attr == PASSWORD_FIELD:
                value = self.hasher.hash(value)
            setattr(user, attr, value)

        # A new password means the account is no longer disabled, so drop the
        # recorded reason.
        if PASSWORD_FIELD in changes:
            user.disabled_for = None

        return user

    def disable_password(self, user_id, reason=None):
        """Replace the user's password with the hasher's disabled marker."""
        user = self.get_user(user_id)
        user.password = self.hasher.disable()
        user.disabled_for = reason

    def is_disabled(self, user_id):
        """Return ``(is_disabled, reason)`` for the user.

        Frozen accounts report ``DisableReason.AccountFrozen``; accounts with a
        disabled password report ``user.disabled_for``; otherwise
        ``(False, None)``.
        """
        user = self.get_user(user_id)

        if user.is_frozen:
            return (True, DisableReason.AccountFrozen)

        # User is disabled due to password being disabled
        if not self.hasher.is_enabled(user.password):
            return (True, user.disabled_for)

        # User is not disabled.
        return (False, None)

    def has_two_factor(self, user_id):
        """
        Returns True if the user has any form of two factor
        authentication and is allowed to use it.
        """
        user = self.get_user(user_id)

        return user.has_two_factor

    def has_totp(self, user_id):
        """
        Returns True if the user has a TOTP device provisioned.
        """
        user = self.get_user(user_id)

        return user.totp_secret is not None

    def has_webauthn(self, user_id):
        """
        Returns True if the user has a security key provisioned.
        """
        user = self.get_user(user_id)

        return len(user.webauthn) > 0

    def has_recovery_codes(self, user_id):
        """
        Returns True if the user has generated recovery codes.
        """
        user = self.get_user(user_id)

        return user.has_recovery_codes

    def get_recovery_codes(self, user_id):
        """
        Returns all stored recovery codes for the user.

        If there are none, records a failure metric, hits the login rate
        limiters and raises NoRecoveryCodes.
        """
        user = self.get_user(user_id)

        stored_recovery_codes = self.db.query(RecoveryCode).filter_by(user=user).all()

        if stored_recovery_codes:
            return stored_recovery_codes

        self._metrics.increment(
            "warehouse.authentication.recovery_code.failure",
            tags=["failure_reason:no_recovery_codes"],
        )
        # Record a failed attempt in our rate limiting before raising an
        # exception to indicate a failed recovery code verification.
        self._hit_ratelimits(userid=user_id)
        raise NoRecoveryCodes

    def get_recovery_code(self, user_id, code):
        """
        Returns the stored recovery code whose hash matches ``code``.

        Raises NoRecoveryCodes (via get_recovery_codes) if the user has none,
        or InvalidRecoveryCode if none match.
        """
        user = self.get_user(user_id)

        for stored_recovery_code in self.get_recovery_codes(user.id):
            if self.hasher.verify(code, stored_recovery_code.code):
                return stored_recovery_code

        self._metrics.increment(
            "warehouse.authentication.recovery_code.failure",
            tags=["failure_reason:invalid_recovery_code"],
        )
        # Record a failed attempt in our rate limiting before raising
        # InvalidRecoveryCode to indicate a failed recovery code verification.
        self._hit_ratelimits(userid=user_id)
        raise InvalidRecoveryCode

    def get_totp_secret(self, user_id):
        """
        Returns the user's TOTP secret as bytes.

        If the user doesn't have a TOTP, returns None.
        """
        user = self.get_user(user_id)

        return user.totp_secret

    def get_last_totp_value(self, user_id):
        """
        Returns the user's last (accepted) TOTP value.

        If the user doesn't have a TOTP or hasn't used their TOTP
        method, returns None.
        """
        user = self.get_user(user_id)

        return user.last_totp_value

    def check_totp_value(self, user_id, totp_value, *, tags=None):
        """
        Returns True if the given TOTP is valid against the user's secret.

        Returns False if the user has no TOTP secret, if ``totp_value`` repeats
        the last accepted value, or if verification fails. Raises
        TooManyFailedLogins if a login limiter rejects the attempt. ``tags`` is
        copied, never mutated.
        """
        # Copy so the caller's list is not modified by the append below.
        tags = list(tags) if tags is not None else []
        tags.append("mechanism:check_totp_value")
        self._metrics.increment("warehouse.authentication.two_factor.start", tags=tags)

        self._check_ratelimits(userid=user_id, tags=tags)

        totp_secret = self.get_totp_secret(user_id)

        if totp_secret is None:
            self._metrics.increment(
                "warehouse.authentication.two_factor.failure",
                tags=tags + ["failure_reason:no_totp"],
            )
            # Record a failed attempt in our rate limiting before returning
            # False to indicate a failed TOTP verification.
            self._hit_ratelimits(userid=user_id)
            return False

        last_totp_value = self.get_last_totp_value(user_id)

        # Reject a replay of the last accepted code. Note that this path emits
        # no failure metric and does not hit the rate limiters.
        if last_totp_value is not None and totp_value == last_totp_value.encode():
            return False

        valid = otp.verify_totp(totp_secret, totp_value)

        if valid:
            self._metrics.increment("warehouse.authentication.two_factor.ok", tags=tags)
        else:
            self._metrics.increment(
                "warehouse.authentication.two_factor.failure",
                tags=tags + ["failure_reason:invalid_totp"],
            )
            # Record a failed attempt in our rate limiting before returning
            # False to indicate a failed TOTP verification.
            self._hit_ratelimits(userid=user_id)

        return valid

    def get_webauthn_credential_options(self, user_id, *, challenge, rp_name, rp_id):
        """
        Returns a dictionary of credential options suitable for beginning the WebAuthn
        provisioning process for the given user.
        """
        user = self.get_user(user_id)

        return webauthn.get_credential_options(
            user, challenge=challenge, rp_name=rp_name, rp_id=rp_id
        )

    def get_webauthn_assertion_options(self, user_id, *, challenge, rp_id):
        """
        Returns a dictionary of assertion options suitable for beginning the WebAuthn
        authentication process for the given user.
        """
        user = self.get_user(user_id)

        return webauthn.get_assertion_options(user, challenge=challenge, rp_id=rp_id)

    def verify_webauthn_credential(self, credential, *, challenge, rp_id, origin):
        """
        Checks whether the given credential is valid, i.e. suitable for generating
        assertions during authentication.

        Returns the validated credential on success, raises
        webauthn.RegistrationRejectedError on failure, including when the
        credential ID is already stored for any user.
        """
        validated_credential = webauthn.verify_registration_response(
            credential, challenge=challenge, rp_id=rp_id, origin=origin
        )

        webauthn_cred = (
            self.db.query(WebAuthn)
            .filter_by(
                credential_id=bytes_to_base64url(validated_credential.credential_id)
            )
            .first()
        )

        if webauthn_cred is not None:
            raise webauthn.RegistrationRejectedError("Credential ID already in use")

        return validated_credential

    def verify_webauthn_assertion(
        self, user_id, assertion, *, challenge, origin, rp_id
    ):
        """
        Checks whether the given assertion was produced by the given user's WebAuthn
        device.

        Returns the updated sign count on success, raises
        webauthn.AuthenticationRejectedError on failure.
        """
        user = self.get_user(user_id)

        return webauthn.verify_assertion_response(
            assertion, challenge=challenge, user=user, origin=origin, rp_id=rp_id
        )

    def add_webauthn(self, user_id, **kwargs):
        """
        Adds a WebAuthn credential to the given user.

        Returns the new WebAuthn row, flushed so that its id is populated.
        """
        user = self.get_user(user_id)

        # Named ``credential`` to avoid shadowing the ``webauthn`` module.
        credential = WebAuthn(user=user, **kwargs)
        self.db.add(credential)
        self.db.flush()  # flush the db now so credential.id is available

        return credential

    def get_webauthn_by_label(self, user_id, label):
        """
        Returns a WebAuthn credential for the given user by its label,
        or None if no credential for the user has this label.
        """
        user = self.get_user(user_id)

        return next(
            (credential for credential in user.webauthn if credential.label == label),
            None,
        )

    def get_webauthn_by_credential_id(self, user_id, credential_id):
        """
        Returns a WebAuthn credential for the given user by its credential ID,
        or None if the user doesn't have a credential with this ID.
        """
        user = self.get_user(user_id)

        return next(
            (
                credential
                for credential in user.webauthn
                if credential.credential_id == credential_id
            ),
            None,
        )

    def record_event(self, user_id, *, tag, additional=None):
        """
        Creates a new UserEvent for the given user with the given
        tag, IP address, and additional metadata.

        Returns the event.
        """
        user = self.get_user(user_id)
        return user.record_event(
            tag=tag, ip_address=self.remote_addr, additional=additional
        )

    def generate_recovery_codes(self, user_id):
        """Replace the user's recovery codes with ``RECOVERY_CODE_COUNT`` new ones.

        Only hashes are stored; the plaintext codes are returned to the caller.
        """
        user = self.get_user(user_id)

        if user.has_recovery_codes:
            self.db.query(RecoveryCode).filter_by(user=user).delete()

        recovery_codes = [secrets.token_hex(8) for _ in range(RECOVERY_CODE_COUNT)]
        for recovery_code in recovery_codes:
            self.db.add(RecoveryCode(user=user, code=self.hasher.hash(recovery_code)))

        self.db.flush()

        return recovery_codes

    def check_recovery_code(self, user_id, code):
        """Verify ``code`` and mark it burned; returns True on success.

        Raises TooManyFailedLogins when rate limited, NoRecoveryCodes or
        InvalidRecoveryCode from get_recovery_code, and BurnedRecoveryCode if
        the code was already used.
        """
        self._metrics.increment("warehouse.authentication.recovery_code.start")

        self._check_ratelimits(
            userid=user_id,
            tags=["mechanism:check_recovery_code"],
        )

        user = self.get_user(user_id)
        stored_recovery_code = self.get_recovery_code(user.id, code)

        if stored_recovery_code.burned:
            self._metrics.increment(
                "warehouse.authentication.recovery_code.failure",
                tags=["failure_reason:burned_recovery_code"],
            )
            raise BurnedRecoveryCode

        # The code is valid and not burned. Mark it as burned
        stored_recovery_code.burned = datetime.datetime.now()
        self.db.flush()
        self._metrics.increment("warehouse.authentication.recovery_code.ok")
        return True

    def get_password_timestamp(self, user_id):
        """Return ``user.password_date`` as a POSIX timestamp, or 0 if unset."""
        user = self.get_user(user_id)
        return user.password_date.timestamp() if user.password_date is not None else 0
예제 #3
0
class DatabaseUserService:
    """
    Database-backed service for looking up, authenticating and managing users.

    All reads and writes go through the session passed to ``__init__``.
    Password verification is guarded by optional rate limiters and reports
    its outcome through the supplied metrics client.
    """

    def __init__(self, session, *, ratelimiters=None, metrics):
        """
        :param session: database session used for every query and write.
        :param ratelimiters: optional mapping of limiter name to rate limiter.
            This class reads the ``"global"`` and ``"user"`` limiters; any
            name that is not supplied falls back to a ``DummyRateLimiter``.
        :param metrics: metrics client, called as ``increment(name, tags=...)``.
        """
        if ratelimiters is None:
            ratelimiters = {}
        # defaultdict means self.ratelimiters[name] never raises KeyError: an
        # unconfigured limiter is replaced by a DummyRateLimiter instance.
        ratelimiters = collections.defaultdict(DummyRateLimiter, ratelimiters)

        self.db = session
        self.ratelimiters = ratelimiters
        self.hasher = CryptContext(
            schemes=[
                "argon2",
                "bcrypt_sha256",
                "bcrypt",
                "django_bcrypt",
                "unix_disabled",
            ],
            # "auto" deprecates every scheme except the default (argon2), so
            # verify_and_update() re-hashes passwords stored with older schemes.
            deprecated=["auto"],
            truncate_error=True,
            # Argon 2 Configuration
            argon2__memory_cost=1024,
            argon2__parallelism=6,
            argon2__time_cost=6,
        )
        self._metrics = metrics

    # NOTE(review): lru_cache on methods keys on ``self`` and keeps entries in
    # a class-level cache, so service instances (and their sessions) stay
    # alive until evicted. Negative results (None) are cached too, so a row
    # created later through the same instance is not seen by these lookups.
    # Consider a per-instance cache instead.
    @functools.lru_cache()
    def get_user(self, userid):
        """Return the User with primary key ``userid``, or None if absent."""
        # TODO: We probably don't actually want to just return the database
        #       object here.
        # TODO: We need some sort of Anonymous User.
        return self.db.query(User).get(userid)

    @functools.lru_cache()
    def get_user_by_username(self, username):
        """Return the User whose username is ``username``, or None."""
        user_id = self.find_userid(username)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def get_user_by_email(self, email):
        """Return the User owning the address ``email``, or None."""
        user_id = self.find_userid_by_email(email)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def find_userid(self, username):
        """Return the id of the user named ``username``, or None if no match."""
        try:
            user = (
                self.db.query(User.id).filter(User.username == username).one()
            )
        except NoResultFound:
            return None

        return user.id

    @functools.lru_cache()
    def find_userid_by_email(self, email):
        """Return the user id owning the address ``email``, or None."""
        try:
            row = (
                self.db.query(Email.user_id).filter(Email.email == email).one()
            )
        except NoResultFound:
            return None

        return row[0]

    def check_password(self, userid, password, *, tags=None):
        """
        Verify ``password`` for the user with id ``userid``.

        Checks the global and per-user rate limiters first. Then verifies the
        password with the configured hasher, storing an upgraded hash when the
        hasher returns one. Failed attempts are recorded against the rate
        limiters.

        :param tags: optional iterable of metric tags attached to every metric
            emitted by this call.
        :returns: True if the password matched; otherwise False, including
            when no user exists for ``userid``.
        :raises TooManyFailedLogins: if the global or per-user rate limit has
            been reached.
        """
        # Copy into a list so ``tags + [...]`` below works for any iterable
        # (e.g. a tuple) and never aliases the caller's object.
        tags = list(tags) if tags is not None else []

        self._metrics.increment("warehouse.authentication.start", tags=tags)

        # The very first thing we want to do is check to see if we've hit our
        # global rate limit or not, assuming that we've been configured with a
        # global rate limiter anyways.
        if not self.ratelimiters["global"].test():
            logger.warning("Global failed login threshold reached.")
            self._metrics.increment(
                "warehouse.authentication.ratelimited",
                tags=tags + ["ratelimiter:global"],
            )
            raise TooManyFailedLogins(
                resets_in=self.ratelimiters["global"].resets_in())

        user = self.get_user(userid)
        if user is not None:
            # Now, check to make sure that we haven't hit a rate limit on a
            # per user basis.
            if not self.ratelimiters["user"].test(user.id):
                self._metrics.increment(
                    "warehouse.authentication.ratelimited",
                    tags=tags + ["ratelimiter:user"],
                )
                raise TooManyFailedLogins(
                    resets_in=self.ratelimiters["user"].resets_in(user.id))

            # Actually check our hash, optionally getting a new hash for it if
            # we should upgrade our saved hash.
            ok, new_hash = self.hasher.verify_and_update(
                password, user.password)

            if ok:
                # The password was OK; if the hasher produced an upgraded hash,
                # save it.
                if new_hash:
                    user.password = new_hash

                self._metrics.increment("warehouse.authentication.ok",
                                        tags=tags)

                return True

            self._metrics.increment(
                "warehouse.authentication.failure",
                tags=tags + ["failure_reason:password"],
            )
        else:
            self._metrics.increment("warehouse.authentication.failure",
                                    tags=tags + ["failure_reason:user"])

        # If we've gotten here, then we'll want to record a failed login in our
        # rate limiting before returning False to indicate a failed password
        # verification.
        if user is not None:
            self.ratelimiters["user"].hit(user.id)
        self.ratelimiters["global"].hit()

        return False

    def create_user(self, username, name, password):
        """
        Create a new User, storing a hash of ``password`` (never the plaintext).

        :returns: the new User; the session is flushed so ``user.id`` is set.
        """
        user = User(username=username,
                    name=name,
                    password=self.hasher.hash(password))
        self.db.add(user)
        self.db.flush()  # flush the db now so user.id is available

        return user

    def add_email(self, user_id, email_address, primary=None, verified=False):
        """
        Attach a new Email address to the user with id ``user_id``.

        :param primary: whether this is the user's primary address. When None,
            it becomes primary only if the user has no primary address yet.
        :param verified: initial verification state of the address.
        :returns: the new Email; the session is flushed so ``email.id`` is set.
        """
        user = self.get_user(user_id)

        # Auto-detect: if the user doesn't already have a primary address,
        # the address being added becomes their primary one.
        if primary is None:
            primary = user.primary_email is None

        email = Email(email=email_address,
                      user=user,
                      primary=primary,
                      verified=verified)
        self.db.add(email)
        self.db.flush()  # flush the db now so email.id is available

        return email

    def update_user(self, user_id, **changes):
        """
        Apply ``changes`` as attribute assignments on the user and return it.

        A value for ``PASSWORD_FIELD`` is hashed before being stored.
        """
        user = self.get_user(user_id)
        for attr, value in changes.items():
            if attr == PASSWORD_FIELD:
                value = self.hasher.hash(value)
            setattr(user, attr, value)

        # A new password means the user is no longer disabled, so clear the
        # recorded reason.
        if PASSWORD_FIELD in changes:
            user.disabled_for = None

        return user

    def disable_password(self, user_id, reason=None):
        """
        Replace the user's password with the hasher's disabled marker.

        ``reason`` is stored in ``user.disabled_for``.
        """
        user = self.get_user(user_id)
        user.password = self.hasher.disable()
        user.disabled_for = reason

    def is_disabled(self, user_id):
        """
        Report whether the user's password hash has been disabled.

        :returns: ``(False, None)`` if the hash is enabled, otherwise
            ``(True, user.disabled_for)``.
        """
        user = self.get_user(user_id)

        if self.hasher.is_enabled(user.password):
            return (False, None)
        return (True, user.disabled_for)
예제 #4
0
class DatabaseUserService:
    """
    Database-backed service for looking up, authenticating and managing users.

    All reads and writes go through the session passed to ``__init__``.
    Password verification is guarded by optional rate limiters and reports
    its outcome through the supplied metrics client.
    """

    def __init__(self, session, *, ratelimiters=None, metrics):
        """
        :param session: database session used for every query and write.
        :param ratelimiters: optional mapping of limiter name to rate limiter.
            This class reads the ``"global"`` and ``"user"`` limiters; any name
            that is not supplied falls back to a ``DummyRateLimiter``.
        :param metrics: metrics client, called as ``increment(name, tags=...)``.
        """
        if ratelimiters is None:
            ratelimiters = {}
        # defaultdict means self.ratelimiters[name] never raises KeyError: an
        # unconfigured limiter is replaced by a DummyRateLimiter instance.
        ratelimiters = collections.defaultdict(DummyRateLimiter, ratelimiters)

        self.db = session
        self.ratelimiters = ratelimiters
        self.hasher = CryptContext(
            schemes=[
                "argon2",
                "bcrypt_sha256",
                "bcrypt",
                "django_bcrypt",
                "unix_disabled",
            ],
            # "auto" deprecates every scheme except the default (argon2), so
            # verify_and_update() re-hashes passwords stored with older schemes.
            deprecated=["auto"],
            truncate_error=True,
            # Argon 2 Configuration
            argon2__memory_cost=1024,
            argon2__parallelism=6,
            argon2__time_cost=6,
        )
        self._metrics = metrics

    # NOTE(review): lru_cache on methods keys on ``self`` and keeps entries in a
    # class-level cache, so service instances (and their sessions) stay alive
    # until evicted. Negative results (None) are cached too, so a row created
    # later through the same instance is not seen by these lookups. Consider a
    # per-instance cache instead.
    @functools.lru_cache()
    def get_user(self, userid):
        """Return the User with primary key ``userid``, or None if absent."""
        # TODO: We probably don't actually want to just return the database
        #       object here.
        # TODO: We need some sort of Anonymous User.
        return self.db.query(User).get(userid)

    @functools.lru_cache()
    def get_user_by_username(self, username):
        """Return the User whose username is ``username``, or None."""
        user_id = self.find_userid(username)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def get_user_by_email(self, email):
        """Return the User owning the address ``email``, or None."""
        user_id = self.find_userid_by_email(email)
        return None if user_id is None else self.get_user(user_id)

    @functools.lru_cache()
    def find_userid(self, username):
        """Return the id of the user named ``username``, or None if no match."""
        try:
            user = self.db.query(User.id).filter(User.username == username).one()
        except NoResultFound:
            return None

        return user.id

    @functools.lru_cache()
    def find_userid_by_email(self, email):
        """Return the user id owning the address ``email``, or None."""
        try:
            row = self.db.query(Email.user_id).filter(Email.email == email).one()
        except NoResultFound:
            return None

        return row[0]

    def check_password(self, userid, password, *, tags=None):
        """
        Verify ``password`` for the user with id ``userid``.

        Checks the global and per-user rate limiters first. Then verifies the
        password with the configured hasher, storing an upgraded hash when the
        hasher returns one. Failed attempts are recorded against the rate
        limiters.

        :param tags: optional iterable of metric tags attached to every metric
            emitted by this call.
        :returns: True if the password matched; otherwise False, including when
            no user exists for ``userid``.
        :raises TooManyFailedLogins: if the global or per-user rate limit has
            been reached.
        """
        # Copy into a list so ``tags + [...]`` below works for any iterable
        # (e.g. a tuple) and never aliases the caller's object.
        tags = list(tags) if tags is not None else []

        self._metrics.increment("warehouse.authentication.start", tags=tags)

        # The very first thing we want to do is check to see if we've hit our
        # global rate limit or not, assuming that we've been configured with a
        # global rate limiter anyways.
        if not self.ratelimiters["global"].test():
            logger.warning("Global failed login threshold reached.")
            self._metrics.increment(
                "warehouse.authentication.ratelimited",
                tags=tags + ["ratelimiter:global"],
            )
            raise TooManyFailedLogins(resets_in=self.ratelimiters["global"].resets_in())

        user = self.get_user(userid)
        if user is not None:
            # Now, check to make sure that we haven't hit a rate limit on a
            # per user basis.
            if not self.ratelimiters["user"].test(user.id):
                self._metrics.increment(
                    "warehouse.authentication.ratelimited",
                    tags=tags + ["ratelimiter:user"],
                )
                raise TooManyFailedLogins(
                    resets_in=self.ratelimiters["user"].resets_in(user.id)
                )

            # Actually check our hash, optionally getting a new hash for it if
            # we should upgrade our saved hash.
            ok, new_hash = self.hasher.verify_and_update(password, user.password)

            if ok:
                # The password was OK; if the hasher produced an upgraded hash,
                # save it.
                if new_hash:
                    user.password = new_hash

                self._metrics.increment("warehouse.authentication.ok", tags=tags)

                return True

            self._metrics.increment(
                "warehouse.authentication.failure",
                tags=tags + ["failure_reason:password"],
            )
        else:
            self._metrics.increment(
                "warehouse.authentication.failure", tags=tags + ["failure_reason:user"]
            )

        # If we've gotten here, then we'll want to record a failed login in our
        # rate limiting before returning False to indicate a failed password
        # verification.
        if user is not None:
            self.ratelimiters["user"].hit(user.id)
        self.ratelimiters["global"].hit()

        return False

    def create_user(
        self, username, name, password, is_active=False, is_superuser=False
    ):
        """
        Create a new User, storing a hash of ``password`` (never the plaintext).

        :param is_active: initial value of ``User.is_active``.
        :param is_superuser: initial value of ``User.is_superuser``.
        :returns: the new User; the session is flushed so ``user.id`` is set.
        """
        user = User(
            username=username,
            name=name,
            password=self.hasher.hash(password),
            is_active=is_active,
            is_superuser=is_superuser,
        )
        self.db.add(user)
        self.db.flush()  # flush the db now so user.id is available

        return user

    def add_email(self, user_id, email_address, primary=None, verified=False):
        """
        Attach a new Email address to the user with id ``user_id``.

        :param primary: whether this is the user's primary address. When None,
            it becomes primary only if the user has no primary address yet.
        :param verified: initial verification state of the address.
        :returns: the new Email; the session is flushed so ``email.id`` is set.
        """
        user = self.get_user(user_id)

        # Auto-detect: if the user doesn't already have a primary address, the
        # address being added becomes their primary one.
        if primary is None:
            primary = user.primary_email is None

        email = Email(
            email=email_address, user=user, primary=primary, verified=verified
        )
        self.db.add(email)
        self.db.flush()  # flush the db now so email.id is available

        return email

    def update_user(self, user_id, **changes):
        """
        Apply ``changes`` as attribute assignments on the user and return it.

        A value for ``PASSWORD_FIELD`` is hashed before being stored.
        """
        user = self.get_user(user_id)
        for attr, value in changes.items():
            if attr == PASSWORD_FIELD:
                value = self.hasher.hash(value)
            setattr(user, attr, value)

        # A new password means the user is no longer disabled, so clear the
        # recorded reason.
        if PASSWORD_FIELD in changes:
            user.disabled_for = None

        return user

    def disable_password(self, user_id, reason=None):
        """
        Replace the user's password with the hasher's disabled marker.

        ``reason`` is stored in ``user.disabled_for``.
        """
        user = self.get_user(user_id)
        user.password = self.hasher.disable()
        user.disabled_for = reason

    def is_disabled(self, user_id):
        """
        Report whether the user's password hash has been disabled.

        :returns: ``(False, None)`` if the hash is enabled, otherwise
            ``(True, user.disabled_for)``.
        """
        user = self.get_user(user_id)

        if self.hasher.is_enabled(user.password):
            return (False, None)
        return (True, user.disabled_for)