# Example no. 1
# 0
class FeaturedCommunity(db.Model):
    """A community promoted as "featured" from a given date onwards.

    Several rows may exist; :meth:`get_current` picks the one whose
    ``start_date`` is the latest one not after a reference datetime.
    """

    __tablename__ = 'communityFEATURED'

    # Featured community identifier.
    id = db.Column(
        db.Integer(15, unsigned=True),
        primary_key=True,
        autoincrement=True,
    )

    # Identifier of the community being featured.
    id_community = db.Column(
        db.String(100),
        db.ForeignKey(Community.id),
        nullable=False,
    )

    # The date from which this featuring takes effect.
    start_date = db.Column(
        db.DateTime(),
        nullable=False,
        default=datetime.now,
    )

    # Relation to the featured community.
    community = db.relationship(Community, backref="featuredcommunity")

    @classmethod
    def get_current(cls, start_date=None):
        """Return the featured community in effect at a given moment.

        :param start_date: reference datetime; when falsy, ``datetime.now()``
            is used.
        :returns: the entry with the greatest ``start_date`` that is not
            after the reference datetime (with ``community.collection``
            eagerly loaded), or ``None`` when there is no such entry.
        """
        reference = start_date or datetime.now()
        eager_load = db.joinedload_all('community.collection')

        query = cls.query.options(eager_load)
        query = query.filter(cls.start_date <= reference)
        query = query.order_by(cls.start_date.desc())
        return query.first()
# Example no. 2
# 0
class PidLog(db.Model):

    """Audit trail of actions performed on persistent identifiers.

    Entries are normally written through ``PersistentIdentifier.log``
    rather than being created by hand.
    """

    __tablename__ = 'pidLOG'
    __table_args__ = (db.Index('idx_action', 'action'),)

    # Primary key of the log entry.
    id = db.Column(db.Integer(15, unsigned=True), primary_key=True)

    # Persistent identifier the entry is about. Nullable so that entries can
    # be kept after being detached from their identifier.
    id_pid = db.Column(db.Integer(15, unsigned=True),
                       db.ForeignKey(PersistentIdentifier.id),
                       nullable=True)

    # Creation datetime of the entry.
    timestamp = db.Column(db.DateTime(), default=datetime.now, nullable=False)

    # Short action identifier (at most 10 characters), e.g. "CREATE".
    action = db.Column(db.String(10), nullable=False)

    # Free-text log message.
    message = db.Column(db.Text(), nullable=False)

    # Relationship to the identifier; the backref exposes
    # ``PersistentIdentifier.logs``.
    pid = db.relationship("PersistentIdentifier", backref="logs")
# Example no. 3
# 0
class Page(db.Model):
    """Represents a page."""

    __tablename__ = 'pages'

    # Page identifier.
    id = db.Column(db.Integer(15, unsigned=True),
                   nullable=False,
                   primary_key=True,
                   autoincrement=True)

    # Page URL; unique across pages.
    url = db.Column(db.String(100), unique=True, nullable=False)

    # Page title.
    title = db.Column(db.String(200), nullable=True)

    # Page content. The explicit length is only requested on MySQL (where it
    # lets the column hold very large pages); other backends get a plain TEXT
    # column instead of a TEXT(<length>) they may not accept.
    content = db.Column(db.Text().with_variant(db.Text(length=2**32 - 2),
                                               'mysql'),
                        nullable=True)

    # Template used to render the page.
    # Default is pages/templates/default.html
    template_name = db.Column(db.String(70), nullable=True)

    # Creation datetime.
    created = db.Column(db.DateTime(), nullable=False, default=datetime.now)

    # Last modification datetime; refreshed on every UPDATE.
    last_modified = db.Column(db.DateTime(),
                              nullable=False,
                              default=datetime.now,
                              onupdate=datetime.now)
def do_upgrade():
    """Create the ``pid``, ``pidLOG`` and ``pidREGISTRY`` tables.

    The current schema is reflected into a fresh MetaData first, and the
    three tables are then defined on it and created in dependency order
    (``pid`` before the two tables holding foreign keys to it).
    """
    m = db.MetaData(bind=db.engine)
    m.reflect()

    tpid = db.Table(
        'pid',
        m,
        db.Column('id', db.Integer(15, unsigned=True), primary_key=True,
                  nullable=False),
        db.Column('type', db.String(length=6), nullable=False),
        db.Column('pid', db.String(length=255), nullable=False),
        db.Column('status', db.Char(length=1), nullable=False),
        db.Column('created', db.DateTime(), nullable=False),
        db.Column('last_modified', db.DateTime(), nullable=False),
        db.Index('uidx_type_pid', 'type', 'pid', unique=True),
        db.Index('idx_status', 'status'),
        mysql_engine='MyISAM',
    )

    # All SQLAlchemy names are taken from ``db`` (as in the rest of this
    # function) instead of relying on separately imported
    # ForeignKey/DateTime/Text.
    tpidlog = db.Table(
        'pidLOG',
        m,
        db.Column('id', db.Integer(15, unsigned=True), primary_key=True,
                  nullable=False),
        db.Column('id_pid', db.Integer(15, unsigned=True),
                  db.ForeignKey('pid.id')),
        db.Column('timestamp', db.DateTime(), nullable=False),
        db.Column('action', db.String(length=10), nullable=False),
        db.Column('message', db.Text(), nullable=False),
        db.Index('idx_action', 'action'),
        mysql_engine='MyISAM',
    )

    tpidregistry = db.Table(
        'pidREGISTRY',
        m,
        db.Column('object_type', db.String(length=3), primary_key=True,
                  nullable=False),
        db.Column('object_id', db.String(length=255), nullable=False),
        db.Column('id_pid', db.Integer(15, unsigned=True),
                  db.ForeignKey('pid.id'), primary_key=True, nullable=False),
        db.Index('idx_type_id', 'object_type', 'object_id'),
        mysql_engine='MyISAM',
    )

    tpid.create()
    tpidlog.create()
    tpidregistry.create()
# Example no. 5
# 0
class Page(db.Model):
    """Represents a page."""

    __tablename__ = 'pages'

    # Page identifier.
    id = db.Column(db.Integer(15, unsigned=True),
                   nullable=False,
                   primary_key=True,
                   autoincrement=True)

    # Page URL; unique across pages.
    url = db.Column(db.String(100), unique=True, nullable=False)

    # Page title.
    title = db.Column(db.String(200), nullable=True)

    # Page content. The explicit length is requested only on MySQL; other
    # backends get a plain TEXT column.
    content = db.Column(db.Text().with_variant(db.Text(length=2**32 - 2),
                                               'mysql'),
                        nullable=True)

    # Page description.
    description = db.Column(db.String(200), nullable=True)

    # Page template name. Default is cfg["PAGES_DEFAULT_TEMPLATE"]
    # (presumably pages/templates/default.html -- confirm against config).
    template_name = db.Column(db.String(70), nullable=True)

    # Page creation date.
    created = db.Column(db.DateTime(), nullable=False, default=datetime.now)

    # Page last modification date; refreshed on every UPDATE.
    last_modified = db.Column(db.DateTime(),
                              nullable=False,
                              default=datetime.now,
                              onupdate=datetime.now)

    def __repr__(self):
        """Page representation.

        Used on Page admin view in inline model.

        :returns: unambiguous page representation.
        """
        return "URL: %s, title: %s" % (self.url, self.title)
def do_upgrade():
    """Add ``fixed_points``, ``last_record_accepted`` and ``ranking`` to
    the ``community`` table.

    The new columns are NOT NULL, so each one gets a server-side default.
    That gives rows already in the table a valid value when the column is
    added, instead of a backend-specific implicit value (e.g. MySQL's zero
    date). The defaults mirror the Python-side defaults declared on the
    ``Community`` model: 0 for the integer columns and 2000-01-01 00:00:00
    for ``last_record_accepted``.
    """
    op.add_column(u'community',
                  db.Column('fixed_points', db.Integer(display_width=9),
                            nullable=False, server_default='0'))
    op.add_column(u'community',
                  db.Column('last_record_accepted', db.DateTime(),
                            nullable=False,
                            server_default='2000-01-01 00:00:00'))
    op.add_column(u'community',
                  db.Column('ranking', db.Integer(display_width=9),
                            nullable=False, server_default='0'))
def do_upgrade():
    """Create the ``pages`` table (utf8 charset, MyISAM engine)."""
    columns = [
        db.Column('id', mysql.INTEGER(display_width=15), nullable=False),
        db.Column('url', db.String(length=100), nullable=False),
        db.Column('title', db.String(length=200), nullable=True),
        db.Column('content', db.TEXT(length=4294967294), nullable=True),
        db.Column('template_name', db.String(length=70), nullable=True),
        db.Column('created', db.DateTime(), nullable=False),
        db.Column('last_modified', db.DateTime(), nullable=False),
    ]
    constraints = [
        db.PrimaryKeyConstraint('id'),
        db.UniqueConstraint('url'),
    ]
    table_options = {'mysql_charset': 'utf8', 'mysql_engine': 'MyISAM'}

    op.create_table('pages', *(columns + constraints), **table_options)
