예제 #1
0
class ApiKey(db.Model, ModelMixin):
    """An API key owned by a user and granting a set of scopes.

    Granted scopes are stored in the ``scope`` column as one space-separated
    string; the valid scope names are listed in ``SCOPES``.
    """

    SCOPES = ["read", "write"]

    key = db.Column(db.String, primary_key=True)
    user_id = db.Column(
        db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
    )
    # The backref adds ``User.api_keys``; with "delete-orphan" a key removed
    # from that collection is deleted instead of being left without an owner.
    user = db.relationship(
        'User',
        backref=db.backref("api_keys", cascade="all, delete-orphan"),
    )
    scope = db.Column(db.String, nullable=False)

    def __init__(self, key=None, user=None, scope=None) -> None:
        super().__init__()
        self.key = key
        self.user = user
        # The column is NOT NULL, so "no scope" is stored as an empty string.
        # NOTE: the initial scope string is stored as given, without validation.
        self.scope = scope or ""

    def __repr__(self) -> str:
        return "ApiKey {} (scope: {})".format(self.key, self.scope)

    def set_scope(self, scope):
        """Add the space-separated scope names in *scope* to this key.

        Despite the name, scopes are *added* to those already granted.
        Every requested name is validated before anything is changed, so an
        invalid name leaves the key untouched; names the key already has are
        not duplicated.

        :param scope: space-separated scope names; a falsy or blank value is a
            no-op.
        :raises ValueError: if any name is not listed in ``SCOPES``.
        """
        requested = scope.split() if scope else []
        if not requested:
            return
        # Validate everything first so a bad name cannot leave a half-updated
        # scope behind.
        for s in requested:
            if s not in self.SCOPES:
                raise ValueError("Scope '{}' not valid".format(s))
        granted = (self.scope or "").split()
        for s in requested:
            if s not in granted:
                granted.append(s)
        # Join with single spaces: no leading separator, no empty entries.
        self.scope = " ".join(granted)

    def check_scopes(self, scopes: list | str) -> bool:
        """Return True if this key grants every scope in *scopes*.

        :param scopes: a list of scope names, or a space-separated string of
            them. Requesting no scopes at all returns True.
        """
        if isinstance(scopes, str):
            scopes = scopes.split()
        granted = set((self.scope or "").split())
        return all(s in granted for s in scopes)

    @classmethod
    def find(cls, api_key) -> ApiKey:
        """Return the key whose value is *api_key*, or None if there is none."""
        return cls.query.filter(cls.key == api_key).first()

    @classmethod
    def all(cls) -> List[ApiKey]:
        """Return every stored API key."""
        return cls.query.all()
예제 #2
0
class Client(db.Model, OAuth2ClientMixin):
    """An OAuth2 client registered by a user (table ``oauth2_client``).

    Redirect URIs and the token endpoint auth method live in the client
    metadata handled by ``OAuth2ClientMixin``.
    """

    id = db.Column(db.Integer, primary_key=True)
    client_id = db.Column(db.String(48), index=True, unique=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id',
                                                  ondelete='CASCADE'))
    # The backref adds ``User.clients``; clients removed from that collection
    # are deleted ("delete-orphan").
    user = db.relationship(
        'User',
        backref=db.backref(
            "clients",
            cascade="all, delete-orphan",
        ),
    )

    __tablename__ = "oauth2_client"

    @property
    def redirect_uris(self):
        """The registered redirect URIs (an empty list when unset)."""
        return self.client_metadata.get('redirect_uris', [])

    @redirect_uris.setter
    def redirect_uris(self, value):
        """Store the redirect URIs in the client metadata.

        :param value: a list of URIs, or one comma-separated string. For a
            string, whitespace around each URI is stripped and empty entries
            are dropped.
        """
        if isinstance(value, str):
            # "a, b" must give ["a", "b"]: otherwise the leading space would
            # be kept as part of the second URI, and "" would give [""].
            value = [uri.strip() for uri in value.split(',') if uri.strip()]
        metadata = self.client_metadata
        metadata['redirect_uris'] = value
        self.set_client_metadata(metadata)

    @property
    def auth_method(self):
        """The ``token_endpoint_auth_method`` metadata value, or None."""
        return self.client_metadata.get('token_endpoint_auth_method')

    @auth_method.setter
    def auth_method(self, value):
        """Store *value* as the ``token_endpoint_auth_method`` metadata."""
        metadata = self.client_metadata
        metadata['token_endpoint_auth_method'] = value
        self.set_client_metadata(metadata)

    @classmethod
    def find_by_id(cls, client_id) -> Client:
        """Return the client with primary key *client_id*, or None.

        NOTE(review): ``Query.get`` looks up by primary key, i.e. the integer
        ``id`` column, not the ``client_id`` column — confirm that callers
        pass the primary key.
        """
        return cls.query.get(client_id)

    @classmethod
    def all(cls) -> List[Client]:
        """Return every registered client."""
        return cls.query.all()
