示例#1
0
class ExternalServiceAccessAuthorization(db.Model):
    """Base model for credentials a user holds to access external resources.

    Rows are polymorphic on the ``type`` column; subclasses declare their own
    ``polymorphic_identity`` and override :meth:`as_http_header`.
    """

    id = db.Column(db.Integer, primary_key=True)
    # Discriminator column for polymorphic loading (see __mapper_args__).
    type = db.Column(db.String, nullable=False)
    user_id = db.Column('user_id',
                        db.Integer,
                        db.ForeignKey('user.id', ondelete='CASCADE'),
                        nullable=False)

    # Resources this authorization applies to, linked through
    # `resource_authorization_table`; the backref exposes the reverse side
    # on Resource as `_authorizations`.
    resources = db.relationship("Resource",
                                secondary=resource_authorization_table,
                                backref="_authorizations",
                                passive_deletes=True)

    # NOTE(review): "all, delete" on this many-to-one side means deleting an
    # authorization also deletes the owning User — confirm this is intended.
    user = db.relationship("User",
                           back_populates="authorizations",
                           cascade="all, delete")

    __mapper_args__ = {
        'polymorphic_identity': 'authorization',
        'polymorphic_on': type,
    }

    def __init__(self, user) -> None:
        """Create an authorization owned by *user*."""
        super().__init__()
        self.user = user

    def as_http_header(self) -> str:
        """Return the value to send in an HTTP ``Authorization`` header.

        The base implementation returns an empty string; subclasses override it.
        """
        return ""

    @staticmethod
    def find_by_user_and_resource(user: User, resource: Resource):
        """Return the authorizations of *user* whose ``resources`` include *resource*.

        Filtering is done in Python over ``user.authorizations``; a new list is
        returned on every call, so callers may extend it.
        """
        return [a for a in user.authorizations if resource in a.resources]
示例#2
0
class Permission(db.Model):
    """Roles granted to one user on one resource.

    The primary key is the (user_id, resource_id) pair.
    """

    user_id = db.Column(db.Integer,
                        db.ForeignKey('user.id', ondelete='CASCADE'),
                        primary_key=True)
    resource_id = db.Column(db.Integer,
                            db.ForeignKey('resource.id', ondelete='CASCADE'),
                            primary_key=True)
    roles = db.Column(db.ARRAY(db.String), nullable=True)
    user = db.relationship("User", back_populates="permissions")
    resource = db.relationship("Resource", back_populates="permissions")

    def __init__(self,
                 user: User = None,
                 resource: Resource = None,
                 roles=None) -> None:
        """Create a permission; *roles* is copied into a fresh list."""
        self.user = user
        self.resource = resource
        # Copy so the caller's sequence is never shared with the model.
        self.roles = list(roles) if roles else []

    def __repr__(self):
        template = '<Permission of user {} for resource {}: {}>'
        return template.format(self.user, self.resource, self.roles)
示例#3
0
class ExternalServiceAuthorizationHeader(ExternalServiceAccessAuthorization):
    """Authorization that stores a literal HTTP ``Authorization`` header value."""

    # Joined-table inheritance: the primary key is also a foreign key to the
    # base authorization table.
    id = db.Column(db.Integer,
                   db.ForeignKey('external_service_access_authorization.id'),
                   primary_key=True)

    __mapper_args__ = dict(polymorphic_identity='authorization_header')

    header = db.Column(db.String, nullable=False)

    def __init__(self, user, header) -> None:
        """Create a header authorization for *user* carrying *header*."""
        super().__init__(user)
        self.header = header

    def as_http_header(self):
        """Return the stored header value unchanged."""
        header_value = self.header
        return header_value
示例#4
0
class ApiKey(db.Model, ModelMixin):
    """An API key owned by a user, carrying a space-separated scope string."""

    # Scope names that may be granted to an API key.
    SCOPES = ["read", "write"]

    key = db.Column(db.String, primary_key=True)
    user_id = db.Column(
        db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
    )
    user = db.relationship(
        'User',
        backref=db.backref("api_keys", cascade="all, delete-orphan"),
    )
    # Granted scopes, space-separated (e.g. "read write").
    scope = db.Column(db.String, nullable=False)

    def __init__(self, key=None, user=None, scope=None) -> None:
        super().__init__()
        self.key = key
        self.user = user
        self.scope = scope or ""

    def __repr__(self) -> str:
        return "ApiKey {} (scope: {})".format(self.key, self.scope)

    def set_scope(self, scope):
        """Add the space-separated scopes in *scope* to this key.

        Every requested scope is validated before the key is modified, so an
        invalid entry leaves ``self.scope`` untouched. Scopes already granted
        are not duplicated, and the stored string has no leading, trailing or
        repeated separators.

        :param scope: space-separated scope names; falsy values are a no-op.
        :raises ValueError: if a requested scope is not in ``SCOPES``.
        """
        if not scope:
            return
        requested = scope.split()
        for s in requested:
            if s not in self.SCOPES:
                raise ValueError("Scope '{}' not valid".format(s))
        granted = (self.scope or "").split()
        for s in requested:
            if s not in granted:
                granted.append(s)
        self.scope = " ".join(granted)

    def check_scopes(self, scopes: "list | str"):
        """Return True if every scope in *scopes* is granted to this key.

        :param scopes: a list of scope names or a space-separated string.
            An empty request is trivially satisfied.
        """
        if isinstance(scopes, str):
            scopes = scopes.split()
        # split() without arguments ignores empty tokens, so a stray
        # separator can never make '' look like a granted scope.
        granted = set((self.scope or "").split())
        return all(s in granted for s in scopes)

    @classmethod
    def find(cls, api_key) -> ApiKey:
        """Return the key whose ``key`` equals *api_key*, or None."""
        return cls.query.filter(ApiKey.key == api_key).first()

    @classmethod
    def all(cls) -> List[ApiKey]:
        """Return all API keys."""
        return cls.query.all()
