Exemple #1
0
def save_token(token_data, request):
    """Persist a freshly issued OAuth access token.

    Looks up the application by ``request.client.client_id`` (``.one()``, so
    it must exist exactly once) and the application/user link for
    ``request.user``:

    * no link yet: a new link is created that grants the requested scopes;
    * existing link and no scope requested: the token reuses every scope the
      user already granted, and ``token_data['scope']`` is rewritten to match;
    * existing link with additional scopes: the link is widened via
      ``link.update_scopes``.

    The token is attached to the link. Afterwards, tokens of that link with
    exactly the same scope set are trimmed to the newest
    ``MAX_TOKENS_PER_SCOPE``.
    """
    scopes = set(scope_to_list(token_data.get('scope', '')))
    app = OAuthApplication.query.filter_by(client_id=request.client.client_id).one()
    link = (OAuthApplicationUserLink.query
            .with_parent(app)
            .with_parent(request.user)
            .first())

    if link is not None:
        granted = set(link.scopes)
        if not scopes:
            # for already-authorized apps not specifying a scope uses all scopes the
            # user previously granted to the app
            scopes = granted
            token_data['scope'] = list_to_scope(scopes)
        added = scopes - granted
        if added:
            logger.info('New scopes for %r: %s', link, added)
            link.update_scopes(added)
    else:
        link = OAuthApplicationUserLink(application=app,
                                        user=request.user,
                                        scopes=scopes)

    new_token = OAuthToken(access_token=token_data['access_token'], scopes=scopes)
    link.tokens.append(new_token)

    # Drop everything past the newest MAX_TOKENS_PER_SCOPE tokens that share
    # this link and this exact (sorted) scope array.
    scope_key = db.cast(sorted(scopes), ARRAY(db.String))
    stale_ids = (db.session.query(OAuthToken.id)
                 .with_parent(link)
                 .filter_by(_scopes=scope_key)
                 .order_by(OAuthToken.created_dt.desc())
                 .offset(MAX_TOKENS_PER_SCOPE)
                 .scalar_subquery())
    OAuthToken.query.filter(OAuthToken.id.in_(stale_ids)).delete(
        synchronize_session='fetch')
Exemple #2
0
class Group(db.Model):
    """Named group of users carrying a list of capability names.

    ``users`` is not declared in this class; it is presumably provided by a
    backref from the user/membership side (confirm).
    """
    __tablename__ = 'group'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(32), index=True, unique=True, nullable=False)
    # Capability names; defaults server-side to an empty PostgreSQL array.
    capabilities = db.Column('capabilities',
                             ARRAY(db.Text),
                             nullable=False,
                             server_default='{}')
    # For private groups the group name equals a user login (see pending_group).
    private = db.Column(db.Boolean)

    @property
    def pending_group(self):
        """Whether this is a private group whose owning user is pending.

        For non-private groups the (falsy) ``private`` value is returned
        unchanged. For private groups, the ``pending`` flag of the user whose
        login equals the group name is returned, or ``False`` if there is no
        such user.
        """
        if not self.private:
            return self.private
        owner = db.session.query(User).filter(User.login == self.name).first()
        # .first() yields None when no user matches; report "not pending"
        # instead of raising AttributeError on None.
        return owner is not None and owner.pending

    @property
    def immutable(self):
        """Private groups and the "public" group are immutable."""
        return self.private or self.name == "public"

    @property
    def user_logins(self):
        """Logins of all users in this group."""
        return [ug.login for ug in self.users]

    @staticmethod
    def public_group():
        """Return the group named "public", or None if it does not exist."""
        return Group.get_by_name("public")

    @staticmethod
    def get_by_name(name):
        """Return the group with the given name, or None if there is none."""
        return db.session.query(Group).filter(Group.name == name).first()

    @staticmethod
    def all_access_groups():
        """Return all groups whose capabilities include ``access_all_objects``."""
        return db.session.query(Group) \
            .filter(Group.capabilities.contains([Capabilities.access_all_objects])).all()