예제 #3
0
class OAuthIdentity(models.ExternalServiceAccessAuthorization, ModelMixin):
    """A user's identity at an external OAuth2 identity provider.

    Stored in ``oauth2_identity`` as a polymorphic subclass of
    ``ExternalServiceAccessAuthorization`` (its ``id`` is a foreign key to the
    parent row). A ``provider_user_id`` is unique within one provider.
    """

    id = db.Column(db.Integer,
                   db.ForeignKey('external_service_access_authorization.id'),
                   primary_key=True)
    provider_user_id = db.Column(db.String(256), nullable=False)
    provider_id = db.Column(db.Integer,
                            db.ForeignKey("oauth2_identity_provider.id"),
                            nullable=False)
    created_at = db.Column(DateTime, default=datetime.utcnow, nullable=False)
    # Raw token data stored as JSON; may be NULL.
    token = db.Column(JSON, nullable=True)
    # Cached provider profile. It is not a mapped column, so it is never
    # persisted; instances that never assigned it see this class-level None.
    _user_info = None
    provider = db.relationship("OAuth2IdentityProvider",
                               uselist=False,
                               back_populates="identities")
    user = db.relationship(
        models.User,
        # This `backref` sets up an `oauth_identity` property on the User
        # model: a dictionary of the OAuth identities associated with that
        # user, keyed by the identity provider's name.
        backref=db.backref(
            "oauth_identity",
            collection_class=attribute_mapped_collection("provider.name"),
            cascade="all, delete-orphan",
        ),
    )

    __table_args__ = (db.UniqueConstraint("provider_id", "provider_user_id"), )
    __tablename__ = "oauth2_identity"
    __mapper_args__ = {'polymorphic_identity': 'oauth2_identity'}

    def __init__(self, provider, user_info, provider_user_id, token):
        # NOTE(review): ``self.user`` has not been assigned yet at this point,
        # so this presumably passes None to the parent constructor — confirm
        # that is intended.
        super().__init__(self.user)
        self.provider = provider
        self.provider_user_id = provider_user_id
        self._user_info = user_info
        self.token = token
        # Attach the provider's API resource; ``resources`` presumably comes
        # from the parent class.
        self.resources.append(provider.api_resource)

    def as_http_header(self):
        """Return ``"<token_type> <access_token>"`` for this identity.

        Presumably used as an HTTP ``Authorization`` header value.
        """
        return f"{self.provider.token_type} {self.token['access_token']}"

    @property
    def username(self):
        """A name unique per identity: ``"<provider name>_<provider user id>"``."""
        return f"{self.provider.name}_{self.provider_user_id}"

    @property
    def user_info(self):
        """The user's profile at the provider, fetched lazily and cached.

        Only a missing (None) profile triggers a fetch; an empty profile
        returned by the provider is cached like any other value instead of
        being requested again on every access.
        """
        if self._user_info is None:
            self._user_info = self.provider.get_user_info(
                self.provider_user_id, self.token)
        return self._user_info

    @user_info.setter
    def user_info(self, value):
        self._user_info = value

    def set_token(self, token):
        """Replace the stored token data."""
        self.token = token

    def __repr__(self):
        parts = []
        parts.append(self.__class__.__name__)
        if self.id:
            parts.append("id={}".format(self.id))
        if self.provider:
            parts.append('provider="{}"'.format(self.provider))
        return "<{}>".format(" ".join(parts))

    @staticmethod
    def find_by_user_id(user_id, provider_name) -> OAuthIdentity:
        """Return the identity of *user_id* at the provider *provider_name*.

        NOTE(review): the unique constraint only covers
        (provider_id, provider_user_id), so ``.one()`` could presumably also
        raise ``MultipleResultsFound`` here — confirm.

        :raises OAuthIdentityNotFoundException: if no such identity exists.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(user_id=user_id).one()
        except NoResultFound as e:
            raise OAuthIdentityNotFoundException(
                f"{user_id}_{provider_name}") from e

    @staticmethod
    def find_by_provider_user_id(provider_user_id,
                                 provider_name) -> OAuthIdentity:
        """Return the identity with *provider_user_id* at *provider_name*.

        :raises OAuthIdentityNotFoundException: if no such identity exists.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(provider_user_id=provider_user_id).one()
        except NoResultFound as e:
            raise OAuthIdentityNotFoundException(
                f"{provider_name}_{provider_user_id}") from e

    @classmethod
    def all(cls) -> List[OAuthIdentity]:
        """Return every stored identity."""
        return cls.query.all()