示例#5
0
class Client(db.Model, OAuth2ClientMixin):
    """OAuth2 client registered by a user (table ``oauth2_client``)."""

    id = db.Column(db.Integer, primary_key=True)
    client_id = db.Column(db.String(48), index=True, unique=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id',
                                                  ondelete='CASCADE'))
    user = db.relationship(
        'User',
        backref=db.backref(
            "clients",
            cascade="all, delete-orphan",
        ),
    )

    __tablename__ = "oauth2_client"

    @property
    def redirect_uris(self):
        """Redirect URIs stored in the client metadata (empty list if unset)."""
        return self.client_metadata.get('redirect_uris', [])

    @redirect_uris.setter
    def redirect_uris(self, value):
        """Store redirect URIs given as a list or a comma-separated string.

        Entries of a comma-separated string are stripped of surrounding
        whitespace and empty entries are dropped, so ``"a, b"`` yields
        ``["a", "b"]`` and ``""`` yields ``[]``.
        """
        if isinstance(value, str):
            value = [uri.strip() for uri in value.split(',') if uri.strip()]
        metadata = self.client_metadata
        metadata['redirect_uris'] = value
        self.set_client_metadata(metadata)

    @property
    def auth_method(self):
        """The ``token_endpoint_auth_method`` from the client metadata."""
        return self.client_metadata.get('token_endpoint_auth_method')

    @auth_method.setter
    def auth_method(self, value):
        metadata = self.client_metadata
        metadata['token_endpoint_auth_method'] = value
        self.set_client_metadata(metadata)

    @classmethod
    def find_by_id(cls, client_id) -> Client:
        """Return the client with the given primary key, or None.

        NOTE(review): ``query.get`` looks up the primary key ``id``, not the
        ``client_id`` column, despite the parameter name — confirm callers.
        """
        return cls.query.get(client_id)

    @classmethod
    def all(cls) -> List[Client]:
        """Return all clients."""
        return cls.query.all()
示例#6
0
class Resource(db.Model, ModelMixin):
    """A resource identified by URI and UUID; polymorphic on ``type``."""

    id = db.Column('id', db.Integer, primary_key=True)
    uuid = db.Column(UUID, default=_uuid.uuid4)
    type = db.Column(db.String, nullable=False)
    name = db.Column(db.String, nullable=True)
    uri = db.Column(db.String, nullable=False)
    version = db.Column(db.String, nullable=True)
    created = db.Column(db.DateTime, default=datetime.datetime.utcnow)
    modified = db.Column(db.DateTime,
                         default=datetime.datetime.utcnow,
                         onupdate=datetime.datetime.utcnow)

    permissions = db.relationship("Permission",
                                  back_populates="resource",
                                  cascade="all, delete-orphan")

    __mapper_args__ = {
        'polymorphic_identity': 'resource',
        'polymorphic_on': type,
    }

    def __init__(self, uri, uuid=None, name=None, version=None) -> None:
        self.uri = uri
        self.name = name
        self.version = version
        self.uuid = uuid

    def __repr__(self):
        return '<Resource {}: {} -> {} (type={})>'.format(
            self.id, self.uuid, self.uri, self.type)

    @hybrid_property
    def authorizations(self):
        """Authorizations linked to this resource (via the ``_authorizations`` backref)."""
        return self._authorizations

    def get_authorization(self, user: User):
        """Return the authorizations *user* holds for this resource.

        Authorizations for the sub-resources listed below (currently only
        ``api``) are included when this resource has such an attribute.
        """
        auths = ExternalServiceAccessAuthorization.find_by_user_and_resource(
            user, self)
        # check for sub-resource authorizations
        for subresource in ["api"]:
            if hasattr(self, subresource):
                # The lookup must be made for *user*, not for this resource.
                auths.extend(
                    ExternalServiceAccessAuthorization.
                    find_by_user_and_resource(user, getattr(self,
                                                            subresource)))
        return auths

    @classmethod
    def find_by_uuid(cls, uuid):
        """Return the resource with the given UUID (normalized via lm_utils), or None."""
        return cls.query.filter(cls.uuid == lm_utils.uuid_param(uuid)).first()