Exemple #3
0
def test_delete_old_tokens(db, dummy_application, dummy_user):
    """Only the newest MAX_TOKENS_PER_SCOPE tokens per scope survive save_token."""
    request = MagicMock(client=dummy_application, user=dummy_user)
    scopes = ('foo', 'bar')
    issued = {scope: [] for scope in scopes}

    def count_tokens(scope):
        # number of stored tokens whose scope array is exactly [scope]
        return OAuthToken.query.filter_by(
            _scopes=db.cast([scope], ARRAY(db.String))).count()

    for scope in scopes:
        for n in range(MAX_TOKENS_PER_SCOPE + 2):
            raw_token = generate_token(69)
            issued[scope].append(hashlib.sha256(raw_token.encode()).hexdigest())
            save_token({'scope': scope, 'access_token': raw_token}, request)
            # the cap must hold after every single save, not just at the end
            assert count_tokens(scope) == min(n + 1, MAX_TOKENS_PER_SCOPE)

    # the survivors must be exactly the most recently issued tokens, oldest first
    for scope in scopes:
        surviving = (db.session.query(OAuthToken.access_token_hash)
                     .filter_by(_scopes=db.cast([scope], ARRAY(db.String)))
                     .order_by(OAuthToken.created_dt))
        stored_hashes = [row.access_token_hash for row in surviving]
        assert stored_hashes == issued[scope][-MAX_TOKENS_PER_SCOPE:]
Exemple #4
0
class _BookMixin(object):
    """Mixin declaring the book-specific columns.

    Classes inheriting from it receive these column definitions. Keep the
    declaration order: it determines the column order of the mapped tables.
    """
    author = Column(String(100))
    editor = Column(String(100))
    # Array of ``enums.activity_type`` values.
    activities = Column(ArrayOfEnum(enums.activity_type))
    url = Column(String(255))
    # 17 chars presumably fits a hyphenated ISBN-13 (13 digits + 4 hyphens) — confirm.
    isbn = Column(String(17))
    # Array of ``enums.book_type`` values.
    book_types = Column(ArrayOfEnum(enums.book_type))
    nb_pages = Column(SmallInteger)
    # Stored as free text (String), not as a date type.
    publication_date = Column(String(100))
    # Array of 2-character strings; presumably ISO 639-1 language codes — confirm.
    langs = Column(ARRAY(String(2)))
Exemple #5
0
class Group(db.Model):
    """Named group of users with capabilities, membership held in ``Member`` rows."""
    __tablename__ = "group"

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(32), index=True, unique=True, nullable=False)
    # Capability names; defaults server-side to an empty PostgreSQL array.
    capabilities = db.Column("capabilities",
                             ARRAY(db.Text),
                             nullable=False,
                             server_default="{}")
    # For private groups the group name equals a user login (see pending_group).
    private = db.Column(db.Boolean, nullable=False, default=False)

    members = db.relationship("Member",
                              back_populates="group",
                              cascade="all, delete-orphan")
    # Appending a user creates the Member association row for it.
    users = association_proxy("members",
                              "user",
                              creator=lambda user: Member(user=user))

    PUBLIC_GROUP_NAME = "public"
    EVERYTHING_GROUP_NAME = "everything"

    @property
    def pending_group(self):
        """Whether this is a private group whose owning user is pending.

        For non-private groups the (falsy) ``private`` value is returned
        unchanged. For private groups, the ``pending`` flag of the user whose
        login equals the group name is returned, or ``False`` if there is no
        such user.
        """
        from .user import User

        if not self.private:
            return self.private
        owner = db.session.query(User).filter(User.login == self.name).first()
        # .first() yields None when no user matches; report "not pending"
        # instead of raising AttributeError on None.
        return owner is not None and owner.pending

    @property
    def immutable(self):
        """Private groups and the public group are immutable."""
        return self.private or self.name == self.PUBLIC_GROUP_NAME

    @property
    def user_logins(self):
        """Logins of all users in this group."""
        return [ug.login for ug in self.users]

    @property
    def group_admins(self):
        """Logins of members flagged as ``group_admin``."""
        return [
            member.user.login for member in self.members if member.group_admin
        ]

    @staticmethod
    def public_group():
        """Return the public group, or None if it does not exist."""
        return Group.get_by_name(Group.PUBLIC_GROUP_NAME)

    @staticmethod
    def get_by_name(name):
        """Return the group with the given name, or None if there is none."""
        return db.session.query(Group).filter(Group.name == name).first()

    @staticmethod
    def all_access_groups():
        """Return all groups whose capabilities include ``access_all_objects``."""
        return (db.session.query(Group).filter(
            Group.capabilities.contains([Capabilities.access_all_objects
                                         ])).all())
