class Bibfmt(db.Model):
    """Represent a row of the ``bibfmt`` table.

    Rows are keyed by the composite primary key ``(id_bibrec, format)``.
    """

    __tablename__ = 'bibfmt'

    # Composite primary key, part 1: record id (explicitly not auto-increment).
    id_bibrec = db.Column(db.MediumInteger(8, unsigned=True),
                          primary_key=True, autoincrement=False,
                          nullable=False, server_default='0')

    # Composite primary key, part 2: format code (also indexed on its own).
    format = db.Column(db.String(10), primary_key=True, index=True,
                       nullable=False, server_default='')

    kind = db.Column(db.String(10), index=True,
                     nullable=False, server_default='')

    # Indexed timestamp; server-side default is a 1900-01-01 sentinel.
    last_updated = db.Column(db.DateTime, index=True, nullable=False,
                             server_default='1900-01-01 00:00:00')

    # Payload stored as binary.
    value = db.Column(db.iLargeBinary)

    # Small integer flag with a server default of 0.
    needs_2nd_pass = db.Column(db.TinyInteger(1), server_default='0')
Esempio n. 2
0
def do_upgrade():
    """Upgrade recipe.

    Adds two new columns (``password_salt`` and ``password_scheme``), copies
    every user's email into ``password_salt`` (setting ``password_scheme`` to
    ``invenio_aes_encrypted_email``), stores the SHA-256 hex digest of each
    legacy password blob in a temporary ``new_password`` column, and finally
    drops the old ``password`` column and renames ``new_password`` to it.
    """
    op.add_column(
        'user', db.Column('password_salt',
                          db.String(length=255),
                          nullable=True))
    op.add_column(
        'user',
        db.Column('password_scheme', db.String(length=50), nullable=False))

    # Temporary column needed for data migration
    op.add_column('user', db.Column('new_password', db.String(length=255)))

    # Reflect the current schema so the columns added above are visible.
    m = db.MetaData(bind=db.engine)
    m.reflect()
    u = m.tables['user']

    conn = db.engine.connect()
    try:
        # Migrate emails to password_salt
        conn.execute(
            u.update().values(password_salt=u.c.email,
                              password_scheme='invenio_aes_encrypted_email'))

        # Migrate password blob to password varchar.
        for row in conn.execute(select([u])):
            # NOTE: Empty string passwords were stored as empty strings
            # instead of a hashed version, hence they must be treated
            # differently.
            legacy_pw = (row[u.c.password] or
                         mysql_aes_encrypt(row[u.c.email], ""))

            stmt = u.update().where(u.c.id == row[u.c.id]).values(
                new_password=hashlib.sha256(legacy_pw).hexdigest())
            conn.execute(stmt)
    finally:
        # Release the data-migration connection before running the DDL
        # below; previously it was never closed.
        conn.close()

    # Create index
    op.create_index(op.f('ix_user_password_scheme'),
                    'user', ['password_scheme'],
                    unique=False)

    # Drop old database column and rename new.
    op.drop_column('user', 'password')
    op.alter_column(
        'user',
        'new_password',
        new_column_name='password',
        existing_type=mysql.VARCHAR(255),
        existing_nullable=True,
    )
Esempio n. 3
0
class AccARGUMENT(db.Model):

    """Authorization argument: a ``keyword``/``value`` pair.

    Both columns are covered together by the composite ``KEYVAL`` index.
    """

    __tablename__ = 'accARGUMENT'

    id = db.Column(db.Integer(15), autoincrement=True, primary_key=True)
    keyword = db.Column(db.String(32), nullable=True)
    value = db.Column(db.String(255), nullable=True)

    # Composite index over both columns, followed by the base model's
    # shared table arguments.
    __table_args__ = (
        db.Index('KEYVAL', keyword, value),
        db.Model.__table_args__,
    )

    def __repr__(self):
        """Render the argument as ``keyword=value``."""
        return '{0}={1}'.format(self.keyword, self.value)
Esempio n. 4
0
class AccROLE(db.Model):

    """Access role stored in the ``accROLE`` table."""

    __tablename__ = 'accROLE'

    id = db.Column(db.Integer(15, unsigned=True),
                   autoincrement=True, primary_key=True)
    # Role name; unique across the table.
    name = db.Column(db.String(32), nullable=True, unique=True)
    description = db.Column(db.String(255), nullable=True)
    # Binary and text forms of the role's "firerole" definition.
    firerole_def_ser = db.Column(db.iBinary, nullable=True)
    firerole_def_src = db.Column(db.Text, nullable=True)

    def __repr__(self):
        """Render the role as ``name - description``."""
        return '{0} - {1}'.format(self.name, self.description)
Esempio n. 5
0
class Session(db.Model):
    """Represent Session record."""

    __tablename__ = 'session'
    session_key = db.Column(db.String(32),
                            nullable=False,
                            server_default='',
                            primary_key=True)
    # ``None`` means the session never expires.
    session_expiry = db.Column(db.DateTime, nullable=True, index=True)
    session_object = db.Column(db.LargeBinary, nullable=True)
    uid = db.Column(db.Integer(15, unsigned=True), nullable=False, index=True)

    def get_session(self, name, expired=False):
        """Return the :class:`Session` stored under ``name``.

        :param name: value of ``session_key`` to look up.
        :param expired: when true, additionally require ``session_expiry`` to
            be at or after the database's current timestamp, so only a
            still-valid session matches.
        :raises sqlalchemy.orm.exc.NoResultFound: if no row matches.
        :raises sqlalchemy.orm.exc.MultipleResultsFound: if several match.
        """
        where = Session.session_key == name
        if expired:
            where = db.and_(
                where, Session.session_expiry >= db.func.current_timestamp())
        return self.query.filter(where).one()

    def set_session(self, name, value, timeout=None):
        """Build (but do not persist) a :class:`Session` for the current user.

        :param name: session key.
        :param value: serialized session payload.
        :param timeout: lifetime as a :class:`datetime.timedelta`, added to
            ``datetime.utcnow()``. ``None`` (the default) stores no expiry.
        :return: a new, unsaved :class:`Session` instance.
        """
        uid = current_user.get_id()
        # ``datetime.utcnow() + None`` used to raise TypeError for the
        # default argument; the column is nullable, so store no expiry.
        if timeout is None:
            session_expiry = None
        else:
            session_expiry = datetime.utcnow() + timeout
        return Session(session_key=name,
                       session_object=value,
                       session_expiry=session_expiry,
                       uid=uid)
Esempio n. 6
0
class UserQuery(db.Model):
    """Represent a UserQuery record."""

    __tablename__ = 'user_query'
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        primary_key=True,
                        server_default='0')
    id_query = db.Column(db.Integer(15, unsigned=True),
                         db.ForeignKey(WebQuery.id),
                         primary_key=True,
                         index=True,
                         server_default='0')
    hostname = db.Column(db.String(50),
                         nullable=True,
                         server_default='unknown host')
    date = db.Column(db.DateTime, nullable=True, default=datetime.datetime.now)

    webquery = db.relationship(WebQuery, backref='executions')

    @classmethod
    def log(cls, urlargs=None, id_user=None):
        """Record that a user executed a web query, and commit it.

        :param urlargs: query string to log; falls back to the current
            request's ``query_string`` when falsy.
        :param id_user: user id to attribute the query to; defaults to
            ``current_user.get_id()`` when ``None``.

        Nothing is logged when the resolved user id is ``None`` or negative.
        An existing :class:`WebQuery` with the same ``urlargs`` is reused;
        otherwise a new one is created and attached to the log entry.
        """
        if id_user is None:
            # The former ``id_user if not None else ...`` always took the
            # first branch (``not None`` is True), so this fallback never ran.
            id_user = current_user.get_id()
        urlargs = urlargs or request.query_string
        if id_user is None or id_user < 0:
            return
        webquery = WebQuery.query.filter_by(urlargs=urlargs).first()
        if webquery is None:
            webquery = WebQuery(urlargs=urlargs)
        db.session.add(
            cls(id_user=id_user, hostname=request.host, webquery=webquery))
        db.session.commit()
Esempio n. 7
0
class CheckerRecord(db.Model):
    """Association row between a checker rule and a record it has run on.

    The composite primary key is ``(rec_id, rule_name)``.
    """

    __tablename__ = 'checker_record'

    # Key part 1: the record, as a foreign key to RecordMetadata.
    rec_id = db.Column(
        db.MediumInteger(8, unsigned=True),
        db.ForeignKey(RecordMetadata.id),
        nullable=False,
        autoincrement=True,
        primary_key=True,
    )

    # The ``checker_record`` backref on RecordMetadata cascades
    # "all, delete-orphan" to these rows.
    record = db.relationship(
        RecordMetadata,
        doc="The record associated with a task.",
        backref=backref("checker_record", cascade="all, delete-orphan"),
    )

    # Key part 2: the rule, as a foreign key to ``checker_rule.name``.
    rule_name = db.Column(
        db.String(127),
        db.ForeignKey('checker_rule.name'),
        primary_key=True,
        index=True,
        nullable=False,
        doc="Name of the task in this associaton.",
    )

    last_run_version_id = db.Column(
        db.Integer,
        doc="Last checked version ID of associated record.",
        nullable=False,
    )