示例#7
0
class Token(db.Model, ModelMixin, OAuth2TokenMixin):
    """OAuth2 token issued to a client on behalf of a user (table ``oauth2_client_token``)."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(
        db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
    )
    user = db.relationship('User')
    client_id = db.Column(db.String,
                          db.ForeignKey('oauth2_client.client_id', ondelete='CASCADE'))
    client = db.relationship('Client')

    __tablename__ = "oauth2_client_token"

    def is_expired(self) -> bool:
        """Return True if the token's expiry time has passed."""
        return self.check_token_expiration(self.expires_at)

    def is_refresh_token_valid(self) -> bool:
        """Return this token if it has not been revoked, otherwise None.

        NOTE(review): despite the ``bool`` annotation the token object itself
        (truthy) is returned; kept for callers that may use the result.
        """
        return self if not self.revoked else None

    @property
    def expires_at(self):
        """Expiry time computed as ``issued_at + expires_in`` (epoch seconds)."""
        return self.issued_at + self.expires_in

    @classmethod
    def find(cls, access_token):
        """Return the token with the given access token string, or None."""
        return cls.query.filter(Token.access_token == access_token).first()

    @classmethod
    def find_by_user(cls, user: User) -> List[Token]:
        return cls.query.filter(Token.user == user).all()

    @classmethod
    def find_by_client_user(cls, client: Client, user: User) -> List[Token]:
        return cls.query.filter(Token.client == client, Token.user == user).all()

    @classmethod
    def all(cls) -> List[Token]:
        return cls.query.all()

    @staticmethod
    def check_token_expiration(expires_at) -> bool:
        """Return True if the epoch timestamp *expires_at* lies in the past.

        ``time.time()`` is used instead of ``datetime.utcnow().timestamp()``:
        ``timestamp()`` treats the naive UTC value as local time, which skews
        the result by the host's UTC offset.
        """
        import time
        return time.time() - expires_at > 0
示例#8
0
class ExternalServiceAccessToken(ExternalServiceAccessAuthorization,
                                 OAuth2TokenMixin):
    """OAuth2 access token used as an authorization for external services."""

    # Joined-table inheritance: primary key doubles as FK to the base table.
    id = db.Column(db.Integer,
                   db.ForeignKey('external_service_access_authorization.id'),
                   primary_key=True)

    __mapper_args__ = {'polymorphic_identity': 'access_token'}

    def is_expired(self) -> bool:
        """Return True if the token's expiry time has passed."""
        return self.check_token_expiration(self.expires_at)

    def is_refresh_token_valid(self) -> bool:
        """Return this token if it has not been revoked, otherwise None.

        NOTE(review): despite the ``bool`` annotation the token object itself
        (truthy) is returned; kept for callers that may use the result.
        """
        return self if not self.revoked else None

    def save(self):
        """Add the token to the session and commit."""
        db.session.add(self)
        db.session.commit()

    def delete(self):
        """Delete the token and commit."""
        db.session.delete(self)
        db.session.commit()

    def as_http_header(self):
        """Return ``"<token_type> <access_token>"`` for an Authorization header."""
        return f"{self.token_type} {self.access_token}"

    @classmethod
    def find(cls, access_token):
        return cls.query.filter(cls.access_token == access_token).first()

    @classmethod
    def find_by_user(cls, user: User) -> List[ExternalServiceAccessToken]:
        return cls.query.filter(cls.user == user).all()

    @classmethod
    def all(cls):
        return cls.query.all()

    @staticmethod
    def check_token_expiration(expires_at) -> bool:
        """Return True if the epoch timestamp *expires_at* lies in the past.

        ``time.time()`` is used instead of ``datetime.utcnow().timestamp()``:
        ``timestamp()`` treats the naive UTC value as local time, which skews
        the result by the host's UTC offset.
        """
        import time
        return time.time() - expires_at > 0