예제 #4
0
class OAuthIdentity(models.ExternalServiceAccessAuthorization, ModelMixin):
    """A user's identity at an external OAuth2 identity provider.

    Stored in ``oauth2_identity`` as a polymorphic subclass of
    ``ExternalServiceAccessAuthorization`` (its ``id`` is a foreign key to the
    parent row). A ``provider_user_id`` is unique within one provider. The
    raw token is kept in the ``token`` column and exposed through the
    ``token`` property wrapped in ``OAuth2Token``.
    """

    id = db.Column(db.Integer,
                   db.ForeignKey('external_service_access_authorization.id'),
                   primary_key=True)
    provider_user_id = db.Column(db.String(256), nullable=False)
    provider_id = db.Column(db.Integer,
                            db.ForeignKey("oauth2_identity_provider.id"),
                            nullable=False)
    created_at = db.Column(DateTime, default=datetime.utcnow, nullable=False)
    # Raw token data (JSON column named "token"); may be NULL.
    _token = db.Column("token", JSON, nullable=True)
    # Cached provider profile. It is not a mapped column, so it is never
    # persisted; instances that never assigned it see this class-level None.
    _user_info = None
    provider = db.relationship("OAuth2IdentityProvider",
                               uselist=False,
                               back_populates="identities")
    user = db.relationship(
        models.User,
        # This `backref` sets up an `oauth_identity` property on the User
        # model: a dictionary of the OAuth identities associated with that
        # user, keyed by the identity provider's name.
        backref=db.backref(
            "oauth_identity",
            collection_class=attribute_mapped_collection("provider.name"),
            cascade="all, delete-orphan",
        ),
    )

    __table_args__ = (db.UniqueConstraint("provider_id", "provider_user_id"), )
    __tablename__ = "oauth2_identity"
    __mapper_args__ = {'polymorphic_identity': 'oauth2_identity'}

    def __init__(self, provider, user_info, provider_user_id, token):
        # NOTE(review): ``self.user`` has not been assigned yet at this point,
        # so this presumably passes None to the parent constructor — confirm
        # that is intended.
        super().__init__(self.user)
        self.provider = provider
        self.provider_user_id = provider_user_id
        self._user_info = user_info
        self.token = token
        # Attach the provider's API resource; ``resources`` presumably comes
        # from the parent class.
        self.resources.append(provider.api_resource)

    def as_http_header(self):
        """Return ``"<token_type> <access_token>"`` using a fresh token.

        The token goes through ``fetch_token()`` first, so it may be
        refreshed. Presumably used as an HTTP ``Authorization`` header value.
        """
        return f"{self.provider.token_type} {self.fetch_token()['access_token']}"

    @property
    def username(self):
        """A name unique per identity: ``"<provider name>_<provider user id>"``."""
        return f"{self.provider.name}_{self.provider_user_id}"

    @property
    def token(self) -> OAuth2Token:
        """The stored token data wrapped in a new ``OAuth2Token``."""
        return OAuth2Token(self._token)

    @token.setter
    def token(self, token: dict):
        self._token = token

    def fetch_token(self):
        """Return the identity's token, refreshing it first when needed.

        Refresh is only attempted for identities already stored in the
        database; for those, the row is re-read, and if the token reports
        ``to_be_refreshed()`` and carries a ``refresh_token``, a new token is
        requested from the provider's ``access_token_url`` and saved.

        :returns: the current (possibly refreshed) token.
        """
        # enable dynamic refresh only if the identity
        # has been already stored in the database
        if not inspect(self).persistent:
            return self.token
        # fetch up to date identity data
        self.refresh()
        token = self.token
        # the token should be refreshed
        # if it is expired or close to expire (i.e., n secs before expiration)
        if not token.to_be_refreshed():
            return token
        if 'refresh_token' not in token:
            logger.debug(
                "The token should be refreshed but no refresh token is associated with the token"
            )
            return token
        logger.debug("Trying to refresh the token...")
        oauth2session = OAuth2Session(self.provider.client_id,
                                      self.provider.client_secret,
                                      token=self.token)
        self.token = oauth2session.refresh_token(
            self.provider.access_token_url,
            refresh_token=token['refresh_token'])
        self.save()
        # Never log the token itself: it contains live access/refresh
        # credentials.
        logger.debug("User token updated")
        return self.token

    @property
    def user_info(self):
        """The user's profile at the provider, fetched lazily and cached.

        Only a missing (None) profile triggers a fetch, which goes through
        ``fetch_token()``. An empty profile returned by the provider is cached
        like any other value instead of being requested again on every access.
        """
        if self._user_info is None:
            logger.debug(
                "[Identity %r], Trying to read profile of user %r from provider %r...",
                self.id, self.user_id, self.provider.name)
            self._user_info = self.provider.get_user_info(
                self.provider_user_id, self.fetch_token())
        return self._user_info

    @user_info.setter
    def user_info(self, value):
        self._user_info = value

    def __repr__(self):
        parts = []
        parts.append(self.__class__.__name__)
        if self.id:
            parts.append("id={}".format(self.id))
        if self.provider:
            parts.append('provider="{}"'.format(self.provider))
        return "<{}>".format(" ".join(parts))

    @staticmethod
    def find_by_user_id(user_id, provider_name) -> OAuthIdentity:
        """Return the identity of *user_id* at the provider *provider_name*.

        NOTE(review): the unique constraint only covers
        (provider_id, provider_user_id), so ``.one()`` could presumably also
        raise ``MultipleResultsFound`` here — confirm.

        :raises OAuthIdentityNotFoundException: if no such identity exists.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(user_id=user_id).one()
        except NoResultFound as e:
            raise OAuthIdentityNotFoundException(
                f"{user_id}_{provider_name}") from e

    @staticmethod
    def find_by_provider_user_id(provider_user_id,
                                 provider_name) -> OAuthIdentity:
        """Return the identity with *provider_user_id* at *provider_name*.

        :raises OAuthIdentityNotFoundException: if no such identity exists.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(provider_user_id=provider_user_id).one()
        except NoResultFound as e:
            raise OAuthIdentityNotFoundException(
                f"{provider_name}_{provider_user_id}") from e

    @classmethod
    def all(cls) -> List[OAuthIdentity]:
        """Return every stored identity."""
        return cls.query.all()