# Example no. 8
# 0
def do_upgrade():
    """Create the ``communityFEATURED`` table (utf8 charset, MyISAM engine).

    ``id_community`` is declared as a foreign key to ``community.id``.
    """
    table_elements = (
        db.Column('id', db.Integer(), nullable=False),
        db.Column('id_community', db.String(length=100), nullable=False),
        db.Column('start_date', db.DateTime(), nullable=False),
        db.ForeignKeyConstraint(['id_community'], ['community.id'], ),
        db.PrimaryKeyConstraint('id'),
    )
    op.create_table('communityFEATURED',
                    *table_elements,
                    mysql_charset='utf8',
                    mysql_engine='MyISAM')
def do_upgrade():
    """Create the ``userCOLLECTION`` table (MyISAM engine).

    The existing schema is reflected first so the table is defined on a
    MetaData that already knows the referenced tables (``user``,
    ``collection``, ``oaiREPOSITORY``).
    """
    metadata = db.MetaData(bind=db.engine)
    metadata.reflect()

    def optional_mediumint_fk(column_name, target):
        # Nullable unsigned MEDIUMINT(9) column referencing ``target``.
        return db.Column(column_name,
                         db.MediumInteger(9, unsigned=True),
                         db.ForeignKey(target),
                         nullable=True)

    columns = [
        db.Column('id', db.String(length=100),
                  primary_key=True, nullable=False),
        db.Column('id_user', db.Integer(15, unsigned=True),
                  db.ForeignKey('user.id'), nullable=False),
        optional_mediumint_fk('id_collection', 'collection.id'),
        optional_mediumint_fk('id_collection_provisional', 'collection.id'),
        optional_mediumint_fk('id_oairepository', 'oaiREPOSITORY.id'),
        db.Column('title', db.String(length=255), nullable=False),
        db.Column('description', db.Text(), nullable=False),
        db.Column('page', db.Text(), nullable=False),
        db.Column('curation_policy', db.Text(), nullable=False),
        db.Column('has_logo', db.Boolean(), nullable=False),
        db.Column('created', db.DateTime(), nullable=False),
        db.Column('last_modified', db.DateTime(), nullable=False),
    ]

    user_collection = db.Table('userCOLLECTION', metadata, *columns,
                               mysql_engine='MyISAM')
    user_collection.create()
# Example no. 10
# 0
def do_upgrade():
    """Create the OAuth2 server tables ``oauth2CLIENT`` and ``oauth2TOKEN``.

    Each table is created only when ``op.has_table`` does not report it;
    otherwise a warning is emitted and creation of that table is skipped.
    Both tables use the utf8 charset and the MyISAM engine.
    """
    if not op.has_table('oauth2CLIENT'):
        op.create_table(
            'oauth2CLIENT',
            db.Column('name', db.String(length=40), nullable=True),
            db.Column('description', db.Text(), nullable=True),
            db.Column('website', URLType(), nullable=True),
            db.Column('user_id', db.Integer(15, unsigned=True), nullable=True),
            db.Column('client_id', db.String(length=255), nullable=False),
            db.Column('client_secret', db.String(length=255), nullable=False),
            db.Column('is_confidential', db.Boolean(), nullable=True),
            db.Column('is_internal', db.Boolean(), nullable=True),
            db.Column('_redirect_uris', db.Text(), nullable=True),
            db.Column('_default_scopes', db.Text(), nullable=True),
            db.ForeignKeyConstraint(['user_id'], ['user.id'], ),
            db.PrimaryKeyConstraint('client_id'),
            mysql_charset='utf8',
            mysql_engine='MyISAM'
        )
    else:
        warnings.warn("*** Creation of table 'oauth2CLIENT' skipped!")

    if not op.has_table('oauth2TOKEN'):
        op.create_table(
            'oauth2TOKEN',
            db.Column('id', db.Integer(15, unsigned=True), autoincrement=True,
                      nullable=False),
            # Foreign key to oauth2CLIENT.client_id, which is String(255);
            # the column must be as wide as the referenced one or tokens of
            # clients with longer ids could not be stored.
            db.Column('client_id', db.String(length=255), nullable=False),
            db.Column('user_id', db.Integer(15, unsigned=True), nullable=True),
            db.Column('token_type', db.String(length=255), nullable=True),
            db.Column('access_token', db.String(length=255), nullable=True),
            db.Column('refresh_token', db.String(length=255), nullable=True),
            db.Column('expires', db.DateTime(), nullable=True),
            db.Column('_scopes', db.Text(), nullable=True),
            db.Column('is_personal', db.Boolean(), nullable=True),
            db.Column('is_internal', db.Boolean(), nullable=True),
            db.ForeignKeyConstraint(
                ['client_id'], ['oauth2CLIENT.client_id'],),
            db.ForeignKeyConstraint(['user_id'], ['user.id'], ),
            db.PrimaryKeyConstraint('id'),
            db.UniqueConstraint('access_token'),
            db.UniqueConstraint('refresh_token'),
            mysql_charset='utf8',
            mysql_engine='MyISAM'
        )
    else:
        warnings.warn("*** Creation of table 'oauth2TOKEN' skipped!")
# Example no. 11
# 0
def do_upgrade():
    """Create the ``quotaUSAGE`` table and its two non-unique lookup indexes.

    Rows are unique per (object_type, object_id, metric); ``object_id`` and
    ``object_type`` additionally get their own indexes.
    """
    op.create_table(
        'quotaUSAGE',
        db.Column('id', db.Integer(display_width=15), nullable=False),
        db.Column('object_type', db.String(length=40), nullable=True),
        db.Column('object_id', db.String(length=250), nullable=True),
        db.Column('metric', db.String(length=40), nullable=True),
        db.Column('value', db.BigInteger(), nullable=False),
        db.Column('modified', db.DateTime(), nullable=False),
        db.PrimaryKeyConstraint('id'),
        db.UniqueConstraint('object_type', 'object_id', 'metric'),
        mysql_charset='utf8',
        mysql_engine='MyISAM',
    )

    single_column_indexes = (
        ('ix_quotaUSAGE_object_id', 'object_id'),
        ('ix_quotaUSAGE_object_type', 'object_type'),
    )
    for index_name, column_name in single_column_indexes:
        op.create_index(op.f(index_name), 'quotaUSAGE', [column_name],
                        unique=False)