示例#9
0
class User(db.Model, UserMixin):
    """Local user account with optional password and external authorizations."""
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(256), unique=True, nullable=False)
    password_hash = db.Column(db.LargeBinary, nullable=True)
    picture = db.Column(db.String(), nullable=True)

    permissions = db.relationship("Permission",
                                  back_populates="user",
                                  cascade="all, delete-orphan")
    authorizations = db.relationship("ExternalServiceAccessAuthorization",
                                     cascade="all, delete-orphan")

    def __init__(self, username=None) -> None:
        super().__init__()
        self.username = username

    def get_user_id(self):
        return self.id

    def get_authorization(self, resource: Resource):
        """Return this user's authorizations for *resource*.

        Authorizations for the sub-resources listed below (currently only
        ``api``) are included when *resource* has such an attribute.
        """
        auths = ExternalServiceAccessAuthorization.find_by_user_and_resource(
            self, resource)
        # check for sub-resource authorizations
        for subresource in ["api"]:
            if hasattr(resource, subresource):
                auths.extend(
                    ExternalServiceAccessAuthorization.
                    find_by_user_and_resource(self,
                                              getattr(resource, subresource)))
        return auths

    @property
    def current_identity(self):
        """Identities relevant to the current request context.

        Returns all OAuth identities when ``current_user`` is authenticated;
        otherwise, if a ``current_registry`` is set, a one-entry dict with the
        identity whose provider matches the registry's server credentials.
        Returns None when nothing matches.
        """
        from .services import current_registry, current_user
        if not current_user.is_anonymous:
            return self.oauth_identity
        if current_registry:
            for p, i in self.oauth_identity.items():
                if i.provider == current_registry.server_credentials:
                    return {p: i}
        return None

    @property
    def password(self):
        raise AttributeError("password is not a readable attribute")

    @password.setter
    def password(self, password):
        self.password_hash = generate_password_hash(password)

    @password.deleter
    def password(self):
        self.password_hash = None

    @property
    def has_password(self):
        return bool(self.password_hash)

    def has_permission(self, resource: Resource) -> bool:
        return self.get_permission(resource) is not None

    def get_permission(self, resource: Resource) -> Permission:
        """Return this user's Permission on *resource*, or None."""
        return next((p for p in self.permissions if p.resource == resource),
                    None)

    def verify_password(self, password):
        """Return True if *password* matches the stored hash.

        Users without a stored hash (``has_password`` is False) never match;
        the hash checker is not called with a missing hash.
        """
        if not self.password_hash:
            return False
        return check_password_hash(self.password_hash, password)

    def save(self):
        db.session.add(self)
        db.session.commit()

    def to_dict(self):
        """Return id, username and per-provider user info as a dict."""
        return {
            "id": self.id,
            "username": self.username,
            "identities":
            {n: i.user_info
             for n, i in self.oauth_identity.items()}
        }

    @classmethod
    def find_by_username(cls, username):
        return cls.query.filter(cls.username == username).first()

    @classmethod
    def all(cls):
        return cls.query.all()
示例#10
0
        for subresource in ["api"]:
            if hasattr(self, subresource):
                auths.extend(
                    ExternalServiceAccessAuthorization.
                    find_by_user_and_resource(self, getattr(self,
                                                            subresource)))
        return auths

    @classmethod
    def find_by_uuid(cls, uuid):
        """Return the first instance whose ``uuid`` matches, or None.

        *uuid* is normalized with ``lm_utils.uuid_param`` before comparing.
        """
        wanted = lm_utils.uuid_param(uuid)
        query = cls.query.filter(cls.uuid == wanted)
        return query.first()


# Association table linking resources to the external-service authorizations
# that apply to them (many-to-many). Both foreign keys are declared with
# ON DELETE CASCADE, so link rows are removed by the database when either
# the resource or the authorization row is deleted.
resource_authorization_table = db.Table(
    'resource_authorization', db.Model.metadata,
    db.Column('resource_id', db.Integer,
              db.ForeignKey("resource.id", ondelete="CASCADE")),
    db.Column(
        'authorization_id', db.Integer,
        db.ForeignKey("external_service_access_authorization.id",
                      ondelete="CASCADE")))


class RoleType:
    """Role names as plain strings (presumably the values stored in ``Permission.roles``)."""

    owner, viewer = "owner", "viewer"


class Permission(db.Model):
    user_id = db.Column(db.Integer,
                        db.ForeignKey('user.id', ondelete='CASCADE'),
                        primary_key=True)
