Exemplo n.º 1
0
class LegacyRecordsMirror(db.Model):
    """Mirror of legacy MARCXML records, keyed by their legacy ``recid``.

    The MARCXML payload is stored compressed in the ``marcxml`` column.
    ``valid`` is ``None`` for freshly created records (see
    :meth:`from_marcxml`) and is set to ``False`` when an error is recorded
    through the :attr:`error` setter.
    """

    __tablename__ = 'legacy_records_mirror'

    __table_args__ = (
        db.Index('ix_legacy_records_mirror_valid_collection', 'valid', 'collection'),
    )

    recid = db.Column(db.Integer, primary_key=True)
    last_updated = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, index=True)
    _marcxml = db.Column('marcxml', db.LargeBinary, nullable=False)
    valid = db.Column(db.Boolean, default=None, nullable=True)
    _errors = db.Column('errors', db.Text(), nullable=True)
    collection = db.Column(db.Text(), default='')

    # Raw string so that ``\d`` is passed to the regex engine instead of being
    # treated as an (invalid) string escape, which newer Pythons warn about.
    re_recid = re.compile(r'<controlfield.*?tag=.001.*?>(?P<recid>\d+)</controlfield>')

    @hybrid_property
    def marcxml(self):
        """marcxml column wrapper to compress/decompress on the fly."""
        try:
            return decompress(self._marcxml)
        except error:
            # Presumably rows written before compression was introduced hold
            # uncompressed data; return them as stored.
            return self._marcxml

    @marcxml.setter
    def marcxml(self, value):
        """Store ``value`` compressed in the ``marcxml`` column."""
        self._marcxml = compress(value)

    @hybrid_property
    def error(self):
        """Return the stored error description (``'<ExcType>: <message>'``)."""
        return self._errors

    @error.setter
    def error(self, value):
        """Errors column setter that stores an Exception and sets the ``valid`` flag.

        Also refreshes ``collection`` from the record's MARCXML, so
        ``marcxml`` must already be set on this instance.
        """
        self.valid = False
        self.collection = get_collection_from_marcxml(self.marcxml)
        self._errors = u'{}: {}'.format(type(value).__name__, value)

    @classmethod
    def from_marcxml(cls, raw_record):
        """Create an instance from a MARCXML record.

        The record must have a ``001`` tag containing the recid, otherwise it raises a ValueError.

        :param raw_record: The MARCXML record.
        :returns: A new, not yet persisted instance with ``valid`` set to
            ``None``.
        :raises ValueError: If no numeric ``001`` controlfield is found.
        """
        match = cls.re_recid.search(raw_record)
        if match is None:
            raise ValueError('The MARCXML record contains no recid or recid is malformed')
        recid = int(match.group('recid'))
        # FIXME also get last_updated from marcxml
        record = cls(recid=recid)
        record.marcxml = raw_record
        record.valid = None
        return record
class SIPFile(db.Model, Timestamp):
    """Extra SIP info regarding files."""

    __tablename__ = 'sipstore_sipfile'

    sip_id = db.Column(UUIDType, db.ForeignKey(SIP.id))
    """Id of SIP."""

    filepath = db.Column(db.Text().with_variant(mysql.VARCHAR(255), 'mysql'),
                         nullable=False)
    """Filepath of submitted file within the SIP record."""

    file_id = db.Column(UUIDType,
                        db.ForeignKey(FileInstance.id, ondelete='RESTRICT'),
                        primary_key=True,
                        nullable=False)
    """Id of the FileInstance."""

    @validates('filepath')
    def validate_key(self, filepath, filepath_):
        """Validate the length of a filepath before it is assigned.

        :param filepath: Name of the attribute being set (``'filepath'``).
        :param filepath_: The new filepath value.
        :returns: ``filepath_`` unchanged.
        :raises ValueError: If ``filepath_`` is longer than the
            ``SIPSTORE_FILEPATH_MAX_LEN`` application config value.
        """
        if len(filepath_) > current_app.config['SIPSTORE_FILEPATH_MAX_LEN']:
            raise ValueError('Filepath too long ({0}).'.format(len(filepath_)))
        return filepath_

    #
    # Relations
    #
    sip = db.relationship(SIP, backref='sip_files', foreign_keys=[sip_id])
    """Relation to the SIP along which given file was submitted."""

    file = db.relationship(FileInstance,
                           backref='sip_files',
                           foreign_keys=[file_id])
    """Relation to the FileInstance holding the submitted file."""