예제 #5
0
class Client(db.Model, OAuth2ClientMixin):
    """An OAuth2 client registered by a user (table ``oauth2_client``).

    Metadata written through ``set_client_metadata`` is normalised first:
    the scope becomes a space-separated string and the list-valued fields
    become lists.
    """

    id = db.Column(db.Integer, primary_key=True)
    client_id = db.Column(db.String(48), index=True, unique=True)
    user_id = db.Column(
        db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
    )
    # The backref adds ``User.clients``; clients removed from that collection
    # are deleted ("delete-orphan").
    user = db.relationship(
        'User',
        backref=db.backref(
            "clients",
            cascade="all, delete-orphan",
        ),
    )

    __tablename__ = "oauth2_client"

    def is_confidential(self):
        """Return whether the client has a secret (``has_client_secret()``)."""
        return self.has_client_secret()

    def set_client_metadata(self, value):
        """Normalise *value* and store it as the client metadata.

        ``scope``, when present, is stored as a space-separated string, and
        ``redirect_uris``, ``grant_types``, ``response_types`` and
        ``contacts`` are stored as lists (missing ones become empty lists).
        The caller's dict is not modified. Non-dict values are ignored.
        """
        if not isinstance(value, dict):
            return
        data = copy.deepcopy(value)
        # Metadata of a new client (e.g. when only redirect URIs or the auth
        # method are being set) may carry no scope yet; indexing it directly
        # would raise KeyError.
        if 'scope' in value:
            data['scope'] = values_as_string(value['scope'], out_separator=" ")
        for p in ('redirect_uris', 'grant_types', 'response_types', 'contacts'):
            data[p] = values_as_list(value.get(p, []))
        return super().set_client_metadata(data)

    @property
    def redirect_uris(self):
        """The registered redirect URIs, as provided by the mixin."""
        return super().redirect_uris

    @redirect_uris.setter
    def redirect_uris(self, value):
        """Store the redirect URIs; normalised by ``set_client_metadata``."""
        metadata = self.client_metadata
        metadata['redirect_uris'] = value
        self.set_client_metadata(metadata)

    @property
    def scopes(self):
        """The client's scopes as a list (empty when no scope is set)."""
        # split() without an argument never yields empty entries.
        return self.scope.split() if self.scope else []

    @scopes.setter
    def scopes(self, scopes):
        """Store *scopes*; ``set_client_metadata`` turns them into a string."""
        metadata = self.client_metadata
        metadata['scope'] = scopes
        self.set_client_metadata(metadata)

    @property
    def auth_method(self):
        """The ``token_endpoint_auth_method`` metadata value, or None."""
        return self.client_metadata.get('token_endpoint_auth_method')

    @auth_method.setter
    def auth_method(self, value):
        """Store *value* as the ``token_endpoint_auth_method`` metadata."""
        metadata = self.client_metadata
        metadata['token_endpoint_auth_method'] = value
        self.set_client_metadata(metadata)

    @classmethod
    def find_by_id(cls, client_id) -> Client:
        """Return the client with primary key *client_id*, or None.

        NOTE(review): ``Query.get`` looks up by primary key, i.e. the integer
        ``id`` column, not the ``client_id`` column — confirm that callers
        pass the primary key.
        """
        return cls.query.get(client_id)

    @classmethod
    def all(cls) -> List[Client]:
        """Return every registered client."""
        return cls.query.all()