Exemple #6
0
class Rule(Base):
    """ORM model for a row of the ``rules`` table."""
    __tablename__ = 'rules'

    id = Column(Integer, primary_key=True)
    statistic = Column(Integer, nullable=False)
    sensor = Column(Integer, nullable=False)
    tags = Column(ARRAY(Integer))

    def __repr__(self):
        """Return a JSON-like, multi-line dump of the rule's columns."""
        # No trailing comma after the last field, so the braces enclose a
        # well-formed JSON-like object (values are rendered with str()).
        return """
            {{
                "id": {},
                "statistic": {},
                "sensor": {},
                "tags": {}
            }}
        """.format(self.id, self.statistic, self.sensor, self.tags)
Exemple #7
0
class TeamStats(Base):
    """Statistics of one team in one match.

    The ``uix_1`` unique constraint allows at most one row per
    (``match_id``, ``team_name``) pair.
    """
    __tablename__ = 'team_stats'
    __table_args__ = (UniqueConstraint('match_id', 'team_name',
                                       name='uix_1'), )

    id = Column(Integer, primary_key=True)
    # Indexed, but not declared as a foreign key here.
    match_id = Column(Integer, index=True)
    league_id = Column(Integer, ForeignKey("league.league_id"))
    season = Column(String, index=True)
    team_name = Column(String, index=True)
    # Presumably True when the team played at home — confirm.
    host_status = Column(Boolean)
    team_score = Column(Integer)
    # Stored as strings; presumably minute marks such as "45+2" — confirm.
    scoring_minutes = Column(ARRAY(String))
    # Corner-kick counters; exact definitions are not visible here.
    corners_total = Column(Integer)
    corners_chances_created = Column(Integer)
    corners_assists = Column(Integer)
    corners_failed = Column(Integer)
Exemple #8
0
class Tariff(Base):
    """ORM model for the ``tariffs`` table, built from a ``BaseTariffVal``."""
    __tablename__ = 'tariffs'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    tariff_type = Column(Integer)
    vals = Column(JSON)
    company = Column(Integer, ForeignKey('companies.id'))
    compatibility = Column(ARRAY(Integer))

    def __init__(self, name, val, company, id=None):
        """Create a tariff row from a tariff value object.

        :param name: tariff name.
        :param val: a ``BaseTariffVal``; its ``get_type()``, ``get_value()``
            and ``get_compatibility()`` populate ``tariff_type``, ``vals`` and
            ``compatibility``.
        :param company: id of the owning row in ``companies``.
        :param id: optional explicit primary key.
        :raises ValueError: if ``val`` is not a ``BaseTariffVal``.
        """
        if not isinstance(val, BaseTariffVal):
            raise ValueError("val must be a BaseTariffVal instance")
        self.id = id
        self.name = name
        self.tariff_type = val.get_type()
        self.vals = val.get_value()
        self.company = company
        self.compatibility = val.get_compatibility()

    @property
    def tariff(self):
        """Rebuild the tariff value object from the stored ``vals``.

        :raises ValueError: if ``tariff_type`` is not a supported type
            (currently only type 1, ``MonoTariffVal``).
        """
        if self.tariff_type == 1:
            # The stored value lives in the ``vals`` column (set from
            # ``val.get_value()`` in __init__); there is no ``val`` attribute.
            return MonoTariffVal.from_dict(self.vals)
        raise ValueError(
            "Unsupported tariff_type: {!r}".format(self.tariff_type))

    def __repr__(self):
        """Return a JSON-like, multi-line dump of the tariff's columns."""
        return """
            {{
                "id": {},
                "name": {},
                "tariff_type": {},
                "vals": {},
                "company": {},
                "compatibility": {}
            }}
        """.format(self.id,
                   self.name,
                   self.tariff_type,
                   self.vals,
                   self.company,
                   self.compatibility)