Exemplo n.º 3
0
def do_upgrade():
    """Create the ``remoteACCOUNT`` and ``remoteTOKEN`` tables if missing.

    Each table is only created when it does not already exist; otherwise a
    warning is emitted and its creation is skipped. ``remoteTOKEN`` has a
    foreign key to ``remoteACCOUNT.id``, so the account table is handled
    first.
    """
    if not op.has_table('remoteACCOUNT'):
        op.create_table('remoteACCOUNT',
                        db.Column('id',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('user_id',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('client_id',
                                  db.String(length=255),
                                  nullable=False),
                        db.Column('extra_data', db.JSON, nullable=True),
                        db.ForeignKeyConstraint(
                            ['user_id'],
                            ['user.id'],
                        ),
                        db.PrimaryKeyConstraint('id'),
                        db.UniqueConstraint('user_id', 'client_id'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')
    else:
        warnings.warn("*** Creation of table 'remoteACCOUNT' skipped!")

    if not op.has_table('remoteTOKEN'):
        op.create_table('remoteTOKEN',
                        db.Column('id_remote_account',
                                  db.Integer(display_width=15),
                                  nullable=False),
                        db.Column('token_type',
                                  db.String(length=40),
                                  nullable=False),
                        db.Column('access_token', db.Text(), nullable=False),
                        db.Column('secret', db.Text(), nullable=False),
                        db.ForeignKeyConstraint(
                            ['id_remote_account'],
                            ['remoteACCOUNT.id'],
                        ),
                        db.PrimaryKeyConstraint('id_remote_account',
                                                'token_type'),
                        mysql_charset='utf8',
                        mysql_engine='MyISAM')
    else:
        warnings.warn("*** Creation of table 'remoteTOKEN' skipped!")
Exemplo n.º 4
0
class Page(db.Model, Timestamp):
    """Represents a page."""

    __versioned__ = {}

    __tablename__ = 'pages_page'

    id = db.Column(db.Integer,
                   nullable=False,
                   primary_key=True,
                   autoincrement=True)
    """Page identifier."""

    url = db.Column(db.String(100), unique=True, nullable=False)
    """Page url."""

    title = db.Column(db.String(200), nullable=False, default='')
    """Page title."""

    content = db.Column(db.Text(), nullable=False, default='')
    """Page content. Default is pages/templates/default.html"""

    description = db.Column(db.String(200), nullable=False, default='')
    """Page description."""

    template_name = db.Column(db.String(70), nullable=False)
    """Page template name."""

    @classmethod
    def get_by_url(cls, url):
        """Get a page by URL.

        :param url: The page URL.
        :returns: A :class:`invenio_pages.models.Page` instance.
        :raises: Whatever ``Query.one()`` raises when no page, or more than
            one page, matches ``url``.
        """
        return cls.query.filter_by(url=url).one()

    @validates('template_name')
    def validate_template_name(self, key, value):
        """Validate template name.

        :param key: Name of the attribute being set (``'template_name'``).
        :param value: The template name.
        :returns: ``value`` unchanged when it is a configured template.
        :raises ValueError: If ``value`` is not a key of the
            ``PAGES_TEMPLATES`` application config.
        """
        if value not in dict(current_app.config['PAGES_TEMPLATES']):
            raise ValueError('Template "{0}" does not exist.'.format(value))
        return value

    def __repr__(self):
        """Page representation.

        Used on Page admin view in inline model.
        :returns: unambiguous page representation.
        """
        return "URL: %s, title: %s" % (self.url, self.title)
Exemplo n.º 5
0
class FileRecordModelMixin:
    """Base class for a record file, storing its state and metadata."""

    __record_model_cls__ = None
    """Record model class referenced by the ``record_id`` foreign key."""

    key = db.Column(
        db.Text().with_variant(mysql.VARCHAR(255), 'mysql'),
        nullable=False,
    )
    """Filename key (can be path-like also)."""

    @declared_attr
    def record_id(cls):
        """Record ID foreign key (to ``__record_model_cls__.id``)."""
        return db.Column(
            UUIDType,
            db.ForeignKey(cls.__record_model_cls__.id, ondelete='RESTRICT'),
            nullable=False,
        )

    @declared_attr
    def record(cls):
        """Record the file belongs to."""
        return db.relationship(cls.__record_model_cls__)

    @declared_attr
    def object_version_id(cls):
        """Object version ID foreign key (nullable)."""
        return db.Column(
            UUIDType,
            db.ForeignKey(ObjectVersion.version_id, ondelete='RESTRICT'),
            nullable=True,
        )

    @declared_attr
    def object_version(cls):
        """Object version connected to the record file."""
        return db.relationship(ObjectVersion)

    @declared_attr
    def __table_args__(cls):
        """Table args: a unique index over ``(id, key)``."""
        # NOTE(review): if ``id`` is the primary key of the concrete table,
        # this index cannot constrain anything beyond the PK; a per-record
        # unique key would be ``('record_id', 'key')``. Changing it alters the
        # schema and needs a migration -- confirm intent before touching.
        return (db.Index(
            f'uidx_{cls.__tablename__}_id_key',
            'id',
            'key',
            unique=True,
        ), )
Exemplo n.º 6
0
class OAISync(db.Model):
    """One OAI synchronization run of a provider, with its status and stats."""

    __tablename__ = "oarepo_oai_sync"
    id = db.Column(db.Integer, primary_key=True)
    # Provider this run belongs to (FK to oarepo_oai_provider.id).
    provider_id = db.Column(db.Integer, ForeignKey('oarepo_oai_provider.id'))
    sync_start = db.Column(db.TIMESTAMP)
    sync_end = db.Column(db.TIMESTAMP)
    status = db.Column(db.String(32))
    logs = db.Column(db.Text())

    # number of created, modified and deleted records for statistics
    rec_created = db.Column(db.Integer)
    rec_modified = db.Column(db.Integer)
    rec_deleted = db.Column(db.Integer)
    provider = relationship("OAIProvider", backref=backref("synchronizations"))
    # NOTE(review): from OAIRecordExc's side this backref is presumably
    # many-to-one (one sync per exception), so the plural name
    # "synchronizations" looks misleading; renaming would break callers.
    traceback = relationship("OAIRecordExc",
                             backref=backref("synchronizations"))
Exemplo n.º 7
0
class OAISync(db.Model):
    """One OAI synchronization run, with its status, logs and record stats."""

    __tablename__ = "oarepo_oai_sync"
    id = db.Column(db.Integer, primary_key=True)
    # NOTE(review): unbounded db.String needs a length on MySQL; fine on
    # PostgreSQL/SQLite -- confirm target backend.
    provider_code = db.Column(db.String, nullable=False)
    synchronizer_code = db.Column(db.String)
    purpose = db.Column(db.String)
    sync_start = db.Column(db.TIMESTAMP)
    sync_end = db.Column(db.TIMESTAMP)
    status = db.Column(db.String(32))
    logs = db.Column(db.Text())

    # number of created, modified and deleted records for statistics
    records_created = db.Column(db.Integer)
    records_modified = db.Column(db.Integer)
    records_deleted = db.Column(db.Integer)
    # Per-record exceptions raised during this run; each OAIRecordExc gets a
    # ``synchronization`` back-reference to this row.
    tracebacks = relationship("OAIRecordExc",
                              backref=backref("synchronization"))
Exemplo n.º 8
0
class InspireProdRecords(db.Model):
    """Production records stored as compressed MARCXML, keyed by ``recid``."""

    __tablename__ = 'inspire_prod_records'

    # NOTE(review): index=True on a primary key may create a redundant extra
    # index on some backends; removing it is a schema change.
    recid = db.Column(db.Integer, primary_key=True, index=True)
    # NOTE(review): datetime.now stores local time, unlike utcnow used by
    # similar models -- confirm which is intended.
    last_updated = db.Column(db.DateTime, default=datetime.now, nullable=False, index=True)
    _marcxml = db.Column('marcxml', db.LargeBinary, nullable=False)
    valid = db.Column(db.Boolean, default=None, nullable=True, index=True)
    errors = db.Column(db.Text(), nullable=True)

    @hybrid_property
    def marcxml(self):
        """marcxml column wrapper to compress/decompress on the fly."""
        try:
            return decompress(self._marcxml)
        except error:
            # Legacy uncompress data?
            # (``error`` is presumably zlib.error -- data that fails to
            # decompress is returned exactly as stored.)
            return self._marcxml

    @marcxml.setter
    def marcxml(self, value):
        """Store ``value`` compressed in the ``marcxml`` column."""
        self._marcxml = compress(value)
Exemplo n.º 9
0
class MementoArchives(db.Model):
    """Relationship between Memento and Buckets."""

    __tablename__ = 'memento_archives'

    archived = db.Column(
        db.DateTime,
        primary_key=True,
    )
    """The archival date and time (stored without microseconds)."""

    key = db.Column(
        db.Text().with_variant(mysql.VARCHAR(255), 'mysql'),
        primary_key=True,
    )
    """Key identifying the archived object."""

    bucket_id = db.Column(
        UUIDType,
        db.ForeignKey(Bucket.id),
        nullable=False,
    )
    """The bucket with archived files related to the ``key``.

    .. note:: There must be an ``ObjectVersion`` with the same key.
    """

    # Bucket referenced by ``bucket_id``.
    bucket = db.relationship(Bucket)

    def __repr__(self):
        """Return representation of Memento as ``archived/key:bucket_id``."""
        return '{0.archived}/{0.key}:{0.bucket_id}'.format(self)

    @validates('archived')
    def validate_archived(self, key, value):
        """Remove microseconds from the value.

        Falsy values (e.g. ``None``) are returned unchanged.
        """
        return value.replace(microsecond=0) if value else value
Exemplo n.º 10
0
class OAIRecordExc(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    oai_identifier = db.Column(db.String, nullable=False)
    traceback = db.Column(db.Text(), nullable=True)
    oai_sync_id = db.Column(db.Integer, ForeignKey('oarepo_oai_sync.id'))
Exemplo n.º 11
0
class ObjectVersion(db.Model, Timestamp):
    """Model for storing versions of objects.

    A bucket stores one or more objects identified by a key. Each object is
    versioned where each version is represented by an ``ObjectVersion``.

    An object version can either be 1) a *normal version* which is linked to
    a file instance, or 2) a *delete marker*, which is *not* linked to a file
    instance.

    A normal object version is linked to a physical file on disk via a file
    instance. This allows for multiple object versions to point to the same
    file on disk, to optimize storage efficiency (e.g. useful for snapshotting
    an entire bucket without duplicating the files).

    A delete marker object version represents that the object at hand was
    deleted.

    The latest version of an object is marked using the ``is_head`` property.
    If the latest object version is a delete marker the object will not be
    shown in the bucket.
    """

    __tablename__ = 'files_object'

    bucket_id = db.Column(
        UUIDType,
        db.ForeignKey(Bucket.id, ondelete='RESTRICT'),
        default=uuid.uuid4,
        primary_key=True,
    )
    """Bucket identifier."""

    key = db.Column(
        db.Text().with_variant(mysql.VARCHAR(255), 'mysql'),
        primary_key=True,
    )
    """Key identifying the object."""

    version_id = db.Column(
        UUIDType,
        primary_key=True,
        default=uuid.uuid4,
    )
    """Identifier for the specific version of an object."""

    file_id = db.Column(UUIDType,
                        db.ForeignKey(FileInstance.id, ondelete='RESTRICT'),
                        nullable=True)
    """File instance for this object version.

    A null value in this column defines that the object has been deleted.
    """

    _mimetype = db.Column(
        db.String(255),
        index=True,
        nullable=True,
    )
    """MIME type of the object."""

    is_head = db.Column(db.Boolean, nullable=False, default=True)
    """Defines if object is the latest version."""

    # Relationships definitions
    bucket = db.relationship(Bucket, backref='objects')
    """Relationship to buckets."""

    file = db.relationship(FileInstance, backref='objects')
    """Relationship to file instance."""

    @validates('key')
    def validate_key(self, key, key_):
        """Validate key (delegates to the module-level ``validate_key``)."""
        return validate_key(key_)

    def __repr__(self):
        """Return representation as ``bucket_id:version_id:key``."""
        return '{0}:{2}:{1}'.format(self.bucket_id, self.key, self.version_id)

    @hybrid_property
    def mimetype(self):
        """Get MIME type of object.

        Returns the stored MIME type when set; otherwise the type guessed
        from the key's extension; otherwise ``application/octet-stream``.
        """
        # Initialised so that an object with neither a stored MIME type nor a
        # key falls back to the default instead of raising UnboundLocalError.
        m = None
        if self._mimetype:
            m = self._mimetype
        elif self.key:
            m = mimetypes.guess_type(self.key)[0]
        return m or 'application/octet-stream'

    @mimetype.setter
    def mimetype(self, value):
        """Setter for MIME type."""
        self._mimetype = value

    @property
    def basename(self):
        """Return the last path component of the object's key."""
        return basename(self.key)

    @property
    def deleted(self):
        """Determine if object version is a delete marker."""
        return self.file_id is None

    @ensure_no_file()
    @update_bucket_size
    def set_contents(self,
                     stream,
                     chunk_size=None,
                     size=None,
                     size_limit=None,
                     progress_callback=None):
        """Save contents of stream to file instance.

        If a file instance has already been set, this methods raises an
        ``FileInstanceAlreadySetError`` exception.

        :param stream: File-like stream.
        :param size: Size of stream if known.
        :param chunk_size: Desired chunk size to read stream in. It is up to
            the storage interface if it respects this value.
        :param size_limit: Maximum allowed size; defaults to the bucket's
            ``size_limit`` when ``None``.
        :param progress_callback: Passed through to the file instance's
            ``set_contents``.
        :returns: ``self``.
        """
        if size_limit is None:
            size_limit = self.bucket.size_limit

        self.file = FileInstance.create()
        self.file.set_contents(
            stream,
            size_limit=size_limit,
            size=size,
            chunk_size=chunk_size,
            progress_callback=progress_callback,
            default_location=self.bucket.location.uri,
            default_storage_class=self.bucket.default_storage_class,
        )

        return self

    @ensure_no_file()
    @update_bucket_size
    def set_location(self, uri, size, checksum, storage_class=None):
        """Set only the URI location of the object.

        Useful to link files on externally controlled storage. If a file
        instance has already been set, this methods raises an
        ``FileInstanceAlreadySetError`` exception.

        :param uri: Full URI to object (which can be interpreted by the storage
            interface).
        :param size: Size of file.
        :param checksum: Checksum of file.
        :param storage_class: Storage class where file is stored ()
        :returns: ``self``.
        """
        self.file = FileInstance()
        self.file.set_uri(uri, size, checksum, storage_class=storage_class)
        db.session.add(self.file)
        return self

    @ensure_no_file()
    @update_bucket_size
    def set_file(self, fileinstance):
        """Set a file instance and return ``self``."""
        self.file = fileinstance
        return self

    def send_file(self, restricted=True, **kwargs):
        """Wrapper around FileInstance's send file."""
        return self.file.send_file(self.basename,
                                   restricted=restricted,
                                   mimetype=self.mimetype,
                                   **kwargs)

    @ensure_is_previous_version()
    def restore(self):
        """Restore this object version to become the latest version.

        Raises an exception if the object is the latest version.
        """
        # Note, copy calls create which will fail if bucket is locked.
        return self.copy()

    @ensure_not_deleted(msg='Cannot copy a delete marker.')
    def copy(self, bucket=None, key=None):
        """Copy an object version to a given bucket + object key.

        The copy operation is handled completely at the metadata level. The
        actual data on disk is not copied. Instead, the two object versions
        will point to the same physical file (via the same FileInstance).

        .. warning::

           If the destination object exists, it will be replaced by  the new
           object version which will become the latest version.

        :param bucket: The bucket (instance or id) to copy the object to.
            Default: current bucket.
        :param key: Key name of destination object.
            Default: current object key.
        :returns: The copied object version.
        """
        return ObjectVersion.create(
            self.bucket if bucket is None else as_bucket(bucket),
            key or self.key,
            _file_id=self.file_id)

    @ensure_unlocked(getter=lambda o: not o.bucket.locked)
    def remove(self):
        """Permanently remove a specific object version from the database.

        .. note::

           This overrides the normal versioning and should only be used when
           you want to permanently delete an object.

        :returns: ``self``.
        """
        with db.session.begin_nested():
            db.session.delete(self)
        return self

    @classmethod
    def create(cls,
               bucket,
               key,
               _file_id=None,
               stream=None,
               mimetype=None,
               **kwargs):
        """Create a new object in a bucket.

        The created object is by default created as a delete marker. You must
        use ``set_contents()`` or ``set_location()`` in order to change this.

        :param bucket: The bucket (instance or id) to create the object in.
        :param key: Key of object.
        :param _file_id: For internal use.
        :param stream: File-like stream object. Used to set content of object
            immediately after being created.
        :param mimetype: MIME type of the file object if it is known.
        :param kwargs: Keyword arguments passed to ``Object.set_contents()``.
        :raises BucketLockedError: If the bucket is locked.
        """
        bucket = as_bucket(bucket)

        if bucket.locked:
            raise BucketLockedError()

        with db.session.begin_nested():
            # Demote the current head (if any) so the new version is the
            # only one with is_head=True.
            latest_obj = cls.get(bucket.id, key)
            if latest_obj is not None:
                latest_obj.is_head = False
                db.session.add(latest_obj)

            # By default objects are created in a deleted state (i.e.
            # file_id is null).
            obj = cls(
                bucket=bucket,
                key=key,
                version_id=uuid.uuid4(),
                is_head=True,
                mimetype=mimetype,
            )
            if _file_id:
                file_ = _file_id if isinstance(_file_id, FileInstance) else \
                    FileInstance.get(_file_id)
                obj.set_file(file_)
            db.session.add(obj)
        if stream:
            obj.set_contents(stream, **kwargs)
        return obj

    @classmethod
    def get(cls, bucket, key, version_id=None):
        """Fetch a specific object.

        By default the latest object version is returned, if
        ``version_id`` is not set.

        :param bucket: The bucket (instance or id) to get the object from.
        :param key: Key of object.
        :param version_id: Specific version of an object.
        :returns: The object version, or ``None`` if not found. Without
            ``version_id``, a head that is a delete marker yields ``None``.
        """
        filters = [
            cls.bucket_id == as_bucket_id(bucket),
            cls.key == key,
        ]

        if version_id:
            filters.append(cls.version_id == version_id)
        else:
            filters.append(cls.is_head.is_(True))
            filters.append(cls.file_id.isnot(None))

        return cls.query.filter(*filters).one_or_none()

    @classmethod
    def get_versions(cls, bucket, key):
        """Fetch all versions of a specific object.

        :param bucket: The bucket (instance or id) to get the object from.
        :param key: Key of object.
        :returns: Query ordered by key, then newest version first.
        """
        filters = [
            cls.bucket_id == as_bucket_id(bucket),
            cls.key == key,
        ]

        return cls.query.filter(*filters).order_by(cls.key, cls.created.desc())

    @classmethod
    def delete(cls, bucket, key):
        """Delete an object.

        Technically works by creating a new version which works as a delete
        marker.

        :param bucket: The bucket (instance or id) to delete the object from.
        :param key: Key of object.
        :returns: Created delete marker object if key exists else ``None``.
        """
        bucket_id = as_bucket_id(bucket)

        obj = cls.get(bucket_id, key)
        if obj:
            return cls.create(as_bucket(bucket), key)
        return None

    @classmethod
    def get_by_bucket(cls, bucket, versions=False):
        """Return query that fetches all the objects in a bucket.

        :param bucket: The bucket (instance or id).
        :param versions: If ``False`` (default) only current, non-deleted
            heads are returned; if ``True`` all versions are returned.
        """
        bucket_id = bucket.id if isinstance(bucket, Bucket) else bucket

        filters = [
            cls.bucket_id == bucket_id,
        ]

        if not versions:
            filters.append(cls.file_id.isnot(None))
            filters.append(cls.is_head.is_(True))

        return cls.query.filter(*filters).order_by(cls.key, cls.created.desc())

    @classmethod
    def relink_all(cls, old_file, new_file):
        """Relink all object versions (for a given file) to a new file.

        .. warning::

           Use this method with great care.
        """
        # NOTE(review): these asserts are stripped under ``python -O``.
        assert old_file.checksum == new_file.checksum
        assert old_file.id
        assert new_file.id

        with db.session.begin_nested():
            ObjectVersion.query.filter_by(file_id=str(old_file.id)).update(
                {ObjectVersion.file_id: str(new_file.id)})
Exemplo n.º 12
0
class FileInstance(db.Model, Timestamp):
    """Model for storing files.

    A file instance represents a file on disk. A file instance may be linked
    from many objects, while an object can have one and only one file instance.

    A file instance also records the storage class, size and checksum of the
    file on disk.

    Additionally, a file instance can be read only in case the storage layer
    is not capable of writing to the file (e.g. can typically be used to
    link to files on externally controlled storage).
    """

    __tablename__ = 'files_files'

    id = db.Column(
        UUIDType,
        primary_key=True,
        default=uuid.uuid4,
    )
    """Identifier of file."""

    uri = db.Column(db.Text().with_variant(mysql.VARCHAR(255), 'mysql'),
                    unique=True,
                    nullable=True)
    """Location of file."""

    storage_class = db.Column(db.String(1), nullable=True)
    """Storage class of file."""

    size = db.Column(db.BigInteger, default=0, nullable=True)
    """Size of file."""

    checksum = db.Column(db.String(255), nullable=True)
    """String representing the checksum of the object."""

    readable = db.Column(db.Boolean, default=True, nullable=False)
    """Defines if the file is readable."""

    writable = db.Column(db.Boolean, default=True, nullable=False)
    """Defines if file is writable.

    This property is used to create a file instance prior to having the actual
    file at the given URI. This is useful when e.g. copying a file instance.
    """

    last_check_at = db.Column(db.DateTime, nullable=True)
    """Timestamp of last fixity check."""

    last_check = db.Column(db.Boolean, default=True, nullable=False)
    """Result of last fixity check."""

    @validates('uri')
    def validate_uri(self, key, uri):
        """Validate uri length against ``FILES_REST_FILE_URI_MAX_LEN``.

        ``None`` is accepted unchanged because the column is nullable.

        :raises ValueError: If the URI is too long.
        """
        if uri is not None and \
                len(uri) > current_app.config['FILES_REST_FILE_URI_MAX_LEN']:
            raise ValueError('FileInstance URI too long ({0}).'.format(
                len(uri)))
        return uri

    @classmethod
    def get(cls, file_id):
        """Get a file instance, or ``None`` if it does not exist."""
        return cls.query.filter_by(id=file_id).one_or_none()

    @classmethod
    def get_by_uri(cls, uri):
        """Get a file instance by URI, or ``None`` if it does not exist."""
        assert uri is not None
        return cls.query.filter_by(uri=uri).one_or_none()

    @classmethod
    def create(cls):
        """Create a file instance.

        Note, object is only added to the database session.
        """
        obj = cls(
            id=uuid.uuid4(),
            writable=True,
            readable=False,
            size=0,
        )
        db.session.add(obj)
        return obj

    def delete(self):
        """Delete a file instance.

        The file instance can be deleted if it has no references from other
        objects. The caller is responsible to test if the file instance is
        writable and that the disk file can actually be removed.

        .. note::

           Normally you should use the Celery task to delete a file instance,
           as this method will not remove the file on disk.
        """
        self.query.filter_by(id=self.id).delete()
        return self

    def storage(self, **kwargs):
        """Get storage interface for object.

        Uses the applications storage factory to create a storage interface
        that can be used for this particular file instance.

        :returns: Storage interface.
        """
        return current_files_rest.storage_factory(fileinstance=self, **kwargs)

    @ensure_readable()
    def update_checksum(self, progress_callback=None, **kwargs):
        """Update checksum based on file."""
        self.checksum = self.storage(**kwargs).checksum(
            progress_callback=progress_callback)

    def verify_checksum(self, progress_callback=None, **kwargs):
        """Verify checksum of file instance.

        Records the outcome in ``last_check`` / ``last_check_at``.

        :returns: ``True`` if the stored checksum matches the storage's.
        """
        real_checksum = self.storage(**kwargs).checksum(
            progress_callback=progress_callback)
        with db.session.begin_nested():
            self.last_check = (self.checksum == real_checksum)
            self.last_check_at = datetime.utcnow()
        return self.last_check

    @ensure_writable()
    def init_contents(self, size=0, **kwargs):
        """Initialize file."""
        self.set_uri(*self.storage(**kwargs).initialize(size=size),
                     readable=False,
                     writable=True)

    @ensure_writable()
    def update_contents(self,
                        stream,
                        seek=0,
                        size=None,
                        chunk_size=None,
                        progress_callback=None,
                        **kwargs):
        """Save contents of stream to this file.

        Clears the stored checksum since the content changes.

        :param stream: File-like stream.
        :param seek: Offset passed to the storage ``update``.
        :param size: Size of stream if known.
        :param chunk_size: Desired chunk size, forwarded to the storage.
        :param progress_callback: Forwarded to the storage.
        :returns: Whatever the storage's ``update`` returns.
        """
        self.checksum = None
        return self.storage(**kwargs).update(
            stream,
            seek=seek,
            size=size,
            chunk_size=chunk_size,
            progress_callback=progress_callback)

    @ensure_writable()
    def set_contents(self,
                     stream,
                     chunk_size=None,
                     size=None,
                     size_limit=None,
                     progress_callback=None,
                     **kwargs):
        """Save contents of stream to this file.

        :param stream: File-like stream.
        :param chunk_size: Desired chunk size, forwarded to the storage.
        :param size: Size of stream if known.
        :param size_limit: Maximum allowed size, forwarded to the storage.
        :param progress_callback: Forwarded to the storage.
        """
        self.set_uri(*self.storage(
            **kwargs).save(stream,
                           chunk_size=chunk_size,
                           size=size,
                           size_limit=size_limit,
                           progress_callback=progress_callback))

    @ensure_writable()
    def copy_contents(self,
                      fileinstance,
                      progress_callback=None,
                      chunk_size=None,
                      **kwargs):
        """Copy this file instance into another file instance.

        :raises ValueError: If the source is not readable or this file
            instance already has data.
        """
        if not fileinstance.readable:
            raise ValueError('Source file instance is not readable.')
        if self.size != 0:
            raise ValueError('File instance has data.')

        self.set_uri(*self.storage(
            **kwargs).copy(fileinstance.storage(**kwargs),
                           chunk_size=chunk_size,
                           progress_callback=progress_callback))

    @ensure_readable()
    def send_file(self, filename, restricted=True, mimetype=None, **kwargs):
        """Send file to client."""
        return self.storage(**kwargs).send_file(
            filename,
            mimetype=mimetype,
            restricted=restricted,
            checksum=self.checksum,
        )

    def set_uri(self,
                uri,
                size,
                checksum,
                readable=True,
                writable=False,
                storage_class=None):
        """Set a location of a file.

        ``storage_class`` defaults to ``FILES_REST_DEFAULT_STORAGE_CLASS``.

        :returns: ``self``.
        """
        self.uri = uri
        self.size = size
        self.checksum = checksum
        self.writable = writable
        self.readable = readable
        self.storage_class = \
            current_app.config['FILES_REST_DEFAULT_STORAGE_CLASS'] \
            if storage_class is None else \
            storage_class
        return self
Exemplo n.º 13
0
class MultipartObject(db.Model, Timestamp):
    """Model for storing files in chunks.

    A multipart object belongs to a specific bucket and key and is identified
    by an upload id. You can have multiple multipart uploads for the same
    bucket and key. Once all parts of a multipart object are uploaded, the
    state is changed to ``completed``. Afterwards it is not possible to upload
    new parts. Once completed, the multipart object is merged, and added as
    a new version in the current object/bucket.

    All parts for a multipart upload must be of the same size, except for the
    last part. Parts are numbered from ``0`` up to ``last_part_number``.
    """

    __tablename__ = 'files_multipartobject'

    __table_args__ = (db.UniqueConstraint('upload_id',
                                          'bucket_id',
                                          'key',
                                          name='uix_item'), )

    upload_id = db.Column(
        UUIDType,
        default=uuid.uuid4,
        primary_key=True,
    )
    """Identifier of this multipart upload."""

    bucket_id = db.Column(
        UUIDType,
        db.ForeignKey(Bucket.id, ondelete='RESTRICT'),
    )
    """Bucket identifier."""

    key = db.Column(db.Text().with_variant(mysql.VARCHAR(255), 'mysql'), )
    """Key identifying the object."""

    file_id = db.Column(UUIDType,
                        db.ForeignKey(FileInstance.id, ondelete='RESTRICT'),
                        nullable=False)
    """File instance for this multipart object."""

    chunk_size = db.Column(db.Integer, nullable=True)
    """Size of chunks for file."""

    size = db.Column(db.BigInteger, nullable=True)
    """Size of file."""

    completed = db.Column(db.Boolean, nullable=False, default=False)
    """Whether the multipart upload has been completed."""

    # Relationships definitions
    bucket = db.relationship(Bucket, backref='multipart_objects')
    """Relationship to buckets."""

    file = db.relationship(FileInstance, backref='multipart_objects')
    """Relationship to the file instance holding the uploaded data."""

    def __repr__(self):
        """Return representation of the multipart object.

        Format: ``<bucket_id>:<upload_id>:<key>``.
        """
        return "{0}:{2}:{1}".format(self.bucket_id, self.key, self.upload_id)

    @property
    def last_part_number(self):
        """Get the (zero-based) number of the last part.

        Integer division avoids routing large ``BigInteger`` sizes through a
        float, which would lose precision.
        """
        # NOTE(review): when size is an exact multiple of chunk_size this
        # yields a trailing part of size 0 (see last_part_size) - confirm
        # that clients are expected to upload it.
        return self.size // self.chunk_size

    @property
    def last_part_size(self):
        """Get size of last part."""
        return self.size % self.chunk_size

    @validates('key')
    def validate_key(self, key, key_):
        """Validate key."""
        # Bare name resolves to the module-level ``validate_key`` helper,
        # not to this method.
        return validate_key(key_)

    @staticmethod
    def is_valid_chunksize(chunk_size):
        """Check if chunk size lies within the configured bounds (inclusive).

        Bounds come from ``FILES_REST_MULTIPART_CHUNKSIZE_MIN`` and
        ``FILES_REST_MULTIPART_CHUNKSIZE_MAX``.
        """
        min_csize = current_app.config['FILES_REST_MULTIPART_CHUNKSIZE_MIN']
        max_csize = current_app.config['FILES_REST_MULTIPART_CHUNKSIZE_MAX']
        return min_csize <= chunk_size <= max_csize

    @staticmethod
    def is_valid_size(size, chunk_size):
        """Validate max theoretical size.

        The size must be strictly larger than the minimum chunk size and at
        most ``chunk_size * FILES_REST_MULTIPART_MAX_PARTS``.
        """
        min_csize = current_app.config['FILES_REST_MULTIPART_CHUNKSIZE_MIN']
        max_size = \
            chunk_size * current_app.config['FILES_REST_MULTIPART_MAX_PARTS']
        return min_csize < size <= max_size

    def expected_part_size(self, part_number):
        """Get expected part size for a particular part number.

        :param part_number: Zero-based part number.
        :returns: ``last_part_size`` for the last part and ``chunk_size`` for
            every earlier part.
        :raises MultipartInvalidPartNumber: If ``part_number`` is negative or
            greater than ``last_part_number``.
        """
        # This model has no ``multipart`` attribute; the sizes live directly
        # on ``self``.
        last_part = self.last_part_number

        if part_number == last_part:
            return self.last_part_size
        elif 0 <= part_number < last_part:
            return self.chunk_size
        raise MultipartInvalidPartNumber()

    @ensure_uncompleted()
    def complete(self):
        """Mark a multipart object as complete.

        Also flips the underlying file to readable and non-writable.

        :raises MultipartMissingParts: If the number of stored parts differs
            from ``last_part_number + 1``.
        :returns: ``self``.
        """
        if Part.count(self) != self.last_part_number + 1:
            raise MultipartMissingParts()

        with db.session.begin_nested():
            self.completed = True
            self.file.readable = True
            self.file.writable = False
        return self

    @ensure_completed()
    def merge_parts(self, **kwargs):
        """Merge parts into a new object version.

        Updates the file checksum, creates an ``ObjectVersion`` pointing at
        this upload's file, then deletes this multipart object.

        :returns: The created object version.
        """
        self.file.update_checksum(**kwargs)
        with db.session.begin_nested():
            obj = ObjectVersion.create(self.bucket,
                                       self.key,
                                       _file_id=self.file_id)
            self.delete()
        return obj

    def delete(self):
        """Delete a multipart object together with all its parts.

        The object's size is subtracted from the bucket size first.
        """
        # Update bucket size.
        self.bucket.size -= self.size
        # Remove parts
        Part.query_by_multipart(self).delete()
        # Remove self
        self.query.filter_by(upload_id=self.upload_id).delete()

    @classmethod
    def create(cls, bucket, key, size, chunk_size):
        """Create a new multipart upload in a bucket.

        :param bucket: Bucket, or anything accepted by ``as_bucket``.
        :param key: Object key.
        :param size: Total size of the file to upload.
        :param chunk_size: Size of every part except the last.
        :raises BucketLockedError: If the bucket is locked.
        :raises MultipartInvalidChunkSize: If ``chunk_size`` is out of bounds.
        :raises MultipartInvalidSize: If ``size`` is out of bounds.
        :raises FileSizeError: If ``size`` exceeds the bucket size limit.
        :returns: The new multipart object.
        """
        bucket = as_bucket(bucket)

        if bucket.locked:
            raise BucketLockedError()

        # Validate chunk size.
        if not cls.is_valid_chunksize(chunk_size):
            raise MultipartInvalidChunkSize()

        # Validate max theoretical size.
        if not cls.is_valid_size(size, chunk_size):
            raise MultipartInvalidSize()

        # Validate max bucket size.
        bucket_limit = bucket.size_limit
        if bucket_limit and size > bucket_limit:
            desc = 'File size limit exceeded.' \
                if isinstance(bucket_limit, int) else bucket_limit.reason
            raise FileSizeError(description=desc)

        with db.session.begin_nested():
            file_ = FileInstance.create()
            file_.size = size
            obj = cls(
                upload_id=uuid.uuid4(),
                bucket=bucket,
                key=key,
                chunk_size=chunk_size,
                size=size,
                completed=False,
                file=file_,
            )
            bucket.size += size
            db.session.add(obj)
        file_.init_contents(
            size=size,
            default_location=bucket.location.uri,
            default_storage_class=bucket.default_storage_class,
        )
        return obj

    @classmethod
    def get(cls, bucket, key, upload_id, with_completed=False):
        """Fetch a specific multipart object.

        :param with_completed: Also return the object if it is completed.
        :returns: The multipart object or ``None``.
        """
        q = cls.query.filter_by(
            upload_id=upload_id,
            bucket_id=as_bucket_id(bucket),
            key=key,
        )
        if not with_completed:
            q = q.filter(cls.completed.is_(False))

        return q.one_or_none()

    @classmethod
    def query_expired(cls, dt, bucket=None):
        """Query multipart uploads created before ``dt``.

        :param dt: Cut-off datetime; only objects with ``created < dt`` match.
        :param bucket: Optional bucket, passed through ``as_bucket_id``, to
            restrict the query to.
        :returns: A query of matching multipart objects.
        """
        # NOTE(review): this filters on ``completed=True`` although expired
        # uploads are usually the *uncompleted* ones; confirm which is
        # intended before relying on it for cleanup.
        q = cls.query.filter(cls.created < dt).filter_by(completed=True)
        if bucket:
            q = q.filter(cls.bucket_id == as_bucket_id(bucket))
        return q

    @classmethod
    def query_by_bucket(cls, bucket):
        """Query all multipart uploads (completed or not) of a bucket."""
        return cls.query.filter(cls.bucket_id == as_bucket_id(bucket))
Exemplo n.º 14
0
class Client(db.Model):
    """A client is the app that wants to use the resources of a user.

    It is suggested that the client is registered by a user on your site, but
    it is not required.

    The client should contain at least this information:

        client_id: A random string
        client_secret: A random string
        client_type: A string representing whether it is confidential
        redirect_uris: A list of redirect uris
        default_redirect_uri: One of the redirect uris
        default_scopes: Default scopes of the client

    But it could be better, if you implemented:

        allowed_grant_types: A list of grant types
        allowed_response_types: A list of response types
        validate_scopes: A function to validate scopes
    """

    __tablename__ = 'oauth2CLIENT'

    name = db.Column(
        db.String(40),
        info=dict(label=_('Name'),
                  description=_('Name of application (displayed to users).'),
                  validators=[validators.DataRequired()]))
    """Human readable name of the application."""

    description = db.Column(db.Text(),
                            default=u'',
                            info=dict(
                                label=_('Description'),
                                description=_(
                                    'Optional. Description of the application'
                                    ' (displayed to users).'),
                            ))
    """Human readable description."""

    website = db.Column(
        URLType(),
        info=dict(
            label=_('Website URL'),
            description=_('URL of your application (displayed to users).'),
        ),
        default=u'',
    )

    user_id = db.Column(db.ForeignKey(User.id), nullable=True)
    """Creator of the client application."""

    client_id = db.Column(db.String(255), primary_key=True)
    """Client application ID."""

    client_secret = db.Column(db.String(255),
                              unique=True,
                              index=True,
                              nullable=False)
    """Client application secret."""

    is_confidential = db.Column(db.Boolean, default=True)
    """Determine if client application is public or not."""

    is_internal = db.Column(db.Boolean, default=False)
    """Determines if client application is an internal application."""

    _redirect_uris = db.Column(db.Text)
    """A newline-separated list of redirect URIs. First is the default URI."""

    _default_scopes = db.Column(db.Text)
    """A space-separated list of default scopes of the client.

    The value of the scope parameter is expressed as a list of space-delimited,
    case-sensitive strings.
    """

    user = db.relationship(User,
                           backref=db.backref(
                               "oauth2clients",
                               cascade="all, delete-orphan",
                           ))
    """Relationship to user."""

    @property
    def allowed_grant_types(self):
        """Return allowed grant types (``OAUTH2_ALLOWED_GRANT_TYPES``)."""
        return current_app.config['OAUTH2_ALLOWED_GRANT_TYPES']

    @property
    def allowed_response_types(self):
        """Return allowed response types (``OAUTH2_ALLOWED_RESPONSE_TYPES``)."""
        return current_app.config['OAUTH2_ALLOWED_RESPONSE_TYPES']

    @property
    def client_type(self):
        """Return ``'confidential'`` or ``'public'``."""
        if self.is_confidential:
            return 'confidential'
        return 'public'

    @property
    def redirect_uris(self):
        """Return redirect uris as a list (empty if none are stored)."""
        if self._redirect_uris:
            return self._redirect_uris.splitlines()
        return []

    @redirect_uris.setter
    def redirect_uris(self, value):
        """Validate and store redirect URIs for client.

        :param value: Either a newline-separated text string or an iterable
            of URIs. Each entry is stripped and passed to
            ``validate_redirect_uri`` before being stored.
        """
        if isinstance(value, six.text_type):
            value = value.split("\n")

        value = [v.strip() for v in value]

        for v in value:
            validate_redirect_uri(v)

        self._redirect_uris = "\n".join(value)

    @property
    def default_redirect_uri(self):
        """Return the first redirect uri, or ``None`` if there is none."""
        uris = self.redirect_uris
        return uris[0] if uris else None

    @property
    def default_scopes(self):
        """List of default scopes for client."""
        if self._default_scopes:
            return self._default_scopes.split(" ")
        return []

    @default_scopes.setter
    def default_scopes(self, scopes):
        """Set default scopes for client.

        Scopes are validated first; duplicates are removed before storing.
        """
        validate_scopes(scopes)
        self._default_scopes = " ".join(set(scopes)) if scopes else ""

    def validate_scopes(self, scopes):
        """Validate if client is allowed to access scopes.

        :returns: ``True`` if every scope exists, ``False`` otherwise.
        """
        # Bare name resolves to the module-level ``validate_scopes`` helper.
        try:
            validate_scopes(scopes)
            return True
        except ScopeDoesNotExists:
            return False

    def gen_salt(self):
        """Regenerate both the client id and the client secret."""
        self.reset_client_id()
        self.reset_client_secret()

    def reset_client_id(self):
        """Reset client id (length from ``OAUTH2_CLIENT_ID_SALT_LEN``)."""
        self.client_id = gen_salt(
            current_app.config.get('OAUTH2_CLIENT_ID_SALT_LEN'))

    def reset_client_secret(self):
        """Reset client secret (length from ``OAUTH2_CLIENT_SECRET_SALT_LEN``)."""
        self.client_secret = gen_salt(
            current_app.config.get('OAUTH2_CLIENT_SECRET_SALT_LEN'))
Exemplo n.º 15
0
class LegacyRecordsMirror(db.Model):
    """Table ``legacy_records_mirror``: raw MARCXML records keyed by recid."""

    __tablename__ = "legacy_records_mirror"

    __table_args__ = (
        db.Index("ix_legacy_records_mirror_valid_collection", "valid",
                 "collection"),
    )

    recid = db.Column(db.Integer, primary_key=True)
    last_updated = db.Column(
        db.DateTime, default=datetime.utcnow, nullable=False, index=True)
    _marcxml = db.Column("marcxml", db.LargeBinary, nullable=False)
    valid = db.Column(db.Boolean, default=None, nullable=True)
    _errors = db.Column("errors", db.Text(), nullable=True)
    collection = db.Column(db.Text(), default="")

    re_recid = re.compile(
        r"<controlfield.*?tag=.001.*?>(?P<recid>\d+)</controlfield>")

    @hybrid_property
    def marcxml(self):
        """Return the stored MARCXML, decompressing it when possible."""
        stored = self._marcxml
        try:
            return decompress(stored)
        except error:
            # ``error`` is the module-level (presumably zlib) exception, not
            # the ``error`` hybrid property below. Data that fails to
            # decompress is assumed to be legacy uncompressed content.
            return stored

    @marcxml.setter
    def marcxml(self, value):
        """Compress and store ``value``; text is UTF-8 encoded first."""
        payload = bytes(value, "utf8") if isinstance(value, str) else value
        self._marcxml = compress(payload)

    @hybrid_property
    def error(self):
        """Return the stored error description, if any."""
        return self._errors

    @error.setter
    def error(self, value):
        """Record ``value`` (an exception) as error and flag record invalid.

        ``collection`` is also recomputed from the stored MARCXML, and the
        error is stored as ``"<ExceptionClass>: <message>"``.
        """
        self.valid = False
        self.collection = get_collection_from_marcxml(self.marcxml)
        self._errors = f"{type(value).__name__}: {value}"

    @classmethod
    def from_marcxml(cls, raw_record):
        """Build a new, not-yet-validated instance from raw MARCXML.

        The recid is read from the ``001`` control field of the record.

        :raises ValueError: If no numeric ``001`` control field is found.
        """
        # NOTE(review): for ``bytes`` input ``str()`` yields the repr
        # (``b'...'``); the regex still finds the recid there, but confirm
        # whether decoding was intended.
        try:
            found = cls.re_recid.search(str(raw_record))
            recid = int(found.group("recid"))
        except AttributeError:
            raise ValueError(
                "The MARCXML record contains no recid or recid is malformed")
        # FIXME also get last_updated from marcxml
        record = cls(recid=recid)
        record.marcxml = raw_record
        record.valid = None
        return record
Exemplo n.º 16
0
class Community(db.Model, Timestamp):
    """Represent a community."""

    __tablename__ = 'communities_community'

    id = db.Column(db.String(100), primary_key=True)
    """Id of the community."""

    id_user = db.Column(
        db.Integer,
        db.ForeignKey(User.id),
        nullable=False
    )
    """Owner of the community."""

    title = db.Column(db.String(length=255), nullable=False, default='')
    """Title of the community."""

    description = db.Column(db.Text, nullable=False, default='')
    """Short description of community, displayed in portal boxes."""

    page = db.Column(db.Text, nullable=False, default='')
    """Long description of community, displayed on an individual page."""

    curation_policy = db.Column(db.Text(), nullable=False, default='')
    """Community curation policy."""

    community_header = db.Column(db.Text, nullable=False, default='')
    """Header design of community, displayed in portal boxes."""

    community_footer = db.Column(db.Text, nullable=False, default='')
    """Footer design of community, displayed in portal boxes."""

    last_record_accepted = db.Column(
        db.DateTime(), nullable=False, default=datetime(2000, 1, 1, 0, 0, 0))
    """Last record acceptance datetime."""

    logo_ext = db.Column(db.String(length=4), nullable=True, default=None)
    """Extension of the logo."""

    ranking = db.Column(db.Integer, nullable=False, default=0)
    """Ranking of community. Updated by ranking daemon."""

    fixed_points = db.Column(db.Integer, nullable=False, default=0)
    """Points which will be always added to overall score of community."""

    deleted_at = db.Column(db.DateTime, nullable=True, default=None)
    """Time at which the community was soft-deleted."""

    root_node_id = db.Column(
        db.BigInteger,
        db.ForeignKey(Index.id),
        nullable=False
    )
    """Id of Root Node"""

    #
    # Relationships
    #
    owner = db.relationship(User, backref='communities',
                            foreign_keys=[id_user])
    """Relation to the owner (User) of the community."""

    index = db.relationship(Index, backref='index', foreign_keys=[root_node_id])
    """Relation to the root node (Index) of the community."""

    def __repr__(self):
        """String representation of the community object."""
        return '<Community, ID: {}>'.format(self.id)

    @classmethod
    def create(cls, community_id, user_id, root_node_id, **data):
        """Create a community and add it to the session.

        :param community_id: Id of the new community.
        :param user_id: Id of the owning user.
        :param root_node_id: Id of the community's root ``Index`` node.
        :param data: Any further column values.
        :returns: The new community.
        """
        with db.session.begin_nested():
            obj = cls(id=community_id, id_user=user_id,
                      root_node_id=root_node_id, **data)
            db.session.add(obj)
        return obj

    def save_logo(self, stream, filename):
        """Save and validate a logo for this community.

        :returns: ``True`` if the logo was saved (its extension is then
            stored in ``logo_ext``), ``False`` otherwise.
        """
        logo_ext = save_and_validate_logo(stream, filename, self.id)
        if logo_ext:
            self.logo_ext = logo_ext
            return True
        return False

    @classmethod
    def get(cls, community_id, with_deleted=False):
        """Get a community by id.

        :param with_deleted: Also return soft-deleted communities.
        :returns: The community or ``None``.
        """
        q = cls.query.filter_by(id=community_id)
        if not with_deleted:
            q = q.filter(cls.deleted_at.is_(None))
        return q.one_or_none()

    @classmethod
    def get_by_user(cls, user_id, with_deleted=False):
        """Get the communities owned by a user, ordered by title.

        :param with_deleted: Also include soft-deleted communities.
        :returns: A query of communities.
        """
        query = cls.query.filter_by(
            id_user=user_id
        )
        if not with_deleted:
            query = query.filter(cls.deleted_at.is_(None))

        return query.order_by(db.asc(Community.title))

    @classmethod
    def filter_communities(cls, p, so, with_deleted=False):
        """Search for communities.

        Helper function which takes from database only those communities which
        match search criteria. Uses parameter 'so' to set communities in the
        correct order.

        :param p: Search pattern. Spaces become SQL wildcards and the pattern
            is matched case-insensitively against id, title and description.
            A falsy value disables text filtering.
        :param so: Sort field. If it is listed in
            ``COMMUNITIES_SORTING_OPTIONS`` results are sorted by it
            (ascending for ``'title'``, descending otherwise); if not, they
            are sorted by descending ranking.
        :param with_deleted: Also include soft-deleted communities.
        :returns: A query of matching communities.
        """
        query = cls.query if with_deleted else \
            cls.query.filter(cls.deleted_at.is_(None))

        if p:
            p = p.replace(' ', '%')
            query = query.filter(db.or_(
                cls.id.ilike('%' + p + '%'),
                cls.title.ilike('%' + p + '%'),
                cls.description.ilike('%' + p + '%'),
            ))

        if so in current_app.config['COMMUNITIES_SORTING_OPTIONS']:
            order = db.asc if so == 'title' else db.desc
            query = query.order_by(order(getattr(cls, so)))
        else:
            query = query.order_by(db.desc(cls.ranking))
        return query

    def add_record(self, record):
        """Add a record to the community.

        :param record: Record object.
        :type record: `invenio_records.api.Record`
        """
        key = current_app.config['COMMUNITIES_RECORD_KEY']
        record.setdefault(key, [])

        if self.has_record(record):
            current_app.logger.warning(
                'Community addition: record {uuid} is already in community '
                '"{comm}"'.format(uuid=record.id, comm=self.id))
        else:
            record[key].append(self.id)
            record[key] = sorted(record[key])
        if current_app.config['COMMUNITIES_OAI_ENABLED']:
            if not self.oaiset.has_record(record):
                self.oaiset.add_record(record)

    def remove_record(self, record):
        """Remove an already accepted record from the community.

        :param record: Record object.
        :type record: `invenio_records.api.Record`
        """
        if not self.has_record(record):
            current_app.logger.warning(
                'Community removal: record {uuid} was not in community '
                '"{comm}"'.format(uuid=record.id, comm=self.id))
        else:
            key = current_app.config['COMMUNITIES_RECORD_KEY']
            record[key] = [c for c in record[key] if c != self.id]

        if current_app.config['COMMUNITIES_OAI_ENABLED']:
            if self.oaiset.has_record(record):
                self.oaiset.remove_record(record)

    def has_record(self, record):
        """Check if record is in community."""
        return self.id in \
            record.get(current_app.config['COMMUNITIES_RECORD_KEY'], [])

    def accept_record(self, record):
        """Accept a record for inclusion in the community.

        :param record: Record object.
        :raises InclusionRequestMissingError: If there is no pending request.
        """
        with db.session.begin_nested():
            req = InclusionRequest.get(self.id, record.id)
            if req is None:
                raise InclusionRequestMissingError(community=self,
                                                   record=record)
            req.delete()
            self.add_record(record)
            self.last_record_accepted = datetime.utcnow()

    def reject_record(self, record):
        """Reject a record for inclusion in the community.

        :param record: Record object.
        :raises InclusionRequestMissingError: If there is no pending request.
        """
        with db.session.begin_nested():
            req = InclusionRequest.get(self.id, record.id)
            if req is None:
                raise InclusionRequestMissingError(community=self,
                                                   record=record)
            req.delete()

    def delete(self):
        """Mark the community for deletion by setting ``deleted_at`` to now.

        :raises CommunitiesError: If the community is already marked.
        """
        if self.deleted_at is not None:
            raise CommunitiesError(community=self)
        else:
            self.deleted_at = datetime.utcnow()

    def undelete(self):
        """Remove the community marking for deletion.

        :raises CommunitiesError: If the community is not marked.
        """
        if self.deleted_at is None:
            raise CommunitiesError(community=self)
        else:
            self.deleted_at = None

    @property
    def is_deleted(self):
        """Return whether given community is marked for deletion."""
        return self.deleted_at is not None

    @property
    def logo_url(self):
        """Get URL to community logo.

        :returns: Path to community logo, or ``None`` if it has no logo.
        :rtype: str
        """
        if self.logo_ext:
            return '/api/files/{bucket}/{key}'.format(
                bucket=current_app.config['COMMUNITIES_BUCKET_UUID'],
                key='{0}/logo.{1}'.format(self.id, self.logo_ext),
            )
        return None

    @property
    def community_url(self):
        """Get the external URL of the community detail page."""
        return url_for(
            'invenio_communities.detail', community_id=self.id, _external=True)

    @property
    def community_provisional_url(self):
        """Get provisional (curation) URL."""
        return url_for(
            'invenio_communities.curate', community_id=self.id, _external=True)

    @property
    def upload_url(self):
        """Get the external URL for a new deposit in this community."""
        return url_for('invenio_deposit_ui.new', c=self.id, _external=True)

    @property
    def oaiset_spec(self):
        """Return the OAISet 'spec' name for given community.

        :returns: name of corresponding OAISet ('spec').
        :rtype: str
        """
        return current_app.config['COMMUNITIES_OAI_FORMAT'].format(
            community_id=self.id)

    @property
    def oaiset(self):
        """Return the corresponding OAISet for given community.

        If ``COMMUNITIES_OAI_ENABLED`` is false this property returns None.

        :returns: returns OAISet object corresponding to this community.
        :rtype: `invenio_oaiserver.models.OAISet` or None
        """
        if current_app.config['COMMUNITIES_OAI_ENABLED']:
            from invenio_oaiserver.models import OAISet
            return OAISet.query.filter_by(spec=self.oaiset_spec).one()
        else:
            return None

    @property
    def oaiset_url(self):
        """Return the OAISet URL for given community.

        :returns: URL of corresponding OAISet.
        :rtype: str
        """
        return url_for(
            'invenio_oaiserver.response',
            verb='ListRecords',
            metadataPrefix='oai_dc', set=self.oaiset_spec, _external=True)

    @property
    def version_id(self):
        """Return the version of the community.

        :returns: hash which encodes the community id and its last update.
        :rtype: str
        """
        return hashlib.sha1('{0}__{1}'.format(
            self.id, self.updated).encode('utf-8')).hexdigest()
class RemoteToken(db.Model):
    """Storage for the access tokens for linked accounts."""

    __tablename__ = 'oauthclient_remotetoken'

    #
    # Fields
    #
    id_remote_account = db.Column(
        db.Integer,
        db.ForeignKey(RemoteAccount.id),
        nullable=False,
        primary_key=True,
    )
    """Foreign key to account."""

    token_type = db.Column(
        db.String(40),
        default='',
        nullable=False,
        primary_key=True,
    )
    """Type of token."""

    access_token = db.Column(
        TextEncryptedType(type_in=db.Text, key=_secret_key),
        nullable=False,
    )
    """Access token to remote application."""

    secret = db.Column(db.Text(), default='', nullable=False)
    """Used only by OAuth 1."""

    def token(self):
        """Return the ``(access_token, secret)`` pair for Flask-OAuthlib."""
        return self.access_token, self.secret

    def update_token(self, token, secret):
        """Store a new token/secret pair, only if either value changed."""
        unchanged = self.access_token == token and self.secret == secret
        if unchanged:
            return
        with db.session.begin_nested():
            self.access_token = token
            self.secret = secret
            db.session.add(self)

    @classmethod
    def get(cls, user_id, client_id, token_type='', access_token=None):
        """Look up the token of a user for a given client.

        :param access_token: When truthy, also filter on the access token.
        :returns: The first matching token, or ``None``.
        """
        criteria = [
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.user_id == user_id,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
        ]
        if access_token:
            criteria.append(RemoteToken.access_token == access_token)

        eager = cls.query.options(db.joinedload('remote_account'))
        return eager.filter(*criteria).first()

    @classmethod
    def get_by_token(cls, client_id, access_token, token_type=''):
        """Look up a token by its client id and access token value."""
        eager = cls.query.options(db.joinedload('remote_account'))
        matching = eager.filter(
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
            RemoteToken.access_token == access_token,
        )
        return matching.first()

    @classmethod
    def create(cls,
               user_id,
               client_id,
               token,
               secret,
               token_type='',
               extra_data=None):
        """Create and persist a new access token.

        A ``RemoteAccount`` for the user/client pair is created as well when
        none exists yet.
        """
        account = RemoteAccount.get(user_id, client_id)

        with db.session.begin_nested():
            if account is None:
                account = RemoteAccount(
                    user_id=user_id,
                    client_id=client_id,
                    extra_data=extra_data or dict(),
                )
                db.session.add(account)

            new_token = cls(
                token_type=token_type,
                remote_account=account,
                access_token=token,
                secret=secret,
            )
            db.session.add(new_token)
        return new_token
Exemplo n.º 18
0
class RemoteToken(db.Model):
    """Storage for the access tokens for linked accounts."""

    __tablename__ = 'oauthclient_remotetoken'

    #
    # Fields
    #
    id_remote_account = db.Column(
        db.Integer,
        db.ForeignKey(RemoteAccount.id,
                      name='fk_oauthclient_remote_token_remote_account'),
        nullable=False,
        primary_key=True)
    """Foreign key to account."""

    token_type = db.Column(db.String(40),
                           default='',
                           nullable=False,
                           primary_key=True)
    """Type of token."""

    access_token = db.Column(EncryptedType(type_in=db.Text, key=_secret_key),
                             nullable=False)
    """Access token to remote application."""

    secret = db.Column(db.Text(), default='', nullable=False)
    """Used only by OAuth 1."""

    #
    # Relationships properties
    #
    remote_account = db.relationship(RemoteAccount,
                                     backref=backref(
                                         'remote_tokens',
                                         cascade='all, delete-orphan'))
    """SQLAlchemy relationship to RemoteAccount objects."""

    def __repr__(self):
        """String representation for model.

        The access token is masked: the column is encrypted at rest, and a
        repr ends up in logs, tracebacks and debugger output, so it must not
        expose the credential in plain text.
        """
        return ('Remote Token <token_type={0.token_type} '
                'access_token=***>'.format(self))

    def token(self):
        """Get token as expected by Flask-OAuthlib.

        :returns: The ``(access_token, secret)`` tuple.
        """
        return (self.access_token, self.secret)

    def update_token(self, token, secret):
        """Update token with new values.

        Nothing is written when both values are unchanged.

        :param token: The token value.
        :param secret: The secret key.
        """
        if self.access_token != token or self.secret != secret:
            with db.session.begin_nested():
                self.access_token = token
                self.secret = secret
                db.session.add(self)

    @classmethod
    def get(cls, user_id, client_id, token_type='', access_token=None):
        """Get RemoteToken for user.

        :param user_id: The user id.
        :param client_id: The client id.
        :param token_type: The token type. (Default: ``''``)
        :param access_token: If set, will filter also by access token.
            (Default: ``None``)
        :returns: A :class:`invenio_oauthclient.models.RemoteToken` instance.
        """
        args = [
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.user_id == user_id,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
        ]

        if access_token:
            args.append(RemoteToken.access_token == access_token)

        return cls.query.options(
            db.joinedload('remote_account')).filter(*args).first()

    @classmethod
    def get_by_token(cls, client_id, access_token, token_type=''):
        """Get RemoteToken object for token.

        :param client_id: The client id.
        :param access_token: The access token.
        :param token_type: The token type. (Default: ``''``)
        :returns: A :class:`invenio_oauthclient.models.RemoteToken` instance.
        """
        return cls.query.options(db.joinedload('remote_account')).filter(
            RemoteAccount.id == RemoteToken.id_remote_account,
            RemoteAccount.client_id == client_id,
            RemoteToken.token_type == token_type,
            RemoteToken.access_token == access_token,
        ).first()

    @classmethod
    def create(cls,
               user_id,
               client_id,
               token,
               secret,
               token_type='',
               extra_data=None):
        """Create a new access token.

        .. note:: Creates RemoteAccount as well if it does not exist.

        :param user_id: The user id.
        :param client_id: The client id.
        :param token: The token.
        :param secret: The secret key.
        :param token_type: The token type. (Default: ``''``)
        :param extra_data: Extra data to set in the remote account if the
            remote account doesn't exist. (Default: ``None``)
        :returns: A :class:`invenio_oauthclient.models.RemoteToken` instance.
        """
        account = RemoteAccount.get(user_id, client_id)

        with db.session.begin_nested():
            if account is None:
                account = RemoteAccount(
                    user_id=user_id,
                    client_id=client_id,
                    extra_data=extra_data or dict(),
                )
                db.session.add(account)

            token = cls(
                token_type=token_type,
                remote_account=account,
                access_token=token,
                secret=secret,
            )
            db.session.add(token)
        return token
Exemplo n.º 19
0
class WidgetDesignPage(db.Model):
    """Database model for widget design (menu) pages.

    Each row is one page of a repository's widget layout: its URL, HTML
    content, widget settings and per-language data.
    """

    __tablename__ = 'widget_design_page'

    id = db.Column(db.Integer, primary_key=True, nullable=False)

    title = db.Column(db.String(100), nullable=True)

    # Owning repository; note this is a string column, not an integer.
    repository_id = db.Column(db.String(100), nullable=False)

    # Page URL; unique across all widget design pages.
    url = db.Column(db.String(100), nullable=False, unique=True)

    template_name = db.Column(  # May be used in the future
        db.String(100), nullable=True)

    # HTML content of the page.
    content = db.Column(db.Text(), nullable=True, default='')

    # Widget settings stored as JSON: JSONB on PostgreSQL (Python None is
    # stored as SQL NULL there), JSONType on SQLite and MySQL. New rows
    # default to an empty dict when no value is supplied.
    settings = db.Column(db.JSON().with_variant(
        postgresql.JSONB(none_as_null=True),
        'postgresql',
    ).with_variant(
        JSONType(),
        'sqlite',
    ).with_variant(
        JSONType(),
        'mysql',
    ),
                         default=lambda: dict(),
                         nullable=True)

    # Flag marking the page as the main layout.
    # NOTE(review): the first positional argument of SQLAlchemy's Boolean
    # is ``create_constraint``, not a name, so the string here only acts as
    # a truthy flag; confirm whether ``name='is_main_layout'`` was intended.
    is_main_layout = db.Column(db.Boolean('is_main_layout'), nullable=True)

    # Per-language data, exposed as a dict keyed by each child's
    # ``lang_code``. With cascade='all, delete-orphan', children removed from
    # the collection (or belonging to a page deleted via the session) are
    # deleted too.
    multi_lang_data = db.relationship(
        'WidgetDesignPageMultiLangData',
        backref='widget_design_page',
        cascade='all, delete-orphan',
        collection_class=attribute_mapped_collection('lang_code'))

    @classmethod
    def create_or_update(cls,
                         repository_id,
                         title,
                         url,
                         content,
                         page_id=0,
                         settings=None,
                         multi_lang_data=None,
                         is_main_layout=False):
        """Insert a new widget design page or update an existing one.

        If ``page_id`` matches an existing page, that page is updated and
        keeps its stored ``repository_id`` (the argument is ignored).
        Otherwise a new page is created.

        :param repository_id: Identifier of the repository
        :param title: Page title
        :param url: Page URL
        :param content: HTML content
        :param page_id: Page identifier; ``0`` (default) or an unknown id
            creates a new page
        :param settings: Page widget setting data
        :param multi_lang_data: Mapping of language code to the data passed
            to ``WidgetDesignPageMultiLangData``; ``None`` means no entries
        :param is_main_layout: Main layout flag
        :return: True on success; False if the effective ``repository_id``
            or ``url`` is empty
        :raises Exception: any error while saving is re-raised after the
            session has been rolled back
        """
        # Avoid a shared mutable default argument.
        if multi_lang_data is None:
            multi_lang_data = {}
        try:
            prev = cls.query.filter_by(id=int(page_id)).one_or_none()
            if prev:
                # Existing pages never move to another repository.
                repository_id = prev.repository_id
            page = prev or WidgetDesignPage()

            if not repository_id or not url:
                return False

            with db.session.begin_nested():
                page.repository_id = repository_id
                page.title = title
                page.url = url
                page.content = content
                page.settings = settings
                page.is_main_layout = is_main_layout
                for lang, lang_data in multi_lang_data.items():
                    page.multi_lang_data[lang] = \
                        WidgetDesignPageMultiLangData(lang, lang_data)
                db.session.merge(page)
            db.session.commit()
            return True
        except Exception as ex:
            db.session.rollback()
            current_app.logger.debug(ex)
            # Bare raise keeps the original traceback intact.
            raise

    @classmethod
    def delete(cls, page_id):
        """Delete a widget design page.

        The page is loaded and removed through the ORM session so that the
        ``multi_lang_data`` relationship's cascade also removes its
        per-language rows. A bulk ``Query.delete()`` does not apply ORM
        cascades and could orphan those rows or violate their foreign key.

        :param page_id: Page model's id
        :return: True if the transaction committed (also when no page has
            that id); False if ``page_id`` is falsy
        :raises BaseException: any error is re-raised after rollback
        """
        if not page_id:
            return False
        try:
            with db.session.begin_nested():
                page = cls.query.filter_by(id=int(page_id)).one_or_none()
                if page is not None:
                    db.session.delete(page)
            db.session.commit()
            return True
        except BaseException as ex:
            # BaseException so that even interrupts leave the session clean.
            db.session.rollback()
            current_app.logger.debug(ex)
            raise

    @classmethod
    def update_settings(cls, page_id, settings=None):
        """Replace the widget settings of a single page.

        :param page_id: Identifier of the page.
        :param settings: New page widget setting data.
        :return: True when the page exists and was saved; False when no
            page has the given id.
        :raises Exception: re-raised after rolling back the session.
        """
        try:
            target = cls.query.filter_by(id=int(page_id)).one_or_none()
            if not target:
                return False
            with db.session.begin_nested():
                target.settings = settings
                db.session.merge(target)
            db.session.commit()
        except Exception as ex:
            db.session.rollback()
            current_app.logger.debug(ex)
            raise ex
        return True

    @classmethod
    def update_settings_by_repository_id(cls, repository_id, settings=None):
        """Update design page setting by repository id.

        Note: ALL pages belonging to the repository get the same settings,
        which can be used to make all pages uniform in design. The pages
        are updated in a single transaction: either all are updated or, on
        error, none are.

        :param repository_id: Repository id. It is compared as a string
            because the ``repository_id`` column is ``String(100)``; an
            ``int()`` conversion would reject non-numeric ids and, on
            PostgreSQL, fail with a varchar/integer operator error.
        :param settings: Page widget setting data.
        :return: True if successful (including when the repository has no
            pages), otherwise False.
        """
        try:
            pages = cls.query.filter_by(
                repository_id=str(repository_id)).all()
            if not pages:
                return True
            with db.session.begin_nested():
                for page in pages:
                    page.settings = settings
                    db.session.merge(page)
            db.session.commit()
            return True
        except Exception as ex:
            db.session.rollback()
            current_app.logger.debug(ex)
            return False

    @classmethod
    def get_all(cls):
        """
        Return every widget design page.

        :return: List of all pages (empty list when there are none).
        """
        every_page = db.session.query(cls)
        return every_page.all()

    @classmethod
    def get_all_valid(cls):
        """
        Get all pages with widget settings.

        :return: List of pages whose ``settings`` column is not NULL.
        """
        # ``cls.settings is not None`` is evaluated by Python (always True)
        # and filters nothing; ``isnot(None)`` emits SQL ``IS NOT NULL``.
        return db.session.query(cls).filter(cls.settings.isnot(None)).all()

    @classmethod
    def get_by_id(cls, id):
        """
        Fetch the widget page with the given id.

        The id is converted with ``int()`` before querying. ``Query.one()``
        raises when no page (or more than one) matches.

        :return: Single page object.
        """
        page_query = db.session.query(cls)
        return page_query.filter_by(id=int(id)).one()

    @classmethod
    def get_by_id_or_none(cls, id):
        """
        Fetch the widget page with the given id, or None if it is missing.

        :return: Single page object or None.
        """
        page_query = db.session.query(cls)
        return page_query.filter_by(id=int(id)).one_or_none()

    @classmethod
    def get_by_url(cls, url):
        """
        Fetch the widget page registered under ``url``.

        This uses ``Query.one()``, so it raises when no page (or more than
        one) matches instead of returning None.

        :return: Single page object.
        """
        matching = db.session.query(cls).filter_by(url=url)
        return matching.one()

    @classmethod
    def get_by_repository_id(cls, repository_id):
        """
        List the widget pages that belong to a community/repository.

        :return: List of page objects, possibly empty.
        """
        page_query = db.session.query(cls)
        return page_query.filter_by(repository_id=repository_id).all()