Пример #1
0
class PidLog(db.Model):

    """Audit log of actions happening to persistent identifiers.

    This model is primarily used through PersistentIdentifier.log and rarely
    created manually. Each row stores one ``action`` (indexed) with a
    free-text ``message`` and its creation time. ``id_pid`` links the row to
    a :class:`PersistentIdentifier` and may be NULL.
    """

    __tablename__ = 'pidLOG'
    # Secondary index on the ``action`` column.
    __table_args__ = (
        db.Index('idx_action', 'action'),
    )

    # Primary key of the log row itself.
    # NOTE(review): the attribute docstring below says "persistent identifier
    # entry", but this column identifies the log entry, not the PID.
    id = db.Column(db.Integer(15, unsigned=True), primary_key=True)
    """Id of persistent identifier entry."""

    # Foreign key to PersistentIdentifier.id; nullable, so a log row can
    # exist without an associated PID.
    id_pid = db.Column(
        db.Integer(15, unsigned=True), db.ForeignKey(PersistentIdentifier.id),
        nullable=True,
    )
    """PID."""

    # ``datetime.now`` is passed as a callable, so the (naive, local) time is
    # evaluated at insert time, not at class definition.
    timestamp = db.Column(db.DateTime(), nullable=False, default=datetime.now)
    """Creation datetime of entry."""

    # Short action code, at most 10 characters.
    action = db.Column(db.String(10), nullable=False)
    """Action identifier."""

    message = db.Column(db.Text(), nullable=False)
    """Log message."""

    # Relationship
    # Also installs ``PersistentIdentifier.logs`` via the backref.
    pid = db.relationship("PersistentIdentifier", backref="logs")
Пример #2
0
class UserQuery(db.Model):
    """Represent a UserQuery record.

    Each row links a user to a :class:`WebQuery` they executed, together with
    the requesting host name and the execution time.
    """

    __tablename__ = 'user_query'
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        primary_key=True,
                        server_default='0')
    id_query = db.Column(db.Integer(15, unsigned=True),
                         db.ForeignKey(WebQuery.id),
                         primary_key=True,
                         index=True,
                         server_default='0')
    hostname = db.Column(db.String(50),
                         nullable=True,
                         server_default='unknown host')
    date = db.Column(db.DateTime, nullable=True, default=datetime.datetime.now)

    # Also installs ``WebQuery.executions`` via the backref.
    webquery = db.relationship(WebQuery, backref='executions')

    @classmethod
    def log(cls, urlargs=None, id_user=None):
        """Record that a user executed a query, and commit.

        :param urlargs: Query string to log. Falls back to the query string
            of the current request when empty or ``None``.
        :param id_user: Id of the user who ran the query. Falls back to
            ``current_user.get_id()`` when ``None``.

        Nothing is logged when no user id can be resolved or when it is
        negative. An existing :class:`WebQuery` with the same ``urlargs`` is
        reused; otherwise a new one is created along with the log row.
        """
        # The former ``id_user if not None else ...`` tested the constant
        # ``not None`` (always true), so the current_user fallback never ran.
        if id_user is None:
            id_user = current_user.get_id()
        urlargs = urlargs or request.query_string
        # Guard against ``None`` too: comparing it with 0 fails on Python 3
        # and a NULL primary key cannot be inserted anyway.
        if id_user is None or id_user < 0:
            return
        webquery = WebQuery.query.filter_by(urlargs=urlargs).first()
        if webquery is None:
            webquery = WebQuery(urlargs=urlargs)
        db.session.add(
            cls(id_user=id_user, hostname=request.host, webquery=webquery))
        db.session.commit()
Пример #3
0
def do_upgrade():
    """Make the wtgTAG owner and group columns nullable.

    ``id_usergroup`` and then ``id_user`` are altered, inside one batch
    operation, to ``db.Integer(15, unsigned=True)`` with ``nullable=True``.
    """
    with op.batch_alter_table("wtgTAG") as batch:
        for column in ('id_usergroup', 'id_user'):
            batch.alter_column(column_name=column,
                               type_=db.Integer(15, unsigned=True),
                               nullable=True)
class CmtCOLLAPSED(db.Model):
    """Represents a CmtCOLLAPSED record.

    A row marks that a user has collapsed a comment on a record. The three
    foreign keys together form the composite primary key.
    """

    __tablename__ = 'cmtCOLLAPSED'

    id_bibrec = db.Column(db.MediumInteger(8, unsigned=True),
                          db.ForeignKey(Bibrec.id),
                          primary_key=True)
    id_cmtRECORDCOMMENT = db.Column(db.Integer(15, unsigned=True),
                                    db.ForeignKey(CmtRECORDCOMMENT.id),
                                    primary_key=True)
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        primary_key=True)
Пример #5
0
class UserAccROLE(db.Model):

    """Represent a user-role assignment.

    Links a :class:`User` to an :class:`AccROLE`. The pair of foreign keys is
    the composite primary key. ``expiration`` has a server-side default of
    ``9999-12-31 23:59:59``.
    """

    __tablename__ = 'user_accROLE'
    id_user = db.Column(db.Integer(15, unsigned=True), db.ForeignKey(User.id),
                        nullable=False, primary_key=True)
    id_accROLE = db.Column(db.Integer(15, unsigned=True),
                           db.ForeignKey(AccROLE.id), nullable=False,
                           primary_key=True)
    # Far-future server default; presumably means "does not expire" —
    # confirm against the code that reads ``expiration``.
    expiration = db.Column(db.DateTime, nullable=False,
                           server_default='9999-12-31 23:59:59')

    # Backrefs install ``User.roles`` and ``AccROLE.users``.
    user = db.relationship(User, backref='roles')
    role = db.relationship(AccROLE, backref='users')
Пример #6
0
class Session(db.Model):
    """Represent Session record."""

    __tablename__ = 'session'
    session_key = db.Column(db.String(32),
                            nullable=False,
                            server_default='',
                            primary_key=True)
    session_expiry = db.Column(db.DateTime, nullable=True, index=True)
    session_object = db.Column(db.LargeBinary, nullable=True)
    uid = db.Column(db.Integer(15, unsigned=True), nullable=False, index=True)

    def get_session(self, name, expired=False):
        """Return the :class:`Session` stored under ``name``.

        :param name: Session key to look up.
        :param expired: When true, additionally require ``session_expiry`` to
            be at or after the database's current timestamp, so only a
            session that has not expired yet is returned.
        :raises sqlalchemy.orm.exc.NoResultFound: if no row matches.
        :raises sqlalchemy.orm.exc.MultipleResultsFound: if several rows match.
        """
        where = Session.session_key == name
        if expired:
            where = db.and_(
                where, Session.session_expiry >= db.func.current_timestamp())
        return self.query.filter(where).one()

    def set_session(self, name, value, timeout=None):
        """Build a new :class:`Session` for the current user.

        The instance is only constructed; it is neither added to the database
        session nor committed.

        :param name: Session key.
        :param value: Session payload, stored in ``session_object``.
        :param timeout: :class:`datetime.timedelta` after which the session
            expires, counted from now in UTC. ``None`` (the default) leaves
            ``session_expiry`` NULL.
        :returns: The new :class:`Session` instance.
        """
        uid = current_user.get_id()
        # ``utcnow() + None`` raised TypeError for the default timeout.
        if timeout is None:
            session_expiry = None
        else:
            session_expiry = datetime.utcnow() + timeout
        return Session(session_key=name,
                       session_object=value,
                       session_expiry=session_expiry,
                       uid=uid)