Exemple #9
0
class Group(db.Model):
    """User group: members, capabilities, and object/attribute permissions."""

    __tablename__ = "group"

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(32), index=True, unique=True, nullable=False)
    # Capability names; defaults server-side to an empty PostgreSQL array.
    capabilities = db.Column(
        "capabilities", ARRAY(db.Text), nullable=False, server_default="{}"
    )
    # Group is user's private group
    private = db.Column(db.Boolean, nullable=False, default=False)
    # New users are automatically added to this group
    default = db.Column(db.Boolean, nullable=False, default=False)
    # Workspace groups have two traits:
    # - group members can list all the other group members
    # - they are candidates for sharing when upload_as:*
    workspace = db.Column(db.Boolean, nullable=False, default=True)

    members = db.relationship(
        "Member", back_populates="group", cascade="all, delete-orphan"
    )
    # Appending a user creates the Member association row for it.
    users = association_proxy("members", "user", creator=lambda user: Member(user=user))

    permissions = db.relationship(
        "ObjectPermission",
        back_populates="group",
        cascade="all, delete",
        passive_deletes=True,
    )

    attributes = db.relationship(
        "AttributePermission",
        back_populates="group",
        cascade="all, delete",
        passive_deletes=True,
    )

    PUBLIC_GROUP_NAME = "public"
    # These groups are just pre-created for convenience by 'mwdb-core configure'
    DEFAULT_EVERYTHING_GROUP_NAME = "everything"
    DEFAULT_REGISTERED_GROUP_NAME = "registered"

    @property
    def pending_group(self):
        """
        Whether this is a private group whose owning user is pending.

        For non-private groups ``False`` is returned. For private groups, the
        ``pending`` flag of the user whose login equals the group name is
        returned, or ``False`` if there is no such user.
        """
        from .user import User

        if not self.private:
            return self.private
        owner = db.session.query(User).filter(User.login == self.name).first()
        # .first() yields None when no user matches; report "not pending"
        # instead of raising AttributeError on None.
        return owner is not None and owner.pending

    @property
    def immutable(self):
        """
        Immutable groups can't be renamed, joined and left.
        The only thing that can be changed are capabilities.
        """
        return self.private or self.name == self.PUBLIC_GROUP_NAME

    @property
    def user_logins(self):
        """Logins of all users in this group."""
        return [ug.login for ug in self.users]

    @property
    def group_admins(self):
        """Logins of members flagged as ``group_admin``."""
        return [member.user.login for member in self.members if member.group_admin]

    def add_member(self, user):
        """
        Add ``user`` to the group inside a SAVEPOINT.

        Returns False if the user is already a member or if committing the
        savepoint fails with FlushError/IntegrityError (which is rolled back);
        True otherwise.
        """
        if user in self.users:
            return False

        db.session.begin_nested()
        try:
            self.users.append(user)
            db.session.commit()
        except (FlushError, IntegrityError):
            db.session.rollback()
            return False
        return True

    def remove_member(self, user):
        """
        Remove ``user`` from the group inside a SAVEPOINT.

        Returns False if the user is not a member or if committing the
        savepoint fails with FlushError/IntegrityError (which is rolled back);
        True otherwise.
        """
        if user not in self.users:
            return False

        db.session.begin_nested()
        try:
            self.users.remove(user)
            db.session.commit()
        except (FlushError, IntegrityError):
            db.session.rollback()
            return False
        return True

    @staticmethod
    def public_group():
        """Return the public group, or None if it does not exist."""
        return Group.get_by_name(Group.PUBLIC_GROUP_NAME)

    @staticmethod
    def get_by_name(name):
        """Return the group with the given name, or None if there is none."""
        return db.session.query(Group).filter(Group.name == name).first()

    @staticmethod
    def all_access_groups():
        """Return all groups whose capabilities include ``access_all_objects``."""
        return (
            db.session.query(Group)
            .filter(Group.capabilities.contains([Capabilities.access_all_objects]))
            .all()
        )

    @staticmethod
    def all_default_groups():
        """
        Return all default groups
        """
        return db.session.query(Group).filter(Group.default.is_(True)).all()