示例#11
0
class OAuthIdentity(models.ExternalServiceAccessAuthorization, ModelMixin):
    """Link between a local User and their account at an OAuth2 identity provider.

    Stored in ``oauth2_identity`` as a joined-table subclass of
    ``ExternalServiceAccessAuthorization``; each (provider_id,
    provider_user_id) pair is unique.
    """
    id = db.Column(db.Integer,
                   db.ForeignKey('external_service_access_authorization.id'),
                   primary_key=True)
    provider_user_id = db.Column(db.String(256), nullable=False)
    provider_id = db.Column(db.Integer,
                            db.ForeignKey("oauth2_identity_provider.id"),
                            nullable=False)
    created_at = db.Column(DateTime, default=datetime.utcnow, nullable=False)
    # Token data stored as JSON; as_http_header reads its 'access_token' key.
    token = db.Column(JSON, nullable=True)
    # In-memory cache of the provider's user profile (plain class attribute,
    # not a mapped column, so it is not persisted).
    _user_info = None
    provider = db.relationship("OAuth2IdentityProvider",
                               uselist=False,
                               back_populates="identities")
    user = db.relationship(
        models.User,
        # This `backref` sets up an `oauth_identity` property on the User
        # model, which is a dictionary of OAuthIdentity objects associated
        # with that user, where the dictionary key is the provider name.
        backref=db.backref(
            "oauth_identity",
            collection_class=attribute_mapped_collection("provider.name"),
            cascade="all, delete-orphan",
        ),
    )

    __table_args__ = (db.UniqueConstraint("provider_id", "provider_user_id"), )
    __tablename__ = "oauth2_identity"
    __mapper_args__ = {'polymorphic_identity': 'oauth2_identity'}

    def __init__(self, provider, user_info, provider_user_id, token):
        """Create an identity for *provider_user_id* at *provider*.

        The provider's API resource is added to ``resources`` so this identity
        acts as an authorization for it.
        """
        # NOTE(review): self.user has not been assigned yet, so this presumably
        # passes None; the owning user must be set by the caller — confirm.
        super().__init__(self.user)
        self.provider = provider
        self.provider_user_id = provider_user_id
        self._user_info = user_info
        self.token = token
        self.resources.append(provider.api_resource)

    def as_http_header(self):
        """Return ``"<token_type> <access_token>"`` built from the stored token.

        NOTE(review): raises if ``token`` is None (the column is nullable).
        """
        return f"{self.provider.token_type} {self.token['access_token']}"

    @property
    def username(self):
        """Provider-qualified name: ``<provider name>_<provider user id>``."""
        return f"{self.provider.name}_{self.provider_user_id}"

    @property
    def user_info(self):
        """User profile, fetched from the provider on first access and cached."""
        if not self._user_info:
            self._user_info = self.provider.get_user_info(
                self.provider_user_id, self.token)
        return self._user_info

    @user_info.setter
    def user_info(self, value):
        self._user_info = value

    def set_token(self, token):
        self.token = token

    def __repr__(self):
        parts = []
        parts.append(self.__class__.__name__)
        if self.id:
            parts.append("id={}".format(self.id))
        if self.provider:
            parts.append('provider="{}"'.format(self.provider))
        return "<{}>".format(" ".join(parts))

    @staticmethod
    def find_by_user_id(user_id, provider_name) -> OAuthIdentity:
        """Return the identity of *user_id* at *provider_name*.

        :raises OAuthIdentityNotFoundException: if no identity matches.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(user_id=user_id).one()
        except NoResultFound:
            raise OAuthIdentityNotFoundException(f"{user_id}_{provider_name}")

    @staticmethod
    def find_by_provider_user_id(provider_user_id,
                                 provider_name) -> OAuthIdentity:
        """Return the identity with *provider_user_id* at *provider_name*.

        :raises OAuthIdentityNotFoundException: if no identity matches.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(provider_user_id=provider_user_id).one()
        except NoResultFound:
            raise OAuthIdentityNotFoundException(
                f"{provider_name}_{provider_user_id}")

    @classmethod
    def all(cls) -> List[OAuthIdentity]:
        return cls.query.all()