Esempio n. 8
0
class PidLog(db.Model):

    """Audit log of actions happening to persistent identifiers.

    This model is primarily used through PersistentIdentifier.log and rarely
    created manually.
    """

    __tablename__ = 'pidLOG'
    # Secondary index on the ``action`` column.
    __table_args__ = (db.Index('idx_action', 'action'), )

    id = db.Column(db.Integer(15, unsigned=True), primary_key=True)
    """Id of persistent identifier entry."""

    id_pid = db.Column(db.Integer(15, unsigned=True),
                       db.ForeignKey(PersistentIdentifier.id),
                       nullable=True)
    """PID."""

    timestamp = db.Column(db.DateTime(), default=datetime.now,
                          nullable=False)
    """Creation datetime of entry."""

    action = db.Column(db.String(10), nullable=False)
    """Action identifier."""

    message = db.Column(db.Text(), nullable=False)
    """Log message."""

    # Relationship (the backref exposes ``PersistentIdentifier.logs``).
    pid = db.relationship("PersistentIdentifier", backref="logs")
Esempio n. 9
0
class AccACTION(db.Model):

    """Access action stored in the ``accACTION`` table."""

    __tablename__ = 'accACTION'

    id = db.Column(db.Integer(15, unsigned=True),
                   autoincrement=True, primary_key=True)
    # Action name; unique across the table.
    name = db.Column(db.String(32), nullable=True, unique=True)
    description = db.Column(db.String(255), nullable=True)
    allowedkeywords = db.Column(db.String(255), nullable=True)
    # 'yes'/'no' enum (``yes_no`` is the enum type name on backends that
    # create one); defaults to 'no' server-side.
    optional = db.Column(db.Enum('yes', 'no', name='yes_no'),
                         server_default='no', nullable=False)

    def __repr__(self):
        """Render the action as its name."""
        return '{0}'.format(self.name)
class UserEXT(db.Model):
    """External identifier of a user for a given authentication method.

    The primary key is ``(id, method)``; ``(id_user, method)`` is also
    unique through the ``userext_id_user_method`` index.
    """

    __tablename__ = 'userEXT'

    id = db.Column(db.String(255), nullable=False, primary_key=True)
    method = db.Column(db.String(50), nullable=False, primary_key=True)
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id), nullable=False)

    # Exposed on the user as ``User.external_identifiers``.
    user = db.relationship(User, backref="external_identifiers")

    __table_args__ = (
        db.Index('userext_id_user_method', id_user, method, unique=True),
        db.Model.__table_args__,
    )