Exemple #10
0
    def _get_column_type(self, format_type):
        """Map a PostgreSQL ``format_type()`` string to a SQLAlchemy type instance.

        Adapted from SQLAlchemy's ``PGDialect.get_column_info``. Modifiers in
        parentheses (length, precision, scale, geometry subtype, ...) become
        positional/keyword arguments of the type constructor, and a trailing
        ``[]`` wraps the resulting type in ``ARRAY``.

        :param format_type: type string such as ``'character varying(5)'``,
            ``'numeric(10,2)'`` or ``'integer[]'``.
        :returns: an instance of the type registered in ``ischema_names`` for
            the base type name, or ``sqltypes.NULLTYPE`` if it is unknown.
        """
        # strip (*) from character varying(5), timestamp(5)
        # with time zone, geometry(POLYGON), etc.
        attype = re.sub(r'\(.*\)', '', format_type)

        # strip '[]' from integer[], etc.
        attype = re.sub(r'\[\]', '', attype)

        is_array = format_type.endswith('[]')
        # Purely numeric modifier, e.g. '5' or '10,2'. Raw strings keep the
        # regex escapes from being (invalid) Python string escapes.
        charlen = re.search(r'\(([\d,]+)\)', format_type)
        if charlen:
            charlen = charlen.group(1)
        # Any modifier list, e.g. ('POLYGON',) for geometry(POLYGON); used
        # unchanged for types that are not special-cased below.
        args = re.search(r'\((.*)\)', format_type)
        if args and args.group(1):
            args = tuple(re.split(r'\s*,\s*', args.group(1)))
        else:
            args = ()
        kwargs = {}

        if attype == 'numeric':
            if charlen:
                prec, scale = charlen.split(',')
                args = (int(prec), int(scale))
            else:
                args = ()
        elif attype == 'double precision':
            args = (53, )
        elif attype == 'integer':
            args = ()
        elif attype in ('timestamp with time zone', 'time with time zone'):
            kwargs['timezone'] = True
            if charlen:
                kwargs['precision'] = int(charlen)
            args = ()
        elif attype in ('timestamp without time zone',
                        'time without time zone', 'time'):
            kwargs['timezone'] = False
            if charlen:
                kwargs['precision'] = int(charlen)
            args = ()
        elif attype == 'bit varying':
            kwargs['varying'] = True
            if charlen:
                args = (int(charlen), )
            else:
                args = ()
        elif attype in ('interval', 'interval year to month',
                        'interval day to second'):
            if charlen:
                kwargs['precision'] = int(charlen)
            args = ()
        elif charlen:
            args = (int(charlen), )

        coltype = ischema_names.get(attype, None)
        if coltype:
            coltype = coltype(*args, **kwargs)
            if is_array:
                coltype = ARRAY(coltype)
        else:
            coltype = sqltypes.NULLTYPE
        return coltype
Exemple #11
0
from sqlalchemy.dialects.postgresql.array import ARRAY

# revision identifiers, used by Alembic.
revision = "f4ccb4be2170"
# The revision this migration is applied on top of.
down_revision = "e304b81836b0"
branch_labels = None
depends_on = None

logger = logging.getLogger("alembic")

# Minimal table descriptions, each bound to its own throw-away MetaData.
# Presumably they let this migration build queries without importing the
# application's ORM models (which may differ from the schema at this
# revision) — confirm. Only the columns listed here can be referenced.
group_helper = sa.Table(
    "group",
    sa.MetaData(),
    sa.Column("id", sa.Integer()),
    sa.Column("name", sa.String(32)),
    sa.Column("capabilities", ARRAY(sa.Text())),
    sa.Column("private", sa.Boolean()),
    sa.Column("default", sa.Boolean()),
    sa.Column("workspace", sa.Boolean()),
)

# Only the primary key of "user" is described.
user_helper = sa.Table(
    "user",
    sa.MetaData(),
    sa.Column("id", sa.Integer()),
)