# Example no. 12
# 0
class PersistentIdentifier(db.Model):

    """Store and register persistent identifiers.

    Assumptions:
      * Persistent identifiers can be represented as a string of max 255 chars.
      * An object has many persistent identifiers.
      * A persistent identifier has one and only one object.
    """

    __tablename__ = 'pidSTORE'
    __table_args__ = (
        db.Index('uidx_type_pid', 'pid_type', 'pid_value', unique=True),
        db.Index('idx_status', 'status'),
        db.Index('idx_object', 'object_type', 'object_value'),
    )

    # Id of persistent identifier entry.
    id = db.Column(db.Integer(15, unsigned=True), primary_key=True)

    # Persistent Identifier Schema.
    pid_type = db.Column(db.String(6), nullable=False)

    # Persistent Identifier.
    pid_value = db.Column(db.String(length=255), nullable=False)

    # Persistent Identifier Provider.
    pid_provider = db.Column(db.String(length=255), nullable=False)

    # Status of persistent identifier, e.g. registered, reserved, deleted.
    status = db.Column(db.CHAR(length=1), nullable=False)

    # Object Type - e.g. rec for record.
    object_type = db.Column(db.String(3), nullable=True)

    # Object ID - e.g. a record id.
    object_value = db.Column(db.String(length=255), nullable=True)

    # Creation datetime of entry.
    created = db.Column(db.DateTime(), nullable=False, default=datetime.now)

    # Last modification datetime of entry.
    last_modified = db.Column(
        db.DateTime(), nullable=False, default=datetime.now,
        onupdate=datetime.now
    )

    # Cached provider instance. Declared on the class so that instances
    # loaded by arbitrary queries (which never run create()/get()) still have
    # the attribute; get_provider() then resolves the provider lazily instead
    # of raising AttributeError.
    _provider = None

    #
    # Class methods
    #
    @classmethod
    def create(cls, pid_type, pid_value, pid_provider='', provider=None):
        """Internally reserve a new persistent identifier.

        A provider for the given persistent identifier type must exist. By
        default the system will choose a provider according to the pid
        type. If desired, the default system provider can be overridden via
        the provider keyword argument.

        :raises Exception: if no provider is found for the identifier.
        :returns: the PID object if successful, otherwise None.
        """
        # Ensure provider exists
        if provider is None:
            provider = PidProvider.create(pid_type, pid_value, pid_provider)
            if not provider:
                raise Exception(
                    "No provider found for %s:%s (%s)" % (
                        pid_type, pid_value, pid_provider)
                )

        # Bound before the try so the except clause can always refer to it,
        # even when the failure happens while the object is being built.
        obj = None
        try:
            obj = cls(pid_type=provider.pid_type,
                      pid_value=provider.create_new_pid(pid_value),
                      pid_provider=pid_provider,
                      status=cfg['PIDSTORE_STATUS_NEW'])
            obj._provider = provider
            db.session.add(obj)
            db.session.commit()
            obj.log("CREATE", "Created")
            return obj
        except SQLAlchemyError:
            db.session.rollback()
            if obj is not None:
                obj.log("CREATE", "Failed to created. Already exists.")
            return None

    @classmethod
    def get(cls, pid_type, pid_value, pid_provider='', provider=None):
        """Get persistent identifier.

        :param provider: optional provider instance to attach to the result.
        :returns: the matching PID object, or None if not found.
        """
        pid_value = to_unicode(pid_value)
        obj = cls.query.filter_by(
            pid_type=pid_type, pid_value=pid_value, pid_provider=pid_provider
        ).first()
        if obj:
            obj._provider = provider
            return obj
        else:
            return None

    #
    # Instance methods
    #
    def has_object(self, object_type, object_value):
        """Determine if this PID is assigned to a specific object.

        :raises Exception: if ``object_type`` is not a configured type.
        """
        if object_type not in cfg['PIDSTORE_OBJECT_TYPES']:
            raise Exception("Invalid object type %s." % object_type)

        object_value = to_unicode(object_value)

        return self.object_type == object_type and \
            self.object_value == object_value

    def get_provider(self):
        """Get the provider for this type of persistent identifier.

        The provider is created on first use and cached on the instance.
        """
        if self._provider is None:
            self._provider = PidProvider.create(
                self.pid_type, self.pid_value, self.pid_provider
            )
        return self._provider

    def assign(self, object_type, object_value, overwrite=False):
        """Assign this persistent identifier to a given object.

        Note, the persistent identifier must first have been reserved. Also,
        if an existing object is already assigned to the pid, it will raise an
        exception unless overwrite=True.
        """
        if object_type not in cfg['PIDSTORE_OBJECT_TYPES']:
            raise Exception("Invalid object type %s." % object_type)
        object_value = to_unicode(object_value)

        if not self.id:
            raise Exception(
                "You must first create the persistent identifier before you "
                "can assign objects to it."
            )

        if self.is_deleted():
            raise Exception(
                "You cannot assign objects to a deleted persistent identifier."
            )

        # Check for an existing object assigned to this pid
        existing_obj_id = self.get_assigned_object(object_type)

        if existing_obj_id and existing_obj_id != object_value:
            if not overwrite:
                raise Exception(
                    "Persistent identifier is already assigned to another "
                    "object"
                )
            else:
                self.log(
                    "ASSIGN",
                    "Unassigned object %s:%s (overwrite requested)" % (
                        self.object_type, self.object_value)
                )
                self.object_type = None
                self.object_value = None
        elif existing_obj_id and existing_obj_id == object_value:
            # The object is already assigned to this pid.
            return True

        self.object_type = object_type
        self.object_value = object_value
        db.session.commit()
        self.log("ASSIGN", "Assigned object %s:%s" % (self.object_type,
                                                      self.object_value))
        return True

    def update(self, with_deleted=False, *args, **kwargs):
        """Update the persistent identifier with the provider.

        :param with_deleted: allow updating a deleted PID; on success its
            status is set back to registered.
        :raises Exception: if the PID is new/reserved, deleted (unless
            ``with_deleted``), or has no provider.
        :returns: True if the provider reported success, else False.
        """
        if self.is_new() or self.is_reserved():
            raise Exception(
                "Persistent identifier has not yet been registered."
            )

        if not with_deleted and self.is_deleted():
            raise Exception("Persistent identifier has been deleted.")

        provider = self.get_provider()
        if provider is None:
            self.log("UPDATE", "No provider found.")
            raise Exception("No provider found.")

        if provider.update(self, *args, **kwargs):
            if with_deleted and self.is_deleted():
                self.status = cfg['PIDSTORE_STATUS_REGISTERED']
                db.session.commit()
            return True
        return False

    def reserve(self, *args, **kwargs):
        """Reserve the persistent identifier with the provider.

        Note, the reserve method may be called multiple times, even if it was
        already reserved.
        """
        if not (self.is_new() or self.is_reserved()):
            raise Exception(
                "Persistent identifier has already been registered."
            )

        provider = self.get_provider()
        if provider is None:
            self.log("RESERVE", "No provider found.")
            raise Exception("No provider found.")

        if provider.reserve(self, *args, **kwargs):
            self.status = cfg['PIDSTORE_STATUS_RESERVED']
            db.session.commit()
            return True
        return False

    def register(self, *args, **kwargs):
        """Register the persistent identifier with the provider."""
        if self.is_registered() or self.is_deleted():
            raise Exception(
                "Persistent identifier has already been registered."
            )

        provider = self.get_provider()
        if provider is None:
            self.log("REGISTER", "No provider found.")
            raise Exception("No provider found.")

        if provider.register(self, *args, **kwargs):
            self.status = cfg['PIDSTORE_STATUS_REGISTERED']
            db.session.commit()
            return True
        return False

    def delete(self, *args, **kwargs):
        """Delete the persistent identifier.

        A PID that is still new (never reserved or registered) is removed
        from the database; its log entries are kept but detached from it.
        Any other PID is deleted through its provider and marked deleted.

        :raises Exception: if a non-new PID has no provider.
        :returns: True on success, False if the provider refused.
        """
        if self.is_new():
            # Write the log entry first, while the PID row still exists, so
            # the unlink below detaches it together with all older entries
            # instead of leaving it pointing at the removed row.
            self.log("DELETE", "Unregistered PID successfully deleted")
            PidLog.query.filter_by(id_pid=self.id).update({'id_pid': None})
            db.session.delete(self)
            db.session.commit()
        else:
            provider = self.get_provider()
            if provider is None:
                self.log("DELETE", "No provider found.")
                raise Exception("No provider found.")
            if not provider.delete(self, *args, **kwargs):
                return False
            self.status = cfg['PIDSTORE_STATUS_DELETED']
            db.session.commit()
        return True

    def sync_status(self, *args, **kwargs):
        """Synchronize persistent identifier status.

        Used when the provider uses an external service, which might have been
        modified outside of our system.
        """
        provider = self.get_provider()
        result = provider.sync_status(self, *args, **kwargs)
        db.session.commit()
        return result

    def get_assigned_object(self, object_type=None):
        """Return the assigned object value if it has type ``object_type``."""
        if object_type is not None and self.object_type == object_type:
            return self.object_value
        return None

    def is_registered(self):
        """Return true if the persistent identifier has been registered."""
        return self.status == cfg['PIDSTORE_STATUS_REGISTERED']

    def is_deleted(self):
        """Return true if the persistent identifier has been deleted."""
        return self.status == cfg['PIDSTORE_STATUS_DELETED']

    def is_new(self):
        """Return true if the PID has not yet been registered or reserved."""
        return self.status == cfg['PIDSTORE_STATUS_NEW']

    def is_reserved(self):
        """Return true if the PID has been reserved (but not registered)."""
        return self.status == cfg['PIDSTORE_STATUS_RESERVED']

    def log(self, action, message):
        """Store action and message in the PID log and commit.

        The message is prefixed with ``[pid_type:pid_value]`` when both are
        set.
        """
        if self.pid_type and self.pid_value:
            message = "[%s:%s] %s" % (self.pid_type, self.pid_value, message)
        p = PidLog(id_pid=self.id, action=action, message=message)
        db.session.add(p)
        db.session.commit()