示例#12
0
class OAuth2IdentityProvider(db.Model, ModelMixin):
    """Configuration of an external OAuth2 identity provider.

    Each provider owns an API ``Resource`` whose URI is the provider's
    ``api_base_url``; relative authorize/token URLs are resolved against it.
    Rows are polymorphic on the ``type`` column.
    """

    id = db.Column(db.Integer, primary_key=True)
    _type = db.Column("type", db.String, nullable=False)
    name = db.Column(db.String, nullable=False, unique=True)
    client_id = db.Column(db.String, nullable=False)
    client_secret = db.Column(db.String, nullable=False)
    client_kwargs = db.Column(JSON, nullable=True)
    _authorize_url = db.Column("authorize_url", db.String, nullable=False)
    authorize_params = db.Column(JSON, nullable=True)
    _access_token_url = db.Column("access_token_url",
                                  db.String,
                                  nullable=False)
    access_token_params = db.Column(JSON, nullable=True)
    userinfo_endpoint = db.Column(db.String, nullable=False)
    api_resource_id = db.Column(db.Integer,
                                db.ForeignKey("resource.id"),
                                nullable=False)
    api_resource = db.relationship("Resource", cascade="all, delete")
    identities = db.relationship("OAuthIdentity",
                                 back_populates="provider",
                                 cascade="all, delete")

    __tablename__ = "oauth2_identity_provider"
    __mapper_args__ = {
        'polymorphic_on': _type,
        'polymorphic_identity': 'oauth2_identity_provider'
    }

    def __init__(self,
                 name,
                 client_id,
                 client_secret,
                 api_base_url,
                 authorize_url,
                 access_token_url,
                 userinfo_endpoint,
                 client_kwargs=None,
                 authorize_params=None,
                 access_token_params=None,
                 **kwargs):
        """Create a provider; relative URLs are resolved against *api_base_url*.

        Extra keyword arguments are accepted and ignored.
        """
        self.name = name
        self.client_id = client_id
        self.client_secret = client_secret
        # Must be set before the URL setters below: they resolve relative
        # URLs against `api_base_url`, i.e. `api_resource.uri`.
        self.api_resource = models.Resource(api_base_url, name=self.name)
        self.client_kwargs = client_kwargs
        self.authorize_url = authorize_url
        # Store the parameter so `oauth_config` exposes it.
        self.authorize_params = authorize_params
        self.access_token_url = access_token_url
        self.access_token_params = access_token_params
        self.userinfo_endpoint = urljoin(api_base_url, userinfo_endpoint)

    @property
    def type(self):
        return self._type

    @property
    def token_type(self):
        return "Bearer"

    def get_user_info(self, provider_user_id, token, normalized=True):
        """Fetch the user profile from the provider's userinfo endpoint.

        :param provider_user_id: not used by this implementation.
        :param token: a token dict (its ``access_token`` is used) or a raw
            access-token string.
        :param normalized: if True, pass the data through
            :meth:`normalize_userinfo`.
        :raises NotAuthorizedException: on HTTP 401/403.
        :raises LifeMonitorException: on any other non-200 status or if the
            response body is not valid JSON.
        """
        access_token = token['access_token'] if isinstance(token,
                                                           dict) else token
        response = requests.get(
            urljoin(self.api_base_url, self.userinfo_endpoint),
            headers={'Authorization': f'Bearer {access_token}'})
        if response.status_code in (401, 403):
            raise NotAuthorizedException(
                detail=f"Unable to get user info from provider {self.name}")
        if response.status_code != 200:
            raise LifeMonitorException(details=response.content)
        try:
            data = response.json()
        except Exception as e:
            raise LifeMonitorException(title="Unable to decode user data",
                                       details=str(e))
        return data if not normalized \
            else self.normalize_userinfo(OAuth2Registry.get_instance().get_client(self.name), data)

    @property
    def api_base_url(self):
        return self.api_resource.uri

    @api_base_url.setter
    def api_base_url(self, api_base_url):
        assert api_base_url and len(api_base_url) > 0, "URL cannot be empty"
        # NOTE(review): also sets an unmapped `uri` attribute on the provider;
        # kept in case other code reads it.
        self.uri = api_base_url
        self.api_resource.uri = api_base_url

    @hybrid_property
    def authorize_url(self):
        return self._authorize_url

    @authorize_url.setter
    def authorize_url(self, authorize_url):
        assert authorize_url and len(authorize_url) > 0, "URL cannot be empty"
        self._authorize_url = urljoin(self.api_base_url, authorize_url)

    @hybrid_property
    def access_token_url(self):
        return self._access_token_url

    @access_token_url.setter
    def access_token_url(self, token_url):
        assert token_url and len(token_url) > 0, "URL cannot be empty"
        self._access_token_url = urljoin(self.api_base_url, token_url)

    @property
    def oauth_config(self):
        """Keyword arguments describing this provider for an OAuth client."""
        return {
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'client_kwargs': self.client_kwargs,
            'api_base_url': self.api_base_url,
            'authorize_url': self.authorize_url,
            'authorize_params': self.authorize_params,
            'access_token_url': self.access_token_url,
            'access_token_params': self.access_token_params,
            'userinfo_endpoint': self.userinfo_endpoint,
            'userinfo_compliance_fix': self.normalize_userinfo,
        }

    def normalize_userinfo(self, client, data):
        """Normalize raw user *data* with a provider-specific helper.

        Looks for a ``normalize_userinfo`` function in
        ``lifemonitor.auth.oauth2.client.providers.<name>``, then in
        ``...providers.<type>``, and returns the result of the first one found.
        Errors raised by the helper itself propagate unchanged.

        :raises LifeMonitorException: if no helper module/function is found;
            the lookup failures are reported in ``details``.
        """
        errors = []
        for client_type in (self.name, self.type):
            logger.debug(f"Searching with {client_type}")
            m = f"lifemonitor.auth.oauth2.client.providers.{client_type}"
            try:
                mod = import_module(m)
            except ModuleNotFoundError:
                errors.append(
                    f"ModuleNotFoundError: Unable to load module {m}")
                continue
            normalize = getattr(mod, "normalize_userinfo", None)
            if normalize is None:
                errors.append(
                    f"Unable to find function 'normalize_userinfo' in module {m}")
                continue
            # Called outside any `try`, so an AttributeError raised inside the
            # helper is not mistaken for a missing helper.
            return normalize(client, data)

        raise LifeMonitorException(
            f"Unable to load utility to normalize user info from provider {self.name}",
            details="; ".join(errors)
        )

    def find_identity_by_provider_user_id(self, provider_user_id):
        """Return this provider's identity with *provider_user_id*.

        :raises OAuthIdentityNotFoundException: if none matches.
        """
        try:
            return OAuthIdentity.query.with_parent(self)\
                .filter_by(provider_user_id=provider_user_id).one()
        except NoResultFound:
            raise OAuthIdentityNotFoundException(f"{provider_user_id}@{self}")

    @classmethod
    def find(cls, name) -> OAuth2IdentityProvider:
        """Return the provider named *name*.

        :raises EntityNotFoundException: if no provider has that name.
        """
        try:
            return cls.query.filter(cls.name == name).one()
        except NoResultFound:
            raise EntityNotFoundException(cls, entity_id=name)

    @classmethod
    def all(cls) -> List[OAuth2IdentityProvider]:
        return cls.query.all()