member_helper = sa.Table(
    "member",
    sa.MetaData(),
    sa.Column("user_id", sa.Integer()),
Exemple #12
0
class File(Object):
    """
    Uploaded sample file: metadata, hashes and access to the stored contents.

    Contents are stored on local disk or in an S3 bucket, depending on
    ``app_config.mwdb.storage_provider``. During the request that uploaded the
    file, the uploaded stream is kept in ``upload_stream`` so its contents can
    be reused without reading them back from storage.
    """

    __tablename__ = "file"

    id = db.Column(db.Integer, db.ForeignKey("object.id"), primary_key=True)
    file_name = db.Column(db.String, nullable=False, index=True)
    file_size = db.Column(db.Integer, nullable=False, index=True)
    file_type = db.Column(db.Text, nullable=False, index=True)
    md5 = db.Column(db.String(32), nullable=False, index=True)
    crc32 = db.Column(db.String(8), nullable=False, index=True)
    sha1 = db.Column(db.String(40), nullable=False, index=True)
    sha256 = db.Column(db.String(64), nullable=False, index=True, unique=True)
    sha512 = db.Column(db.String(128), nullable=False, index=True)
    # ssdeep is nullable due to lack of support in earlier versions
    ssdeep = db.Column(db.String(255), nullable=True, index=True)
    # Other (sanitized) names the same contents were uploaded under.
    alt_names = db.Column(
        MutableList.as_mutable(ARRAY(db.String)), nullable=False, server_default="{}"
    )

    __mapper_args__ = {
        "polymorphic_identity": __tablename__,
    }

    @classmethod
    def get(cls, identifier):
        """
        Return a query for the file identified by ``identifier``.

        The identifier is lower-cased and matched against ``dhash`` first; if
        that matches nothing, it is matched against sha1/sha256/sha512/md5.
        """
        identifier = identifier.lower()
        file = File.query.filter(File.dhash == identifier)
        if file.scalar():
            return file
        return File.query.filter(
            or_(
                File.sha1 == identifier,
                File.sha256 == identifier,
                File.sha512 == identifier,
                File.md5 == identifier,
            )
        )

    @property
    def upload_stream(self):
        """
        Stream with file contents if a file is uploaded in current request.

        In that case, we don't need to download it from object storage.
        """
        return getattr(self, "_upload_stream", None)

    @upload_stream.setter
    def upload_stream(self, stream):
        setattr(self, "_upload_stream", stream)

    @classmethod
    def get_or_create(
        cls,
        file_name,
        file_stream,
        parent=None,
        attributes=None,
        share_with=None,
        analysis_id=None,
        tags=None,
    ):
        """
        Hash ``file_stream`` and return ``(file_obj, is_new)``.

        A new file's contents are written to the configured storage. For an
        already known file, the sanitized ``file_name`` is recorded in
        ``alt_names`` if it differs from the stored name. The stream is kept
        as ``file_obj.upload_stream`` in both cases.

        :raises EmptyFileError: if the stream is empty.
        """
        file_stream.seek(0, os.SEEK_END)
        file_size = file_stream.tell()
        if file_size == 0:
            raise EmptyFileError

        secure_name = secure_filename(file_name)
        sha256 = calc_hash(file_stream, hashlib.sha256(), lambda h: h.hexdigest())
        file_obj = File(
            dhash=sha256,
            file_name=secure_name,
            file_size=file_size,
            file_type=calc_magic(file_stream),
            crc32=calc_crc32(file_stream),
            md5=calc_hash(file_stream, hashlib.md5(), lambda h: h.hexdigest()),
            sha1=calc_hash(file_stream, hashlib.sha1(), lambda h: h.hexdigest()),
            sha256=sha256,
            sha512=calc_hash(file_stream, hashlib.sha512(), lambda h: h.hexdigest()),
            ssdeep=calc_ssdeep(file_stream),
        )

        file_obj, is_new = cls._get_or_create(
            file_obj,
            parent=parent,
            attributes=attributes,
            share_with=share_with,
            analysis_id=analysis_id,
            tags=tags,
        )

        if not is_new:
            # Known contents uploaded under another name: remember the alias.
            if (
                file_obj.file_name != secure_name
                and secure_name not in file_obj.alt_names
            ):
                file_obj.alt_names.append(secure_name)
        else:
            file_stream.seek(0, os.SEEK_SET)
            if app_config.mwdb.storage_provider == StorageProviderType.S3:
                get_minio_client(
                    app_config.mwdb.s3_storage_endpoint,
                    app_config.mwdb.s3_storage_access_key,
                    app_config.mwdb.s3_storage_secret_key,
                    app_config.mwdb.s3_storage_region_name,
                    app_config.mwdb.s3_storage_secure,
                    app_config.mwdb.s3_storage_iam_auth,
                ).put_object(
                    app_config.mwdb.s3_storage_bucket_name,
                    file_obj._calculate_path(),
                    file_stream,
                    file_size,
                )
            else:
                with open(file_obj._calculate_path(), "wb") as f:
                    shutil.copyfileobj(file_stream, f)

        file_obj.upload_stream = file_stream
        return file_obj, is_new

    def _calculate_path(self):
        """
        Return the storage location of the contents.

        For DISK storage this is an absolute path under the uploads folder
        (its directories are created if missing); for other providers it is
        the object key. With ``hash_pathing`` enabled, the first four hex
        characters of the sha256 are used as nested directory levels.
        """
        if app_config.mwdb.storage_provider == StorageProviderType.DISK:
            upload_path = app_config.mwdb.uploads_folder
        else:
            upload_path = ""

        sample_sha256 = self.sha256.lower()

        if app_config.mwdb.hash_pathing:
            # example: uploads/9/f/8/6/9f86d0818...
            upload_path = os.path.join(upload_path, *list(sample_sha256)[0:4])

        if app_config.mwdb.storage_provider == StorageProviderType.DISK:
            upload_path = os.path.abspath(upload_path)
            os.makedirs(upload_path, mode=0o755, exist_ok=True)
        return os.path.join(upload_path, sample_sha256)

    def get_path(self):
        """
        Legacy method used to retrieve the path to the file contents.

        Creates a NamedTemporaryFile if mwdb-core uses a different type of
        storage than DISK and the file is too small to be written to disk
        by Werkzeug.

        Deprecated, use File.open() to get the stream with contents.

        :raises ValueError: if storage is not DISK and no stream was uploaded
            in the current request.
        """
        if app_config.mwdb.storage_provider == StorageProviderType.DISK:
            # Just return path of file stored in local file-system
            return self._calculate_path()

        if not self.upload_stream:
            raise ValueError("Can't retrieve local path for this file")

        # In-memory streams such as io.BytesIO have no ``name`` attribute, so
        # look it up defensively; only a string name is a usable path.
        stream_name = getattr(self.upload_stream, "name", None)
        if isinstance(stream_name, str):
            return stream_name

        fd_path = get_fd_path(self.upload_stream)
        if fd_path:
            return fd_path

        # If not a file (BytesIO), copy contents to the named temporary file
        tmpfile = tempfile.NamedTemporaryFile()
        self.upload_stream.seek(0, os.SEEK_SET)
        shutil.copyfileobj(self.upload_stream, tmpfile)
        # NamedTemporaryFile is buffered: flush so that anything opening the
        # returned path sees the complete contents.
        tmpfile.flush()
        self.upload_stream.close()
        self.upload_stream = tmpfile
        return self.upload_stream.name

    def open(self):
        """
        Opens the file stream with contents.

        File stream must be closed using File.close.
        """
        if self.upload_stream is not None:
            # If file contents are uploaded in this request,
            # try to reuse the existing file instead of downloading it from Minio.
            if isinstance(self.upload_stream, io.BytesIO):
                return io.BytesIO(self.upload_stream.getbuffer())
            else:
                dupfd = os.dup(self.upload_stream.fileno())
                stream = os.fdopen(dupfd, "rb")
                stream.seek(0, os.SEEK_SET)
                return stream
        if app_config.mwdb.storage_provider == StorageProviderType.S3:
            return get_minio_client(
                app_config.mwdb.s3_storage_endpoint,
                app_config.mwdb.s3_storage_access_key,
                app_config.mwdb.s3_storage_secret_key,
                app_config.mwdb.s3_storage_region_name,
                app_config.mwdb.s3_storage_secure,
                app_config.mwdb.s3_storage_iam_auth,
            ).get_object(app_config.mwdb.s3_storage_bucket_name, self._calculate_path())
        elif app_config.mwdb.storage_provider == StorageProviderType.DISK:
            return open(self._calculate_path(), "rb")
        else:
            raise RuntimeError(
                f"StorageProvider {app_config.mwdb.storage_provider} is not supported"
            )

    def read(self):
        """
        Reads all bytes from the file
        """
        fh = self.open()
        try:
            return fh.read()
        finally:
            File.close(fh)

    def iterate(self, chunk_size=1024 * 256):
        """
        Iterates over bytes in the file contents
        """
        fh = self.open()
        try:
            if hasattr(fh, "stream"):
                # S3 response objects expose their own chunked reader.
                yield from fh.stream(chunk_size)
            else:
                while True:
                    chunk = fh.read(chunk_size)
                    if chunk:
                        yield chunk
                    else:
                        return
        finally:
            File.close(fh)

    @staticmethod
    def close(fh):
        """
        Closes file stream opened by File.open
        """
        fh.close()
        if hasattr(fh, "release_conn"):
            fh.release_conn()

    def release_after_upload(self):
        """
        Release additional resources used by uploaded file.
        e.g. NamedTemporaryFile opened by get_path()
        """
        if self.upload_stream:
            self.upload_stream.close()
            self.upload_stream = None

    def generate_download_token(self):
        """Return a ``download_file``-scoped token for this file (expiration=60)."""
        return generate_token(
            {"identifier": self.sha256},
            scope=AuthScope.download_file,
            expiration=60,
        )

    @staticmethod
    def get_by_download_token(download_token):
        """Return the file named by a valid download token, or None."""
        download_req = verify_token(download_token, scope=AuthScope.download_file)
        if not download_req:
            return None
        return File.get(download_req["identifier"]).first()

    def _send_to_karton(self):
        return send_file_to_karton(self)
Exemple #13
0
class DocumentChange(Base):
    """One entry of the change feed (homepage feed and user profiles).

    A row is written whenever a user acts on a document, for example by
    creating it or by uploading images.
    """
    __tablename__ = 'feed_document_changes'

    change_id = Column(Integer, primary_key=True)

    time = Column(DateTime(timezone=True), default=func.now(), nullable=False)

    # who made the change
    user_id = Column(Integer,
                     ForeignKey(users_schema + '.user.id'),
                     nullable=False)
    user = relationship(User, primaryjoin=user_id == User.id)

    # what kind of change it was (e.g. creation or update of a document)
    change_type = Column(feed_change_type, nullable=False)

    # which document was changed
    document_id = Column(Integer,
                         ForeignKey(schema + '.documents.document_id'),
                         nullable=False)
    document = relationship(Document,
                            primaryjoin=document_id == Document.document_id)

    document_type = Column(String(1), nullable=False)

    # activities of the changed document
    activities = Column(ArrayOfEnum(enums.activity_type),
                        nullable=False,
                        server_default='{}')

    # languages of the document's locales
    langs = Column(ArrayOfEnum(enums.lang),
                   nullable=False,
                   server_default='{}')

    # Areas and users are kept as plain integer arrays rather than PK-FK
    # relations, for performance; triggers keep them consistent.

    # areas in which the change happened
    area_ids = Column(ARRAY(Integer), nullable=False, server_default='{}')

    # users involved in the change (e.g. the document's creator, or the
    # participants of an outing)
    user_ids = Column(ARRAY(Integer), nullable=False, server_default='{}')

    # up to three referenced images, plus a flag when there are more
    image1_id = Column(Integer, ForeignKey(schema + '.images.document_id'))
    image1 = relationship(Image, primaryjoin=image1_id == Image.document_id)
    image2_id = Column(Integer, ForeignKey(schema + '.images.document_id'))
    image2 = relationship(Image, primaryjoin=image2_id == Image.document_id)
    image3_id = Column(Integer, ForeignKey(schema + '.images.document_id'))
    image3 = relationship(Image, primaryjoin=image3_id == Image.document_id)
    more_images = Column(Boolean, server_default='FALSE', nullable=False)

    __table_args__ = (
        # feed queries always sort by time (desc) then change_id, so both
        # columns are indexed together
        Index('ix_guidebook_feed_document_changes_time_and_change_id',
              time.desc(),
              change_id,
              postgresql_using='btree'),
        Base.__table_args__)

    def copy(self):
        """Return a new, unsaved DocumentChange carrying this change's document data.

        document_id, document_type, change_type, activities, langs and area_ids
        are carried over (by reference). user_ids is carried over only for
        outings and set to [] for every other document type. change_id, time,
        user_id and the image columns are left unset.
        """
        duplicate = DocumentChange()
        for attr in ('document_id', 'document_type', 'change_type',
                     'activities', 'langs', 'area_ids'):
            setattr(duplicate, attr, getattr(self, attr))
        duplicate.user_ids = (
            self.user_ids if duplicate.document_type == OUTING_TYPE else [])
        return duplicate