# Example no. 13
# 0
class Community(db.Model):
    """Represents a Community.

    A layer around Invenio's collections and portalboxes.
    """

    __tablename__ = 'community'

    #
    # Fields
    #
    id = db.Column(db.String(100), primary_key=True)
    """
    Community identifier used to generate the real collection_identifier
    """

    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=False)
    """ Owner of the community. """

    id_collection = db.Column(db.Integer(15, unsigned=True),
                              db.ForeignKey(Collection.id),
                              nullable=True,
                              default=None)
    """ Invenio collection generated from this community"""

    id_collection_provisional = db.Column(db.Integer(15, unsigned=True),
                                          db.ForeignKey(Collection.id),
                                          nullable=True,
                                          default=None)
    """ Invenio provisional collection generated from this community"""

    id_oairepository = db.Column(db.MediumInteger(9, unsigned=True),
                                 db.ForeignKey(OaiREPOSITORY.id),
                                 nullable=True,
                                 default=None)
    """ OAI Repository set specification """

    title = db.Column(db.String(length=255), nullable=False, default='')
    """ Title of community."""

    description = db.Column(db.Text(), nullable=False, default='')
    """ Short description of community, displayed in portal boxes. """

    page = db.Column(db.Text(), nullable=False, default='')
    """ Long description of community, displayed on an individual page. """

    curation_policy = db.Column(db.Text(), nullable=False, default='')
    """ Curation policy """

    logo_ext = db.Column(db.String(length=5))
    """Logo extension."""

    created = db.Column(db.DateTime(), nullable=False, default=datetime.now)
    """ Creation datetime """

    last_modified = db.Column(db.DateTime(),
                              nullable=False,
                              default=datetime.now,
                              onupdate=datetime.now)
    """ Last modification datetime """

    last_record_accepted = db.Column(db.DateTime(),
                                     nullable=False,
                                     default=datetime(2000, 1, 1, 0, 0, 0))
    """ Last record acceptance datetime"""

    ranking = db.Column(db.Integer(9), nullable=False, default=0)
    """ Ranking of community. Updated by ranking deamon"""

    fixed_points = db.Column(db.Integer(9), nullable=False, default=0)
    """ Points which will be always added to overall score of community"""
    #
    # Relation ships
    #
    owner = db.relationship(User,
                            backref='communities',
                            foreign_keys=[id_user])
    """ Relation to the owner (User) of the community"""

    collection = db.relationship(Collection,
                                 uselist=False,
                                 backref='community',
                                 foreign_keys=[id_collection])
    """ Relationship to collection. """

    collection_provisional = db.relationship(
        Collection,
        uselist=False,
        backref='community_provisional',
        foreign_keys=[id_collection_provisional])
    """Relationship to restricted collection containing uncurated records."""

    oai_set = db.relationship(OaiREPOSITORY,
                              uselist=False,
                              backref='community',
                              foreign_keys=[id_oairepository])
    """Relation to the owner (User) of the community."""

    #
    # Properties
    #
    @property
    def logo_url(self):
        """Static URL of the community logo, or None if no logo is stored."""
        if not self.logo_ext:
            return None
        logo_file = 'media/communities' + '/' + self.id + self.logo_ext
        return url_for('static', filename=logo_file)

    @property
    def oai_url(self):
        """OAI-PMH ListRecords link restricted to this community's set."""
        set_name = self.get_collection_name()
        template = "/oai2d?verb=ListRecords&metadataPrefix=oai_dc&set=%s"
        return template % (set_name, )

    @property
    def community_url(self):
        """Relative URL of this community's collection page."""
        name = self.get_collection_name()
        return "/collection/%s" % name

    @property
    def community_provisional_url(self):
        """Search URL scoped to this community's provisional collection."""
        provisional_name = self.get_collection_name(provisional=True)
        return "/search?cc=%s" % provisional_name

    @property
    def upload_url(self):
        """Deposit URL that preselects this community (``c`` parameter)."""
        endpoint = 'webdeposit.index'
        return url_for(endpoint, c=self.id)

    @classmethod
    def from_recid(cls, recid, provisional=False):
        """Get the communities a record is tagged with.

        For every 980 field of the record, only the first subfield is
        inspected: if its code is ``a`` and its value starts with
        ``<prefix>-`` (the provisional prefix when ``provisional`` is true),
        the remainder is looked up as a community id.

        :param recid: record identifier.
        :param provisional: match provisional collection tags instead.
        :returns: list of matching communities; empty if the record could
            not be loaded or carries no matching tag.
        """
        from invenio.legacy.search_engine import get_record
        rec = get_record(recid)
        if not rec:
            # NOTE(review): get_record may yield None for an unknown recid;
            # treat that as "no communities" instead of failing on .get().
            return []
        prefix = "%s-" % (cfg['COMMUNITIES_ID_PREFIX_PROVISIONAL']
                          if provisional else cfg['COMMUNITIES_ID_PREFIX'])

        usercomm = []
        for field in rec.get('980', []):
            try:
                # We are only interested in subfield 'a'
                code, val = field[0][0]
            except IndexError:
                # Field without subfields.
                continue
            if code == 'a' and val.startswith(prefix):
                community = cls.query.filter_by(id=val[len(prefix):]).first()
                if community:
                    usercomm.append(community)
        return usercomm

    @classmethod
    def filter_communities(cls, p, so):
        """Search for communities.

        Build a query over communities, optionally restricted to those whose
        id, title or description contains ``p``, ordered according to ``so``.

        :param p: substring to search for (SQL ``LIKE %p%``); a falsy value
            disables filtering.
        :param so: sort field. If it is listed in
            ``cfg['COMMUNITIES_SORTING_OPTIONS']`` results are ordered by the
            attribute of that name, ascending for ``'title'`` and descending
            otherwise; any other value orders by descending ``ranking``.
        :returns: the query object (not yet executed).
        """
        query = cls.query
        if p:
            pattern = "%" + p + "%"
            query = query.filter(
                db.or_(
                    cls.id.like(pattern),
                    cls.title.like(pattern),
                    cls.description.like(pattern),
                ))
        if so in cfg['COMMUNITIES_SORTING_OPTIONS']:
            order = db.asc if so == 'title' else db.desc
            query = query.order_by(order(getattr(cls, so)))
        else:
            query = query.order_by(db.desc(cls.ranking))
        return query

    #
    # Utility methods
    #
    def get_collection_name(self, provisional=False):
        """Return the unique collection name ``<prefix>-<community id>``.

        The provisional prefix is used when ``provisional`` is true.
        """
        prefix_key = ('COMMUNITIES_ID_PREFIX_PROVISIONAL' if provisional
                      else 'COMMUNITIES_ID_PREFIX')
        return "%s-%s" % (cfg[prefix_key], self.id)

    def get_title(self, provisional=False):
        """Return the collection title, prefixed for provisional ones."""
        return "Provisional: %s" % self.title if provisional else self.title

    def get_collection_dbquery(self, provisional=False):
        """Return the search query string ``field:value`` for the collection."""
        field, value = self.get_query(provisional=provisional)
        return "%s:%s" % (field, value)

    def get_query(self, provisional=False):
        """Return the ``(field, value)`` pair used by the search engine."""
        collection_name = self.get_collection_name(provisional=provisional)
        return ("980__a", collection_name)

    def render_portalbox_bodies(self, templates):
        """Get a list of rendered portal boxes for this community.

        :param templates: iterable of template names; each is rendered with
            this community available as ``community``.
        :returns: list of rendered strings, in the order of ``templates``.
        """
        ctx = {
            'community': self,
        }
        # A list (not a lazy ``map``) so the result is a real list on both
        # Python 2 and Python 3 and can be iterated more than once.
        return [render_template_to_string(t, **ctx) for t in templates]

    #
    # Curation methods
    #
    def _modify_record(self,
                       recid,
                       test_func,
                       replace_func,
                       include_func,
                       append_colls=None,
                       replace_colls=None):
        """Build a minimal record carrying an updated set of 980 tags.

        @param recid: Identifier of the record to modify.
        @param test_func: ``test_func(code, val)`` -- whether the first
            subfield of a 980 field should be replaced.
        @param replace_func: ``replace_func(code, val)`` -- returns the
            replacement ``(code, value)`` subfield.
        @param include_func: ``include_func(code, val)`` -- whether the
            (possibly replaced) field is kept; called with the original
            code and value.
        @param append_colls: Collection names appended as new 980__a fields.
        @param replace_colls: If non-empty, the 980 fields become exactly
            these collection names (the other callbacks are not used).
        @return: A record dict containing only field 001 (the recid) and the
            new 980 fields, or False when the record cannot be loaded, has
            no 980 field, or nothing changed.
        """
        from invenio.legacy.search_engine import get_record
        # None defaults avoid sharing mutable default lists between calls.
        append_colls = append_colls or []
        replace_colls = replace_colls or []

        rec = get_record(recid)
        if not rec:
            # NOTE(review): get_record may yield None for an unknown recid.
            return False
        newcolls = []
        dirty = False

        try:
            colls = rec['980']
            if replace_colls:
                for c in replace_colls:
                    newcolls.append([('a', c)])
                    dirty = True
            else:
                for c in colls:
                    try:
                        # We are only interested in subfield 'a'
                        code, val = c[0][0]
                        if test_func(code, val):
                            c[0][0] = replace_func(code, val)
                            dirty = True
                        if include_func(code, val):
                            newcolls.append(c[0])
                        else:
                            dirty = True
                    except IndexError:
                        pass
                for c in append_colls:
                    newcolls.append([('a', c)])
                    dirty = True
        except KeyError:
            return False

        if not dirty:
            return False

        rec = {}
        record_add_field(rec, '001', controlfield_value=str(recid))

        for subfields in newcolls:
            record_add_field(rec, '980', subfields=subfields)

        return rec

    def _upload_record(self, rec, pretend=False):
        """Bibupload a single record in correction (``-c``) mode.

        @param rec: record built by _modify_record, or ``False``.
        @param pretend: if true, skip the actual upload.
        @return: ``None`` when ``rec`` is ``False``, otherwise ``rec``.
        """
        from invenio.legacy.bibupload.utils import bibupload_record
        if rec is False:
            return None
        if pretend:
            return rec
        bibupload_record(record=rec, file_prefix='community', mode='-c',
                         opts=[], alias="community")
        return rec

    def _upload_collection(self, coll):
        """Bibupload several records at once in correction (``-c``) mode.

        @param coll: list of records to upload.
        @return: always ``True``.
        """
        from invenio.legacy.bibupload.utils import bibupload_record
        upload_options = dict(file_prefix='community', mode='-c', opts=[],
                              alias="community")
        bibupload_record(collection=coll, **upload_options)
        return True

    def accept_record(self, recid, pretend=False):
        """Accept a record for inclusion in this community.

        Rewrites the record's provisional collection id (980__a) to the final
        collection id and uploads the change unless ``pretend``. It then sets
        ``last_record_accepted`` and commits the session. Receivers of
        ``pre_curation`` may supply collections to append or replace.

        @param recid: Record ID
        @param pretend: if true, compute the change but do not bibupload it.
        @return: the record that was (or would be) uploaded, or ``None`` if
            the record has no 980 field or nothing changed.
        """
        provisional_id = self.get_collection_name(provisional=True)
        final_id = self.get_collection_name(provisional=False)

        signal_results = pre_curation.send(self, action='accept', recid=recid,
                                           pretend=pretend)
        append_colls, replace_colls = signalresult2list(signal_results)

        modified = self._modify_record(
            recid,
            lambda code, val: code == 'a' and val == provisional_id,
            lambda code, val: (code, final_id),
            lambda code, val: True,  # keep every collection
            append_colls=append_colls,
            replace_colls=replace_colls,
        )
        rec = self._upload_record(modified, pretend=pretend)

        self.last_record_accepted = datetime.now()
        db.session.commit()
        post_curation.send(self, action='accept', recid=recid, record=rec,
                           pretend=pretend)
        return rec

    def reject_record(self, recid, pretend=False):
        """Reject a record from this community.

        Removes both the provisional and the final collection id of this
        community from the record's 980__a fields and uploads the change
        unless ``pretend``. Receivers of ``pre_curation`` may supply
        collections to append or replace. The session is not committed here.

        @param recid: Record ID
        @param pretend: if true, compute the change but do not bibupload it.
        @return: the record that was (or would be) uploaded, or ``None`` if
            the record has no 980 field or nothing changed.
        """
        community_ids = (self.get_collection_name(provisional=True),
                         self.get_collection_name(provisional=False))

        append_colls, replace_colls = signalresult2list(
            pre_curation.send(self, action='reject', recid=recid,
                              pretend=pretend))

        def keep_collection(code, val):
            return not (code == 'a' and val in community_ids)

        modified = self._modify_record(
            recid,
            lambda code, val: False,        # never rewrite a collection id
            lambda code, val: (code, val),  # identity, never used
            keep_collection,
            append_colls=append_colls,
            replace_colls=replace_colls,
        )
        rec = self._upload_record(modified, pretend=pretend)

        post_curation.send(self, action='reject', recid=recid, record=rec,
                           pretend=pretend)
        return rec

    #
    # Data persistence methods
    #
    def save_collectionname(self, collection, title):
        """Create or update the Collectionname of ``collection``.

        For an already persisted collection, the ``'ln'`` name in
        ``CFG_SITE_LANG`` is updated if it exists. Otherwise a new
        Collectionname is added to the session (not committed here).

        @param collection: Collection object.
        @param title: name value to store.
        @return: the Collectionname object.
        """
        existing = None
        if collection.id:
            existing = Collectionname.query.filter_by(
                id_collection=collection.id, ln=CFG_SITE_LANG,
                type='ln').first()
        if existing:
            update_changed_fields(existing, dict(value=title))
            return existing

        c_name = Collectionname(collection=collection, ln=CFG_SITE_LANG,
                                type='ln', value=title)
        db.session.add(c_name)
        return c_name

    def save_collectiondetailedrecordpagetabs(self, collection):
        """Create or update the detailed-record tabs of ``collection``.

        The tabs value comes from ``cfg['COMMUNITIES_TABS']``. A new object is
        added to the session (not committed here).

        @param collection: Collection object.
        @return: the Collectiondetailedrecordpagetabs object.
        """
        c_tabs = None
        if collection.id:
            c_tabs = Collectiondetailedrecordpagetabs.query.filter_by(
                id_collection=collection.id).first()
        tabs = cfg['COMMUNITIES_TABS']
        if c_tabs:
            update_changed_fields(c_tabs, dict(tabs=tabs))
            return c_tabs

        c_tabs = Collectiondetailedrecordpagetabs(collection=collection,
                                                  tabs=tabs)
        db.session.add(c_tabs)
        return c_tabs

    def save_collectioncollection(self, collection, parent_name):
        """Create or update the tree link between a parent and ``collection``.

        @param collection: child Collection object.
        @param parent_name: name of the parent Collection.
        @return: the CollectionCollection object (added to the session, not
            committed).
        """
        # NOTE(review): assumes the parent collection exists; dad is None
        # otherwise and ``dad.id`` below raises AttributeError.
        dad = Collection.query.filter_by(name=parent_name).first()

        c_tree = None
        if collection.id:
            c_tree = CollectionCollection.query.filter_by(
                id_dad=dad.id, id_son=collection.id).first()

        link_type = cfg['COMMUNITIES_COLLECTION_TYPE']
        link_score = cfg['COMMUNITIES_COLLECTION_SCORE']
        if c_tree:
            update_changed_fields(c_tree,
                                  dict(type=link_type, score=link_score))
            return c_tree

        c_tree = CollectionCollection(dad=dad, son=collection,
                                      type=link_type, score=link_score)
        db.session.add(c_tree)
        return c_tree

    def save_collectionformat(self, collection, fmt_str):
        """Create or update the output format of ``collection``.

        An existing CollectionFormat gets the format id and ``score=1``. A
        new one is created with the format id only and added to the session.

        @param collection: Collection object.
        @param fmt_str: code of the Format to use.
        @return: the CollectionFormat object.
        """
        fmt = Format.query.filter_by(code=fmt_str).first()

        c_fmt = None
        if collection.id:
            c_fmt = CollectionFormat.query.filter_by(
                id_collection=collection.id).first()
        if c_fmt:
            update_changed_fields(c_fmt, dict(id_format=fmt.id, score=1))
            return c_fmt

        c_fmt = CollectionFormat(collection=collection, id_format=fmt.id)
        db.session.add(c_fmt)
        return c_fmt

    def save_collectionportalboxes(self, collection, templates):
        """Create or update Portalbox and CollectionPortalbox objects.

        Each template is rendered for this community. If the collection
        already has exactly as many portal boxes (in CFG_SITE_LANG) as there
        are templates, they are updated in place. Otherwise all existing ones
        are deleted and new ones are created. Nothing is committed here.

        @param collection: Collection object.
        @param templates: list of template names.
        @return: list of CollectionPortalbox objects.
        """
        # Setup portal boxes. Materialize as a list: len() and reverse() are
        # needed below, and the renderer may return a lazy iterator.
        bodies = list(self.render_portalbox_bodies(templates))
        bodies.reverse()  # Highest score is on the top, so we reverse the list

        objects = []
        if collection.id:
            c_pboxes = CollectionPortalbox.query.filter_by(
                id_collection=collection.id,
                ln=CFG_SITE_LANG,
            ).all()
            if len(c_pboxes) == len(bodies):
                for score, (c_pbox, body) in enumerate(zip(c_pboxes, bodies)):
                    update_changed_fields(c_pbox.portalbox, dict(body=body))
                    update_changed_fields(
                        c_pbox,
                        dict(score=score,
                             position=cfg['COMMUNITIES_PORTALBOX_POSITION']))
                    objects.append(c_pbox)
                return objects
            # Either the templates were modified or the collection portal
            # boxes were modified outside of this community. In either case,
            # remove the existing portal boxes and add new ones.
            for c_pbox in c_pboxes:
                db.session.delete(c_pbox.portalbox)
                db.session.delete(c_pbox)

        for score, body in enumerate(bodies):
            p = Portalbox(title='', body=body)
            c_pbox = CollectionPortalbox()
            update_changed_fields(
                c_pbox,
                dict(
                    collection=collection,
                    portalbox=p,
                    ln=CFG_SITE_LANG,
                    position=cfg['COMMUNITIES_PORTALBOX_POSITION'],
                    score=score,
                ))
            db.session.add_all([p, c_pbox])
            objects.append(c_pbox)
        return objects

    def save_oairepository_set(self, provisional=False):
        """Create or update the OAI repository set of this community.

        The set is named after the collection and uses the (field, value)
        pair from get_query as its first pattern. The other pattern slots
        and the set definition are left empty. A new set is added to the
        session (not committed here).

        @param provisional: build the set for the provisional collection.
        """
        collection_name = self.get_collection_name(provisional=provisional)
        field, pattern = self.get_query(provisional=provisional)
        fields = {
            'setName': '%s set' % collection_name,
            'setSpec': collection_name,
            'setDescription': self.description,
            'p1': pattern,
            'f1': field,
            'm1': 'e',
            'p2': '',
            'f2': '',
            'm2': '',
            'p3': '',
            'f3': '',
            'm3': '',
            'setDefinition': '',
        }

        if not self.oai_set:
            self.oai_set = OaiREPOSITORY(**fields)
            db.session.add(self.oai_set)
        else:
            update_changed_fields(self.oai_set, fields)

    def save_acl(self, collection_id, collection_name):
        """Create or update authorization.

        Needed for the owner (``id_user``) to view the provisional
        collection. This ensures:

        * a role ``coll_<collection_id>``;
        * a ``collection`` argument;
        * the owner as the only member of that role;
        * a ``viewrestrcoll`` authorization linking them.

        Objects are added to the session; nothing is committed here.

        @param collection_id: id of the collection (used in the role name).
        @param collection_name: name of the collection to authorize.
        """
        # Role - use the collection id, because role name is limited to
        # 32 chars.
        role_name = 'coll_%s' % collection_id
        role = AccROLE.query.filter_by(name=role_name).first()
        if not role:
            role = AccROLE(
                name=role_name,
                description='Curators of Community {collection}'.format(
                    collection=collection_name))
            db.session.add(role)

        # Argument
        fields = dict(keyword='collection', value=collection_name)
        arg = AccARGUMENT.query.filter_by(**fields).first()
        if not arg:
            arg = AccARGUMENT(**fields)
            db.session.add(arg)

        # Action
        action = AccACTION.query.filter_by(name='viewrestrcoll').first()

        # User role: keep the owner's membership and remove any other user.
        userrole = None
        for ur in UserAccROLE.query.filter_by(role=role).all():
            if ur.id_user != self.id_user:
                db.session.delete(ur)
            else:
                userrole = ur

        if not userrole:
            userrole = UserAccROLE(id_user=self.id_user, role=role)
            db.session.add(userrole)

        # Authorization
        auth = AccAuthorization.query.filter_by(role=role,
                                                action=action,
                                                argument=arg).first()
        if not auth:
            auth = AccAuthorization(role=role,
                                    action=action,
                                    argument=arg,
                                    argumentlistid=1)
            # Add explicitly rather than relying on relationship cascades.
            db.session.add(auth)

    def save_collection(self, provisional=False):
        """Create or update a new collection.

        Including name, tabs, collection tree, collection output formats,
        portalboxes, and either the OAI repository set (public collection)
        or the view authorization (provisional collection).

        Sends ``before_save_collection`` (``is_new`` is True only when the
        collection does not exist yet) and ``after_save_collection``.

        @param provisional: save the provisional collection instead of the
            public one.
        """
        # Setup collection
        collection_name = self.get_collection_name(provisional=provisional)
        c = Collection.query.filter_by(name=collection_name).first()
        fields = dict(
            name=collection_name,
            dbquery=self.get_collection_dbquery(provisional=provisional))

        before_save_collection.send(self,
                                    is_new=not c,
                                    provisional=provisional)
        if c:
            update_changed_fields(c, fields)
        else:
            c = Collection(**fields)
            db.session.add(c)
            # Commit so the new collection has an id for the helpers below.
            db.session.commit()
        setattr(self,
                'collection_provisional' if provisional else 'collection', c)

        # Setup access control (provisional) or OAI Repository (public)
        if provisional:
            self.save_acl(c.id, collection_name)
        else:
            self.save_oairepository_set(provisional=provisional)

        # Setup title, tabs and collection tree
        self.save_collectionname(c, self.get_title(provisional=provisional))
        self.save_collectiondetailedrecordpagetabs(c)
        self.save_collectioncollection(
            c, cfg['COMMUNITIES_PARENT_NAME_PROVISIONAL']
            if provisional else cfg['COMMUNITIES_PARENT_NAME'])

        # Setup collection format if needed
        if not provisional and cfg['COMMUNITIES_OUTPUTFORMAT']:
            self.save_collectionformat(c, cfg['COMMUNITIES_OUTPUTFORMAT'])
        elif provisional and cfg['COMMUNITIES_OUTPUTFORMAT_PROVISIONAL']:
            self.save_collectionformat(
                c, cfg['COMMUNITIES_OUTPUTFORMAT_PROVISIONAL'])

        # Setup portal boxes
        self.save_collectionportalboxes(
            c, cfg['COMMUNITIES_PORTALBOXES_PROVISIONAL']
            if provisional else cfg['COMMUNITIES_PORTALBOXES'])
        db.session.commit()
        after_save_collection.send(self, collection=c, provisional=provisional)

    def save_collections(self):
        """Create or update the public and the provisional collection.

        The public collection is saved first, then the provisional one. The
        two saves are wrapped by ``before_save_collections`` and
        ``after_save_collections``.
        """
        before_save_collections.send(self)
        for provisional in (False, True):
            self.save_collection(provisional=provisional)
        after_save_collections.send(self)

    def delete_record_collection_identifiers(self):
        """Remove collection identifiers from all records.

        Finds records whose 980__a is this community's final or provisional
        collection id. It strips those fields and bibuploads the modified
        records in a single batch.
        """
        from invenio.legacy.search_engine import search_pattern
        provisional_id = self.get_collection_name(provisional=True)
        normal_id = self.get_collection_name(provisional=False)

        def test_func(code, val):
            return False

        def replace_func(code, val):
            return (code, val)

        def include_func(code, val):
            return not (code == 'a' and
                        (val == provisional_id or val == normal_id))

        coll = []
        for r in search_pattern(p="980__a:%s OR 980__a:%s" %
                                (normal_id, provisional_id)):
            rec = self._modify_record(r, test_func, replace_func, include_func)
            # _modify_record returns False when there is nothing to change;
            # that is not a record and must not be sent to bibupload.
            if rec is not False:
                coll.append(rec)

        # Avoid scheduling an upload with an empty collection.
        if coll:
            self._upload_collection(coll)

    def delete_collection(self, provisional=False):
        """Delete all objects related to a single collection.

        For an existing collection, deletes its portal boxes, output formats,
        names, tree links and detailed-record tabs. It then deletes the ACL
        objects (provisional) or the OAI set (public), and finally the
        collection itself. Sends ``before_delete_collection`` and
        ``after_delete_collection``.

        @param provisional: delete the provisional collection instead of the
            public one.
        """
        # Most of the logic in this method ought to be moved to a
        # Collection.delete() method.
        c = getattr(self,
                    "collection_provisional" if provisional else "collection")
        collection_name = self.get_collection_name(provisional=provisional)

        before_delete_collection.send(self,
                                      collection=c,
                                      provisional=provisional)

        if c:
            # Delete portal boxes
            for c_pbox in c.portalboxes:
                if c_pbox.portalbox:
                    db.session.delete(c_pbox.portalbox)
                db.session.delete(c_pbox)
            db.session.commit()
            # Delete output formats:
            CollectionFormat.query.filter_by(id_collection=c.id).delete()

            # Delete title, tabs, collection tree
            Collectionname.query.filter_by(id_collection=c.id).delete()
            CollectionCollection.query.filter_by(id_son=c.id).delete()
            Collectiondetailedrecordpagetabs.query.filter_by(
                id_collection=c.id).delete()

        if provisional:
            # Delete ACLs
            AccARGUMENT.query.filter_by(keyword='collection',
                                        value=collection_name).delete()
            # The role name is derived from the collection id (see
            # save_acl), so it can only be looked up when the collection
            # exists; without this guard ``c.id`` fails on None.
            role = None
            if c:
                role = AccROLE.query.filter_by(name='coll_%s' % c.id).first()
            if role:
                UserAccROLE.query.filter_by(role=role).delete()
                AccAuthorization.query.filter_by(role=role).delete()
                db.session.delete(role)
        else:
            # Delete OAI repository
            if self.oai_set:
                db.session.delete(self.oai_set)

        # Delete collection
        if c:
            db.session.delete(c)
        db.session.commit()
        after_delete_collection.send(self, provisional=provisional)

    def delete_collections(self):
        """Delete both collections and all associated objects.

        First strips this community's collection ids from all records. Then
        it deletes the public collection, followed by the provisional one.
        Everything is wrapped by ``before_delete_collections`` and
        ``after_delete_collections``.
        """
        before_delete_collections.send(self)
        self.delete_record_collection_identifiers()
        for provisional in (False, True):
            self.delete_collection(provisional=provisional)
        after_delete_collections.send(self)

    def __str__(self):
        """Return a string representation of an object (its identifier)."""
        # __str__ must return a string; coerce so an unsaved community
        # (id is None) does not raise TypeError.
        return str(self.id)