def do_upgrade():
    """Create the ``remoteACCOUNT`` and ``remoteTOKEN`` tables.

    Each table is created only if it does not exist yet; otherwise a warning
    is emitted and that table is left untouched.
    """
    if not op.has_table('remoteACCOUNT'):
        op.create_table('remoteACCOUNT',
                        db.Column('id',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('user_id',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('client_id',
                                  db.String(length=255),
                                  nullable=False),
                        db.Column('extra_data', db.JSON, nullable=True),
                        db.ForeignKeyConstraint(
                            ['user_id'],
                            ['user.id'],
                        ),
                        db.PrimaryKeyConstraint('id'),
                        # One remote account per (user, OAuth client).
                        db.UniqueConstraint('user_id', 'client_id'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')
    else:
        # Message quoting fixed (it previously read
        # "'remoteACCOUNT table skipped!'").
        warnings.warn("*** Creation of table 'remoteACCOUNT' skipped!")

    if not op.has_table('remoteTOKEN'):
        op.create_table('remoteTOKEN',
                        db.Column('id_remote_account',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('token_type',
                                  db.String(length=40),
                                  nullable=False),
                        db.Column('access_token', db.Text(), nullable=False),
                        db.Column('secret', db.Text(), nullable=False),
                        db.ForeignKeyConstraint(
                            ['id_remote_account'],
                            ['remoteACCOUNT.id'],
                        ),
                        # At most one token per (account, token type).
                        db.PrimaryKeyConstraint('id_remote_account',
                                                'token_type'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')
    else:
        # Stray trailing quote removed from the message.
        warnings.warn("*** Creation of table 'remoteTOKEN' skipped!")
Esempio n. 12
0
def do_upgrade():
    """Change the types of two ``oauth2TOKEN`` columns (both stay NOT NULL).

    ``client_id`` becomes ``String(255)`` and ``user_id`` becomes an unsigned
    ``Integer(15)``.
    """
    retyped_columns = (
        ('client_id', db.String(255)),
        ('user_id', db.Integer(15, unsigned=True)),
    )
    for column, new_type in retyped_columns:
        op.alter_column(table_name='oauth2TOKEN',
                        column_name=column,
                        type_=new_type,
                        existing_nullable=False)
Esempio n. 13
0
class Fieldname(db.Model):
    """Name of a :class:`Field` for a given ``ln`` and name ``type``.

    Keyed by ``(id_field, ln, type)``; reachable from a field through its
    ``names`` backref.
    """

    __tablename__ = 'fieldname'

    id_field = db.Column(db.MediumInteger(9, unsigned=True),
                         db.ForeignKey(Field.id), primary_key=True)
    ln = db.Column(db.Char(5), server_default='', primary_key=True)
    type = db.Column(db.Char(3), server_default='sn', primary_key=True)
    value = db.Column(db.String(255), nullable=False)

    field = db.relationship(Field, backref='names')
Esempio n. 14
0
class Fieldvalue(db.Model):
    """Represent a Fieldvalue record."""

    __tablename__ = 'fieldvalue'
    id = db.Column(db.MediumInteger(9, unsigned=True),
                   primary_key=True,
                   autoincrement=True)
    name = db.Column(db.String(255), nullable=False)
    value = db.Column(db.Text, nullable=False)

    def __init__(self, **kwargs):
        """Initialize, forwarding keyword column values to the base model.

        The previous ``__init__`` accepted no arguments and never called the
        base constructor, so ``Fieldvalue(name=..., value=...)`` raised
        ``TypeError``. Calling ``Fieldvalue()`` with no arguments behaves as
        before.
        """
        super(Fieldvalue, self).__init__(**kwargs)
Esempio n. 15
0
class Goto(db.Model):

    """Represents a Goto record."""

    __tablename__ = 'goto'
    label = db.Column(db.String(150), primary_key=True)
    plugin = db.Column(db.String(150), nullable=False)
    # ``default=dict`` builds a fresh dict for every insert. The former
    # literal ``default={}`` handed the same dict object to every new row, so
    # an in-place change to one row's parameters could leak into the default
    # used for later inserts.
    _parameters = db.Column(db.JSON, nullable=False, default=dict,
                            name="parameters")
    creation_date = db.Column(db.DateTime, default=datetime.datetime.now,
                              nullable=False, index=True)
    # Refreshed on every ORM update via ``onupdate``.
    modification_date = db.Column(db.DateTime, default=datetime.datetime.now,
                                  onupdate=datetime.datetime.now,
                                  nullable=False, index=True)

    @validates('plugin')
    def validate_plugin(self, key, plugin):
        """Validate plugin name.

        :raises ValueError: if ``plugin`` is not in ``redirect_methods``.
        """
        if plugin not in redirect_methods:
            raise ValueError("%s plugin does not exist" % plugin)

        return plugin

    @db.hybrid_property
    def parameters(self):
        """Get parameters method."""
        return self._parameters

    @parameters.setter
    def parameters(self, value):
        """Set parameters method; a falsy value is stored as ``{}``."""
        self._parameters = value or {}

    def to_dict(self):
        """Return a dict representation of Goto."""
        return {'label': self.label,
                'plugin': self.plugin,
                'parameters': self.parameters,
                'creation_date': self.creation_date,
                'modification_date': self.modification_date}
Esempio n. 16
0
class Field(db.Model):
    """Search field, identified by a unique ``code``."""

    __tablename__ = 'field'

    id = db.Column(db.MediumInteger(9, unsigned=True), primary_key=True)
    name = db.Column(db.String(255), nullable=False)
    code = db.Column(db.String(255), unique=True, nullable=False)

    def __repr__(self):
        """Render as ``ClassName(id)``."""
        return '{0}({1})'.format(type(self).__name__, self.id)

    @property
    def name_ln(self):
        """Return the i18n name of this field via ``get_field_i18nname``.

        The language used is ``g.ln`` when set, else ``CFG_SITE_LANG``.
        """
        from .cache import get_field_i18nname
        field_name = self.name
        language = getattr(g, 'ln', cfg['CFG_SITE_LANG'])
        return get_field_i18nname(field_name, language)

    @classmethod
    def get_field_name(cls, code):
        """Return the ``name`` of the field whose ``code`` matches."""
        matching = cls.query.filter_by(code=code)
        return matching.value(cls.name)

    @classmethod
    def get_field_tags(cls, code, tagtype='marc'):
        """Yield tag values for the field with the given ``code``.

        :param tagtype: ``'marc'`` reads ``Tag.value``; anything else reads
            ``Tag.recjson_value``.

        Each stored value is split on commas and every piece is yielded with
        surrounding whitespace stripped.
        """
        if tagtype == 'marc':
            selected = Tag.value
        else:
            selected = Tag.recjson_value
        rows = (cls.query
                .join(cls.tags)
                .join(FieldTag.tag)
                .filter(cls.code == code)
                .values(selected))
        for row in rows:
            for piece in row[0].split(','):
                yield piece.strip()
Esempio n. 17
0
class CheckerReporter(db.Model):
    """Represent reporters associated with a task.

    ..note::
        These entries are currently not meant to be associated with multiple
        tasks, because it is assumed that they may be deleted without affecting
        more than one tasks.
    """

    __tablename__ = 'checker_reporter'

    plugin = db.Column(db.String(127),
                       primary_key=True,
                       doc="Check associated with this reporter.")

    # Updates and deletes of the parent rule cascade at the database level.
    rule_name = db.Column(db.String(127),
                          db.ForeignKey('checker_rule.name',
                                        onupdate="CASCADE",
                                        ondelete="CASCADE"),
                          index=True,
                          nullable=False,
                          primary_key=True,
                          doc="Task associated with this reporter.")

    # ``default=dict`` gives every new row its own dict; the former
    # ``default={}`` shared a single dict object between all inserts.
    arguments = db.Column(
        JsonEncodedDict(1023),
        default=dict,
        doc="Arguments to be passed to this reporter.",
    )

    @db.hybrid_property
    def module(self):
        """Python module of the associated check.

        Looked up in ``reporters_files`` by ``plugin``; raises ``KeyError``
        when the plugin is unknown.
        """
        return reporters_files[self.plugin]
def do_upgrade():
    """Create the ``facet_collection`` table unless it already exists.

    When the table is already present a warning is emitted and nothing else
    happens.
    """
    if op.has_table('facet_collection'):
        warnings.warn("*** Creation of table 'facet_collection' skipped!")
        return

    op.create_table(
        'facet_collection',
        db.Column('id', mysql.INTEGER(), nullable=False),
        db.Column('id_collection', mysql.INTEGER(), nullable=False),
        db.Column('order', mysql.INTEGER(), nullable=False),
        db.Column('facet_name', db.String(length=80), nullable=False),
        db.ForeignKeyConstraint(['id_collection'], ['collection.id'], ),
        db.PrimaryKeyConstraint('id'),
        mysql_charset='utf8',
        mysql_engine='MyISAM',
    )
Esempio n. 19
0
class Tag(db.Model):
    """Represent a Tag record."""

    __tablename__ = 'tag'

    id = db.Column(db.MediumInteger(9, unsigned=True), primary_key=True)
    name = db.Column(db.String(255), nullable=False)
    value = db.Column(db.Char(6), server_default='', nullable=False)
    recjson_value = db.Column(db.Text, nullable=False)

    def __init__(self, tup=None, *args, **kwargs):
        """Build a tag, optionally from a ``(name, value)`` tuple.

        A tuple ``tup`` is unpacked into ``name``/``value`` before the base
        constructor runs (so keyword arguments can still override them).
        Any other non-``None`` ``tup`` is passed positionally to the base
        constructor.
        """
        if isinstance(tup, tuple):
            self.name, self.value = tup
            # Consumed: do not forward it to the base constructor.
            tup = None
        if tup is None:
            super(Tag, self).__init__(*args, **kwargs)
        else:
            super(Tag, self).__init__(tup, *args, **kwargs)

    @property
    def as_tag(self):
        """Return a ``(name, value)`` tuple."""
        return (self.name, self.value)
Esempio n. 20
0
class KnwKBRVAL(db.Model):
    """Represent a KnwKBRVAL record."""

    __tablename__ = 'knwKBRVAL'
    m_key = db.Column(db.String(255),
                      nullable=False,
                      primary_key=True,
                      index=True)
    m_value = db.Column(db.Text().with_variant(mysql.TEXT(30), 'mysql'),
                        nullable=False)
    id_knwKB = db.Column(db.MediumInteger(8, unsigned=True),
                         db.ForeignKey(KnwKB.id),
                         nullable=False,
                         server_default='0',
                         primary_key=True)
    # ``kb.kbrvals`` is a dict keyed by ``m_key``; orphans are deleted.
    kb = db.relationship(
        KnwKB,
        backref=db.backref(
            'kbrvals',
            cascade="all, delete-orphan",
            collection_class=attribute_mapped_collection("m_key")))

    @staticmethod
    def _like_pattern(term, match_type):
        """Return the SQL LIKE pattern for ``term`` under ``match_type``.

        An empty or ``None`` term matches everything (``'%'``). ``'s'`` wraps
        the term in ``%`` on both sides, ``'sw'`` appends ``%``, and any other
        match type uses the term verbatim.
        """
        if not term:
            return '%'
        if match_type == "s":
            return "%" + term + "%"
        if match_type == "sw":
            return term + "%"
        return term

    @staticmethod
    def query_kb_mappings(kbid, sortby="to", key="", value="", match_type="s"):
        """Return a query over the mappings of the given kb.

        If ``key`` is given, only mappings whose left side (mapFrom) matches
        it are returned; likewise ``value`` restricts the right side (mapTo).

        :param kbid: id of the knowledge base whose mappings are queried
        :param sortby: the sorting criteria ('from' orders by key, anything
            else by value)
        :param key: return only entries where key matches this; empty or
            ``None`` matches every key
        :param value: return only entries where value matches this; empty or
            ``None`` matches every value
        :param match_type: s=substring, e=exact, sw=startswith (any other
            value behaves like exact)
        :return: a query of :class:`KnwKBRVAL` rows
        """
        # NOTE(review): ``%`` and ``_`` inside key/value are not escaped, so
        # they still act as LIKE wildcards, even for exact matches.
        key_pattern = KnwKBRVAL._like_pattern(key, match_type)
        value_pattern = KnwKBRVAL._like_pattern(value, match_type)

        query = KnwKBRVAL.query.filter(KnwKBRVAL.id_knwKB == kbid)
        query = query.filter(KnwKBRVAL.m_key.like(key_pattern),
                             KnwKBRVAL.m_value.like(value_pattern))
        if sortby == "from":
            return query.order_by(KnwKBRVAL.m_key)
        return query.order_by(KnwKBRVAL.m_value)

    def to_dict(self):
        """Return a dict representation of KnwKBRVAL."""
        # FIXME remove 'id' dependency from invenio modules
        return {
            'id': self.m_key + "_" + str(self.id_knwKB),
            'key': self.m_key,
            'value': self.m_value,
            'kbid': self.kb.id if self.kb else None,
            'kbname': self.kb.name if self.kb else None
        }
Esempio n. 21
0
class CheckerRule(db.Model):
    """Represent runnable rules (also known as tasks)."""

    __tablename__ = 'checker_rule'

    name = db.Column(
        db.String(127),
        primary_key=True,
        doc="Name of the rule. Must be unique and user-friendly.",
    )

    plugin = db.Column(
        db.String(127),
        nullable=False,
        doc="Check to use. Must be importable string. Does not need to exist"
        " at task insertion time.",
    )

    # ``default=dict`` gives each new rule its own dict; the former
    # ``default={}`` shared one dict object between all inserts.
    arguments = db.Column(
        JsonEncodedDict(1023),
        default=dict,
        doc="Arguments to pass to the check.",
    )

    # XXX: Currently unsupported by search. Disabled elsewhere in the code.
    consider_deleted_records = db.Column(
        db.Boolean,
        nullable=True,
        default=False,
        doc="Whether to consider deleted records while filtering.",
    )

    filter_pattern = db.Column(
        db.String(255),
        nullable=True,
        doc="String pattern to search with to resolve records to check.",
    )

    filter_records = db.Column(
        IntBitSetType(1023),
        nullable=True,
        doc="Record IDs to run this task on.",
    )

    records = db.relationship(
        'CheckerRecord',
        backref='rule',
        cascade='all, delete-orphan',
        doc="Records which this rule has worked with in the past.",
    )

    reporters = db.relationship(
        'CheckerReporter',
        backref='rule',
        cascade='all, delete-orphan',
        doc="Reporters to be called while this task executes.",
    )

    executions = db.relationship(
        'CheckerRuleExecution',
        backref='rule',
        cascade='all, delete-orphan',
        doc="Past executions of this task. User should be free to clear them.",
    )

    last_scheduled_run = db.Column(
        db.DateTime(),
        nullable=True,
        doc="Last time this task was ran by the scheduler.",
    )

    schedule = db.Column(
        db.String(255),
        nullable=True,
        doc="Cron-style string that defines the schedule for this task.",
    )

    schedule_enabled = db.Column(
        db.Boolean,
        default=True,
        nullable=False,
        doc="Whether `schedule` is enabled.",
    )

    # TODO: You may use this column as a filter for tasks you don't want to see
    # by default in interfaces
    temporary = db.Column(
        db.Boolean,
        default=False,
        doc="Flag for tasks which will not be reused.",
    )

    force_run_on_unmodified_records = db.Column(
        db.Boolean,
        default=False,
        doc="Force a record-centric task to run on records it has checked"
        " before, even if they have already been checked in their current"
        " version.",
    )

    confirm_hash_on_commit = db.Column(
        db.Boolean,
        default=False,
        doc="Only commit recids whose hash has not changed between first"
        " requested modification and commit time.",
    )

    allow_chunking = db.Column(  # XXX unclear name (maybe "run_in_parallel")
        db.Boolean,
        default=True,
        doc="If the check is record-centric, allow checks to run in parallel.",
    )

    last_modification_date = db.Column(
        db.DateTime(),
        nullable=False,
        server_default='1900-01-01 00:00:00',
        doc="Last date on which this task was modified.",
    )

    owner_id = db.Column(
        db.Integer(15, unsigned=True),
        db.ForeignKey('user.id'),
        nullable=False,
        default=1,
    )
    owner = db.relationship(
        'User',
        doc="User that created this task. Used for scheduled tasks.",
    )

    @db.hybrid_property
    def filepath(self):
        """Resolve the filepath of this rule's plugin/check file.

        Returns ``None`` when ``plugin`` is not a key of ``plugin_files``.
        A ``.pyc`` path is mapped to the corresponding ``.py`` path.
        """
        try:
            path = inspect.getfile(plugin_files[self.plugin])
        except KeyError:
            return None
        if path.endswith('.pyc'):
            path = path[:-1]
        return path

    @db.hybrid_property
    def modified_requested_recids(self):
        """Record IDs of records that match the filters of this task.

        Starts from :attr:`requested_recids` (``filter_pattern`` and
        ``filter_records``), registers a :class:`CheckerRecord` for every
        requested id not yet associated with this rule (committing the
        session), then returns the requested ids the rule should (re)check:
        all of them when ``force_run_on_unmodified_records`` is enabled,
        otherwise only those whose ``CheckerRecord.last_run_version_id`` is 1
        or lower than the record's current ``RecordMetadata.version_id``.

        :rtype: intbitset
        """
        # Get all records that are already associated to this rule
        # If this is returning an empty set, you forgot to run bibindex
        # (A comprehension replaces ``zip(*rows)[0]``, which only worked on
        # Python 2 and relied on IndexError for the empty case.)
        associated_records = intbitset([
            row[0] for row in db.session.query(CheckerRecord.rec_id).filter(
                CheckerRecord.rule_name == self.name).all()
        ])

        # Store requested records that were until now unknown to this rule
        requested_ids = self.requested_recids
        for requested_id in requested_ids - associated_records:
            new_record = CheckerRecord(rec_id=requested_id,
                                       rule_name=self.name)
            db.session.add(new_record)
        db.session.commit()

        # Figure out which records have been edited since the last time we ran
        # this rule
        query = db.session.query(CheckerRecord.rec_id).outerjoin(
            RecordMetadata).filter(
                CheckerRecord.rec_id.in_(requested_ids),
                CheckerRecord.rule_name == self.name,
                db.or_(
                    self.force_run_on_unmodified_records,
                    db.or_(
                        CheckerRecord.last_run_version_id == 1,
                        CheckerRecord.last_run_version_id <
                        RecordMetadata.version_id,
                    ),
                ))
        return intbitset([row[0] for row in query])

    @session_manager
    def mark_recids_as_checked(self, recids):
        """Mark ``recids`` as checked by this task at their current version.

        Copies each record's ``RecordMetadata.version_id`` into the matching
        ``CheckerRecord.last_run_version_id`` of this rule.
        """
        db.session.query(CheckerRecord).\
            filter(
                CheckerRecord.rec_id == RecordMetadata.id,
                CheckerRecord.rule_name == self.name,
                CheckerRecord.rec_id.in_(recids),
            ).\
            update({"last_run_version_id": RecordMetadata.version_id},
                   synchronize_session=False)

    @db.hybrid_property
    def requested_recids(self):
        """Search given `self.filter_pattern` and `self.filter_records`.

        Runs a search for ``filter_pattern`` (an empty pattern when unset)
        and, if ``filter_records`` is set, intersects the result with it.

        :rtype: intbitset
        """
        # TODO: Use self.option_consider_deleted_records when it's available
        pattern = self.filter_pattern or ''
        recids = Query(pattern).search().recids

        if self.filter_records is not None:
            recids &= self.filter_records

        return recids

    def __str__(self):
        """Return a multi-line, human-readable summary of this task."""
        name_len = len(self.name)
        # Pad the header line with '=' (nothing when the name is long).
        trails = 61 - name_len
        return '\n'.join((
            '=== Checker Task: {} {}'.format(self.name, trails * '='),
            '* Name: {}'.format(self.name),
            '* Plugin: {}'.format(self.plugin),
            '* Arguments: {}'.format(self.arguments),
            '* Consider deleted records: {}'.format(
                self.consider_deleted_records),
            '* Filter Pattern: {}'.format(self.filter_pattern),
            '* Filter Records: {}'.format(ranges_str(self.filter_records)),
            '* Last scheduled run: {}'.format(self.last_scheduled_run),
            '* Schedule: {} [{}]'.format(
                self.schedule,
                'enabled' if self.schedule_enabled else 'disabled'),
            '* Temporary: {}'.format(self.temporary),
            '* Force-run on unmodified records: {}'.format(
                self.force_run_on_unmodified_records),
            '{}'.format(80 * '='),
        ))

    @staticmethod
    def update_time(mapper, connection, instance):
        """Update the `last_modification_date` to the current time.

        The ``(mapper, connection, instance)`` signature suggests this is
        registered elsewhere as a mapper event listener -- TODO confirm.
        """
        instance.last_modification_date = datetime.now()
Esempio n. 22
0
class Record(db.Model):
    """Represent a record object inside the SQL database."""

    __tablename__ = 'bibrec'

    id = db.Column(db.MediumInteger(8, unsigned=True),
                   primary_key=True,
                   nullable=False,
                   autoincrement=True)
    creation_date = db.Column(db.DateTime,
                              nullable=False,
                              server_default='1900-01-01 00:00:00',
                              index=True)
    modification_date = db.Column(db.DateTime,
                                  nullable=False,
                                  server_default='1900-01-01 00:00:00',
                                  index=True)
    master_format = db.Column(db.String(16),
                              nullable=False,
                              server_default='marc')
    additional_info = db.Column(db.JSON)

    # FIXME: remove this from the model and add them to the record class, all?

    @property
    def deleted(self):
        """Return True if record is marked as deleted.

        The record counts as deleted when one of its collections has primary
        ``DELETED``, or ``DUMMY`` when ``CFG_CERN_SITE`` is configured truthy.
        A record without a ``collections`` entry is not deleted.
        """
        from .api import get_record
        collections = get_record(self.id).get('collections') or []
        dbcollids = [c.get('primary') for c in collections]

        # record exists; now check whether it isn't marked as deleted.
        # ``bool`` so callers never receive ``None`` or a raw config value.
        return bool(
            "DELETED" in dbcollids or
            (current_app.config.get('CFG_CERN_SITE') and
             "DUMMY" in dbcollids))

    @staticmethod
    def _next_merged_recid(recid):
        """Return the ID of record merged with record with ID = recid.

        Uses the first ``970__d`` value that converts to an integer; returns
        ``None`` when no value converts or when that integer is 0.
        """
        from .api import get_record
        # FIXME
        for val in get_record(recid).get("970__d") or []:
            try:
                merged_recid = int(val)
            except (TypeError, ValueError):
                # Skip non-numeric (or non-string, e.g. None) values.
                continue
            return merged_recid or None
        return None

    @cached_property
    def merged_recid(self):
        """Return the ID of the record this record has been merged into.

        :return: merged record ID, or ``None`` if there is none.
        """
        return Record._next_merged_recid(self.id)

    @property
    def merged_recid_final(self):
        """Return the last record ID in the merge chain starting here.

        Follows ``_next_merged_recid`` repeatedly. A cycle (A -> B -> A)
        stops at the last ID reached before revisiting one, instead of
        looping forever.
        """
        cur_id = self.id
        seen = {cur_id}
        next_id = Record._next_merged_recid(cur_id)

        while next_id and next_id not in seen:
            seen.add(next_id)
            cur_id = next_id
            next_id = Record._next_merged_recid(cur_id)

        return cur_id

    @classmethod
    def filter_time_interval(cls, datetext, column='c'):
        """Return filter based on date text and column type.

        :param datetext: either ``'from->to'`` (either side may be empty) or
            a prefix matched with ``LIKE 'prefix%'``.
        :param column: ``'c'`` filters ``creation_date``; any other value
            filters ``modification_date``.
        :return: list of SQLAlchemy filter clauses.
        """
        column = cls.creation_date if column == 'c' else cls.modification_date
        parts = datetext.split('->')
        where = []
        if len(parts) == 2:
            if parts[0] != '':
                where.append(column >= parts[0])
            if parts[1] != '':
                where.append(column <= parts[1])

        else:
            where.append(column.like(datetext + '%'))
        return where

    @classmethod
    def allids(cls):
        """Return all existing record ids."""
        return intbitset(db.session.query(cls.id).all())
def do_upgrade():
    """Add the ``family_name`` and ``given_names`` columns to ``user``.

    Both columns are created as nullable strings of length 255, in that
    order.
    """
    for column_name in ('family_name', 'given_names'):
        op.add_column(
            'user',
            db.Column(column_name, db.String(length=255), nullable=True))
class User(db.Model):
    """Represents a User record."""

    def __str__(self):
        """Return string representation as ``"nickname <email>"``."""
        return "%s <%s>" % (self.nickname, self.email)

    __tablename__ = 'user'
    __mapper_args__ = {'confirm_deleted_rows': False}

    id = db.Column(db.Integer(15, unsigned=True),
                   primary_key=True,
                   autoincrement=True)
    email = db.Column(db.String(255),
                      nullable=False,
                      server_default='',
                      index=True)
    _password = db.Column(db.String(255), name="password", nullable=True)
    password_salt = db.Column(db.String(255))
    password_scheme = db.Column(db.String(50), nullable=False, index=True)

    # Account status code. The methods below treat "1" as confirmed, "0" as
    # inactive and "2" as the state in which an activation email is sent
    # (presumably "awaiting email verification" -- TODO confirm).
    _note = db.Column(db.String(255), name="note", nullable=True)
    given_names = db.Column(db.String(255), nullable=False, server_default='')
    family_name = db.Column(db.String(255), nullable=False, server_default='')
    settings = db.Column(db.MutableDict.as_mutable(
        db.MarshalBinary(default_value=get_default_user_preferences,
                         force_type=dict)),
                         nullable=True)
    nickname = db.Column(db.String(255),
                         nullable=False,
                         server_default='',
                         index=True)
    last_login = db.Column(db.DateTime,
                           nullable=False,
                           server_default='1900-01-01 00:00:00')

    # List of fields that can be updated with update_profile.
    PROFILE_FIELDS = ['nickname', 'email', 'family_name', 'given_names']

    @staticmethod
    def check_nickname(nickname):
        """Check if it's a valid nickname.

        A valid nickname is non-empty, has no leading/trailing space, is not
        ``guest`` (case-insensitive) and contains none of ``,``, ``'``, ``@``.
        """
        re_invalid_nickname = re.compile(""".*[,'@]+.*""")
        return bool(nickname) and not nickname.startswith(' ') and \
            not nickname.endswith(' ') and \
            nickname.lower() != 'guest' and \
            not re_invalid_nickname.match(nickname)

    @staticmethod
    def check_email(email):
        """Check if it's a valid email.

        The check is deliberately loose: the address must be non-empty,
        contain no space and look like ``something@something.something``.

        :param email: address to validate.
        :return: ``True`` if the address passes the check, else ``False``.
        """
        # Reject a space anywhere, including at position 0.
        if not email or ' ' in email:
            return False
        return re.match(r'(.)+\@(.)+\.(.)+', email) is not None

    @hybrid_property
    def note(self):
        """Return the note (account status code)."""
        return self._note

    @note.setter
    def note(self, note):
        """Set the note; the value is always stored as a string."""
        self._note = str(note)

    @hybrid_property
    def password(self):
        """Return the stored password hash."""
        return self._password

    @password.setter
    def password(self, password):
        """Hash and set the password; ``None`` makes it unusable."""
        if password is None:
            # Unusable password.
            self._password = None
            self.password_scheme = ''
        else:
            self._password = password_context.encrypt(password)
            self.password_scheme = password_context.default_scheme()

        # Invenio legacy salt is stored in password_salt, and every new
        # password set will be migrated to new hash not relying on
        # password_salt, thus is force to empty value.
        self.password_salt = ""

    def verify_password(self, password, migrate=False):
        """Verify if password matches the stored password hash.

        :param password: clear-text password to check.
        :param migrate: if true and the stored hash is outdated according to
            ``password_context.needs_update``, re-hash it with the current
            default scheme and commit.
        :return: ``True`` on match, ``False`` otherwise (also when either the
            stored or the given password is ``None``).
        """
        if self.password is None or password is None:
            return False

        # Invenio 1.x legacy needs externally store password salt to compute
        # hash.
        scheme_ctx = {} if \
            self.password_scheme != invenio_aes_encrypted_email.name else \
            {'user': self.password_salt}

        # Verify password
        if not password_context.verify(
                password, self.password, scheme=self.password_scheme, **
                scheme_ctx):
            return False

        # Migrate hash if needed.
        if migrate and password_context.needs_update(self.password):
            self.password = password
            try:
                db.session.commit()
            except Exception:
                db.session.rollback()
                raise

        return True

    def verify_email(self, force=False):
        """Send an account activation email when appropriate.

        With ``force`` the note is set to ``"2"`` (committed if it changed)
        before sending; without it, the email is only sent if the note is
        already ``"2"``.

        :return: ``True`` if the activation email was sent, else ``False``.
        """
        if force or self.note == "2":
            if self.note != "2":
                self.note = 2
                try:
                    db.session.commit()
                except Exception:
                    db.session.rollback()
                    raise
            send_account_activation_email(self)
            return True
        return False

    def update_profile(self, data):
        """Update user profile.

        Only fields listed in ``PROFILE_FIELDS`` are taken from ``data``. A
        changed email triggers ``verify_email(force=True)``. Changes are
        committed, ``current_user`` is reloaded and ``profile_updated`` is
        sent so that other modules can subscribe to changes.

        :param data: mapping of field name to new value.
        :return: dict mapping each changed field to its previous value.
        """
        changed_attrs = {}
        for field in self.PROFILE_FIELDS:
            if field in data and getattr(self, field) != data[field]:
                changed_attrs[field] = getattr(self, field)
                setattr(self, field, data[field])

        if 'email' in changed_attrs:
            self.verify_email(force=True)

        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise
        current_user.reload()
        profile_updated.send(sender=self.id,
                             user=self,
                             changed_attrs=changed_attrs)

        return changed_attrs

    @property
    def guest(self):
        """Return True if the user is a guest (has no email)."""
        return not self.email

    #
    # Basic functions for user authentication.
    #
    def get_id(self):
        """Return the id."""
        return self.id

    def is_confirmed(self):
        """Return True if the account has been confirmed."""
        return self.note == "1"

    def is_guest(self):
        """Return True if the user is a guest."""
        return self.guest

    def is_authenticated(self):
        """Return True if the user is an authenticated user (has an email)."""
        return bool(self.email)

    def is_active(self):
        """Return True if the user is active."""
        return self.note != "0"
Esempio n. 25
0
class AccMAILCOOKIE(db.Model):

    """Represent an email cookie."""

    __tablename__ = 'accMAILCOOKIE'

    AUTHORIZATIONS_KIND = (
        'pw_reset', 'mail_activation', 'role', 'authorize_action',
        'comment_msg', 'generic'
    )

    id = db.Column(db.Integer(15, unsigned=True), primary_key=True,
                   autoincrement=True)
    _data = db.Column('data', db.iBinary, nullable=False)
    expiration = db.Column(db.DateTime, nullable=False,
                           server_default='9999-12-31 23:59:59', index=True)
    kind = db.Column(db.String(32), nullable=False)
    onetime = db.Column(db.TinyInteger(1), nullable=False, server_default='0')
    status = db.Column(db.Char(1), nullable=False, server_default='W')

    @validates('kind')
    def validate_kind(self, key, kind):
        """Validate cookie kind against ``AUTHORIZATIONS_KIND``."""
        assert kind in self.AUTHORIZATIONS_KIND
        return kind

    @classmethod
    def get(cls, cookie, delete=False):
        """Get cookie if it is valid.

        The cookie string has the layout produced by :meth:`create`: the
        first and last 16 characters form the decryption password, the part
        in between is the row id in hexadecimal.

        :param cookie: cookie string.
        :param delete: mark the cookie as deleted even if it is not one-time.
        :return: the cookie object, with ``data`` set to the decrypted
            ``(kind, params, expiration, onetime)`` tuple.
        :raises InvenioWebAccessMailCookieError: if the cookie is corrupted
            or has expired.
        :raises InvenioWebAccessMailCookieDeletedError: if the cookie has
            already been deleted.
        """
        password = cookie[:16] + cookie[-16:]
        cookie_id = int(cookie[16:-16], 16)

        obj, data = db.session.query(
            cls,
            AccMAILCOOKIE._data
        ).filter_by(id=cookie_id).one()
        obj.data = loads(mysql_aes_decrypt(data, password))

        (kind_check, params, expiration, onetime_check) = obj.data
        assert obj.kind in cls.AUTHORIZATIONS_KIND

        if not (obj.kind == kind_check and obj.onetime == onetime_check):
            raise InvenioWebAccessMailCookieError("Cookie is corrupted")
        if obj.status == 'D':
            raise InvenioWebAccessMailCookieDeletedError(
                "Cookie has been deleted")
        # ``create`` stores ``datetime.today() + timeout``, so compare against
        # the same naive local clock. Without this check an expired cookie
        # (e.g. a password-reset link) stays usable until ``gc`` runs.
        if obj.expiration < datetime.today():
            raise InvenioWebAccessMailCookieError("Cookie has expired")
        if obj.onetime or delete:
            obj.status = 'D'
            db.session.merge(obj)
            db.session.commit()
        return obj

    @classmethod
    def create(cls, kind, params, cookie_timeout=timedelta(days=1),
               onetime=False):
        """Create cookie with given params.

        :param kind: one of ``AUTHORIZATIONS_KIND``.
        :param params: payload stored (encrypted) with the cookie.
        :param cookie_timeout: validity period from now.
        :param onetime: if true, the cookie is deleted on first successful
            :meth:`get`.
        :return: the cookie string to hand out to the user.
        """
        import os

        expiration = datetime.today() + cookie_timeout
        data = (kind, params, expiration, onetime)
        # Use the OS CSPRNG: ``random()`` is a predictable Mersenne Twister
        # and must not be used to derive secrets such as reset-link keys.
        password = md5(os.urandom(16)).hexdigest()
        cookie = cls(
            expiration=expiration,
            kind=kind,
            onetime=int(onetime),
        )
        cookie._data = mysql_aes_encrypt(dumps(data), password)
        db.session.add(cookie)
        db.session.commit()
        db.session.refresh(cookie)
        # Format the id as plain hex; ``hex(id)[2:-1]`` only worked when
        # Python 2 appended an ``L`` suffix and otherwise cut off the last
        # digit. ``get`` parses this part back with ``int(..., 16)``.
        id_hex = '{0:x}'.format(cookie.id)
        return password[:16] + id_hex + password[-16:]

    @classmethod
    @session_manager
    def gc(cls):
        """Remove expired items."""
        return cls.query.filter(cls.expiration < db.func.now()).delete()
Esempio n. 26
0
class RemoteToken(db.Model):
    """Storage for the access tokens for linked accounts."""

    __tablename__ = 'remoteTOKEN'

    #
    # Fields
    #
    #: Foreign key to account.
    id_remote_account = db.Column(db.Integer(15, unsigned=True),
                                  db.ForeignKey(RemoteAccount.id),
                                  nullable=False,
                                  primary_key=True)

    #: Type of token.
    token_type = db.Column(db.String(40),
                           default='',
                           nullable=False,
                           primary_key=True)

    #: Access token to remote application (stored encrypted).
    access_token = db.Column(TextEncryptedType(type_in=db.Text,
                                               key=secret_key),
                             nullable=False)

    #: Used only by OAuth 1.
    secret = db.Column(db.Text(), default='', nullable=False)

    def token(self):
        """Return the ``(access_token, secret)`` pair for Flask-OAuthlib."""
        return self.access_token, self.secret

    def update_token(self, token, secret):
        """Store new token values, committing only if something changed."""
        unchanged = self.access_token == token and self.secret == secret
        if unchanged:
            return
        self.access_token = token
        self.secret = secret
        db.session.commit()

    @classmethod
    def get(cls, user_id, client_id, token_type='', access_token=None):
        """Find the token of a user for a remote application.

        :param user_id: local user id.
        :param client_id: remote application client id.
        :param token_type: token type to match.
        :param access_token: if given (truthy), the token value must match.
        :return: the first matching token, or ``None``.
        """
        criteria = [
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.user_id == user_id,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
        ]
        if access_token:
            criteria.append(RemoteToken.access_token == access_token)

        query = cls.query.options(db.joinedload('remote_account'))
        return query.filter(*criteria).first()

    @classmethod
    def get_by_token(cls, client_id, access_token, token_type=''):
        """Find a token by its value for a remote application.

        :return: the first matching token, or ``None``.
        """
        query = cls.query.options(db.joinedload('remote_account'))
        matching = query.filter(
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
            RemoteToken.access_token == access_token,
        )
        return matching.first()

    @classmethod
    def create(cls,
               user_id,
               client_id,
               token,
               secret,
               token_type='',
               extra_data=None):
        """Create and commit a new access token.

        The owning ``RemoteAccount`` is created in the same transaction when
        it does not exist yet.

        :return: the new ``RemoteToken``.
        """
        account = RemoteAccount.get(user_id, client_id)
        if account is None:
            account = RemoteAccount(
                user_id=user_id,
                client_id=client_id,
                extra_data=extra_data or {},
            )
            db.session.add(account)

        remote_token = cls(
            token_type=token_type,
            remote_account=account,
            access_token=token,
            secret=secret,
        )
        db.session.add(remote_token)
        db.session.commit()
        return remote_token
Esempio n. 27
0
class RemoteAccount(db.Model):
    """Storage for remote linked accounts."""

    __tablename__ = 'remoteACCOUNT'

    __table_args__ = (db.UniqueConstraint('user_id', 'client_id'),
                      db.Model.__table_args__)

    #
    # Fields
    #
    #: Primary key.
    id = db.Column(db.Integer(15, unsigned=True),
                   primary_key=True,
                   autoincrement=True)

    #: Local user linked with a remote app via the access token.
    user_id = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=False)

    #: Client ID of remote application (defined in OAUTHCLIENT_REMOTE_APPS).
    client_id = db.Column(db.String(255), nullable=False)

    #: Extra data associated with this linked account.
    extra_data = db.Column(MutableDict.as_mutable(db.JSON), nullable=False)

    #
    # Relationship properties
    #
    #: SQLAlchemy relationship to user.
    user = db.relationship('User')

    #: SQLAlchemy relationship to RemoteToken objects.
    tokens = db.relationship(
        "RemoteToken",
        backref="remote_account",
    )

    @classmethod
    def get(cls, user_id, client_id):
        """Return the account linking a user to a remote application.

        :param user_id: User id.
        :param client_id: Client id.
        :return: the first matching account, or ``None``.
        """
        matching = cls.query.filter_by(user_id=user_id, client_id=client_id)
        return matching.first()

    @classmethod
    def create(cls, user_id, client_id, extra_data):
        """Create and commit a new remote account for a user.

        :param user_id: User id.
        :param client_id: Client id.
        :param extra_data: JSON-serializable dictionary of any extra data that
                           needs to be saved together with this link; a falsy
                           value is stored as an empty dict.
        :return: the new account.
        """
        account = cls(
            user_id=user_id,
            client_id=client_id,
            extra_data=extra_data or {},
        )
        db.session.add(account)
        db.session.commit()
        return account

    def delete(self):
        """Delete this account together with all its tokens and commit."""
        own_tokens = RemoteToken.query.filter_by(id_remote_account=self.id)
        own_tokens.delete()
        db.session.delete(self)
        db.session.commit()
Esempio n. 28
0
class KnwKB(db.Model):
    """Represent a KnwKB record (a knowledge base)."""

    KNWKB_TYPES = {
        'written_as': 'w',
        'dynamic': 'd',
        'taxonomy': 't',
    }

    __tablename__ = 'knwKB'
    id = db.Column(db.MediumInteger(8, unsigned=True),
                   nullable=False,
                   primary_key=True,
                   autoincrement=True)
    _name = db.Column(db.String(255),
                      server_default='',
                      unique=True,
                      name="name")
    _description = db.Column(db.Text,
                             nullable=False,
                             name="description",
                             default="")
    _kbtype = db.Column(db.Char(1), nullable=True, default='w', name="kbtype")
    slug = db.Column(db.String(255), unique=True, nullable=False, default="")
    # Enable or disable the access from REST API
    is_api_accessible = db.Column(db.Boolean, default=True, nullable=False)

    @db.hybrid_property
    def name(self):
        """Get name."""
        return self._name

    @name.setter
    def name(self, value):
        """Set name and, if no slug exists yet, generate the slug."""
        self._name = value
        # generate slug
        if not self.slug:
            self.slug = KnwKB.generate_slug(value)

    @db.hybrid_property
    def description(self):
        """Get description."""
        return self._description

    @description.setter
    def description(self, value):
        """Set description (falsy values are stored as ``''``)."""
        # TEXT in mysql don't support default value
        # @see http://bugs.mysql.com/bug.php?id=21532
        self._description = value or ''

    @db.hybrid_property
    def kbtype(self):
        """Get kbtype (one-letter code)."""
        return self._kbtype

    @kbtype.setter
    def kbtype(self, value):
        """Set kbtype.

        Only the first character of ``value`` is stored, so both the long
        names (``'taxonomy'``, ``'dynamic'``, ``'written_as'``) and the
        one-letter codes are accepted. ``None`` leaves the current value
        untouched; an empty string selects ``'w'``.

        :raises ValueError: if the first character is not ``t``, ``d`` or
            ``w``.
        """
        if value is None:
            # Keep the current value (the column default is 'w').
            return
        # or set one of the available values
        kbtype = value[0] if len(value) > 0 else 'w'
        if kbtype not in ['t', 'd', 'w']:
            # Implicit literal concatenation instead of a backslash
            # continuation, which embedded the next line's indentation in
            # the message.
            raise ValueError(
                'unknown type "{value}", please use one of following '
                'values: "taxonomy", "dynamic" or "written_as"'.format(
                    value=value))
        self._kbtype = kbtype

    def is_dynamic(self):
        """Return true if the type is dynamic."""
        return self._kbtype == 'd'

    def to_dict(self):
        """Return a dict representation of KnwKB.

        For dynamic knowledge bases the ``kbdefs`` dict (if any) is merged in.
        """
        mydict = {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'kbtype': self.kbtype
        }
        if self.kbtype == 'd' and self.kbdefs:
            mydict.update(self.kbdefs.to_dict() or {})

        return mydict

    @staticmethod
    def _kbr_like_patterns(searchkey, searchvalue, searchtype):
        """Turn search terms into ``LIKE`` patterns for key and value.

        ``'s'`` wraps non-empty terms in ``%...%``; ``'sw'`` appends ``%`` to
        a non-empty value; empty/``None`` terms match everything (``'%'``).

        :return: ``(key_pattern, value_pattern)``
        """
        if searchtype == 's' and searchkey:
            searchkey = '%' + searchkey + '%'
        if searchtype == 's' and searchvalue:
            searchvalue = '%' + searchvalue + '%'
        if searchtype == 'sw' and searchvalue:  # startswith
            searchvalue = searchvalue + '%'
        return searchkey or '%', searchvalue or '%'

    def get_kbr_items(self, searchkey="", searchvalue="", searchtype='s'):
        """Return dicts of 'key' and 'value' from this knowledge base.

        Deprecated: use ``KnwKBRVAL.query_kb_mappings()`` instead.

        :param searchkey: search using this key (empty matches any key)
        :param searchvalue: search using this value (empty matches any value)
        :param searchtype: s=substring, e=exact, sw=value startswith
        :return: a list of ``KnwKBRVAL.to_dict()`` dictionaries
        """
        import warnings
        warnings.warn("The function is deprecated. Please use the "
                      "`KnwKBRVAL.query_kb_mappings()` instead. "
                      "E.g. [kval.to_dict() for kval in "
                      "KnwKBRVAL.query_kb_mappings(kb_id).all()]",
                      stacklevel=2)
        key_pattern, value_pattern = KnwKB._kbr_like_patterns(
            searchkey, searchvalue, searchtype)

        kvals = KnwKBRVAL.query.filter(KnwKBRVAL.id_knwKB == self.id,
                                       KnwKBRVAL.m_value.like(value_pattern),
                                       KnwKBRVAL.m_key.like(key_pattern)).all()
        return [kval.to_dict() for kval in kvals]

    def get_kbr_values(self, searchkey="", searchvalue="", searchtype='s'):
        """Return the matching values of this knowledge base.

        Deprecated: use ``KnwKBRVAL.query_kb_mappings()`` instead.

        :param searchkey: search using this key (empty matches any key)
        :param searchvalue: search using this value (empty matches any value)
        :param searchtype: s=substring, e=exact, sw=value startswith
        :return: result of executing a select of ``KnwKBRVAL.m_value``
        """
        import warnings
        warnings.warn("The function is deprecated. Please use the "
                      "`KnwKBRVAL.query_kb_mappings()` instead. "
                      "E.g. [(kval.m_value,) for kval in "
                      "KnwKBRVAL.query_kb_mappings(kb_id).all()]",
                      stacklevel=2)
        # Same pattern rules as get_kbr_items: an empty key now matches any
        # key for every searchtype (it used to match only empty keys unless
        # searchtype was 's', and ``None`` raised TypeError).
        key_pattern, value_pattern = KnwKB._kbr_like_patterns(
            searchkey, searchvalue, searchtype)
        # execute query
        return db.session.execute(
            db.select([KnwKBRVAL.m_value],
                      db.and_(KnwKBRVAL.id_knwKB == self.id,
                              KnwKBRVAL.m_value.like(value_pattern),
                              KnwKBRVAL.m_key.like(key_pattern))))

    @session_manager
    def set_dyn_config(self, field, expression, collection=None):
        """Set (update or create) the dynamic configuration."""
        if self.kbdefs:
            # update
            self.kbdefs.output_tag = field
            self.kbdefs.search_expression = expression
            self.kbdefs.collection = collection
            db.session.merge(self.kbdefs)
        else:
            # insert
            self.kbdefs = KnwKBDDEF(output_tag=field,
                                    search_expression=expression,
                                    collection=collection)

    @staticmethod
    def generate_slug(name):
        """Generate a slug for the knowledge base.

        If the slug (or ``slug-*``) is already taken, ``-N`` is appended,
        where N is the number of such existing slugs.

        :param name: text to slugify
        :return: slugified text
        """
        slug = slugify(name)

        i = KnwKB.query.filter(
            db.or_(
                KnwKB.slug.like(slug),
                KnwKB.slug.like(slug + '-%'),
            )).count()

        return slug + ('-{0}'.format(i) if i > 0 else '')

    @staticmethod
    def exists(kb_name):
        """Return True if a kb with the given name exists.

        :param kb_name: the name of the knowledge base (matched with LIKE)
        :return: True if kb exists
        """
        return KnwKB.query_exists(KnwKB.name.like(kb_name))

    @staticmethod
    def query_exists(filters):
        """Return True if a kb with the given filters exists.

        E.g: KnwKB.query_exists(KnwKB.name.like('FAQ'))

        :param filters: filter for sqlalchemy
        :return: True if kb exists
        """
        return db.session.query(KnwKB.query.filter(filters).exists()).scalar()

    def get_filename(self):
        """Construct the file name for taxonomy knowledge."""
        return cfg['CFG_WEBDIR'] + "/kbfiles/" \
            + str(self.id) + ".rdf"
Esempio n. 29
0
class PersistentIdentifier(db.Model):

    """Store and register persistent identifiers.

    Assumptions:
      * Persistent identifiers can be represented as a string of max 255 chars.
      * An object has many persistent identifiers.
      * A persistent identifier has one and only one object.
    """

    __tablename__ = 'pidSTORE'
    __table_args__ = (
        db.Index('uidx_type_pid', 'pid_type', 'pid_value', unique=True),
        db.Index('idx_status', 'status'),
        db.Index('idx_object', 'object_type', 'object_value'),
    )

    # Id of persistent identifier entry.
    id = db.Column(db.Integer(15, unsigned=True), primary_key=True)

    # Persistent Identifier Schema.
    pid_type = db.Column(db.String(6), nullable=False)

    # Persistent Identifier.
    pid_value = db.Column(db.String(length=255), nullable=False)

    # Persistent Identifier Provider.
    pid_provider = db.Column(db.String(length=255), nullable=False)

    # Status of persistent identifier, e.g. registered, reserved, deleted.
    status = db.Column(db.CHAR(length=1), nullable=False)

    # Object Type - e.g. rec for record.
    object_type = db.Column(db.String(3), nullable=True)

    # Object ID - e.g. a record id.
    object_value = db.Column(db.String(length=255), nullable=True)

    # Creation datetime of entry.
    created = db.Column(db.DateTime(), nullable=False, default=datetime.now)

    # Last modification datetime of entry.
    last_modified = db.Column(
        db.DateTime(), nullable=False, default=datetime.now,
        onupdate=datetime.now
    )

    #
    # Class methods
    #
    @classmethod
    def create(cls, pid_type, pid_value, pid_provider='', provider=None):
        """Internally reserve a new persistent identifier.

        A provider for the given persistent identifier type must exists. By
        default the system will choose a provider according to the pid
        type. If desired, the default system provider can be overridden via
        the provider keyword argument.

        Return PID object if successful otherwise None.
        """
        # Ensure provider exists
        if provider is None:
            provider = PidProvider.create(pid_type, pid_value, pid_provider)
            if not provider:
                raise Exception(
                    "No provider found for %s:%s (%s)" % (
                        pid_type, pid_value, pid_provider)
                )

        obj = None
        try:
            # The context manager releases the savepoint on success and rolls
            # it back if the INSERT fails (e.g. duplicate pid_type/pid_value
            # hitting uidx_type_pid); the old bare begin_nested() was never
            # closed on the success path.
            with db.session.begin_nested():
                obj = cls(pid_type=provider.pid_type,
                          pid_value=provider.create_new_pid(pid_value),
                          pid_provider=pid_provider,
                          status=cfg['PIDSTORE_STATUS_NEW'])
                obj._provider = provider
                db.session.add(obj)
            obj.log("CREATE", "Created")
            return obj
        except SQLAlchemyError:
            # ``obj`` is still None if the failure happened before it was
            # constructed (e.g. inside provider.create_new_pid).
            if obj is not None:
                obj.log("CREATE", "Failed to created. Already exists.")
            return None

    @classmethod
    def get(cls, pid_type, pid_value, pid_provider='', provider=None):
        """Get persistent identifier.

        Return None if not found.
        """
        pid_value = to_unicode(pid_value)
        obj = cls.query.filter_by(
            pid_type=pid_type, pid_value=pid_value, pid_provider=pid_provider
        ).first()
        if obj:
            obj._provider = provider
            return obj
        else:
            return None

    #
    # Instance methods
    #
    def has_object(self, object_type, object_value):
        """Determine if this PID is assigned to a specific object."""
        if object_type not in cfg['PIDSTORE_OBJECT_TYPES']:
            raise Exception("Invalid object type %s." % object_type)

        object_value = to_unicode(object_value)

        return self.object_type == object_type and \
            self.object_value == object_value

    def get_provider(self):
        """Get the provider for this type of persistent identifier.

        The provider is created lazily and cached on the instance.
        """
        # Instances loaded through a plain query (not via create/get) never
        # had ``_provider`` assigned, so fall back to None instead of
        # raising AttributeError.
        if getattr(self, '_provider', None) is None:
            self._provider = PidProvider.create(
                self.pid_type, self.pid_value, self.pid_provider
            )
        return self._provider

    def assign(self, object_type, object_value, overwrite=False):
        """Assign this persistent identifier to a given object.

        Note, the persistent identifier must first have been reserved. Also,
        if an existing object is already assigned to the pid, it will raise
        an exception unless overwrite=True.
        """
        if object_type not in cfg['PIDSTORE_OBJECT_TYPES']:
            raise Exception("Invalid object type %s." % object_type)
        object_value = to_unicode(object_value)

        if not self.id:
            raise Exception(
                "You must first create the persistent identifier before you "
                "can assign objects to it."
            )

        if self.is_deleted():
            raise Exception(
                "You cannot assign objects to a deleted persistent identifier."
            )

        with db.session.begin_nested():
            # Check for an existing object assigned to this pid
            existing_obj_id = self.get_assigned_object(object_type)

            if existing_obj_id and existing_obj_id != object_value:
                if not overwrite:
                    raise Exception(
                        "Persistent identifier is already assigned to another "
                        "object"
                    )
                else:
                    self.log(
                        "ASSIGN",
                        "Unassigned object %s:%s (overwrite requested)" % (
                            self.object_type, self.object_value)
                    )
                    self.object_type = None
                    self.object_value = None
            elif existing_obj_id and existing_obj_id == object_value:
                # The object is already assigned to this pid.
                return True

            self.object_type = object_type
            self.object_value = object_value
            # No explicit session.commit() here: it would close the savepoint
            # opened by this ``with`` block, whose exit then fails. Leaving
            # the block releases the savepoint.
            self.log("ASSIGN", "Assigned object {0}:{1}".format(
                self.object_type, self.object_value
            ))
            return True

    def update(self, with_deleted=False, *args, **kwargs):
        """Update the persistent identifier with the provider.

        With ``with_deleted`` a deleted PID may be updated too; on success it
        is set back to registered.
        """
        if self.is_new() or self.is_reserved():
            raise Exception(
                "Persistent identifier has not yet been registered."
            )

        if not with_deleted and self.is_deleted():
            raise Exception("Persistent identifier has been deleted.")

        with db.session.begin_nested():
            provider = self.get_provider()
            if provider is None:
                self.log("UPDATE", "No provider found.")
                raise Exception("No provider found.")

            if provider.update(self, *args, **kwargs):
                if with_deleted and self.is_deleted():
                    self.status = cfg['PIDSTORE_STATUS_REGISTERED']
                return True
        return False

    def reserve(self, *args, **kwargs):
        """Reserve the persistent identifier with the provider.

        Note, the reserve method may be called multiple times, even if it was
        already reserved.
        """
        if not (self.is_new() or self.is_reserved()):
            raise Exception(
                "Persistent identifier has already been registered."
            )

        with db.session.begin_nested():
            provider = self.get_provider()
            if provider is None:
                self.log("RESERVE", "No provider found.")
                raise Exception("No provider found.")

            if provider.reserve(self, *args, **kwargs):
                self.status = cfg['PIDSTORE_STATUS_RESERVED']
                return True
        return False

    def register(self, *args, **kwargs):
        """Register the persistent identifier with the provider."""
        if self.is_registered() or self.is_deleted():
            raise Exception(
                "Persistent identifier has already been registered."
            )

        with db.session.begin_nested():
            provider = self.get_provider()
            if provider is None:
                self.log("REGISTER", "No provider found.")
                raise Exception("No provider found.")

            if provider.register(self, *args, **kwargs):
                self.status = cfg['PIDSTORE_STATUS_REGISTERED']
                return True
        return False

    def delete(self, *args, **kwargs):
        """Delete the persistent identifier.

        A new (never registered/reserved) PID is removed from the table while
        its log entries are kept but unlinked; otherwise the provider is asked
        to delete it and the status is set to deleted.
        """
        with db.session.begin_nested():
            if self.is_new():
                # New persistent identifier which haven't been registered yet.
                # Just delete it completely but keep the log. Log first, so
                # the final entry is unlinked together with the older ones
                # instead of pointing at the deleted row.
                self.log("DELETE", "Unregistered PID successfully deleted")
                # Remove links to log entries (leave them otherwise)
                PidLog.query.filter_by(id_pid=self.id).update({'id_pid': None})
                db.session.delete(self)
            else:
                provider = self.get_provider()
                if provider is None:
                    self.log("DELETE", "No provider found.")
                    raise Exception("No provider found.")
                if not provider.delete(self, *args, **kwargs):
                    return False
                self.status = cfg['PIDSTORE_STATUS_DELETED']
            return True

    def sync_status(self, *args, **kwargs):
        """Synchronize persistent identifier status.

        Used when the provider uses an external service, which might have been
        modified outside of our system.
        """
        with db.session.begin_nested():
            provider = self.get_provider()
            if provider is None:
                self.log("SYNC", "No provider found.")
                raise Exception("No provider found.")
            result = provider.sync_status(self, *args, **kwargs)
            return result

    def get_assigned_object(self, object_type=None):
        """Return the assigned object value if it has ``object_type``."""
        if object_type is not None and self.object_type == object_type:
            return self.object_value
        return None

    def is_registered(self):
        """Return true if the persistent identifier has been registered."""
        return self.status == cfg['PIDSTORE_STATUS_REGISTERED']

    def is_deleted(self):
        """Return true if the persistent identifier has been deleted."""
        return self.status == cfg['PIDSTORE_STATUS_DELETED']

    def is_new(self):
        """Return true if the PID has not yet been registered or reserved."""
        return self.status == cfg['PIDSTORE_STATUS_NEW']

    def is_reserved(self):
        """Return true if the PID is reserved (but not yet registered)."""
        return self.status == cfg['PIDSTORE_STATUS_RESERVED']

    def log(self, action, message):
        """Store action and message in log (prefixed with type:value)."""
        if self.pid_type and self.pid_value:
            message = "[%s:%s] %s" % (self.pid_type, self.pid_value, message)
        with db.session.begin_nested():
            p = PidLog(id_pid=self.id, action=action, message=message)
            db.session.add(p)
Esempio n. 30
0
class CheckerRuleExecution(db.Model):

    """A single execution of a checker rule.

    Each row is identified by a UUID that is shared with the execution's
    ``RedisMaster`` and its log file. The row references the owning ``User``
    and the executed ``checker_rule`` by name. It tracks status (with the
    time of the last status change), the start date, and whether the run is
    a dry run.
    """

    __tablename__ = 'checker_rule_execution'

    # 36-character UUID string; also used as the RedisMaster id and the log
    # file identifier (see ``master`` and ``read_logs``).
    uuid = db.Column(
        db.String(36),
        primary_key=True,
        doc="UUID of the execution. Same with that of RedisMaster and logfile.",
    )

    # NOTE(review): the default owner id of 1 presumably refers to a
    # default/admin user account -- confirm.
    owner_id = db.Column(
        db.Integer(15, unsigned=True),
        db.ForeignKey('user.id'),
        nullable=False,
        default=1,
    )
    owner = db.relationship(
        'User',
        doc="User who owns this execution. May be used by reporters.",
    )

    # Foreign key to ``checker_rule.name`` (the rule being executed).
    rule_name = db.Column(
        db.String(127),
        db.ForeignKey('checker_rule.name'),
        nullable=False,
        index=True,
        doc="Name of the associated task.",
    )

    # ``StatusMaster`` value persisted as an integer. Read and write it via
    # the ``status`` hybrid property, whose setter also refreshes
    # ``status_update_date``.
    _status = db.Column(
        ChoiceType(StatusMaster, impl=db.Integer()),
        default=StatusMaster.unknown,
    )

    # Server default of 1900-01-01 is presumably a "never set" sentinel.
    status_update_date = db.Column(
        db.DateTime(),
        nullable=False,
        server_default='1900-01-01 00:00:00',
        doc="Last date the status was updated.",
    )

    # Same 1900-01-01 server-side placeholder as ``status_update_date``.
    start_date = db.Column(
        db.DateTime(),
        nullable=False,
        server_default='1900-01-01 00:00:00',
        doc="Date at which this task was started.",
    )

    # Drives the ``should_*`` properties; they are all ``not dry_run``.
    dry_run = db.Column(
        db.Boolean,
        default=False,
        doc=
        "Whether this execution is a dry run. Note the `should_*` properties.")

    @db.hybrid_property
    def should_commit(self):
        """Whether this execution should commit record modifications.

        True unless the execution is a dry run.
        """
        return not self.dry_run

    @should_commit.expression
    def should_commit(cls):
        """SQL counterpart of ``should_commit``: ``NOT dry_run``.

        Without this, class-level access evaluated ``not <column>`` in
        Python and produced a plain ``False`` instead of a SQL clause.
        NOTE: a NULL ``dry_run`` is truthy-negated in Python (True) but yields
        NULL in SQL; the column defaults to False, so this should not arise.
        """
        return ~cls.dry_run

    @db.hybrid_property
    def should_report_logs(self):
        """Whether this execution should report logs to the reporters.

        True unless the execution is a dry run.
        """
        return not self.dry_run

    @should_report_logs.expression
    def should_report_logs(cls):
        """SQL counterpart of ``should_report_logs``: ``NOT dry_run``.

        Without this, class-level access evaluated ``not <column>`` in
        Python and produced a plain ``False`` instead of a SQL clause.
        """
        return ~cls.dry_run

    @db.hybrid_property
    def should_report_exceptions(self):
        """Whether this execution should report exceptions to the reporters.

        True unless the execution is a dry run.
        """
        return not self.dry_run

    @should_report_exceptions.expression
    def should_report_exceptions(cls):
        """SQL counterpart of ``should_report_exceptions``: ``NOT dry_run``.

        Without this, class-level access evaluated ``not <column>`` in
        Python and produced a plain ``False`` instead of a SQL clause.
        """
        return ~cls.dry_run

    @db.hybrid_property
    def master(self):
        """Build the master object bound to this execution's UUID.

        NOTE(review): as a hybrid property, class-level access passes the
        ``uuid`` column itself to ``RedisMaster``; only instance access is
        meaningful.

        :rtype: `invenio_checker.clients.master.RedisMaster`
        """
        master_id = self.uuid
        return RedisMaster(master_id)

    @db.hybrid_property
    def status(self):
        """Current status of this execution (backed by ``_status``).

        :rtype: `StatusMaster`
        """
        current = self._status
        return current

    @status.setter
    @session_manager
    def status(self, new_status):
        """Store a new status and stamp ``status_update_date`` with now().

        Wrapped by ``session_manager``.

        :type new_status: `StatusMaster`
        """
        self._status = new_status
        stamp = datetime.now()
        self.status_update_date = stamp

    def read_logs(self):
        """Stream user-friendly structured logs of this execution.

        First attempt to stream using `eliot-tree` which provides a text
        tree-like structure of the execution given its logs.

        If `eliot-tree` fails (which happens when there is an eliot Task
        serialization bug in our code), a warning is yielded, followed by the
        output of `eliot-prettyprint`.

        Therefore, the output of this function is not guaranteed to be
        machine-readable.

        .. note::
            This function may be called mid-run.

        :returns: a generator of text chunks (lines of formatted log output).
        """
        from glob import glob
        import subprocess
        from .config import get_eliot_log_file

        # Every path starting with this execution's eliot log file name
        # (presumably the main file plus any suffixed/rotated ones -- confirm).
        filenames = glob(get_eliot_log_file(master_id=self.uuid).name + "*")
        eliottree_subp = subprocess.Popen(['eliot-tree', '--field-limit', '0'],
                                          stdout=subprocess.PIPE,
                                          stdin=subprocess.PIPE)
        eliottree_failed = False
        # Feed all log files to eliot-tree; leaving the `with` closes stdin so
        # eliot-tree receives EOF.
        # NOTE(review): all input is written before any stdout is read. If
        # eliot-tree emits enough output to fill the pipe buffer before it has
        # consumed all input, both processes block (see the subprocess docs'
        # deadlock warning). Confirm eliot-tree reads its whole input first.
        # NOTE(review): the pipes are binary; writing ``str`` only works on
        # Python 2 -- under Python 3 this raises TypeError, which is not caught.
        with eliottree_subp.stdin:
            try:
                for filename in filenames:
                    with open(filename, 'r') as file_:
                        eliottree_subp.stdin.write(file_.read())
            except (IOError, MemoryError):
                # Reading a log file (whole file into memory) or writing it to
                # the pipe failed; use the prettyprint fallback below.
                eliottree_failed = True

        # Relay whatever eliot-tree produced, line by line.
        with eliottree_subp.stdout:
            for line in eliottree_subp.stdout:
                yield line

        # Fall back on an I/O failure above or a non-zero eliot-tree exit code.
        # Output already yielded above is not retracted.
        # NOTE(review): when eliottree_failed is True, wait() is short-circuited
        # and the child is not explicitly reaped here.
        if eliottree_failed or (eliottree_subp.wait() != 0):
            # eliot-tree can fail on unfinished logging. We still want output
            # for debugging, so we use the less structured eliot-prettyprint
            from eliot.prettyprint import pretty_format
            from eliot._bytesjson import loads
            yield '\n`eliot-tree` failed to format output. ' \
                'Retrying with eliot-prettyprint:\n'
            # Each line of each log file is parsed as one eliot message.
            for filename in filenames:
                yield "{}:\n".format(filename)
                with open(filename, 'r') as file_:
                    for line in file_:
                        yield pretty_format(loads(line))