示例#13
0
class AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin):
    """OAuth2 authorization code tied to a user.

    Presumably the code-specific columns come from
    ``OAuth2AuthorizationCodeMixin``.
    """
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(
        db.Integer,
        db.ForeignKey('user.id', ondelete='CASCADE'),
    )
    user = db.relationship('User')
示例#14
0
class OAuthIdentity(models.ExternalServiceAccessAuthorization, ModelMixin):
    """Link between a local User and their account at an OAuth2 identity provider.

    Stored in ``oauth2_identity`` as a joined-table subclass of
    ``ExternalServiceAccessAuthorization``; each (provider_id,
    provider_user_id) pair is unique. The token is persisted as JSON and
    exposed as an ``OAuth2Token`` through the ``token`` property.
    """
    id = db.Column(db.Integer,
                   db.ForeignKey('external_service_access_authorization.id'),
                   primary_key=True)
    provider_user_id = db.Column(db.String(256), nullable=False)
    provider_id = db.Column(db.Integer,
                            db.ForeignKey("oauth2_identity_provider.id"),
                            nullable=False)
    created_at = db.Column(DateTime, default=datetime.utcnow, nullable=False)
    _token = db.Column("token", JSON, nullable=True)
    # In-memory cache of the provider's user profile (not a mapped column).
    _user_info = None
    provider = db.relationship("OAuth2IdentityProvider",
                               uselist=False,
                               back_populates="identities")
    user = db.relationship(
        models.User,
        # This `backref` sets up an `oauth_identity` property on the User
        # model, which is a dictionary of OAuthIdentity objects associated
        # with that user, where the dictionary key is the provider name.
        backref=db.backref(
            "oauth_identity",
            collection_class=attribute_mapped_collection("provider.name"),
            cascade="all, delete-orphan",
        ),
    )

    __table_args__ = (db.UniqueConstraint("provider_id", "provider_user_id"), )
    __tablename__ = "oauth2_identity"
    __mapper_args__ = {'polymorphic_identity': 'oauth2_identity'}

    def __init__(self, provider, user_info, provider_user_id, token):
        """Create an identity for *provider_user_id* at *provider*.

        The provider's API resource is added to ``resources`` so this identity
        acts as an authorization for it.
        """
        # NOTE(review): self.user has not been assigned yet, so this presumably
        # passes None; the owning user must be set by the caller — confirm.
        super().__init__(self.user)
        self.provider = provider
        self.provider_user_id = provider_user_id
        self._user_info = user_info
        self.token = token
        self.resources.append(provider.api_resource)

    def as_http_header(self):
        """Return ``"<token_type> <access_token>"`` using a fresh token."""
        return f"{self.provider.token_type} {self.fetch_token()['access_token']}"

    @property
    def username(self):
        """Provider-qualified name: ``<provider name>_<provider user id>``."""
        return f"{self.provider.name}_{self.provider_user_id}"

    @property
    def token(self) -> OAuth2Token:
        """The stored token wrapped in a new ``OAuth2Token`` on each access."""
        return OAuth2Token(self._token)

    @token.setter
    def token(self, token: dict):
        self._token = token

    def fetch_token(self):
        """Return the identity's token, refreshing it first when needed.

        Refresh is attempted only for identities already persisted in the
        database. The row is reloaded first. If ``to_be_refreshed()`` reports
        the token as expired or close to expiry and a refresh token is
        present, a new token is requested from the provider's
        ``access_token_url`` and saved.
        """
        # enable dynamic refresh only if the identity
        # has been already stored in the database
        if inspect(self).persistent:
            # fetch up to date identity data
            self.refresh()
            # reference to the token associated with the identity instance
            token = self.token
            # the token should be refreshed
            # if it is expired or close to expire (i.e., n secs before expiration)
            if token.to_be_refreshed():
                if 'refresh_token' not in token:
                    logger.debug(
                        "The token should be refreshed but no refresh token is associated with the token"
                    )
                else:
                    logger.debug("Trying to refresh the token...")
                    # Use the session as a context manager so its HTTP
                    # connections are released after the refresh.
                    with OAuth2Session(self.provider.client_id,
                                       self.provider.client_secret,
                                       token=self.token) as oauth2session:
                        new_token = oauth2session.refresh_token(
                            self.provider.access_token_url,
                            refresh_token=token['refresh_token'])
                    self.token = new_token
                    self.save()
                    # The token itself is deliberately not logged: it holds
                    # live credentials.
                    logger.debug("User token updated")
        return self.token

    @property
    def user_info(self):
        """User profile, fetched from the provider on first access and cached."""
        if not self._user_info:
            logger.debug(
                "[Identity %r], Trying to read profile of user %r from provider %r...",
                self.id, self.user_id, self.provider.name)
            self._user_info = self.provider.get_user_info(
                self.provider_user_id, self.fetch_token())
        return self._user_info

    @user_info.setter
    def user_info(self, value):
        self._user_info = value

    def __repr__(self):
        parts = []
        parts.append(self.__class__.__name__)
        if self.id:
            parts.append("id={}".format(self.id))
        if self.provider:
            parts.append('provider="{}"'.format(self.provider))
        return "<{}>".format(" ".join(parts))

    @staticmethod
    def find_by_user_id(user_id, provider_name) -> OAuthIdentity:
        """Return the identity of *user_id* at *provider_name*.

        :raises OAuthIdentityNotFoundException: if no identity matches.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(user_id=user_id).one()
        except NoResultFound:
            raise OAuthIdentityNotFoundException(f"{user_id}_{provider_name}")

    @staticmethod
    def find_by_provider_user_id(provider_user_id,
                                 provider_name) -> OAuthIdentity:
        """Return the identity with *provider_user_id* at *provider_name*.

        :raises OAuthIdentityNotFoundException: if no identity matches.
        """
        try:
            return OAuthIdentity.query\
                .filter(OAuthIdentity.provider.has(name=provider_name))\
                .filter_by(provider_user_id=provider_user_id).one()
        except NoResultFound:
            raise OAuthIdentityNotFoundException(
                f"{provider_name}_{provider_user_id}")

    @classmethod
    def all(cls) -> List[OAuthIdentity]:
        return cls.query.all()
示例#15
0
class Client(db.Model, OAuth2ClientMixin):
    """OAuth2 client registered by a user (table ``oauth2_client``)."""
    id = db.Column(db.Integer, primary_key=True)
    client_id = db.Column(db.String(48), index=True, unique=True)
    user_id = db.Column(
        db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
    )
    user = db.relationship(
        'User',
        backref=db.backref(
            "clients",
            cascade="all, delete-orphan",
        ),
    )

    __tablename__ = "oauth2_client"

    def is_confidential(self):
        """Return whether the client has a client secret."""
        return self.has_client_secret()

    def set_client_metadata(self, value):
        """Normalize *value* and store it as the client metadata.

        Non-dict values are ignored. When present, ``scope`` is converted to a
        space-separated string. The list-valued fields are converted to lists,
        and missing ones become empty lists. *value* itself is not modified.
        """
        if not isinstance(value, dict):
            return
        data = copy.deepcopy(value)
        # 'scope' is optional: metadata passed through the setters below may
        # not contain it yet, and indexing it unconditionally raised KeyError.
        if 'scope' in value:
            data['scope'] = values_as_string(value['scope'], out_separator=" ")
        for p in ('redirect_uris', 'grant_types', 'response_types', 'contacts'):
            data[p] = values_as_list(value.get(p, []))
        return super().set_client_metadata(data)

    @property
    def redirect_uris(self):
        return super().redirect_uris

    @redirect_uris.setter
    def redirect_uris(self, value):
        metadata = self.client_metadata
        metadata['redirect_uris'] = value
        self.set_client_metadata(metadata)

    @property
    def scopes(self):
        """The client's scope string split into a list (empty if unset)."""
        return self.scope.split(" ") if self.scope else []

    @scopes.setter
    def scopes(self, scopes):
        metadata = self.client_metadata
        metadata['scope'] = scopes
        self.set_client_metadata(metadata)

    @property
    def auth_method(self):
        """The ``token_endpoint_auth_method`` from the client metadata."""
        return self.client_metadata.get('token_endpoint_auth_method')

    @auth_method.setter
    def auth_method(self, value):
        metadata = self.client_metadata
        metadata['token_endpoint_auth_method'] = value
        self.set_client_metadata(metadata)

    @classmethod
    def find_by_id(cls, client_id) -> Client:
        """Return the client with the given primary key, or None.

        NOTE(review): ``query.get`` looks up the primary key ``id``, not the
        ``client_id`` column, despite the parameter name — confirm callers.
        """
        return cls.query.get(client_id)

    @classmethod
    def all(cls) -> List[Client]:
        return cls.query.all()