Ejemplo n.º 14
0
def do_upgrade():
    """Perform upgrade.

    Migrates the legacy ``usergroup`` and ``user_usergroup`` tables to
    ``group`` and ``groupMEMBER``. Existing tables are renamed and altered
    in place; missing ones are created. Finally, ``groupADMIN`` is created.
    """
    # NOTE(review): the NOT NULL columns added to existing tables below only
    # have a Python-side ``default`` (no ``server_default``), so existing
    # rows get whatever value the database fills in implicitly — confirm.
    if op.has_table('usergroup'):
        # Legacy table present: rename it and adapt its schema in place.
        op.rename_table(old_table_name='usergroup', new_table_name='group')
        with op.batch_alter_table("group") as batch_op:
            batch_op.drop_index('ix_usergroup_name')
            batch_op.create_index('ix_group_name', ['name'], unique=True)
            batch_op.alter_column('name', server_default=None)
            batch_op.add_column(
                db.Column('is_managed',
                          db.Boolean(),
                          nullable=False,
                          default=False))
            # NOTE(review): existing join_policy / login_method values must
            # fit the new 1-character string type — confirm.
            batch_op.alter_column(column_name='join_policy',
                                  new_column_name='privacy_policy',
                                  type_=db.String(length=1),
                                  nullable=False)
            batch_op.drop_index('login_method_name')
            batch_op.alter_column(column_name='login_method',
                                  new_column_name='subscription_policy',
                                  type_=db.String(length=1),
                                  nullable=False,
                                  server_default=None)
            batch_op.add_column(
                db.Column('created',
                          db.DateTime(),
                          nullable=False,
                          default=datetime.now))
            batch_op.add_column(
                db.Column('modified',
                          db.DateTime(),
                          nullable=False,
                          default=datetime.now,
                          onupdate=datetime.now))
    else:
        # No legacy table: create ``group`` from scratch.
        op.create_table('group',
                        db.Column('id',
                                  db.Integer(15, unsigned=True),
                                  nullable=False,
                                  autoincrement=True),
                        db.Column('name',
                                  db.String(length=255),
                                  nullable=False,
                                  unique=True,
                                  index=True),
                        db.Column('description',
                                  db.Text,
                                  nullable=True,
                                  default=''),
                        db.Column('is_managed',
                                  db.Boolean(),
                                  default=False,
                                  nullable=False),
                        db.Column('privacy_policy',
                                  db.String(length=1),
                                  nullable=False),
                        db.Column('subscription_policy',
                                  db.String(length=1),
                                  nullable=False),
                        db.Column('created',
                                  db.DateTime,
                                  nullable=False,
                                  default=datetime.now),
                        db.Column('modified',
                                  db.DateTime,
                                  nullable=False,
                                  default=datetime.now,
                                  onupdate=datetime.now),
                        db.PrimaryKeyConstraint('id'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')

    if op.has_table('user_usergroup'):
        # Legacy membership table present: rename it and adapt it in place.
        op.rename_table(old_table_name='user_usergroup',
                        new_table_name='groupMEMBER')
        with op.batch_alter_table("groupMEMBER") as batch_op:
            batch_op.drop_index('id_usergroup')
            batch_op.alter_column('id_user', server_default=None)
            batch_op.alter_column(column_name='id_usergroup',
                                  new_column_name='id_group',
                                  existing_type=db.Integer(15, unsigned=True),
                                  nullable=False)
            batch_op.create_index('id_group', ['id_group'])
            batch_op.alter_column(column_name='user_status',
                                  new_column_name='state',
                                  type_=db.String(length=1),
                                  nullable=False)
            batch_op.drop_column('user_status_date')
            batch_op.add_column(
                db.Column('modified',
                          db.DateTime(),
                          nullable=False,
                          default=datetime.now,
                          onupdate=datetime.now))
            batch_op.add_column(
                db.Column('created',
                          db.DateTime(),
                          nullable=False,
                          default=datetime.now))
    else:
        # No legacy table: create ``groupMEMBER`` from scratch.
        op.create_table('groupMEMBER',
                        db.Column('id_user',
                                  db.Integer(15, unsigned=True),
                                  nullable=False),
                        db.Column('id_group', db.Integer(15, unsigned=True)),
                        db.Column('state', db.String(length=1),
                                  nullable=False),
                        db.Column('created',
                                  db.DateTime(),
                                  nullable=False,
                                  default=datetime.now),
                        db.Column('modified',
                                  db.DateTime(),
                                  nullable=False,
                                  default=datetime.now,
                                  onupdate=datetime.now),
                        db.ForeignKeyConstraint(
                            ['id_group'],
                            [u'group.id'],
                        ),
                        db.ForeignKeyConstraint(
                            ['id_user'],
                            [u'user.id'],
                        ),
                        db.PrimaryKeyConstraint('id_user', 'id_group'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')

    # Created unconditionally (no has_table check for this table).
    op.create_table('groupADMIN',
                    db.Column('id',
                              db.Integer(15, unsigned=True),
                              nullable=False,
                              autoincrement=True),
                    db.Column('group_id', db.Integer(15, unsigned=True)),
                    db.Column('admin_type', db.Unicode(255)),
                    db.Column('admin_id', db.Integer),
                    db.ForeignKeyConstraint(
                        ['group_id'],
                        [u'group.id'],
                    ),
                    db.PrimaryKeyConstraint('id', 'group_id'),
                    mysql_charset='utf8',
                    mysql_engine='MyISAM')
Ejemplo n.º 15
0
class Project(db.Model):
    """Represents a project.

    A project has an owner, an optional group of members, and a collection
    of records tagged ``980__:<collection name>``. Access to the collection
    is granted through a ``viewrestrcoll`` authorization (see save_acl).
    """
    __tablename__ = 'project'

    id = db.Column(db.Integer(15, unsigned=True), primary_key=True)
    """Project id"""

    title = db.Column(db.String(length=255), nullable=False, default='')
    """ Project title """

    description = db.Column(db.Text(), nullable=False, default='')
    """ Project short description """

    creation_date = db.Column(db.DateTime(),
                              nullable=False,
                              default=datetime.now)
    """ creation date of the project"""

    # onupdate keeps the column in sync with its meaning; without it the
    # value never changed after creation.
    modification_date = db.Column(db.DateTime(),
                                  nullable=False,
                                  default=datetime.now,
                                  onupdate=datetime.now)
    """ date of last modification"""

    # collection
    id_collection = db.Column(db.Integer(15, unsigned=True),
                              db.ForeignKey(Collection.id),
                              nullable=True,
                              default=None)
    collection = db.relationship(Collection,
                                 uselist=False,
                                 backref='project',
                                 foreign_keys=[id_collection])

    # owner
    id_user = db.Column(db.Integer(15, unsigned=True),
                        db.ForeignKey(User.id),
                        nullable=False)
    owner = db.relationship(User, backref='projects', foreign_keys=[id_user])

    # Whether the project may still be erased; is_empty() clears it once the
    # project is found to contain records.
    eresable = db.Column(db.Boolean, nullable=False, default=True)

    # group
    id_group = db.Column(db.Integer(15, unsigned=True),
                         db.ForeignKey(Group.id),
                         nullable=True,
                         default=None)
    group = db.relationship(Group, backref='projects', foreign_keys=[id_group])

    #
    # Collection management
    #
    def get_collection_name(self):
        """Return ``"<PROJECTS_COLLECTION_PREFIX>-<project id>"``."""
        return '%s-%s' % (cfg['PROJECTS_COLLECTION_PREFIX'], self.id)

    def get_collection_dbquery(self):
        """Return the collection search query (``980__a:<name>``)."""
        return '%s:%s' % ("980__a", self.get_collection_name())

    def get_project_records(self, record_types=None, public=None,
                            curated=None):
        """Return all records of this project.

        @param record_types: optional list of 980 values; a record must match
            at least one of them.
        @param public: if not None, required value of 983__b.
        @param curated: if not None, required value of 983__a.
        @return: query over Record objects whose ids match the search.
        """
        from invenio.legacy.search_engine import search_pattern_parenthesised
        from invenio.modules.records.models import Record
        q = ['980__:%s' % self.get_collection_name()]
        # None default instead of a shared mutable list.
        qtypes = ['980__:%s' % t for t in (record_types or [])]
        if len(qtypes) > 1:
            q.append('(%s)' % ' OR '.join(qtypes))
        else:
            q.extend(qtypes)
        if public is not None:
            q.append('983__b:%s' % public)
        if curated is not None:
            q.append('983__a:%s' % curated)
        p = ' AND '.join(q)
        recids = search_pattern_parenthesised(p=p)
        return Record.query.filter(Record.id.in_(recids))

    def save_collectionname(self, collection, title):
        """Create or update the Collectionname (type 'ln') of ``collection``.

        @return: the Collectionname object (added to the session if new).
        """
        if collection.id:
            c_name = Collectionname.query.filter_by(
                id_collection=collection.id, ln=CFG_SITE_LANG,
                type='ln').first()
            if c_name:
                update_changed_fields(c_name, dict(value=title))
                return c_name

        c_name = Collectionname(
            collection=collection,
            ln=CFG_SITE_LANG,
            type='ln',
            value=title,
        )
        db.session.add(c_name)
        return c_name

    def save_collectioncollection(self, collection):
        """Create or update CollectionCollection object.

        Links ``collection`` under the ``PROJECTS_PARENT_NAME`` collection.
        """
        dad = Collection.query.filter_by(
            name=cfg['PROJECTS_PARENT_NAME']).first()

        if collection.id:
            c_tree = CollectionCollection.query.filter_by(
                id_dad=dad.id, id_son=collection.id).first()
            if c_tree:
                update_changed_fields(
                    c_tree,
                    dict(type=cfg['PROJECTS_COLLECTION_TYPE'],
                         score=cfg['PROJECTS_COLLECTION_SCORE']))
                return c_tree

        c_tree = CollectionCollection(
            dad=dad,
            son=collection,
            type=cfg['PROJECTS_COLLECTION_TYPE'],
            score=cfg['PROJECTS_COLLECTION_SCORE'],
        )
        db.session.add(c_tree)
        return c_tree

    def save_collectionformat(self, collection):
        """Create or update CollectionFormat object.

        Uses the output format ``PROJECTS_OUTPUTFORMAT``.
        """
        fmt = Format.query.filter_by(code=cfg['PROJECTS_OUTPUTFORMAT']).first()

        if collection.id:
            c_fmt = CollectionFormat.query.filter_by(
                id_collection=collection.id).first()
            if c_fmt:
                update_changed_fields(c_fmt, dict(id_format=fmt.id, score=1))
                return c_fmt

        c_fmt = CollectionFormat(
            collection=collection,
            id_format=fmt.id,
        )
        db.session.add(c_fmt)
        return c_fmt

    def save_acl(self, c):
        """Create or update the authorization to view the project collection.

        This ensures:

        * a role ``project_role_<project id>`` whose FireRole rule allows
          only the project group;
        * a ``collection`` argument for ``c.name``;
        * the owner as the only explicit member of the role;
        * a ``viewrestrcoll`` authorization linking them.

        Objects are added to the session; nothing is committed here.

        @param c: the project's Collection object.
        """
        # Role - use the project id, because role name is limited to 32 chars.
        role_name = 'project_role_%s' % self.id
        role = AccROLE.query.filter_by(name=role_name).first()
        if not role:
            rule = 'allow group "%s"\ndeny any' % self.get_group_name()
            role = AccROLE(name=role_name,
                           description='Owner of project %s' % self.title,
                           firerole_def_ser=serialize(
                               compile_role_definition(rule)),
                           firerole_def_src=rule)
            db.session.add(role)

        # Argument
        fields = dict(keyword='collection', value=c.name)
        arg = AccARGUMENT.query.filter_by(**fields).first()
        if not arg:
            arg = AccARGUMENT(**fields)
            db.session.add(arg)

        # Action
        action = AccACTION.query.filter_by(name='viewrestrcoll').first()

        # User role: keep the owner's membership and remove any other user.
        userrole = None
        for ur in UserAccROLE.query.filter_by(role=role).all():
            if ur.id_user != self.id_user:
                db.session.delete(ur)
            else:
                userrole = ur

        if not userrole:
            userrole = UserAccROLE(id_user=self.id_user, role=role)
            db.session.add(userrole)

        # Authorization
        auth = AccAuthorization.query.filter_by(role=role,
                                                action=action,
                                                argument=arg).first()
        if not auth:
            auth = AccAuthorization(role=role,
                                    action=action,
                                    argument=arg,
                                    argumentlistid=1)
            # Add explicitly rather than relying on relationship cascades.
            db.session.add(auth)

    def save_collection(self):
        """Create or update the project collection and its related objects.

        Saves the name, tree link, output format and ACL, then commits.
        """
        collection_name = self.get_collection_name()
        c = Collection.query.filter_by(name=collection_name).first()
        fields = dict(name=collection_name,
                      dbquery=self.get_collection_dbquery())
        if c:
            update_changed_fields(c, fields)
        else:
            c = Collection(**fields)
            db.session.add(c)
            # Commit so the new collection has an id for the helpers below.
            db.session.commit()
        self.collection = c
        self.save_collectionname(c, self.title)
        self.save_collectioncollection(c)
        self.save_collectionformat(c)
        self.save_acl(c)
        db.session.commit()

    def delete_collection(self):
        """Delete the project collection with its formats, names and links."""
        if self.collection:
            CollectionFormat.query.filter_by(
                id_collection=self.collection.id).delete()
            Collectionname.query.filter_by(
                id_collection=self.collection.id).delete()
            CollectionCollection.query.filter_by(
                id_son=self.collection.id).delete()
            db.session.delete(self.collection)
            db.session.commit()

    def get_group_name(self):
        """Return the name of the project's group."""
        return 'project-group-%d' % self.id

    def save_group(self):
        """Create the project group (owner as admin and member) if missing."""
        g = self.group
        if not g:
            g = Group.create(self.get_group_name(),
                             description='Group for project %s' % self.id,
                             privacy_policy=PrivacyPolicy.MEMBERS,
                             subscription_policy=SubscriptionPolicy.APPROVAL,
                             is_managed=False,
                             admins=[self.owner])
            g.add_member(self.owner)
            self.group = g
            db.session.commit()

    def is_user_allowed(self, user=None):
        """Return whether ``user`` owns the project or belongs to its group.

        @param user: object exposing ``get_id()`` and ``get('group', ...)``
            (the user's group names); defaults to the current user.
        """
        if not user:
            from flask_login import current_user
            user = current_user
        uid = user.get_id()
        groups = user.get('group', [])
        if self.id_user == uid:
            return True
        # id_group is nullable: without a group only the owner is allowed.
        return self.group is not None and self.group.name in groups

    def is_empty(self):
        """Return whether the project is erasable and has no records.

        If records are found, ``eresable`` is cleared and committed, so the
        project is never considered empty again.
        """
        if self.eresable:
            # Ensure the project has no records.
            from invenio.legacy.search_engine import search_pattern
            q = '980__:%s' % self.get_collection_name()
            recids = search_pattern(p=q)
            if len(recids) != 0:
                self.eresable = False
                db.session.commit()
                return False
            else:
                return True
        return False

    @classmethod
    def get_project(cls, id):
        """Return the project with the given id, or None if ``id`` is invalid.

        @param id: project id (anything accepted by ``int()``).
        """
        try:
            return cls.query.get(int(id))
        except (TypeError, ValueError):
            # TypeError covers id=None, which int() rejects.
            return None

    @classmethod
    def filter_projects(cls, p, so):
        """Search for projects.

        Helper function which takes from database only those projects which
        match search criteria. Uses parameter 'so' to set projects in the
        correct order.

        @param p: substring matched against id, title and description.
        @param so: sort field; applied only if listed in
            ``PROJECTS_SORTING_OPTIONS`` (ascending for 'title', descending
            otherwise).
        @return: the (unexecuted) query.
        """
        query = cls.query
        if p:
            query = query.filter(
                db.or_(
                    cls.id.like("%" + p + "%"),
                    cls.title.like("%" + p + "%"),
                    cls.description.like("%" + p + "%"),
                ))
        if so in cfg['PROJECTS_SORTING_OPTIONS']:
            order = db.asc if so == 'title' else db.desc
            query = query.order_by(order(getattr(cls, so)))
        return query

    @classmethod
    def get_user_projects(cls, user):
        """Return a query of projects whose group includes ``user``."""
        gids = [g.id for g in Group.query_by_uid(user.get_id())]
        return cls.query.filter(cls.id_group.in_(gids))
Ejemplo n.º 16
0
def do_upgrade():
    """Create the ``accreqLINK`` and ``accreqREQUEST`` tables and indexes.

    ``accreqLINK`` is created first because ``accreqREQUEST.link_id`` declares
    a foreign key to ``accreqLINK.id``. Creating the referencing table before
    the referenced one fails on backends that validate foreign keys at DDL
    time. MySQL's MyISAM engine does not enforce foreign keys, which is
    presumably why the reverse order went unnoticed.
    """
    op.create_table('accreqLINK',
                    db.Column('id',
                              db.Integer(display_width=15),
                              nullable=False),
                    db.Column('token', db.Text(), nullable=False),
                    db.Column('owner_user_id',
                              db.Integer(display_width=15),
                              nullable=False),
                    db.Column('created', db.DateTime(), nullable=False),
                    db.Column('expires_at', db.DateTime(), nullable=True),
                    db.Column('revoked_at', db.DateTime(), nullable=True),
                    db.Column('title', db.String(length=255), nullable=False),
                    db.Column('description', db.Text(), nullable=False),
                    db.ForeignKeyConstraint(
                        ['owner_user_id'],
                        [u'user.id'],
                    ),
                    db.PrimaryKeyConstraint('id'),
                    mysql_charset='utf8',
                    mysql_engine='MyISAM')
    op.create_index(op.f('ix_accreqLINK_created'),
                    'accreqLINK', ['created'],
                    unique=False)
    op.create_index(op.f('ix_accreqLINK_revoked_at'),
                    'accreqLINK', ['revoked_at'],
                    unique=False)

    # Created after accreqLINK: link_id references accreqLINK.id.
    op.create_table('accreqREQUEST',
                    db.Column('id',
                              db.Integer(display_width=15),
                              nullable=False),
                    db.Column('status', db.String(length=1), nullable=False),
                    db.Column('receiver_user_id',
                              db.Integer(display_width=15),
                              nullable=False),
                    db.Column('sender_user_id',
                              db.Integer(display_width=15),
                              nullable=True),
                    db.Column('sender_full_name',
                              db.String(length=255),
                              nullable=False),
                    db.Column('sender_email',
                              db.String(length=255),
                              nullable=False),
                    db.Column('recid',
                              db.Integer(display_width=15),
                              nullable=False),
                    db.Column('created', db.DateTime(), nullable=False),
                    db.Column('modified', db.DateTime(), nullable=False),
                    db.Column('justification', db.Text(), nullable=False),
                    db.Column('message', db.Text(), nullable=False),
                    db.Column('link_id',
                              db.Integer(display_width=15),
                              nullable=True),
                    db.ForeignKeyConstraint(
                        ['receiver_user_id'],
                        [u'user.id'],
                    ),
                    db.ForeignKeyConstraint(
                        ['sender_user_id'],
                        [u'user.id'],
                    ),
                    db.ForeignKeyConstraint(
                        ['link_id'],
                        [u'accreqLINK.id'],
                    ),
                    db.PrimaryKeyConstraint('id'),
                    mysql_charset='utf8',
                    mysql_engine='MyISAM')
    op.create_index(op.f('ix_accreqREQUEST_created'),
                    'accreqREQUEST', ['created'],
                    unique=False)
    op.create_index(op.f('ix_accreqREQUEST_recid'),
                    'accreqREQUEST', ['recid'],
                    unique=False)
    op.create_index(op.f('ix_accreqREQUEST_status'),
                    'accreqREQUEST', ['status'],
                    unique=False)