def do_upgrade():
    """Create the ``remoteACCOUNT`` and ``remoteTOKEN`` tables.

    Each table is created only if it does not exist yet. Otherwise a warning
    is emitted and creation of that table is skipped. ``remoteTOKEN`` has a
    foreign key to ``remoteACCOUNT.id`` and is created second.
    """
    # NOTE(review): MyISAM does not enforce foreign keys, so the
    # ForeignKeyConstraints below are declarative only — confirm intended.
    if not op.has_table('remoteACCOUNT'):
        op.create_table('remoteACCOUNT',
                        db.Column('id',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('user_id',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('client_id',
                                  db.String(length=255),
                                  nullable=False),
                        db.Column('extra_data', db.JSON, nullable=True),
                        db.ForeignKeyConstraint(
                            ['user_id'],
                            ['user.id'],
                        ),
                        db.PrimaryKeyConstraint('id'),
                        db.UniqueConstraint('user_id', 'client_id'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')
    else:
        # Quote placement fixed; the message used to read
        # "'remoteACCOUNT table skipped!'".
        warnings.warn("*** Creation of table 'remoteACCOUNT' skipped!")

    if not op.has_table('remoteTOKEN'):
        op.create_table('remoteTOKEN',
                        db.Column('id_remote_account',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('token_type',
                                  db.String(length=40),
                                  nullable=False),
                        db.Column('access_token', db.Text(), nullable=False),
                        db.Column('secret', db.Text(), nullable=False),
                        db.ForeignKeyConstraint(
                            ['id_remote_account'],
                            ['remoteACCOUNT.id'],
                        ),
                        db.PrimaryKeyConstraint('id_remote_account',
                                                'token_type'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')
    else:
        # Stray trailing quote removed from the message.
        warnings.warn("*** Creation of table 'remoteTOKEN' skipped!")
Пример #8
0
def do_upgrade():
    """Change the types of two ``oauth2TOKEN`` columns.

    ``client_id`` becomes ``db.String(255)`` and ``user_id`` becomes
    ``db.Integer(15, unsigned=True)``. Both are declared as existing
    NOT NULL columns.
    """
    changes = (
        ('client_id', db.String(255)),
        ('user_id', db.Integer(15, unsigned=True)),
    )
    for column, new_type in changes:
        op.alter_column(table_name='oauth2TOKEN',
                        column_name=column,
                        type_=new_type,
                        existing_nullable=False)
Пример #9
0
class WebQuery(db.Model):
    """Represent a WebQuery record.

    Stores the URL arguments (``urlargs``) of a query in the ``query`` table.
    """

    __tablename__ = 'query'
    id = db.Column(db.Integer(15, unsigned=True),
                   primary_key=True,
                   autoincrement=True)
    # One-character query type with server default 'r'.
    # NOTE(review): the meaning of the type codes is not visible here.
    type = db.Column(db.Char(1), nullable=False, server_default='r')
    # Plain ``Text`` in general; the ``Text(100)`` variant is used on MySQL.
    urlargs = db.Column(db.Text().with_variant(db.Text(100), 'mysql'),
                        nullable=False)
Пример #10
0
class WtgTAGRecord(db.Model, Serializable):

    """Connection between Tag and Record.

    Association object between :class:`WtgTAG` and :class:`Record`, carrying
    an optional annotation and the date the tag was added. Only the fields in
    ``__public__`` are meant for serialization.
    """

    __tablename__ = 'wtgTAG_bibrec'
    __public__ = set(['id_tag', 'id_bibrec', 'date_added'])

    # wtgTAG.id (the comment previously said "tagTAG.id")
    id_tag = db.Column(db.Integer(15, unsigned=True),
                       db.ForeignKey(WtgTAG.id),
                       nullable=False,
                       primary_key=True)

    # Record.id
    id_bibrec = db.Column(db.MediumInteger(8, unsigned=True),
                          db.ForeignKey(Record.id),
                          nullable=False,
                          primary_key=True)

    # Annotation
    annotation = db.Column(
        db.Text(convert_unicode=True),
        default='')

    # Creation date; ``datetime.now`` is evaluated at insert time.
    date_added = db.Column(db.DateTime,
                           default=datetime.now)

    # Relationships
    # Each side is mapped twice: a plain collection backref and a
    # ``lazy='dynamic'`` backref (a query object) for filtering/counting.
    # NOTE(review): both pairs map the same foreign key; newer SQLAlchemy
    # versions may warn about overlapping relationships.
    tag = db.relationship(WtgTAG,
                          backref=db.backref('records_association',
                                             cascade='all'))

    tag_query = db.relationship(WtgTAG,
                                backref=db.backref('records_association_query',
                                                   cascade='all',
                                                   lazy='dynamic'))

    bibrec = db.relationship(Record,
                             backref=db.backref('tags_association',
                                                cascade='all'))

    bibrec_query = db.relationship(Record,
                                   backref=db.backref('tags_association_query',
                                                      cascade='all',
                                                      lazy='dynamic'))

    def __init__(self, bibrec=None, **kwargs):
        """Create the association, optionally attaching ``bibrec``.

        :param bibrec: :class:`Record` to link; assigned after the base
            constructor has processed ``kwargs`` when not ``None``.
        :param kwargs: Column/relationship values passed to the base model.
        """
        super(WtgTAGRecord, self).__init__(**kwargs)

        if bibrec is not None:
            self.bibrec = bibrec
Пример #11
0
class AccAuthorization(db.Model):

    """Represent an authorization.

    Associates an :class:`AccROLE` with an :class:`AccACTION` and,
    optionally, an :class:`AccARGUMENT`. ``argumentlistid`` presumably groups
    rows into one argument list — TODO confirm.
    """

    __tablename__ = 'accROLE_accACTION_accARGUMENT'
    id = db.Column(db.Integer(15, unsigned=True), primary_key=True,
                   autoincrement=True)
    id_accROLE = db.Column(db.Integer(15, unsigned=True),
                           db.ForeignKey(AccROLE.id), nullable=True,
                           index=True)
    id_accACTION = db.Column(db.Integer(15, unsigned=True),
                             db.ForeignKey(AccACTION.id), nullable=True,
                             index=True)
    # Stored in column ``id_accARGUMENT``. It is exposed through the hybrid
    # property of that name so assigned values can be normalized. There is no
    # ForeignKey here; the join to AccARGUMENT is spelled out below.
    _id_accARGUMENT = db.Column(db.Integer(15), nullable=True,
                                name="id_accARGUMENT", index=True)
    argumentlistid = db.Column(db.MediumInteger(8), nullable=True)

    role = db.relationship(AccROLE, backref='authorizations')
    action = db.relationship(AccACTION, backref='authorizations')
    # Rows whose argument id is -1 or NULL are not joined to any argument.
    argument = db.relationship(
        AccARGUMENT, backref='authorizations',
        primaryjoin=db.and_(
            AccARGUMENT.id == _id_accARGUMENT,
            _id_accARGUMENT != -1,
            # SQL ``IS NOT NULL``. The former Python expression
            # ``_id_accARGUMENT is not None`` was always ``True`` and added
            # nothing to the join condition.
            _id_accARGUMENT.isnot(None),
        ),
        foreign_keys=_id_accARGUMENT,
        uselist=False,
        cascade="all, delete",
    )

    @db.hybrid_property
    def id_accARGUMENT(self):
        """Return the id of the linked argument (may be ``None`` or -1)."""
        return self._id_accARGUMENT

    @id_accARGUMENT.setter
    def id_accARGUMENT(self, value):
        """Set the argument id; falsy values (``None``, ``0``) become NULL."""
        self._id_accARGUMENT = value or None
Пример #12
0
class AccARGUMENT(db.Model):

    """Represent an authorization argument.

    A ``keyword``/``value`` pair. Both columns are covered by the composite
    ``KEYVAL`` index.
    """

    __tablename__ = 'accARGUMENT'
    id = db.Column(db.Integer(15), primary_key=True, autoincrement=True)
    keyword = db.Column(db.String(32), nullable=True)
    value = db.Column(db.String(255), nullable=True)
    # The trailing ``db.Model.__table_args__`` presumably carries the
    # project-wide table options (a dict) — confirm in the base model.
    __table_args__ = (db.Index('KEYVAL', keyword, value),
                      db.Model.__table_args__)

    def __repr__(self):
        """Return ``'keyword=value'``."""
        return "{0.keyword}={0.value}".format(self)
class CmtACTIONHISTORY(db.Model):
    """Represents a CmtACTIONHISTORY record.

    One row per action a user performed on a comment of a record, with the
    action code, time and client host.
    """

    __tablename__ = 'cmtACTIONHISTORY'
    # NOTE(review): the three key columns are declared both nullable=True and
    # primary_key=True; databases generally force primary-key columns to be
    # NOT NULL — confirm the intended schema.
    id_cmtRECORDCOMMENT = db.Column(db.Integer(15, unsigned=True),
                                    db.ForeignKey(CmtRECORDCOMMENT.id),
                                    nullable=True,
                                    primary_key=True)
    id_bibrec = db.Column(db.MediumInteger(8, unsigned=True),
                          db.ForeignKey(Bibrec.id),
                          nullable=True,
                          primary_key=True)
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=True,
                        primary_key=True)
    # Unsigned integer host; presumably an IPv4 address packed into an int —
    # verify against the writer.
    client_host = db.Column(db.Integer(10, unsigned=True), nullable=True)
    action_time = db.Column(db.DateTime,
                            nullable=False,
                            server_default='1900-01-01 00:00:00')
    # One-character action code (indexed); code meanings are defined elsewhere.
    action_code = db.Column(db.Char(1), nullable=False, index=True)
    # Installs ``CmtRECORDCOMMENT.actionhistory`` via the backref.
    recordcomment = db.relationship(CmtRECORDCOMMENT, backref='actionhistory')
    bibrec = db.relationship(Bibrec)
    user = db.relationship(User)
Пример #14
0
class AccROLE(db.Model):

    """Represent an access role.

    A uniquely named role with a description and a FireRole definition kept
    as source text (``firerole_def_src``) and in binary form
    (``firerole_def_ser``).
    """

    __tablename__ = 'accROLE'
    id = db.Column(db.Integer(15, unsigned=True), primary_key=True,
                   autoincrement=True)
    name = db.Column(db.String(32), unique=True, nullable=True)
    description = db.Column(db.String(255), nullable=True)
    # Presumably the serialized/compiled form of ``firerole_def_src`` —
    # confirm with the code that writes it.
    firerole_def_ser = db.Column(db.iBinary, nullable=True)
    firerole_def_src = db.Column(db.Text, nullable=True)

    def __repr__(self):
        """Return ``'name - description'``."""
        return "{0.name} - {0.description}".format(self)
Пример #15
0
class AccACTION(db.Model):

    """Represent an access action.

    A uniquely named action, the keywords it accepts and whether it is
    optional (``'yes'``/``'no'``, server default ``'no'``).
    """

    __tablename__ = 'accACTION'
    id = db.Column(db.Integer(15, unsigned=True),
                   primary_key=True, autoincrement=True)
    name = db.Column(db.String(32), unique=True, nullable=True)
    description = db.Column(db.String(255), nullable=True)
    # Stored as a single string; the separator format is not visible here.
    allowedkeywords = db.Column(db.String(255), nullable=True)
    optional = db.Column(db.Enum('yes', 'no', name='yes_no'), nullable=False,
                         server_default='no')

    def __repr__(self):
        """Return the action name."""
        return "{0.name}".format(self)
class UserEXT(db.Model):
    """Represent a UserEXT record.

    Maps an external identifier (``id``) from an authentication ``method`` to
    a local :class:`User`. ``(id, method)`` is the primary key. The unique
    index on ``(id_user, method)`` allows at most one external identifier per
    user and method.
    """

    __tablename__ = 'userEXT'

    id = db.Column(db.String(255), primary_key=True, nullable=False)
    method = db.Column(db.String(50), primary_key=True, nullable=False)
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=False)

    # Installs ``User.external_identifiers`` via the backref.
    user = db.relationship(User, backref="external_identifiers")

    __table_args__ = (db.Index('userext_id_user_method',
                               id_user,
                               method,
                               unique=True), db.Model.__table_args__)
class CmtSUBSCRIPTION(db.Model):
    """Represents a CmtSUBSCRIPTION record.

    A user's subscription to the comments of a record. ``(id_bibrec,
    id_user)`` is the primary key.
    """

    __tablename__ = 'cmtSUBSCRIPTION'

    id_bibrec = db.Column(db.MediumInteger(8, unsigned=True),
                          db.ForeignKey(Bibrec.id),
                          nullable=False,
                          primary_key=True)
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=False,
                        primary_key=True)
    creation_time = db.Column(db.DateTime,
                              nullable=False,
                              server_default='1900-01-01 00:00:00')

    bibrec = db.relationship(Bibrec)
    # Installs ``User.comment_subscriptions`` via the backref.
    user = db.relationship(User, backref='comment_subscriptions')
class User(db.Model):
    """Represents a User record."""

    def __str__(self):
        """Return ``'nickname <email>'``."""
        return "%s <%s>" % (self.nickname, self.email)

    __tablename__ = 'user'
    __mapper_args__ = {'confirm_deleted_rows': False}

    id = db.Column(db.Integer(15, unsigned=True),
                   primary_key=True,
                   autoincrement=True)
    email = db.Column(db.String(255),
                      nullable=False,
                      server_default='',
                      index=True)
    _password = db.Column(db.String(255), name="password", nullable=True)
    password_salt = db.Column(db.String(255))
    password_scheme = db.Column(db.String(50), nullable=False, index=True)

    # Account state as a string; see is_active/is_confirmed/verify_email.
    _note = db.Column(db.String(255), name="note", nullable=True)
    given_names = db.Column(db.String(255), nullable=False, server_default='')
    family_name = db.Column(db.String(255), nullable=False, server_default='')
    settings = db.Column(db.MutableDict.as_mutable(
        db.MarshalBinary(default_value=get_default_user_preferences,
                         force_type=dict)),
                         nullable=True)
    nickname = db.Column(db.String(255),
                         nullable=False,
                         server_default='',
                         index=True)
    last_login = db.Column(db.DateTime,
                           nullable=False,
                           server_default='1900-01-01 00:00:00')

    PROFILE_FIELDS = ['nickname', 'email', 'family_name', 'given_names']
    """List of fields that can be updated with update_profile."""

    @staticmethod
    def check_nickname(nickname):
        """Check if it's a valid nickname.

        A valid nickname is non-empty, has no leading or trailing space, is
        not ``'guest'`` (case-insensitive) and contains none of ``, ' @``.
        """
        re_invalid_nickname = re.compile(""".*[,'@]+.*""")
        return bool(nickname) and not nickname.startswith(' ') and \
            not nickname.endswith(' ') and \
            nickname.lower() != 'guest' and \
            not re_invalid_nickname.match(nickname)

    @staticmethod
    def check_email(email):
        """Check if it's a valid email.

        The check is loose: the address must look like
        ``something@something.something`` and contain no space anywhere.

        :returns: ``True`` or ``False``.
        """
        r = re.compile(r'(.)+\@(.)+\.(.)+')
        # ``email.find(" ") > 0`` accepted a leading space (find() == 0), and
        # the chained ``and`` could return a Match object or None, not a bool.
        return bool(email) and r.match(email) is not None and \
            ' ' not in email

    @hybrid_property
    def note(self):
        """Return the note."""
        return self._note

    @note.setter
    def note(self, note):
        """Set the note, stored as a string."""
        self._note = str(note)

    @hybrid_property
    def password(self):
        """Return the stored password hash."""
        return self._password

    @password.setter
    def password(self, password):
        """Hash and store ``password``; ``None`` makes it unusable."""
        if password is None:
            # Unusable password.
            self._password = None
            self.password_scheme = ''
        else:
            self._password = password_context.encrypt(password)
            self.password_scheme = password_context.default_scheme()

        # Invenio legacy salt is stored in password_salt, and every new
        # password set will be migrated to new hash not relying on
        # password_salt, thus is force to empty value.
        self.password_salt = ""

    def verify_password(self, password, migrate=False):
        """Verify if password matches the stored password hash.

        :param password: Clear-text password to check.
        :param migrate: When true and the stored hash needs an update, re-hash
            ``password`` with the default scheme and commit.
        :returns: ``True`` if the password matches, ``False`` otherwise
            (including when either password is ``None``).
        """
        if self.password is None or password is None:
            return False

        # Invenio 1.x legacy needs externally store password salt to compute
        # hash.
        scheme_ctx = {}
        if self.password_scheme == invenio_aes_encrypted_email.name:
            scheme_ctx = {'user': self.password_salt}

        # Verify password
        if not password_context.verify(
                password, self.password, scheme=self.password_scheme,
                **scheme_ctx):
            return False

        # Migrate hash if needed.
        if migrate and password_context.needs_update(self.password):
            self.password = password
            try:
                db.session.commit()
            except Exception:
                db.session.rollback()
                raise

        return True

    def verify_email(self, force=False):
        """Send an account activation email if the account needs it.

        When ``force`` is true, or the note is already ``"2"``, the note is
        set to ``"2"`` (committing if it changed) and the activation email is
        sent. ``"2"`` presumably means "email awaiting confirmation" — TODO
        confirm.

        :returns: ``True`` if the email was sent, ``False`` otherwise.
        """
        if force or self.note == "2":
            if self.note != "2":
                self.note = 2
                try:
                    db.session.commit()
                except Exception:
                    db.session.rollback()
                    raise
            send_account_activation_email(self)
            return True
        return False

    def update_profile(self, data):
        """Update user profile.

        Only :attr:`PROFILE_FIELDS` present in ``data`` and different from
        the current value are changed. A changed email triggers
        :meth:`verify_email`. The session is committed, ``current_user`` is
        reloaded and ``profile_updated`` is sent so other modules can react.

        :returns: Mapping of changed field names to their previous values.
        """
        changed_attrs = {}
        for field in self.PROFILE_FIELDS:
            if field in data and getattr(self, field) != data[field]:
                changed_attrs[field] = getattr(self, field)
                setattr(self, field, data[field])

        if 'email' in changed_attrs:
            self.verify_email(force=True)

        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise
        current_user.reload()
        profile_updated.send(sender=self.id,
                             user=self,
                             changed_attrs=changed_attrs)

        return changed_attrs

    @property
    def guest(self):
        """Return True if the user is a guest (has no email)."""
        return False if self.email else True

    #
    # Basic functions for user authentication.
    #
    def get_id(self):
        """Return the id."""
        return self.id

    def is_confirmed(self):
        """Return True if the account has been confirmed (note is "1")."""
        return self.note == "1"

    def is_guest(self):
        """Return True if the user is a guest."""
        return self.guest

    def is_authenticated(self):
        """Return True if the user is an authenticated user (has an email)."""
        return True if self.email else False

    def is_active(self):
        """Return True if the user is active (note is not "0")."""
        return self.note != "0"
Пример #19
0
class WtgTAG(db.Model, Serializable):

    """A Tag.

    A named tag owned by a user, optionally shared with a group. It has
    separate access levels for the owner, the group and everyone. Access
    levels are the keys of :attr:`ACCESS_NAMES`.
    """

    __tablename__ = 'wtgTAG'
    __public__ = set(['id', 'name', 'id_owner'])

    #
    # Access Rights
    #
    # Access level -> human readable name.
    ACCESS_NAMES = {
        0: 'Nothing',
        10: 'View',
        20: 'Add',
        30: 'Add and remove',
        40: 'Manage',
    }

    # Human readable name -> access level (inverse of ACCESS_NAMES).
    ACCESS_LEVELS = {v: k for k, v in iteritems(ACCESS_NAMES)}

    # Access level -> list of permitted operations.
    ACCESS_RIGHTS = {
        0: [],
        10: ['view'],
        20: ['view', 'add'],
        30: ['view', 'add', 'remove'],
        40: ['view', 'add', 'remove', 'edit'],
    }

    ACCESS_OWNER_DEFAULT = ACCESS_LEVELS['Manage']
    ACCESS_GROUP_DEFAULT = ACCESS_LEVELS['View']

    # Primary key
    id = db.Column(db.Integer(15, unsigned=True),
                   primary_key=True,
                   nullable=False,
                   autoincrement=True)

    # Name
    name = db.Column(db.String(255),
                     nullable=False,
                     server_default='',
                     index=True)

    # Owner
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=True)

    # Access rights of owner
    user_access_rights = db.Column(db.Integer(2, unsigned=True),
                                   nullable=False,
                                   default=ACCESS_OWNER_DEFAULT)

    # Group
    # equal to NULL for private tags
    id_usergroup = db.Column(
        db.Integer(15, unsigned=True),
        db.ForeignKey(Group.id),
        nullable=True)

    # Group access rights
    group_access_rights = db.Column(
        db.Integer(2, unsigned=True),
        nullable=False,
        default=ACCESS_GROUP_DEFAULT)

    # Access rights of everyone
    public_access_rights = db.Column(db.Integer(2, unsigned=True),
                                     nullable=False,
                                     default=ACCESS_LEVELS['Nothing'])

    # Visibility in document description
    show_in_description = db.Column(db.Boolean,
                                    nullable=False,
                                    default=True)

    # Relationships
    user = db.relationship(User,
                           backref=db.backref('tags', cascade='all'))

    user_query = db.relationship(User,
                                 backref=db.backref('tags_query',
                                                    cascade='all',
                                                    lazy='dynamic'))

    usergroup = db.relationship(
        Group,
        backref=db.backref('tags', cascade='all'))

    # Association proxy from the "records_association" collection
    # to its "bibrec" attribute.
    records = association_proxy('records_association', 'bibrec')

    # Calculated fields
    @db.hybrid_property
    def record_count(self):
        """Return the number of records tagged with this tag."""
        return self.records_association_query.count()

    @record_count.expression
    def record_count(cls):
        """Return a correlated SQL count of records tagged with this tag."""
        return db.select([db.func.count(WtgTAGRecord.id_bibrec)]) \
                 .where(WtgTAGRecord.id_tag == cls.id) \
                 .label('record_count')

    # A single ``validates`` call must name every attribute. Stacked
    # ``validates`` decorators overwrite each other, so previously only
    # ``user_access_rights`` was actually validated.
    @db.validates('user_access_rights',
                  'group_access_rights',
                  'public_access_rights')
    def validate_user_access_rights(self, key, value):
        """Check if the value is among defined levels.

        :raises AssertionError: if ``value`` is not a key of ACCESS_NAMES.
        """
        if value not in WtgTAG.ACCESS_NAMES:
            # Explicit raise (not ``assert``) so the check also runs under
            # ``python -O``; the exception type is unchanged for callers.
            raise AssertionError(
                'Invalid access level {0!r} for {1}'.format(value, key))
        return value
class CmtRECORDCOMMENT(db.Model):
    """Represents a CmtRECORDCOMMENT record.

    A comment (or review) by a user on a record. Comments can reply to
    another comment through ``in_reply_to_id_cmtRECORDCOMMENT``, and users
    can collapse or expand them individually via :class:`CmtCOLLAPSED`.
    """

    __tablename__ = 'cmtRECORDCOMMENT'

    id = db.Column(db.Integer(15, unsigned=True),
                   nullable=False,
                   primary_key=True,
                   autoincrement=True)
    id_bibrec = db.Column(db.MediumInteger(8, unsigned=True),
                          db.ForeignKey(Bibrec.id),
                          nullable=False,
                          server_default='0')
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=False,
                        server_default='0')
    title = db.Column(db.String(255), nullable=False, server_default='')
    body = db.Column(db.Text, nullable=False)
    date_creation = db.Column(db.DateTime,
                              nullable=False,
                              server_default='1900-01-01 00:00:00')
    star_score = db.Column(db.TinyInteger(5, unsigned=True),
                           nullable=False,
                           server_default='0')
    nb_votes_yes = db.Column(db.Integer(10),
                             nullable=False,
                             server_default='0')
    nb_votes_total = db.Column(db.Integer(10, unsigned=True),
                               nullable=False,
                               server_default='0')
    nb_abuse_reports = db.Column(db.Integer(10),
                                 nullable=False,
                                 server_default='0')
    # 'ok' for a live comment; any other value is treated as deleted
    # (see is_deleted).
    status = db.Column(db.Char(2),
                       nullable=False,
                       index=True,
                       server_default='ok')
    round_name = db.Column(db.String(255), nullable=False, server_default='')
    restriction = db.Column(db.String(50), nullable=False, server_default='')
    # Self-referencing foreign key (``id`` is the column defined above);
    # server default 0 presumably means "not a reply" — confirm.
    in_reply_to_id_cmtRECORDCOMMENT = db.Column(db.Integer(15, unsigned=True),
                                                db.ForeignKey(id),
                                                nullable=False,
                                                server_default='0')
    reply_order_cached_data = db.Column(db.Binary, nullable=True)
    bibrec = db.relationship(Bibrec, backref='recordcomments')
    user = db.relationship(User, backref='recordcomments')
    # Self-referential one-to-many: ``replies`` and the ``parent`` backref.
    # NOTE(review): ``order_by`` is set on the many-to-one ``parent`` backref,
    # where it has no useful effect; ordering ``replies`` by date was likely
    # intended — confirm before moving it.
    replies = db.relationship('CmtRECORDCOMMENT',
                              backref=db.backref('parent',
                                                 remote_side=[id],
                                                 order_by=date_creation))

    @property
    def is_deleted(self):
        """Return True if the status is anything other than 'ok'."""
        return self.status != 'ok'

    def is_collapsed(self, id_user):
        """Return true if the comment is collapsed by user."""
        return CmtCOLLAPSED.query.filter(
            db.and_(CmtCOLLAPSED.id_bibrec == self.id_bibrec,
                    CmtCOLLAPSED.id_cmtRECORDCOMMENT == self.id,
                    CmtCOLLAPSED.id_user == id_user)).count() > 0

    # NOTE(review): decorated with session_manager and also commits
    # explicitly; check whether session_manager already commits.
    @session_manager
    def collapse(self, id_user):
        """Collapse this comment for the given user (adds a CmtCOLLAPSED row)."""
        c = CmtCOLLAPSED(id_bibrec=self.id_bibrec,
                         id_cmtRECORDCOMMENT=self.id,
                         id_user=id_user)
        db.session.add(c)
        db.session.commit()

    # NOTE(review): unlike collapse(), this neither commits nor uses
    # session_manager; the caller must commit — confirm intended.
    def expand(self, id_user):
        """Expand this comment for the given user (deletes its CmtCOLLAPSED row)."""
        CmtCOLLAPSED.query.filter(
            db.and_(CmtCOLLAPSED.id_bibrec == self.id_bibrec,
                    CmtCOLLAPSED.id_cmtRECORDCOMMENT == self.id,
                    CmtCOLLAPSED.id_user == id_user)).delete(
                        synchronize_session=False)

    # Prefix index (first 40 bytes, MySQL) on the binary reply-order cache.
    __table_args__ = (db.Index('cmtRECORDCOMMENT_reply_order_cached_data',
                               reply_order_cached_data,
                               mysql_length=40), db.Model.__table_args__)

    @classmethod
    def count(cls, *criteria, **filters):
        """Count comments matching SQL ``criteria`` and keyword ``filters``."""
        return cls.query.filter(*criteria).filter_by(**filters).count()
Пример #21
0
class AccMAILCOOKIE(db.Model):

    """Represent an email cookie.

    A time-limited token sent by email. Its payload is encrypted with a
    random password that is embedded in the cookie string handed to the
    user, around the hexadecimal row id.
    """

    __tablename__ = 'accMAILCOOKIE'

    AUTHORIZATIONS_KIND = (
        'pw_reset', 'mail_activation', 'role', 'authorize_action',
        'comment_msg', 'generic'
    )

    id = db.Column(db.Integer(15, unsigned=True), primary_key=True,
                   autoincrement=True)
    # Encrypted payload; see create() and get().
    _data = db.Column('data', db.iBinary, nullable=False)
    expiration = db.Column(db.DateTime, nullable=False,
                           server_default='9999-12-31 23:59:59', index=True)
    kind = db.Column(db.String(32), nullable=False)
    onetime = db.Column(db.TinyInteger(1), nullable=False, server_default='0')
    # 'D' marks a deleted (used-up) cookie; server default is 'W'.
    status = db.Column(db.Char(1), nullable=False, server_default='W')

    @validates('kind')
    def validate_kind(self, key, kind):
        """Validate cookie kind."""
        assert kind in self.AUTHORIZATIONS_KIND
        return kind

    @classmethod
    def get(cls, cookie, delete=False):
        """Get cookie if it is valid.

        :param cookie: String returned by :meth:`create`: 16 password chars,
            the row id in hex, then the last 16 password chars.
        :param delete: Mark the cookie deleted even if it is not one-time.
        :raises InvenioWebAccessMailCookieError: if the decrypted payload does
            not match the row.
        :raises InvenioWebAccessMailCookieDeletedError: if already deleted.
        """
        password = cookie[:16]+cookie[-16:]
        cookie_id = int(cookie[16:-16], 16)

        obj, data = db.session.query(
            cls,
            AccMAILCOOKIE._data
        ).filter_by(id=cookie_id).one()
        # NOTE(review): if ``loads`` is pickle's, this deserializes data read
        # from the database; it is only safe while that table is trusted.
        obj.data = loads(mysql_aes_decrypt(data, password))

        (kind_check, params, expiration, onetime_check) = obj.data
        assert obj.kind in cls.AUTHORIZATIONS_KIND

        if not (obj.kind == kind_check and obj.onetime == onetime_check):
            raise InvenioWebAccessMailCookieError("Cookie is corrupted")
        if obj.status == 'D':
            raise InvenioWebAccessMailCookieDeletedError(
                "Cookie has been deleted")
        if obj.onetime or delete:
            obj.status = 'D'
            db.session.merge(obj)
            db.session.commit()
        return obj

    @classmethod
    def create(cls, kind, params, cookie_timeout=timedelta(days=1),
               onetime=False):
        """Create cookie with given params.

        :param kind: One of :attr:`AUTHORIZATIONS_KIND`.
        :param params: Payload stored (encrypted) with the cookie.
        :param cookie_timeout: Lifetime of the cookie, counted from now.
        :param onetime: Whether :meth:`get` marks the cookie deleted on use.
        :returns: The cookie string expected by :meth:`get`.
        """
        from random import SystemRandom

        expiration = datetime.today() + cookie_timeout
        data = (kind, params, expiration, onetime)
        # 32 hex digits from the OS CSPRNG. ``md5(str(random()))`` was
        # predictable (Mersenne Twister) and fails on Python 3 (str vs bytes).
        password = '{0:032x}'.format(SystemRandom().getrandbits(128))
        cookie = cls(
            expiration=expiration,
            kind=kind,
            onetime=int(onetime),
        )
        cookie._data = mysql_aes_encrypt(dumps(data), password)
        db.session.add(cookie)
        db.session.commit()
        db.session.refresh(cookie)
        # ``hex(id)[2:-1]`` assumed a Python 2 ``long`` with a trailing 'L';
        # for a plain int it silently dropped the last hex digit of the id.
        return password[:16] + '{0:x}'.format(cookie.id) + password[-16:]

    @classmethod
    @session_manager
    def gc(cls):
        """Remove expired items."""
        return cls.query.filter(cls.expiration < db.func.now()).delete()
Пример #22
0
class HstEXCEPTION(db.Model):
    """Represents a HstEXCEPTION record.

    Each row aggregates the occurrences of one exception, identified by its
    unique ``(name, filename, line)`` triple. ``counter`` and
    ``last_notified`` drive notification throttling; ``total`` counts every
    occurrence.
    """
    __tablename__ = 'hstEXCEPTION'
    id = db.Column(db.Integer(15, unsigned=True),
                   nullable=False,
                   primary_key=True,
                   autoincrement=True)
    name = db.Column(db.String(50), nullable=False)
    filename = db.Column(db.String(255), nullable=True)
    line = db.Column(db.Integer(9), nullable=True)
    last_seen = db.Column(db.DateTime,
                          nullable=False,
                          server_default='1900-01-01 00:00:00',
                          index=True)
    last_notified = db.Column(db.DateTime,
                              nullable=False,
                              server_default='1900-01-01 00:00:00',
                              index=True)
    counter = db.Column(db.Integer(15), nullable=False, server_default='0')
    total = db.Column(db.Integer(15),
                      nullable=False,
                      server_default='0',
                      index=True)

    __table_args__ = (db.Index('name', name, filename, line,
                               unique=True), db.Model.__table_args__)

    @classmethod
    def get_or_create(cls, name, filename, line):
        """Find or create the log entry for an exception and count one hit.

        For an existing entry, ``total`` is incremented and ``last_seen`` and
        ``last_notified`` are set to now. ``counter`` is incremented, or reset
        to 1 when at least
        ``CFG_ERRORLIB_RESET_EXCEPTION_NOTIFICATION_COUNTER_AFTER`` seconds
        have passed since ``last_notified``. A new entry starts with
        ``counter`` and ``total`` at 1.

        This is best effort. Lookup and commit failures are swallowed so that
        error reporting itself never raises.

        :returns: The (possibly unsaved) ``HstEXCEPTION`` instance.
        """
        # Local import. This works whether the module imported the
        # ``datetime`` module or the ``datetime`` class; the previous code
        # mixed both forms, so one of the two paths always failed.
        from datetime import datetime as _datetime

        now = _datetime.now()
        try:
            log = cls.query.filter_by(name=name, filename=filename,
                                      line=line).first()
        except Exception:
            # As before, a failed lookup is treated as a first occurrence.
            log = None

        if log is None:
            log = cls(name=name,
                      filename=filename,
                      line=line,
                      last_seen=now,
                      last_notified=now,
                      counter=1,
                      total=1)
        else:
            delta = now - log.last_notified
            reset_counter = delta.total_seconds() >= \
                cfg['CFG_ERRORLIB_RESET_EXCEPTION_NOTIFICATION_COUNTER_AFTER']
            # Mutate the instance directly. Model instances have no
            # ``update`` method (that lives on ``Query``), which previously
            # sent every repeat occurrence down the "create duplicate" path.
            log.counter = 1 if reset_counter else log.counter + 1
            log.total = log.total + 1
            log.last_seen = now
            log.last_notified = now
        db.session.add(log)
        try:
            db.session.commit()
        except Exception:
            # Best effort: never let exception bookkeeping raise.
            db.session.rollback()
        return log

    @property
    def exception_should_be_notified(self):
        """Whether the current occurrence should trigger a notification.

        Delegates to ``_is_pow_of_2(self.counter)``. Presumably this notifies
        on powers of two, so notifications thin out as an exception repeats.
        """
        return _is_pow_of_2(self.counter)

    @property
    def pretty_notification_info(self):
        """Summarize ``total``, ``last_seen`` and ``last_notified`` as text."""
        return ("This exception has already been seen %s times\n    "
                "last time it was seen: %s\n    "
                "last time it was notified: %s\n" %
                (self.total, self.last_seen.strftime("%Y-%m-%d %H:%M:%S"),
                 self.last_notified.strftime("%Y-%m-%d %H:%M:%S")))

    @classmethod
    def get_pretty_notification_info(cls, name, filename, line):
        """Return a sentence describing when this exception was already seen.

        Falls back to a "first time" message when no entry exists or the
        lookup fails.
        """
        try:
            log = cls.query.filter_by(
                name=name, filename=filename, line=line).first()
            if log is not None:
                return log.pretty_notification_info
        except Exception:
            # Best effort: reporting must not fail while handling an error.
            pass
        return "It is the first time this exception has been seen.\n"
# Example #23
# 0
class CheckerRuleExecution(db.Model):

    """Record of one execution of a checker rule (task)."""

    __tablename__ = 'checker_rule_execution'

    uuid = db.Column(
        db.String(36),
        primary_key=True,
        doc="UUID of the execution. Same with that of RedisMaster and logfile.",
    )

    owner_id = db.Column(
        db.Integer(15, unsigned=True),
        db.ForeignKey('user.id'),
        nullable=False,
        default=1,
    )
    owner = db.relationship(
        'User',
        doc="User who owns this execution. May be used by reporters.",
    )

    rule_name = db.Column(
        db.String(127),
        db.ForeignKey('checker_rule.name'),
        nullable=False,
        index=True,
        doc="Name of the associated task.",
    )

    _status = db.Column(
        ChoiceType(StatusMaster, impl=db.Integer()),
        default=StatusMaster.unknown,
    )

    status_update_date = db.Column(
        db.DateTime(),
        nullable=False,
        server_default='1900-01-01 00:00:00',
        doc="Last date the status was updated.",
    )

    start_date = db.Column(
        db.DateTime(),
        nullable=False,
        server_default='1900-01-01 00:00:00',
        doc="Date at which this task was started.",
    )

    dry_run = db.Column(
        db.Boolean,
        default=False,
        doc=(
            "Whether this execution is a dry run. "
            "Note the `should_*` properties."
        ),
    )

    @db.hybrid_property
    def should_commit(self):
        """Whether this execution should commit record modifications."""
        return not self.dry_run

    @db.hybrid_property
    def should_report_logs(self):
        """Whether this execution should report logs to the reporters."""
        return not self.dry_run

    @db.hybrid_property
    def should_report_exceptions(self):
        """Whether this execution should report exceptions to the reporters."""
        return not self.dry_run

    @db.hybrid_property
    def master(self):
        """The master object of this execution.

        :rtype: `invenio_checker.clients.master.RedisMaster`
        """
        return RedisMaster(self.uuid)

    @db.hybrid_property
    def status(self):
        """The status of the execution.

        :rtype: `StatusMaster`
        """
        return self._status

    @status.setter
    @session_manager
    def status(self, new_status):
        """Set the status and stamp `status_update_date` with the current time.

        :type new_status: `StatusMaster`
        """
        self._status = new_status
        self.status_update_date = datetime.now()

    def read_logs(self):
        """Stream user-friendly structured logs of this execution.

        First attempt to stream using `eliot-tree`, which provides a text
        tree-like structure of the execution given its logs.

        If `eliot-tree` fails (for example when there is an eliot Task
        serialization bug in our code, or when it cannot be started), a
        warning is yielded, followed by the output of `eliot-prettyprint`.

        Therefore, the output of this function is not guaranteed to be
        machine-readable.

        ..note::
            This function may be called mid-run.
        """
        from glob import glob
        import subprocess
        from .config import get_eliot_log_file

        filenames = glob(get_eliot_log_file(master_id=self.uuid).name + "*")

        eliottree_failed = False
        try:
            eliottree_subp = subprocess.Popen(
                ['eliot-tree', '--field-limit', '0'],
                stdout=subprocess.PIPE,
                stdin=subprocess.PIPE)
        except OSError:
            # `eliot-tree` is missing or not executable: go straight to the
            # prettyprint fallback instead of aborting the stream.
            eliottree_subp = None
            eliottree_failed = True

        if eliottree_subp is not None:
            with eliottree_subp.stdin:
                try:
                    for filename in filenames:
                        with open(filename, 'r') as file_:
                            eliottree_subp.stdin.write(file_.read())
                except (IOError, MemoryError):
                    eliottree_failed = True

            with eliottree_subp.stdout:
                for line in eliottree_subp.stdout:
                    yield line

            # Always reap the child. Previously ``eliottree_failed or wait()``
            # short-circuited and left the process unreaped after a write
            # failure.
            if eliottree_subp.wait() != 0:
                eliottree_failed = True

        if eliottree_failed:
            # eliot-tree can fail on unfinished logging. We still want output
            # for debugging, so we use the less structured eliot-prettyprint
            from eliot.prettyprint import pretty_format
            from eliot._bytesjson import loads
            yield '\n`eliot-tree` failed to format output. ' \
                'Retrying with eliot-prettyprint:\n'
            for filename in filenames:
                yield "{}:\n".format(filename)
                with open(filename, 'r') as file_:
                    for line in file_:
                        yield pretty_format(loads(line))
# Example #24
# 0
class CheckerRule(db.Model):
    """Represent runnable rules (also known as tasks)."""

    __tablename__ = 'checker_rule'

    name = db.Column(
        db.String(127),
        primary_key=True,
        doc="Name of the rule. Must be unique and user-friendly.",
    )

    plugin = db.Column(
        db.String(127),
        nullable=False,
        doc="Check to use. Must be importable string. Does not need to exist"
        " at task insertion time.",
    )

    arguments = db.Column(
        JsonEncodedDict(1023),
        default={},
        doc="Arguments to pass to the check.",
    )

    # XXX: Currently unsupported by search. Disabled elsewhere in the code.
    consider_deleted_records = db.Column(
        db.Boolean,
        nullable=True,
        default=False,
        doc="Whether to consider deleted records while filtering.",
    )

    filter_pattern = db.Column(
        db.String(255),
        nullable=True,
        doc="String pattern to search with to resolve records to check.",
    )

    filter_records = db.Column(
        IntBitSetType(1023),
        nullable=True,
        doc="Record IDs to run this task on.",
    )

    records = db.relationship(
        'CheckerRecord',
        backref='rule',
        cascade='all, delete-orphan',
        doc="Records which this rule has worked with in the past.",
    )

    reporters = db.relationship(
        'CheckerReporter',
        backref='rule',
        cascade='all, delete-orphan',
        doc="Reporters to be called while this task executes.",
    )

    executions = db.relationship(
        'CheckerRuleExecution',
        backref='rule',
        cascade='all, delete-orphan',
        doc="Past executions of this task. User should be free to clear them.",
    )

    last_scheduled_run = db.Column(
        db.DateTime(),
        nullable=True,
        doc="Last time this task was ran by the scheduler.",
    )

    schedule = db.Column(
        db.String(255),
        nullable=True,
        doc="Cron-style string that defines the schedule for this task.",
    )

    schedule_enabled = db.Column(
        db.Boolean,
        default=True,
        nullable=False,
        doc="Whether `schedule` is enabled.",
    )

    # TODO: You may use this column as a filter for tasks you don't want to see
    # by default in interfaces
    temporary = db.Column(
        db.Boolean,
        default=False,
        doc="Flag for tasks which will not be reused.",
    )

    force_run_on_unmodified_records = db.Column(
        db.Boolean,
        default=False,
        doc="Force a record-centric task to run on records it has checked"
        " before, even if they have already been checked in their current"
        " version.",
    )

    confirm_hash_on_commit = db.Column(
        db.Boolean,
        default=False,
        doc="Only commit recids whose hash has not changed between first"
        " requested modification and commit time.",
    )

    allow_chunking = db.Column(  # XXX unclear name (maybe "run_in_parallel")
        db.Boolean,
        default=True,
        doc="If the check is record-centric, allow checks to run in parallel.",
    )

    last_modification_date = db.Column(
        db.DateTime(),
        nullable=False,
        server_default='1900-01-01 00:00:00',
        doc="Last date on which this task was modified.",
    )

    owner_id = db.Column(
        db.Integer(15, unsigned=True),
        db.ForeignKey('user.id'),
        nullable=False,
        default=1,
    )
    owner = db.relationship(
        'User',
        doc="User that created this task. Used for scheduled tasks.",
    )

    @db.hybrid_property
    def filepath(self):
        """Resolve the filepath of this rule's plugin/check file.

        :returns: Path of the plugin's source file, or ``None`` if
            ``self.plugin`` is not a known plugin.
        """
        try:
            path = inspect.getfile(plugin_files[self.plugin])
        except KeyError:
            return None
        # ``inspect.getfile`` may return the compiled file; strip the trailing
        # 'c'/'o' to point at the ``.py`` source instead.
        if path.endswith(('.pyc', '.pyo')):
            path = path[:-1]
        return path

    @db.hybrid_property
    def modified_requested_recids(self):
        """Record IDs of records that match the filters of this task.

        Starts from `requested_recids` (built from `filter_pattern` and
        `filter_records`). Unless `force_run_on_unmodified_records` is
        enabled, only records whose `CheckerRecord.last_run_version_id` is 1,
        or lower than the record's current `RecordMetadata.version_id`, are
        kept.

        Side effect: a `CheckerRecord` row is added and committed for every
        requested record not yet associated with this rule.

        :rtype: intbitset
        """
        # Get all records that are already associated to this rule
        # If this is returning an empty set, you forgot to run bibindex
        # (Unpacking single-column rows avoids the Python 2-only
        # ``zip(...)[0]`` / ``IndexError`` dance for the empty case.)
        associated_records = intbitset([
            rec_id for (rec_id,) in db.session.query(
                CheckerRecord.rec_id).filter(
                    CheckerRecord.rule_name == self.name)
        ])

        # Store requested records that were until now unknown to this rule
        requested_ids = self.requested_recids
        for requested_id in requested_ids - associated_records:
            new_record = CheckerRecord(rec_id=requested_id,
                                       rule_name=self.name)
            db.session.add(new_record)
        db.session.commit()

        # Figure out which records have been edited since the last time we ran
        # this rule. With the force flag set, the version filter is skipped.
        conditions = [
            CheckerRecord.rec_id.in_(requested_ids),
            CheckerRecord.rule_name == self.name,
        ]
        if not self.force_run_on_unmodified_records:
            conditions.append(db.or_(
                CheckerRecord.last_run_version_id == 1,
                CheckerRecord.last_run_version_id <
                RecordMetadata.version_id,
            ))
        query = db.session.query(CheckerRecord.rec_id).outerjoin(
            RecordMetadata).filter(*conditions)
        return intbitset([rec_id for (rec_id,) in query])

    @session_manager
    def mark_recids_as_checked(self, recids):
        """Mark `recids` as checked by this task at their current version.

        Copies each record's `RecordMetadata.version_id` into this rule's
        `CheckerRecord.last_run_version_id` with a single bulk UPDATE.
        """
        db.session.query(CheckerRecord).\
            filter(
                CheckerRecord.rec_id == RecordMetadata.id,
                CheckerRecord.rule_name == self.name,
                CheckerRecord.rec_id.in_(recids),
            ).\
            update({"last_run_version_id": RecordMetadata.version_id},
                   synchronize_session=False)

    @db.hybrid_property
    def requested_recids(self):
        """Search given `self.filter_pattern` and `self.filter_records`.

        An empty pattern searches everything; `filter_records`, when set,
        restricts the result further.

        :rtype: intbitset
        """
        # TODO: Use self.option_consider_deleted_records when it's available
        pattern = self.filter_pattern or ''
        recids = Query(pattern).search().recids

        if self.filter_records is not None:
            recids &= self.filter_records

        return recids

    def __str__(self):
        """Return a multi-line, human-readable summary of this task."""
        name_len = len(self.name)
        trails = 61 - name_len
        return '\n'.join((
            '=== Checker Task: {} {}'.format(self.name, trails * '='),
            '* Name: {}'.format(self.name),
            '* Plugin: {}'.format(self.plugin),
            '* Arguments: {}'.format(self.arguments),
            '* Consider deleted records: {}'.format(
                self.consider_deleted_records),
            '* Filter Pattern: {}'.format(self.filter_pattern),
            '* Filter Records: {}'.format(ranges_str(self.filter_records)),
            '* Last scheduled run: {}'.format(self.last_scheduled_run),
            '* Schedule: {} [{}]'.format(
                self.schedule,
                'enabled' if self.schedule_enabled else 'disabled'),
            '* Temporary: {}'.format(self.temporary),
            '* Force-run on unmodified records: {}'.format(
                self.force_run_on_unmodified_records),
            '{}'.format(80 * '='),
        ))

    @staticmethod
    def update_time(mapper, connection, instance):
        """Update the `last_modification_date` to the current time."""
        instance.last_modification_date = datetime.now()
# Example #25
# 0
class RemoteToken(db.Model):
    """Storage for the access tokens for linked accounts."""

    __tablename__ = 'remoteTOKEN'

    #
    # Fields
    #
    id_remote_account = db.Column(db.Integer(15, unsigned=True),
                                  db.ForeignKey(RemoteAccount.id),
                                  nullable=False,
                                  primary_key=True)
    """Foreign key to account."""

    token_type = db.Column(db.String(40),
                           default='',
                           nullable=False,
                           primary_key=True)
    """Type of token."""

    access_token = db.Column(TextEncryptedType(type_in=db.Text,
                                               key=secret_key),
                             nullable=False)
    """Access token to remote application."""

    secret = db.Column(db.Text(), default='', nullable=False)
    """Used only by OAuth 1."""

    def token(self):
        """Return the ``(access_token, secret)`` pair Flask-OAuthlib expects."""
        return self.access_token, self.secret

    def update_token(self, token, secret):
        """Store a new token/secret pair, committing only when it changed."""
        if self.access_token == token and self.secret == secret:
            return
        self.access_token = token
        self.secret = secret
        db.session.commit()

    @classmethod
    def get(cls, user_id, client_id, token_type='', access_token=None):
        """Return the first token of ``user_id`` for ``client_id``, or None.

        When ``access_token`` is truthy the match is narrowed to that token.
        The owning ``remote_account`` is eager-loaded.
        """
        criteria = [
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.user_id == user_id,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
        ]
        if access_token:
            criteria.append(RemoteToken.access_token == access_token)

        eager = cls.query.options(db.joinedload('remote_account'))
        return eager.filter(*criteria).first()

    @classmethod
    def get_by_token(cls, client_id, access_token, token_type=''):
        """Return the token row matching ``access_token`` for a client.

        The owning ``remote_account`` is eager-loaded; None if nothing matches.
        """
        eager = cls.query.options(db.joinedload('remote_account'))
        matching = eager.filter(
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
            RemoteToken.access_token == access_token,
        )
        return matching.first()

    @classmethod
    def create(cls,
               user_id,
               client_id,
               token,
               secret,
               token_type='',
               extra_data=None):
        """Create and commit a new access token.

        A ``RemoteAccount`` for ``(user_id, client_id)`` is created first when
        none exists; ``extra_data`` is only used in that case.
        """
        account = RemoteAccount.get(user_id, client_id)
        if account is None:
            account = RemoteAccount(
                user_id=user_id,
                client_id=client_id,
                extra_data=extra_data or dict(),
            )
            db.session.add(account)

        # Named separately so the ``token`` string parameter is not shadowed.
        remote_token = cls(
            token_type=token_type,
            remote_account=account,
            access_token=token,
            secret=secret,
        )
        db.session.add(remote_token)
        db.session.commit()
        return remote_token
# Example #26
# 0
class RemoteAccount(db.Model):
    """Storage for remote linked accounts."""

    __tablename__ = 'remoteACCOUNT'

    __table_args__ = (db.UniqueConstraint('user_id', 'client_id'),
                      db.Model.__table_args__)

    #
    # Fields
    #
    id = db.Column(db.Integer(15, unsigned=True),
                   primary_key=True,
                   autoincrement=True)
    """Primary key."""

    user_id = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=False)
    """Local user linked with a remote app via the access token."""

    client_id = db.Column(db.String(255), nullable=False)
    """Client ID of remote application (defined in OAUTHCLIENT_REMOTE_APPS)."""

    extra_data = db.Column(MutableDict.as_mutable(db.JSON), nullable=False)
    """Extra data associated with this linked account."""

    #
    # Relationship properties
    #
    user = db.relationship('User')
    """SQLAlchemy relationship to user."""

    tokens = db.relationship(
        "RemoteToken",
        backref="remote_account",
    )
    """SQLAlchemy relationship to RemoteToken objects."""

    @classmethod
    def get(cls, user_id, client_id):
        """Return the account linking ``user_id`` to ``client_id``, or None.

        :param user_id: User id.
        :param client_id: Client id.
        """
        matching = cls.query.filter_by(user_id=user_id, client_id=client_id)
        return matching.first()

    @classmethod
    def create(cls, user_id, client_id, extra_data):
        """Create and commit a new remote account for a user.

        :param user_id: User id.
        :param client_id: Client id.
        :param extra_data: JSON-serializable dictionary of any extra data that
                           needs to be saved together with this link; a falsy
                           value is stored as an empty dict.
        """
        payload = extra_data or dict()
        account = cls(user_id=user_id, client_id=client_id,
                      extra_data=payload)
        db.session.add(account)
        db.session.commit()
        return account

    def delete(self):
        """Delete this remote account and all of its stored tokens, then commit."""
        owned_tokens = RemoteToken.query.filter_by(id_remote_account=self.id)
        owned_tokens.delete()
        db.session.delete(self)
        db.session.commit()
# Example #27
# 0
class PersistentIdentifier(db.Model):

    """Store and register persistent identifiers.

    Assumptions:
      * Persistent identifiers can be represented as a string of max 255 chars.
      * An object has many persistent identifiers.
      * A persistent identifier has one and only one object.
    """

    __tablename__ = 'pidSTORE'
    __table_args__ = (
        db.Index('uidx_type_pid', 'pid_type', 'pid_value', unique=True),
        db.Index('idx_status', 'status'),
        db.Index('idx_object', 'object_type', 'object_value'),
    )

    id = db.Column(db.Integer(15, unsigned=True), primary_key=True)
    """Id of persistent identifier entry."""

    pid_type = db.Column(db.String(6), nullable=False)
    """Persistent Identifier Schema."""

    pid_value = db.Column(db.String(length=255), nullable=False)
    """Persistent Identifier."""

    pid_provider = db.Column(db.String(length=255), nullable=False)
    """Persistent Identifier Provider"""

    status = db.Column(db.CHAR(length=1), nullable=False)
    """Status of persistent identifier, e.g. registered, reserved, deleted."""

    object_type = db.Column(db.String(3), nullable=True)
    """Object Type - e.g. rec for record."""

    object_value = db.Column(db.String(length=255), nullable=True)
    """Object ID - e.g. a record id."""

    created = db.Column(db.DateTime(), nullable=False, default=datetime.now)
    """Creation datetime of entry."""

    last_modified = db.Column(
        db.DateTime(), nullable=False, default=datetime.now,
        onupdate=datetime.now
    )
    """Last modification datetime of entry."""

    #
    # Class methods
    #
    @classmethod
    def create(cls, pid_type, pid_value, pid_provider='', provider=None):
        """Internally reserve a new persistent identifier.

        A provider for the given persistent identifier type must exist. By
        default the system will choose a provider according to the pid
        type. If desired, the default system provider can be overridden via
        the provider keyword argument.

        Return PID object if successful otherwise None.

        :raises Exception: If no provider can be found.
        """
        # Ensure provider exists
        if provider is None:
            provider = PidProvider.create(pid_type, pid_value, pid_provider)
            if not provider:
                raise Exception(
                    "No provider found for %s:%s (%s)" % (
                        pid_type, pid_value, pid_provider)
                )

        # Bound before the ``try`` so the error path never hits an unbound
        # name when construction itself fails.
        obj = None
        try:
            # Same savepoint pattern as the other methods. The context manager
            # flushes and releases the SAVEPOINT on success, and rolls it back
            # (re-raising) on failure. The previous bare ``begin_nested()``
            # was never released on success.
            with db.session.begin_nested():
                obj = cls(pid_type=provider.pid_type,
                          pid_value=provider.create_new_pid(pid_value),
                          pid_provider=pid_provider,
                          status=cfg['PIDSTORE_STATUS_NEW'])
                obj._provider = provider
                db.session.add(obj)
            obj.log("CREATE", "Created")
            return obj
        except SQLAlchemyError:
            if obj is not None:
                obj.log("CREATE", "Failed to created. Already exists.")
            return None

    @classmethod
    def get(cls, pid_type, pid_value, pid_provider='', provider=None):
        """Get persistent identifier.

        Return None if not found.
        """
        pid_value = to_unicode(pid_value)
        obj = cls.query.filter_by(
            pid_type=pid_type, pid_value=pid_value, pid_provider=pid_provider
        ).first()
        if obj:
            obj._provider = provider
            return obj
        else:
            return None

    #
    # Instance methods
    #
    def has_object(self, object_type, object_value):
        """Determine if this PID is assigned to a specific object."""
        if object_type not in cfg['PIDSTORE_OBJECT_TYPES']:
            raise Exception("Invalid object type %s." % object_type)

        object_value = to_unicode(object_value)

        return self.object_type == object_type and \
            self.object_value == object_value

    def get_provider(self):
        """Get the provider for this type of persistent identifier.

        The provider is created lazily and cached on ``self._provider``.
        """
        # Instances loaded by an arbitrary query (not via ``get``/``create``)
        # never had ``_provider`` assigned, so a plain attribute read would
        # raise AttributeError.
        if getattr(self, '_provider', None) is None:
            self._provider = PidProvider.create(
                self.pid_type, self.pid_value, self.pid_provider
            )
        return self._provider

    def assign(self, object_type, object_value, overwrite=False):
        """Assign this persistent identifier to a given object.

        Note, the persistent identifier must first have been reserved. Also,
        if an existing object is already assigned to the pid, it will raise an
        exception unless overwrite=True.
        """
        if object_type not in cfg['PIDSTORE_OBJECT_TYPES']:
            raise Exception("Invalid object type %s." % object_type)
        object_value = to_unicode(object_value)

        if not self.id:
            raise Exception(
                "You must first create the persistent identifier before you "
                "can assign objects to it."
            )

        if self.is_deleted():
            raise Exception(
                "You cannot assign objects to a deleted persistent identifier."
            )

        with db.session.begin_nested():
            # Check for an existing object assigned to this pid
            existing_obj_id = self.get_assigned_object(object_type)

            if existing_obj_id and existing_obj_id != object_value:
                if not overwrite:
                    raise Exception(
                        "Persistent identifier is already assigned to another "
                        "object"
                    )
                else:
                    self.log(
                        "ASSIGN",
                        "Unassigned object %s:%s (overwrite requested)" % (
                            self.object_type, self.object_value)
                    )
                    self.object_type = None
                    self.object_value = None
            elif existing_obj_id and existing_obj_id == object_value:
                # The object is already assigned to this pid.
                return True

            self.object_type = object_type
            self.object_value = object_value
            # No explicit commit here. Committing inside ``begin_nested()``
            # closed the savepoint early and made the context manager's own
            # exit fail; the ``with`` block releases the savepoint itself.
            self.log("ASSIGN", "Assigned object {0}:{1}".format(
                self.object_type, self.object_value
            ))
            return True

    def update(self, with_deleted=False, *args, **kwargs):
        """Update the persistent identifier with the provider.

        :param with_deleted: Allow updating a deleted PID; on success its
            status is set back to registered.
        :returns: True if the provider reported success, otherwise False.
        :raises Exception: If the PID is new or reserved, is deleted (unless
            ``with_deleted``), or has no provider.
        """
        if self.is_new() or self.is_reserved():
            raise Exception(
                "Persistent identifier has not yet been registered."
            )

        if not with_deleted and self.is_deleted():
            raise Exception("Persistent identifier has been deleted.")

        with db.session.begin_nested():
            provider = self.get_provider()
            if provider is None:
                self.log("UPDATE", "No provider found.")
                raise Exception("No provider found.")

            if provider.update(self, *args, **kwargs):
                if with_deleted and self.is_deleted():
                    # Key was previously garbled ('PIDSTORE_ST):TUS_...'),
                    # which raised KeyError on this path.
                    self.status = cfg['PIDSTORE_STATUS_REGISTERED']
                return True
        return False

    def reserve(self, *args, **kwargs):
        """Reserve the persistent identifier with the provider.

        Note, the reserve method may be called multiple times, even if it was
        already reserved.
        """
        if not (self.is_new() or self.is_reserved()):
            raise Exception(
                "Persistent identifier has already been registered."
            )

        with db.session.begin_nested():
            provider = self.get_provider()
            if provider is None:
                self.log("RESERVE", "No provider found.")
                raise Exception("No provider found.")

            if provider.reserve(self, *args, **kwargs):
                self.status = cfg['PIDSTORE_STATUS_RESERVED']
                return True
        return False

    def register(self, *args, **kwargs):
        """Register the persistent identifier with the provider."""
        if self.is_registered() or self.is_deleted():
            raise Exception(
                "Persistent identifier has already been registered."
            )

        with db.session.begin_nested():
            provider = self.get_provider()
            if provider is None:
                self.log("REGISTER", "No provider found.")
                raise Exception("No provider found.")

            if provider.register(self, *args, **kwargs):
                self.status = cfg['PIDSTORE_STATUS_REGISTERED']
                return True
        return False

    def delete(self, *args, **kwargs):
        """Delete the persistent identifier.

        A PID that was never registered is removed from the table; its log
        entries are kept but unlinked. Otherwise the provider is asked to
        delete it and the status is set to deleted.

        :returns: True on success, False if the provider refused.
        :raises Exception: If no provider is found for a registered PID.
        """
        with db.session.begin_nested():
            if self.is_new():
                # New persistent identifier which hasn't been registered yet:
                # delete it completely but keep its log. Write the DELETE
                # entry first, then unlink every entry (including that one).
                # Logging after the delete would create a row referencing
                # the removed PID.
                self.log("DELETE", "Unregistered PID successfully deleted")
                PidLog.query.filter_by(id_pid=self.id).update({'id_pid': None})
                db.session.delete(self)
            else:
                provider = self.get_provider()
                if provider is None:
                    self.log("DELETE", "No provider found.")
                    raise Exception("No provider found.")
                if not provider.delete(self, *args, **kwargs):
                    return False
                self.status = cfg['PIDSTORE_STATUS_DELETED']
            return True

    def sync_status(self, *args, **kwargs):
        """Synchronize persistent identifier status.

        Used when the provider uses an external service, which might have been
        modified outside of our system.
        """
        with db.session.begin_nested():
            provider = self.get_provider()
            result = provider.sync_status(self, *args, **kwargs)
            return result

    def get_assigned_object(self, object_type=None):
        """Return the assigned object value if it has type ``object_type``."""
        if object_type is not None and self.object_type == object_type:
            return self.object_value
        return None

    def is_registered(self):
        """Return true if the persistent identifier has been registered."""
        return self.status == cfg['PIDSTORE_STATUS_REGISTERED']

    def is_deleted(self):
        """Return true if the persistent identifier has been deleted."""
        return self.status == cfg['PIDSTORE_STATUS_DELETED']

    def is_new(self):
        """Return true if the PID has not yet been registered or reserved."""
        return self.status == cfg['PIDSTORE_STATUS_NEW']

    def is_reserved(self):
        """Return true if the PID is reserved (but not yet registered)."""
        return self.status == cfg['PIDSTORE_STATUS_RESERVED']

    def log(self, action, message):
        """Store action and message in log, prefixed with ``[type:value]``."""
        if self.pid_type and self.pid_value:
            message = "[%s:%s] %s" % (self.pid_type, self.pid_value, message)
        with db.session.begin_nested():
            p = PidLog(id_pid=self.id, action=action, message=message)
            db.session.add(p)
# Example #28
# 0
class Token(db.Model):

    """A bearer token is the final token that can be used by the client."""

    __tablename__ = 'oauth2TOKEN'

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    """Object ID."""

    client_id = db.Column(
        db.String(255), db.ForeignKey('oauth2CLIENT.client_id'),
        nullable=False,
    )
    """Foreign key to client application."""

    client = db.relationship(
        'Client',
        backref=db.backref(
            'oauth2tokens',
            cascade="all, delete-orphan"
        ))
    """SQLAlchemy relationship to client application."""

    user_id = db.Column(
        db.Integer(15, unsigned=True), db.ForeignKey('user.id'), nullable=True
    )
    """Foreign key to user."""

    user = db.relationship(
        User,
        backref=db.backref(
            "oauth2tokens",
            cascade="all, delete-orphan",
        )
    )
    """SQLAlchemy relationship to user."""

    token_type = db.Column(db.String(255), default='bearer')
    """Token type - only bearer is supported at the moment."""

    access_token = db.Column(String255EncryptedType(
        type_in=db.String(255),
        key=secret_key),
        unique=True
    )
    """Access token value, stored through ``String255EncryptedType``."""

    refresh_token = db.Column(String255EncryptedType(
        type_in=db.String(255),
        key=secret_key,
        engine=NoneAesEngine),
        unique=True, nullable=True
    )
    """Optional refresh token, stored through ``String255EncryptedType``
    with the ``NoneAesEngine`` engine."""

    expires = db.Column(db.DateTime, nullable=True)
    """Expiration datetime; ``None`` for tokens that never expire."""

    _scopes = db.Column(db.Text)
    """Space-separated scope names; access through the ``scopes`` property."""

    is_personal = db.Column(db.Boolean, default=False)
    """Personal access token."""

    is_internal = db.Column(db.Boolean, default=False)
    """Determines if token is an internally generated token."""

    @property
    def scopes(self):
        """Return all scopes as a list (empty when none are set)."""
        if self._scopes:
            return self._scopes.split()
        return []

    @scopes.setter
    def scopes(self, scopes):
        """Validate and set scopes.

        Duplicate scopes are dropped. The first occurrence of each scope keeps
        its position, so the stored value is deterministic. (A ``set`` was used
        before; its iteration order for strings varies between processes.)
        """
        validate_scopes(scopes)
        unique = []
        for scope in scopes or []:
            if scope not in unique:
                unique.append(scope)
        self._scopes = " ".join(unique)

    def get_visible_scopes(self):
        """Get list of non-internal scopes for token.

        :returns: The keys from the scope registry's ``choices()`` that are
            granted to this token, in registry order.
        """
        from .registry import scopes as scopes_registry
        # Split the stored scopes once, rather than re-evaluating the
        # ``scopes`` property for every registry entry.
        granted = set(self.scopes)
        return [k for k, _ in scopes_registry.choices() if k in granted]

    @classmethod
    def create_personal(cls, name, user_id, scopes=None, is_internal=False):
        """Create a personal access token.

        A token that is bound to a specific user and which doesn't expire, i.e.
        similar to the concept of an API key. A dedicated ``Client`` is created
        for the token. Both objects are added to the session, and the session
        is committed.

        :param name: Name of the client created for the token.
        :param user_id: ID of the user owning the client and the token.
        :param scopes: Iterable of scope names granted to the token
            (default: no scopes).
        :param is_internal: Mark the token as internally generated.
        :returns: The newly created token.
        """
        # Stored space-separated, the format the ``scopes`` property reads.
        # NOTE(review): this bypasses ``validate_scopes`` (which the
        # ``scopes`` setter calls) by writing ``_scopes`` directly.
        scopes = " ".join(scopes) if scopes else ""

        c = Client(
            name=name,
            user_id=user_id,
            # NOTE(review): the backing client is always marked internal,
            # independently of the ``is_internal`` argument -- confirm intent.
            is_internal=True,
            is_confidential=False,
            _default_scopes=scopes
        )
        # Presumably populates ``client_id``, which the token references below.
        c.gen_salt()

        t = cls(
            client_id=c.client_id,
            user_id=user_id,
            access_token=gen_salt(
                current_app.config.get('OAUTH2_TOKEN_PERSONAL_SALT_LEN')
            ),
            expires=None,
            _scopes=scopes,
            is_personal=True,
            is_internal=is_internal,
        )

        db.session.add(c)
        db.session.add(t)
        # Commits the whole session, not only these two objects.
        db.session.commit()

        return t