Ejemplo n.º 1
0
class ApiKey(ModelTimestampsMixin, BaseModel):
    """API key attached to another model instance via a generic foreign key.

    Rows are stored in the ``api_keys`` table. ``object_type`` and
    ``object_id`` together identify the owning object and share a
    non-unique composite index.
    """

    org = peewee.ForeignKeyField(Organization)
    # When no key is supplied, one is produced by generate_token(40).
    api_key = peewee.CharField(index=True, default=lambda: generate_token(40))
    active = peewee.BooleanField(default=True)
    object_type = peewee.CharField()
    object_id = peewee.IntegerField()
    object = GFKField('object_type', 'object_id')
    created_by = peewee.ForeignKeyField(User, null=True)

    class Meta:
        db_table = 'api_keys'
        indexes = (
            (('object_type', 'object_id'), False),
        )

    @classmethod
    def get_by_api_key(cls, api_key):
        """Fetch the active row whose ``api_key`` column equals ``api_key``."""
        matches_key = cls.api_key == api_key
        # The comparison with True is handed to the ORM as a filter expression.
        is_active = cls.active == True
        return cls.get(matches_key, is_active)

    @classmethod
    def get_by_object(cls, object):
        """Return the first active key recorded for ``object``, if any.

        The object is matched on its model's table name and its ``id``.
        """
        conditions = (
            cls.object_type == object._meta.db_table,
            cls.object_id == object.id,
            cls.active == True,
        )
        return cls.select().where(*conditions).first()

    @classmethod
    def create_for_object(cls, object, user):
        """Create a key for ``object`` in ``user``'s organization.

        ``user`` is also stored as the key's creator.
        """
        attributes = dict(org=user.org, object=object, created_by=user)
        return cls.create(**attributes)
Ejemplo n.º 2
0
class ApiKey(TimestampMixin, GFKBase, db.Model):
    """API key attached to another model instance, stored in ``api_keys``.

    The generic ``object`` reference comes from ``GFKBase``; the table
    carries an index over ``object_type`` and ``object_id``.
    """

    id = Column(db.Integer, primary_key=True)
    org_id = Column(db.Integer, db.ForeignKey("organizations.id"))
    org = db.relationship(Organization)
    # When no key is supplied, one is produced by generate_token(40).
    api_key = Column(
        db.String(255), index=True, default=lambda: generate_token(40))
    active = Column(db.Boolean, default=True)
    # 'object' provided by GFKBase
    created_by_id = Column(
        db.Integer, db.ForeignKey("users.id"), nullable=True)
    created_by = db.relationship(User)

    __tablename__ = 'api_keys'
    __table_args__ = (
        db.Index('api_keys_object_type_object_id', 'object_type', 'object_id'),
    )

    @classmethod
    def get_by_api_key(cls, api_key):
        """Return the single active row matching ``api_key``.

        Uses ``.one()``, so the lookup raises unless exactly one row matches.
        """
        matches_key = cls.api_key == api_key
        # The comparison with True is handed to the ORM as a filter expression.
        is_active = cls.active == True
        return cls.query.filter(matches_key, is_active).one()

    @classmethod
    def get_by_object(cls, object):
        """Return the first active key recorded for ``object``, or None.

        The object is matched on its class's ``__tablename__`` and its ``id``.
        """
        conditions = (
            cls.object_type == object.__class__.__tablename__,
            cls.object_id == object.id,
            cls.active == True,
        )
        return cls.query.filter(*conditions).first()

    @classmethod
    def create_for_object(cls, object, user):
        """Build a key for ``object`` and add it to the session.

        The key belongs to ``user``'s organization and records ``user`` as
        its creator. The session is not committed here.
        """
        new_key = cls(org=user.org, object=object, created_by=user)
        db.session.add(new_key)
        return new_key
Ejemplo n.º 3
0
    def pre_save(self, created):
        """Run the parent pre-save hook, then make sure an API key is set.

        A key is generated with generate_token(40) only when the current
        ``api_key`` value is falsy.
        """
        super(User, self).pre_save(created)

        if self.api_key:
            return
        self.api_key = generate_token(40)
Ejemplo n.º 4
0
 def test_format(self):
     """A token requested with length 40 consists of exactly 40 ASCII alphanumerics."""
     token = generate_token(40)
     # assertRegexpMatches was removed in Python 3.12, while Python 2's
     # unittest has no assertRegex; use whichever this interpreter provides.
     assert_regex = getattr(self, "assertRegex", None) or self.assertRegexpMatches
     # The assertion uses re.search, so the pattern is anchored; unanchored it
     # would also accept longer tokens or tokens containing other characters.
     assert_regex(token, r"\A[a-zA-Z0-9]{40}\Z")
Ejemplo n.º 5
0
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin,
           PermissionsCheckMixin):
    """Account belonging to an organization, stored in the ``users`` table.

    The pair (org_id, email) is covered by a unique index. An e-mail passed
    to the constructor is lower-cased before being stored.
    """
    id = Column(db.Integer, primary_key=True)
    org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
    org = db.relationship("Organization",
                          backref=db.backref("users", lazy="dynamic"))
    name = Column(db.String(320))
    email = Column(EmailType)
    # Explicit avatar URL; see the `profile_image_url` property for the fallback.
    _profile_image_url = Column('profile_image_url',
                                db.String(320),
                                nullable=True)
    # None means the user has no local password (reported as 'external' auth).
    password_hash = Column(db.String(128), nullable=True)
    # Ids of the groups the user belongs to; stored in the `groups` column.
    group_ids = Column('groups',
                       MutableList.as_mutable(postgresql.ARRAY(db.Integer)),
                       nullable=True)
    api_key = Column(db.String(40),
                     default=lambda: generate_token(40),
                     unique=True)

    # A non-null value marks the user as disabled.
    disabled_at = Column(db.DateTime(True), default=None, nullable=True)
    # JSON document backing the json_cast_property attributes below.
    details = Column(MutableDict.as_mutable(postgresql.JSON),
                     nullable=True,
                     server_default='{}',
                     default={})
    active_at = json_cast_property(db.DateTime(True),
                                   'details',
                                   'active_at',
                                   default=None)
    is_invitation_pending = json_cast_property(db.Boolean(True),
                                               'details',
                                               'is_invitation_pending',
                                               default=False)
    is_email_verified = json_cast_property(db.Boolean(True),
                                           'details',
                                           'is_email_verified',
                                           default=True)

    __tablename__ = 'users'
    __table_args__ = (db.Index('users_org_id_email',
                               'org_id',
                               'email',
                               unique=True), )

    def __str__(self):
        """Return ``"<name> (<email>)"``."""
        return '%s (%s)' % (self.name, self.email)

    def __init__(self, *args, **kwargs):
        """Construct the user, lower-casing the ``email`` keyword if given."""
        if kwargs.get('email') is not None:
            kwargs['email'] = kwargs['email'].lower()
        super(User, self).__init__(*args, **kwargs)

    @property
    def is_disabled(self):
        """True when ``disabled_at`` is set."""
        return self.disabled_at is not None

    def disable(self):
        """Mark the user as disabled as of the database's current time."""
        self.disabled_at = db.func.now()

    def enable(self):
        """Clear the disabled marker."""
        self.disabled_at = None

    def regenerate_api_key(self):
        """Replace the API key with a new generate_token(40) value."""
        self.api_key = generate_token(40)

    def to_dict(self, with_api_key=False):
        """Serialize the user into a plain dict.

        Disabled users get the bundled ``images/avatar.svg`` asset instead of
        their own profile image. ``auth_type`` is ``'external'`` when no
        password hash is stored and ``'password'`` otherwise. The API key is
        included only when ``with_api_key`` is true.
        """
        profile_image_url = self.profile_image_url
        if self.is_disabled:
            # Resolve the avatar through the webpack asset map when it has an
            # entry for the path, otherwise use the raw path.
            assets = app.extensions['webpack']['assets'] or {}
            path = 'images/avatar.svg'
            profile_image_url = url_for('static',
                                        filename=assets.get(path, path))

        d = {
            'id': self.id,
            'name': self.name,
            'email': self.email,
            'profile_image_url': profile_image_url,
            'groups': self.group_ids,
            'updated_at': self.updated_at,
            'created_at': self.created_at,
            'disabled_at': self.disabled_at,
            'is_disabled': self.is_disabled,
            'active_at': self.active_at,
            'is_invitation_pending': self.is_invitation_pending,
            'is_email_verified': self.is_email_verified,
        }

        if self.password_hash is None:
            d['auth_type'] = 'external'
        else:
            d['auth_type'] = 'password'

        if with_api_key:
            d['api_key'] = self.api_key

        return d

    def is_api_user(self):
        """Always False for this class."""
        return False

    @property
    def profile_image_url(self):
        """Stored profile image URL, or a Gravatar identicon URL as fallback.

        The Gravatar URL is derived from the MD5 of the lower-cased e-mail.
        """
        if self._profile_image_url is not None:
            return self._profile_image_url

        email_md5 = hashlib.md5(self.email.lower().encode()).hexdigest()
        return "https://www.gravatar.com/avatar/{}?s=40&d=identicon".format(
            email_md5)

    @property
    def permissions(self):
        """Flat list of the permissions of every group in ``group_ids``.

        Returns an empty list when the user has no group ids.
        """
        # TODO: this should be cached.
        # `group_ids` is a nullable column: never hand None to in_(), and skip
        # the query altogether when there is nothing to look up.
        if not self.group_ids:
            return []

        return list(
            itertools.chain.from_iterable(
                g.permissions
                for g in Group.query.filter(Group.id.in_(self.group_ids))))

    @classmethod
    def get_by_org(cls, org):
        """Return a query over the users of ``org``."""
        return cls.query.filter(cls.org == org)

    @classmethod
    def get_by_id(cls, _id):
        """Return the user with the given id; ``.one()`` raises otherwise."""
        return cls.query.filter(cls.id == _id).one()

    @classmethod
    def get_by_email_and_org(cls, email, org):
        """Return the single user of ``org`` with this e-mail (``.one()``)."""
        return cls.get_by_org(org).filter(cls.email == email).one()

    @classmethod
    def get_by_api_key_and_org(cls, api_key, org):
        """Return the single user of ``org`` with this API key (``.one()``)."""
        return cls.get_by_org(org).filter(cls.api_key == api_key).one()

    @classmethod
    def all(cls, org):
        """Return a query over the users of ``org`` that are not disabled."""
        return cls.get_by_org(org).filter(cls.disabled_at.is_(None))

    @classmethod
    def all_disabled(cls, org):
        """Return a query over the disabled users of ``org``."""
        return cls.get_by_org(org).filter(cls.disabled_at.isnot(None))

    @classmethod
    def search(cls, base_query, term):
        """Narrow ``base_query`` to users whose name or e-mail contains ``term``.

        The name is matched with ``ilike`` and the e-mail with ``like``.
        """
        term = '%{}%'.format(term)
        search_filter = or_(cls.name.ilike(term), cls.email.like(term))

        return base_query.filter(search_filter)

    @classmethod
    def pending(cls, base_query, pending):
        """Narrow ``base_query`` by invitation state.

        A truthy ``pending`` keeps users whose invitation is pending; a falsy
        one keeps everyone else.
        """
        if pending:
            return base_query.filter(cls.is_invitation_pending.is_(True))
        else:
            return base_query.filter(cls.is_invitation_pending.isnot(
                True))  # check for both `false`/`null`

    @classmethod
    def find_by_email(cls, email):
        """Return a query over users with this e-mail, across organizations."""
        return cls.query.filter(cls.email == email)

    def hash_password(self, password):
        """Hash ``password`` with ``pwd_context`` and store the result."""
        self.password_hash = pwd_context.encrypt(password)

    def verify_password(self, password):
        """Check ``password`` against the stored hash.

        Returns the falsy ``password_hash`` itself when no hash is stored,
        otherwise the result of ``pwd_context.verify``.
        """
        return self.password_hash and pwd_context.verify(
            password, self.password_hash)

    def update_group_assignments(self, group_names):
        """Replace the user's groups and commit the session.

        The new groups are those ``Group.find_by_name`` returns for
        ``group_names`` plus the organization's default group.
        """
        groups = Group.find_by_name(self.org, group_names)
        groups.append(self.org.default_group)
        self.group_ids = [g.id for g in groups]
        db.session.add(self)
        db.session.commit()

    def has_access(self, obj, access_type):
        """Return whether an access permission of this type exists for the user."""
        return AccessPermission.exists(obj, access_type, grantee=self)

    def get_id(self):
        """Return ``"<id>-<md5>"``, hashing ``"<email>,<password_hash>"``.

        The value therefore changes whenever the e-mail or the password hash
        changes.
        """
        identity = hashlib.md5("{},{}".format(
            self.email, self.password_hash).encode()).hexdigest()
        return "{0}-{1}".format(self.id, identity)
Ejemplo n.º 6
0
 def regenerate_api_key(self):
     """Replace the stored API key with a new generate_token(40) value."""
     fresh_key = generate_token(40)
     self.api_key = fresh_key
Ejemplo n.º 7
0
class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
    """Saved query, stored in the ``queries`` table.

    A query belongs to an organization and a user, optionally points at a
    data source and at its latest result, and may carry a schedule.
    """
    id = Column(db.Integer, primary_key=True)
    # Used as the mapper's version_id_col (see __mapper_args__).
    version = Column(db.Integer, default=1)
    org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
    org = db.relationship(Organization, backref="queries")
    data_source_id = Column(db.Integer,
                            db.ForeignKey("data_sources.id"),
                            nullable=True)
    data_source = db.relationship(DataSource, backref='queries')
    latest_query_data_id = Column(db.Integer,
                                  db.ForeignKey("query_results.id"),
                                  nullable=True)
    latest_query_data = db.relationship(QueryResult)
    name = Column(db.String(255))
    description = Column(db.String(4096), nullable=True)
    # Python attribute `query_text` maps to the database column `query`.
    query_text = Column("query", db.Text)
    query_hash = Column(db.String(32))
    api_key = Column(db.String(40), default=lambda: generate_token(40))
    user_id = Column(db.Integer, db.ForeignKey("users.id"))
    user = db.relationship(User, foreign_keys=[user_id])
    last_modified_by_id = Column(db.Integer,
                                 db.ForeignKey('users.id'),
                                 nullable=True)
    last_modified_by = db.relationship(User,
                                       backref="modified_queries",
                                       foreign_keys=[last_modified_by_id])
    is_archived = Column(db.Boolean, default=False, index=True)
    is_draft = Column(db.Boolean, default=True, index=True)
    # Dict read below through the keys 'interval', 'until', 'time' and
    # 'day_of_week'; None when the query is not scheduled.
    schedule = Column(MutableDict.as_mutable(PseudoJSON), nullable=True)
    schedule_failures = Column(db.Integer, default=0)
    visualizations = db.relationship("Visualization",
                                     cascade="all, delete-orphan")
    options = Column(MutableDict.as_mutable(PseudoJSON), default={})
    search_vector = Column(TSVectorType('id',
                                        'name',
                                        'description',
                                        'query',
                                        weights={
                                            'name': 'A',
                                            'id': 'B',
                                            'description': 'C',
                                            'query': 'D'
                                        }),
                           nullable=True)
    tags = Column('tags',
                  MutableList.as_mutable(postgresql.ARRAY(db.Unicode)),
                  nullable=True)

    query_class = SearchBaseQuery
    __tablename__ = 'queries'
    __mapper_args__ = {
        "version_id_col": version,
        'version_id_generator': False
    }

    def __str__(self):
        """Return the query id as text."""
        return text_type(self.id)

    def archive(self, user=None):
        """Archive the query and remove what depends on it.

        Sets ``is_archived``, clears the schedule, and deletes the widgets of
        every visualization as well as the query's alerts. When ``user`` is
        given, the change is recorded through ``record_changes``.
        """
        db.session.add(self)
        self.is_archived = True
        self.schedule = None

        for vis in self.visualizations:
            for w in vis.widgets:
                db.session.delete(w)

        for a in self.alerts:
            db.session.delete(a)

        if user:
            self.record_changes(user)

    def regenerate_api_key(self):
        """Replace the API key with a new generate_token(40) value."""
        self.api_key = generate_token(40)

    @classmethod
    def create(cls, **kwargs):
        """Build a query plus a default "Table" visualization.

        The visualization is added to the session; the query is returned.
        """
        query = cls(**kwargs)
        db.session.add(
            Visualization(query_rel=query,
                          name="Table",
                          description='',
                          type="TABLE",
                          options="{}"))
        return query

    @classmethod
    def all_queries(cls,
                    group_ids,
                    user_id=None,
                    include_drafts=False,
                    include_archived=False):
        """Return the queries whose data source is shared with ``group_ids``.

        ``include_archived`` is compared with ``is_archived`` using ``IS``,
        so a true value selects only archived queries and a false value only
        non-archived ones. Unless ``include_drafts`` is true, drafts are
        limited to those owned by ``user_id``. The user and the latest result
        are loaded together with each query.
        """
        query_ids = (db.session.query(distinct(cls.id)).join(
            DataSourceGroup,
            Query.data_source_id == DataSourceGroup.data_source_id).filter(
                Query.is_archived.is_(include_archived)).filter(
                    DataSourceGroup.group_id.in_(group_ids)))
        queries = (
            cls.query.options(
                joinedload(Query.user),
                joinedload(Query.latest_query_data).load_only(
                    'runtime',
                    'retrieved_at',
                )).filter(cls.id.in_(query_ids))
            # Adding outer joins to be able to order by relationship
            .outerjoin(User, User.id == Query.user_id).outerjoin(
                QueryResult,
                QueryResult.id == Query.latest_query_data_id).options(
                    contains_eager(Query.user),
                    contains_eager(Query.latest_query_data),
                ))

        if not include_drafts:
            queries = queries.filter(
                or_(Query.is_draft.is_(False), Query.user_id == user_id))
        return queries

    @classmethod
    def favorites(cls, user, base_query=None):
        """Restrict ``base_query`` to the queries ``user`` has favorited.

        Without ``base_query``, starts from ``all_queries`` for the user's
        groups with drafts included.
        """
        if base_query is None:
            base_query = cls.all_queries(user.group_ids,
                                         user.id,
                                         include_drafts=True)
        return base_query.join((Favorite,
                                and_(Favorite.object_type == u'Query',
                                     Favorite.object_id == Query.id))).filter(
                                         Favorite.user_id == user.id)

    @classmethod
    def all_tags(cls, user, include_drafts=False):
        """Return (tag, usage_count) rows over the queries visible to ``user``.

        Rows are ordered by descending usage count.
        """
        queries = cls.all_queries(
            group_ids=user.group_ids,
            user_id=user.id,
            include_drafts=include_drafts,
        )

        tag_column = func.unnest(cls.tags).label('tag')
        usage_count = func.count(1).label('usage_count')

        query = (db.session.query(
            tag_column, usage_count).group_by(tag_column).filter(
                Query.id.in_(queries.options(load_only('id')))).order_by(
                    usage_count.desc()))
        return query

    @classmethod
    def by_user(cls, user):
        """Return the queries from ``all_queries`` that are owned by ``user``."""
        return cls.all_queries(user.group_ids,
                               user.id).filter(Query.user == user)

    @classmethod
    def by_api_key(cls, api_key):
        """Return the single query with this API key (``.one()``)."""
        return cls.query.filter(cls.api_key == api_key).one()

    @classmethod
    def past_scheduled_queries(cls):
        """Return a list of scheduled queries whose ``until`` date has passed.

        ``until`` is parsed as ``%Y-%m-%d`` and interpreted as UTC.
        """
        now = utils.utcnow()
        queries = (Query.query.filter(Query.schedule.isnot(None)).order_by(
            Query.id))
        # Build a real list: on Python 3, filter() would return a one-shot
        # iterator that has no len() and can be consumed only once.
        return [
            query for query in queries
            if query.schedule["until"] is not None and pytz.utc.localize(
                datetime.datetime.strptime(query.schedule['until'],
                                           '%Y-%m-%d')) <= now
        ]

    @classmethod
    def outdated_queries(cls):
        """Return a list of scheduled queries that are due to run again.

        Queries without an interval, or whose ``until`` date has passed, are
        skipped. Results are de-duplicated on (query_hash, data_source_id);
        when several queries share a key, the one with the highest id wins
        because rows are visited in id order.
        """
        queries = (Query.query.options(
            joinedload(
                Query.latest_query_data).load_only('retrieved_at')).filter(
                    Query.schedule.isnot(None)).order_by(Query.id))

        now = utils.utcnow()
        outdated_queries = {}
        scheduled_queries_executions.refresh()

        for query in queries:
            if query.schedule['interval'] is None:
                continue

            if query.schedule['until'] is not None:
                schedule_until = pytz.utc.localize(
                    datetime.datetime.strptime(query.schedule['until'],
                                               '%Y-%m-%d'))

                if schedule_until <= now:
                    continue

            if query.latest_query_data:
                retrieved_at = query.latest_query_data.retrieved_at
            else:
                retrieved_at = now

            # A value recorded in scheduled_queries_executions takes
            # precedence over the latest result's timestamp.
            retrieved_at = scheduled_queries_executions.get(
                query.id) or retrieved_at

            if should_schedule_next(retrieved_at, now,
                                    query.schedule['interval'],
                                    query.schedule['time'],
                                    query.schedule['day_of_week'],
                                    query.schedule_failures):
                key = "{}:{}".format(query.query_hash, query.data_source_id)
                outdated_queries[key] = query

        # Return a list rather than a dict view so the result is the same
        # type on Python 2 and Python 3.
        return list(outdated_queries.values())

    @classmethod
    def search(cls,
               term,
               group_ids,
               user_id=None,
               include_drafts=False,
               limit=None,
               include_archived=False,
               multi_byte_search=False):
        """Search the queries visible to ``group_ids`` for ``term``.

        With ``multi_byte_search`` the name and description are matched with
        ``ilike`` and results are ordered by id; otherwise the search vector
        is used and results are sorted by its weights.
        """
        all_queries = cls.all_queries(
            group_ids,
            user_id=user_id,
            include_drafts=include_drafts,
            include_archived=include_archived,
        )

        if multi_byte_search:
            # Since tsvector doesn't work well with CJK languages, use `ilike` too
            pattern = u'%{}%'.format(term)
            return all_queries.filter(
                or_(cls.name.ilike(pattern),
                    cls.description.ilike(pattern))).order_by(
                        Query.id).limit(limit)

        # sort the result using the weight as defined in the search vector column
        return all_queries.search(term, sort=True).limit(limit)

    @classmethod
    def search_by_user(cls, term, user, limit=None):
        """Search ``user``'s own queries for ``term`` using the search vector."""
        return cls.by_user(user).search(term, sort=True).limit(limit)

    @classmethod
    def recent(cls, group_ids, user_id=None, limit=20):
        """Return non-archived queries with qualifying events in the last 7 days.

        Only the listed event actions count. Results are ordered by
        descending event count and limited to ``limit``. When ``user_id`` is
        truthy, only that user's events are considered.
        """
        query = (cls.query.filter(
            Event.created_at > (db.func.current_date() - 7)).join(
                Event, Query.id == Event.object_id.cast(db.Integer)).join(
                    DataSourceGroup, Query.data_source_id ==
                    DataSourceGroup.data_source_id).filter(
                        Event.action.in_([
                            'edit', 'execute', 'edit_name', 'edit_description',
                            'view_source'
                        ]), Event.object_id != None,
                        Event.object_type == 'query',
                        DataSourceGroup.group_id.in_(group_ids),
                        or_(Query.is_draft == False, Query.user_id == user_id),
                        Query.is_archived == False).group_by(
                            Event.object_id,
                            Query.id).order_by(db.desc(db.func.count(0))))

        if user_id:
            query = query.filter(Event.user_id == user_id)

        query = query.limit(limit)

        return query

    @classmethod
    def get_by_id(cls, _id):
        """Return the query with the given id; ``.one()`` raises otherwise."""
        return cls.query.filter(cls.id == _id).one()

    @classmethod
    def all_groups_for_query_ids(cls, query_ids):
        """Return (group_id, view_only) rows for the data sources of ``query_ids``.

        Returns an empty list when ``query_ids`` is empty.
        """
        ids = tuple(query_ids)
        if not ids:
            # An empty tuple would be rendered as `IN ()`, which the database
            # rejects as a syntax error; there is nothing to look up anyway.
            return []

        query = """SELECT group_id, view_only
                   FROM queries
                   JOIN data_source_groups ON queries.data_source_id = data_source_groups.data_source_id
                   WHERE queries.id in :ids"""

        return db.session.execute(query, {'ids': ids}).fetchall()

    def fork(self, user):
        """Create a copy of this query owned by ``user`` and add it to the session.

        The attributes in ``forked_list`` are copied and every visualization
        is duplicated onto the new query. The copy is named
        ``"Copy of (#<id>) <name>"``.
        """
        forked_list = [
            'org', 'data_source', 'latest_query_data', 'description',
            'query_text', 'query_hash', 'options'
        ]
        kwargs = {a: getattr(self, a) for a in forked_list}

        # Query.create will add default TABLE visualization, so use constructor to create bare copy of query
        forked_query = Query(name=u'Copy of (#{}) {}'.format(
            self.id, self.name),
                             user=user,
                             **kwargs)

        for v in self.visualizations:
            forked_v = v.copy()
            forked_v['query_rel'] = forked_query
            fv = Visualization(
                **forked_v
            )  # it will magically add it to `forked_query.visualizations`
            db.session.add(fv)

        db.session.add(forked_query)
        return forked_query

    @property
    def runtime(self):
        """Runtime of the latest result (requires ``latest_query_data``)."""
        return self.latest_query_data.runtime

    @property
    def retrieved_at(self):
        """Retrieval time of the latest result (requires ``latest_query_data``)."""
        return self.latest_query_data.retrieved_at

    @property
    def groups(self):
        """Groups of the data source, or an empty dict without a data source."""
        if self.data_source is None:
            return {}

        return self.data_source.groups

    @hybrid_property
    def lowercase_name(self):
        "Optional property useful for sorting purposes."
        return self.name.lower()

    @lowercase_name.expression
    def lowercase_name(cls):
        "The SQLAlchemy expression for the property above."
        return func.lower(cls.name)

    @property
    def parameters(self):
        """The ``parameters`` entry of ``options``, or an empty list."""
        return self.options.get("parameters", [])

    @property
    def parameterized(self):
        """A ParameterizedQuery built from the query text and its parameters."""
        return ParameterizedQuery(self.query_text, self.parameters)
Ejemplo n.º 8
0
class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
    """Saved query, stored in the ``queries`` table.

    A query belongs to an organization and a user, optionally points at a
    data source and at its latest result, and may carry a schedule.
    """

    id = primary_key("Query")
    # Used as the mapper's version_id_col (see __mapper_args__).
    version = Column(db.Integer, default=1)
    org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id"))
    org = db.relationship(Organization, backref="queries")
    data_source_id = Column(key_type("DataSource"), db.ForeignKey("data_sources.id"), nullable=True)
    data_source = db.relationship(DataSource, backref="queries")
    latest_query_data_id = Column(
        key_type("QueryResult"), db.ForeignKey("query_results.id"), nullable=True
    )
    latest_query_data = db.relationship(QueryResult)
    name = Column(db.String(255))
    description = Column(db.String(4096), nullable=True)
    # Python attribute `query_text` maps to the database column `query`.
    query_text = Column("query", db.Text)
    query_hash = Column(db.String(32))
    api_key = Column(db.String(40), default=lambda: generate_token(40))
    user_id = Column(key_type("User"), db.ForeignKey("users.id"))
    user = db.relationship(User, foreign_keys=[user_id])
    last_modified_by_id = Column(key_type("User"), db.ForeignKey("users.id"), nullable=True)
    last_modified_by = db.relationship(
        User, backref="modified_queries", foreign_keys=[last_modified_by_id]
    )
    is_archived = Column(db.Boolean, default=False, index=True)
    is_draft = Column(db.Boolean, default=True, index=True)
    # Dict read below through the keys 'interval', 'until', 'time',
    # 'day_of_week' and 'disabled'; None when the query is not scheduled.
    schedule = Column(MutableDict.as_mutable(PseudoJSON), nullable=True)
    interval = pseudo_json_cast_property(db.Integer, "schedule", "interval", default=0)
    schedule_failures = Column(db.Integer, default=0)
    visualizations = db.relationship("Visualization", cascade="all, delete-orphan")
    options = Column(MutableDict.as_mutable(PseudoJSON), default={})
    search_vector = Column(
        TSVectorType(
            "id",
            "name",
            "description",
            "query",
            weights={"name": "A", "id": "B", "description": "C", "query": "D"},
        ),
        nullable=True,
    )
    tags = Column(
        "tags", MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True
    )

    query_class = SearchBaseQuery
    __tablename__ = "queries"
    __mapper_args__ = {"version_id_col": version, "version_id_generator": False}

    def __str__(self):
        """Return the query id as a string."""
        return str(self.id)

    def archive(self, user=None):
        """Archive the query and remove what depends on it.

        Sets ``is_archived``, clears the schedule, and deletes the widgets of
        every visualization as well as the query's alerts. When ``user`` is
        given, the change is recorded through ``record_changes``.
        """
        db.session.add(self)
        self.is_archived = True
        self.schedule = None

        for vis in self.visualizations:
            for w in vis.widgets:
                db.session.delete(w)

        for a in self.alerts:
            db.session.delete(a)

        if user:
            self.record_changes(user)

    def regenerate_api_key(self):
        """Replace the API key with a new generate_token(40) value."""
        self.api_key = generate_token(40)

    @classmethod
    def create(cls, **kwargs):
        """Build a query plus a default "Table" visualization.

        The visualization is added to the session; the query is returned.
        """
        query = cls(**kwargs)
        db.session.add(
            Visualization(
                query_rel=query,
                name="Table",
                description="",
                type="TABLE",
                options="{}",
            )
        )
        return query

    @classmethod
    def all_queries(
        cls, group_ids, user_id=None, include_drafts=False, include_archived=False
    ):
        """Return the queries whose data source is shared with ``group_ids``.

        ``include_archived`` is compared with ``is_archived`` using ``IS``,
        so a true value selects only archived queries and a false value only
        non-archived ones. Unless ``include_drafts`` is true, drafts are
        limited to those owned by ``user_id``. The user and the latest result
        are loaded together with each query.
        """
        query_ids = (
            db.session.query(distinct(cls.id))
            .join(
                DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id
            )
            .filter(Query.is_archived.is_(include_archived))
            .filter(DataSourceGroup.group_id.in_(group_ids))
        )
        queries = (
            cls.query.options(
                joinedload(Query.user),
                joinedload(Query.latest_query_data).load_only(
                    "runtime", "retrieved_at"
                ),
            )
            .filter(cls.id.in_(query_ids))
            # Adding outer joins to be able to order by relationship
            .outerjoin(User, User.id == Query.user_id)
            .outerjoin(QueryResult, QueryResult.id == Query.latest_query_data_id)
            .options(
                contains_eager(Query.user), contains_eager(Query.latest_query_data)
            )
        )

        if not include_drafts:
            queries = queries.filter(
                or_(Query.is_draft.is_(False), Query.user_id == user_id)
            )
        return queries

    @classmethod
    def favorites(cls, user, base_query=None):
        """Restrict ``base_query`` to the queries ``user`` has favorited.

        Without ``base_query``, starts from ``all_queries`` for the user's
        groups with drafts included.
        """
        if base_query is None:
            base_query = cls.all_queries(user.group_ids, user.id, include_drafts=True)
        return base_query.join(
            (
                Favorite,
                and_(Favorite.object_type == "Query", Favorite.object_id == Query.id),
            )
        ).filter(Favorite.user_id == user.id)

    @classmethod
    def all_tags(cls, user, include_drafts=False):
        """Return (tag, usage_count) rows over the queries visible to ``user``.

        Rows are ordered by descending usage count.
        """
        queries = cls.all_queries(
            group_ids=user.group_ids, user_id=user.id, include_drafts=include_drafts
        )

        tag_column = func.unnest(cls.tags).label("tag")
        usage_count = func.count(1).label("usage_count")

        query = (
            db.session.query(tag_column, usage_count)
            .group_by(tag_column)
            .filter(Query.id.in_(queries.options(load_only("id"))))
            .order_by(usage_count.desc())
        )
        return query

    @classmethod
    def by_user(cls, user):
        """Return the queries from ``all_queries`` that are owned by ``user``."""
        return cls.all_queries(user.group_ids, user.id).filter(Query.user == user)

    @classmethod
    def by_api_key(cls, api_key):
        """Return the single query with this API key (``.one()``)."""
        return cls.query.filter(cls.api_key == api_key).one()

    @classmethod
    def past_scheduled_queries(cls):
        """Return a list of scheduled queries whose ``until`` date has passed.

        ``until`` is parsed as ``%Y-%m-%d`` and interpreted as UTC.
        """
        now = utils.utcnow()
        queries = Query.query.filter(Query.schedule.isnot(None)).order_by(Query.id)
        return [
            query
            for query in queries
            if query.schedule["until"] is not None
            and pytz.utc.localize(
                datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d")
            )
            <= now
        ]

    @classmethod
    def outdated_queries(cls):
        """Return a list of scheduled queries that are due to run again.

        Queries whose schedule is disabled, or whose ``until`` date has
        passed, are skipped. Results are de-duplicated on
        (query_hash, data_source_id). If evaluating a query raises, its
        schedule is marked disabled and committed, and the error is logged
        and reported to sentry; the remaining queries are still processed.
        """
        queries = (
            Query.query.options(
                joinedload(Query.latest_query_data).load_only("retrieved_at")
            )
            .filter(Query.schedule.isnot(None))
            .order_by(Query.id)
            .all()
        )

        now = utils.utcnow()
        outdated_queries = {}
        scheduled_queries_executions.refresh()

        for query in queries:
            try:
                if query.schedule.get("disabled"):
                    continue

                if query.schedule["until"]:
                    schedule_until = pytz.utc.localize(
                        datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d")
                    )

                    if schedule_until <= now:
                        continue

                # A value recorded in scheduled_queries_executions takes
                # precedence over the latest result's timestamp.
                retrieved_at = scheduled_queries_executions.get(query.id) or (
                    query.latest_query_data and query.latest_query_data.retrieved_at
                )

                if should_schedule_next(
                    retrieved_at or now,
                    now,
                    query.schedule["interval"],
                    query.schedule["time"],
                    query.schedule["day_of_week"],
                    query.schedule_failures,
                ):
                    key = "{}:{}".format(query.query_hash, query.data_source_id)
                    outdated_queries[key] = query
            except Exception as e:
                # Deliberately broad: one malformed schedule must not stop the
                # scan, so the offending schedule is disabled instead.
                query.schedule["disabled"] = True
                db.session.commit()

                message = (
                    "Could not determine if query %d is outdated due to %s. The schedule for this query has been disabled."
                    % (query.id, repr(e))
                )
                logging.info(message)
                sentry.capture_exception(
                    type(e)(message).with_traceback(e.__traceback__)
                )

        return list(outdated_queries.values())

    @classmethod
    def search(
        cls,
        term,
        group_ids,
        user_id=None,
        include_drafts=False,
        limit=None,
        include_archived=False,
        multi_byte_search=False,
    ):
        """Search the queries visible to ``group_ids`` for ``term``.

        With ``multi_byte_search`` the name and description are matched with
        ``ilike`` and results are ordered by id; otherwise the search vector
        is used and results are sorted by its weights.
        """
        all_queries = cls.all_queries(
            group_ids,
            user_id=user_id,
            include_drafts=include_drafts,
            include_archived=include_archived,
        )

        if multi_byte_search:
            # Since tsvector doesn't work well with CJK languages, use `ilike` too
            pattern = "%{}%".format(term)
            return (
                all_queries.filter(
                    or_(cls.name.ilike(pattern), cls.description.ilike(pattern))
                )
                .order_by(Query.id)
                .limit(limit)
            )

        # sort the result using the weight as defined in the search vector column
        return all_queries.search(term, sort=True).limit(limit)

    @classmethod
    def search_by_user(cls, term, user, limit=None):
        """Search ``user``'s own queries for ``term`` using the search vector."""
        return cls.by_user(user).search(term, sort=True).limit(limit)

    @classmethod
    def recent(cls, group_ids, user_id=None, limit=20):
        """Return non-archived queries with qualifying events in the last 7 days.

        Only the listed event actions count. Results are ordered by
        descending event count and limited to ``limit``. When ``user_id`` is
        truthy, only that user's events are considered.
        """
        query = (
            cls.query.filter(Event.created_at > (db.func.current_date() - 7))
            .join(Event, Query.id == Event.object_id.cast(db.Integer))
            .join(
                DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id
            )
            .filter(
                Event.action.in_(
                    ["edit", "execute", "edit_name", "edit_description", "view_source"]
                ),
                Event.object_id != None,
                Event.object_type == "query",
                DataSourceGroup.group_id.in_(group_ids),
                or_(Query.is_draft == False, Query.user_id == user_id),
                Query.is_archived == False,
            )
            .group_by(Event.object_id, Query.id)
            .order_by(db.desc(db.func.count(0)))
        )

        if user_id:
            query = query.filter(Event.user_id == user_id)

        query = query.limit(limit)

        return query

    @classmethod
    def get_by_id(cls, _id):
        """Return the query with the given id; ``.one()`` raises otherwise."""
        return cls.query.filter(cls.id == _id).one()

    @classmethod
    def all_groups_for_query_ids(cls, query_ids):
        """Return (group_id, view_only) rows for the data sources of ``query_ids``.

        Returns an empty list when ``query_ids`` is empty.
        """
        ids = tuple(query_ids)
        if not ids:
            # An empty tuple would be rendered as `IN ()`, which the database
            # rejects as a syntax error; there is nothing to look up anyway.
            return []

        query = """SELECT group_id, view_only
                   FROM queries
                   JOIN data_source_groups ON queries.data_source_id = data_source_groups.data_source_id
                   WHERE queries.id in :ids"""

        return db.session.execute(query, {"ids": ids}).fetchall()

    @classmethod
    def update_latest_result(cls, query_result):
        """Point every matching query at ``query_result`` and return their ids.

        A query matches when it has the same query hash and data source as
        the result. ``updated_at`` is not bumped for these updates.
        """
        # TODO: Investigate how big an impact this select-before-update makes.
        queries = Query.query.filter(
            Query.query_hash == query_result.query_hash,
            Query.data_source == query_result.data_source,
        )

        for q in queries:
            q.latest_query_data = query_result
            # don't auto-update the updated_at timestamp
            q.skip_updated_at = True
            db.session.add(q)

        query_ids = [q.id for q in queries]
        logging.info(
            "Updated %s queries with result (%s).",
            len(query_ids),
            query_result.query_hash,
        )

        return query_ids

    def fork(self, user):
        """Create a copy of this query owned by ``user`` and add it to the session.

        The attributes in ``forked_list`` are copied and every visualization
        is duplicated onto the new query in ascending id order. The copy is
        named ``"Copy of (#<id>) <name>"``.
        """
        forked_list = [
            "org",
            "data_source",
            "latest_query_data",
            "description",
            "query_text",
            "query_hash",
            "options",
            "tags",
        ]
        kwargs = {a: getattr(self, a) for a in forked_list}

        # Query.create will add default TABLE visualization, so use constructor to create bare copy of query
        forked_query = Query(
            name="Copy of (#{}) {}".format(self.id, self.name), user=user, **kwargs
        )

        for v in sorted(self.visualizations, key=lambda v: v.id):
            forked_v = v.copy()
            forked_v["query_rel"] = forked_query
            fv = Visualization(
                **forked_v
            )  # it will magically add it to `forked_query.visualizations`
            db.session.add(fv)

        db.session.add(forked_query)
        return forked_query

    @property
    def runtime(self):
        """Runtime of the latest result (requires ``latest_query_data``)."""
        return self.latest_query_data.runtime

    @property
    def retrieved_at(self):
        """Retrieval time of the latest result (requires ``latest_query_data``)."""
        return self.latest_query_data.retrieved_at

    @property
    def groups(self):
        """Groups of the data source, or an empty dict without a data source."""
        if self.data_source is None:
            return {}

        return self.data_source.groups

    @hybrid_property
    def lowercase_name(self):
        "Optional property useful for sorting purposes."
        return self.name.lower()

    @lowercase_name.expression
    def lowercase_name(cls):
        "The SQLAlchemy expression for the property above."
        return func.lower(cls.name)

    @property
    def parameters(self):
        """The ``parameters`` entry of ``options``, or an empty list."""
        return self.options.get("parameters", [])

    @property
    def parameterized(self):
        """A ParameterizedQuery built from the query text, parameters and org."""
        return ParameterizedQuery(self.query_text, self.parameters, self.org)

    @property
    def dashboard_api_keys(self):
        """API keys of active dashboard keys whose dashboard shows this query.

        A dashboard qualifies when one of its widgets uses a visualization of
        this query.
        """
        query = """SELECT api_keys.api_key
                   FROM api_keys
                   JOIN dashboards ON object_id = dashboards.id
                   JOIN widgets ON dashboards.id = widgets.dashboard_id
                   JOIN visualizations ON widgets.visualization_id = visualizations.id
                   WHERE object_type='dashboards'
                     AND active=true
                     AND visualizations.query_id = :id"""

        api_keys = db.session.execute(query, {"id": self.id}).fetchall()
        return [api_key[0] for api_key in api_keys]

    def update_query_hash(self):
        """Recompute ``query_hash`` from the query text.

        The data source's query runner generates the hash; without a data
        source a ``BaseQueryRunner({})`` is used. The ``apply_auto_limit``
        option (False when absent) is passed along.
        """
        should_apply_auto_limit = self.options.get("apply_auto_limit", False) if self.options else False
        query_runner = self.data_source.query_runner if self.data_source else BaseQueryRunner({})
        self.query_hash = query_runner.gen_query_hash(self.query_text, should_apply_auto_limit)
Ejemplo n.º 9
0
 def regenerate_api_key(self):
     """Replace the stored API key with a new generate_token(40) value."""
     fresh_key = generate_token(40)
     self.api_key = fresh_key
Ejemplo n.º 10
0
 def test_format(self):
     """A token requested with length 40 consists of exactly 40 ASCII alphanumerics."""
     token = generate_token(40)
     # assertRegexpMatches was removed in Python 3.12, while Python 2's
     # unittest has no assertRegex; use whichever this interpreter provides.
     assert_regex = getattr(self, "assertRegex", None) or self.assertRegexpMatches
     # The assertion uses re.search, so the pattern is anchored; unanchored it
     # would also accept longer tokens or tokens containing other characters.
     assert_regex(token, r"\A[a-zA-Z0-9]{40}\Z")
Ejemplo n.º 11
0
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin,
           PermissionsCheckMixin):
    """ORM model mapped to the ``users`` table; each row references one organization.

    Emails are lower-cased when passed to the constructor, and the pair
    (org_id, email) is covered by a unique index.
    """
    id = primary_key("User")
    org_id = Column(key_type("Organization"),
                    db.ForeignKey("organizations.id"))
    # Adds a dynamic ``users`` backref on Organization.
    org = db.relationship("Organization",
                          backref=db.backref("users", lazy="dynamic"))
    name = Column(db.String(320))
    email = Column(EmailType)
    # Nullable: to_dict() reports auth_type "external" when this is None.
    password_hash = Column(db.String(128), nullable=True)
    # Attribute is ``group_ids`` but the database column is named "groups".
    # Nullable array of Group keys, wrapped for in-place mutation tracking.
    group_ids = Column("groups",
                       MutableList.as_mutable(
                           postgresql.ARRAY(key_type("Group"))),
                       nullable=True)
    # Unique per-user token; defaults to a fresh generate_token(40) value.
    api_key = Column(db.String(40),
                     default=lambda: generate_token(40),
                     unique=True)

    # None means the account is enabled (see is_disabled).
    disabled_at = Column(db.DateTime(True), default=None, nullable=True)
    # JSONB document backing the json_cast_property attributes below.
    # NOTE(review): ``default={}`` passes one dict object as the Python-side
    # default; confirm it is copied per row before relying on mutating it.
    details = Column(
        MutableDict.as_mutable(postgresql.JSONB),
        nullable=True,
        server_default="{}",
        default={},
    )
    # The four attributes below are declared with json_cast_property against
    # the "details" column and a key name; presumably they read that key from
    # the JSON document, cast to the given type. Confirm against the helper.
    active_at = json_cast_property(db.DateTime(True),
                                   "details",
                                   "active_at",
                                   default=None)
    _profile_image_url = json_cast_property(db.Text(),
                                            "details",
                                            "profile_image_url",
                                            default=None)
    is_invitation_pending = json_cast_property(db.Boolean(True),
                                               "details",
                                               "is_invitation_pending",
                                               default=False)
    is_email_verified = json_cast_property(db.Boolean(True),
                                           "details",
                                           "is_email_verified",
                                           default=True)

    __tablename__ = "users"
    # Unique index on (org_id, email).
    __table_args__ = (db.Index("users_org_id_email",
                               "org_id",
                               "email",
                               unique=True), )

    def __str__(self):
        return "%s (%s)" % (self.name, self.email)

    def __init__(self, *args, **kwargs):
        """Lower-case a supplied ``email`` keyword, then delegate to the parent.

        Only a keyword argument named ``email`` that is not ``None`` is
        normalized; all arguments are forwarded unchanged otherwise.
        """
        email = kwargs.get("email")
        if email is not None:
            kwargs["email"] = email.lower()
        super(User, self).__init__(*args, **kwargs)

    @property
    def is_disabled(self):
        """``True`` when ``disabled_at`` holds any value other than ``None``."""
        disabled_at = self.disabled_at
        return disabled_at is not None

    def disable(self):
        """Mark the user disabled by setting ``disabled_at`` to ``db.func.now()``.

        The assigned value is a SQL function expression rather than a Python
        datetime; the caller is responsible for persisting the change.
        """
        now_expression = db.func.now()
        self.disabled_at = now_expression

    def enable(self):
        """Re-enable the user by clearing ``disabled_at`` back to ``None``."""
        self.disabled_at = None

    def regenerate_api_key(self):
        """Replace ``self.api_key`` with a new ``generate_token(40)`` value."""
        fresh_key = generate_token(40)
        self.api_key = fresh_key

    def to_dict(self, with_api_key=False):
        """Serialize the user to a plain dict.

        Args:
            with_api_key: when truthy, include the ``"api_key"`` entry.

        Returns:
            A dict of the user's public fields plus ``"auth_type"``, which is
            ``"external"`` when no password hash is stored and ``"password"``
            otherwise. Disabled users get the static ``images/avatar.svg``
            asset URL instead of their own profile image.
        """
        avatar_url = self.profile_image_url
        if self.is_disabled:
            # Resolve the placeholder through the webpack asset manifest,
            # falling back to the raw path when it has no entry.
            manifest = app.extensions["webpack"]["assets"] or {}
            placeholder = "images/avatar.svg"
            avatar_url = url_for("static",
                                 filename=manifest.get(placeholder, placeholder))

        payload = dict(
            id=self.id,
            name=self.name,
            email=self.email,
            profile_image_url=avatar_url,
            groups=self.group_ids,
            updated_at=self.updated_at,
            created_at=self.created_at,
            disabled_at=self.disabled_at,
            is_disabled=self.is_disabled,
            active_at=self.active_at,
            is_invitation_pending=self.is_invitation_pending,
            is_email_verified=self.is_email_verified,
        )

        payload["auth_type"] = ("external"
                                if self.password_hash is None else "password")

        if with_api_key:
            payload["api_key"] = self.api_key

        return payload

    def is_api_user(self):
        """Always ``False`` for this class."""
        return False

    @property
    def profile_image_url(self):
        """Return the stored profile image URL, or a generated avatar URL.

        When ``_profile_image_url`` is ``None`` the URL is built from the MD5
        hex digest of the lower-cased email address.
        """
        stored_url = self._profile_image_url
        if stored_url is None:
            digest = hashlib.md5(self.email.lower().encode()).hexdigest()
            return "//sdn.geekzu.org/avatar/{}?s=40&d=identicon".format(digest)

        return stored_url

    @property
    def permissions(self):
        """Return the permissions of every group this user belongs to.

        Returns:
            A flat list concatenating ``permissions`` of each Group whose id
            is in ``group_ids`` (duplicates are kept). Empty when the user
            has no group ids.
        """
        # TODO: this should be cached.
        group_ids = self.group_ids
        if not group_ids:
            # The "groups" column is nullable and None is not a valid IN
            # list; an empty list cannot match any group either, so skip
            # the query altogether.
            return []

        member_groups = Group.query.filter(Group.id.in_(group_ids))
        return list(
            itertools.chain.from_iterable(
                group.permissions for group in member_groups))

    @classmethod
    def get_by_org(cls, org):
        """Return a query for the rows whose ``org`` equals the given one."""
        in_org = cls.org == org
        return cls.query.filter(in_org)

    @classmethod
    def get_by_id(cls, _id):
        """Return the single row whose ``id`` equals ``_id``, via ``.one()``."""
        matches = cls.query.filter(cls.id == _id)
        return matches.one()

    @classmethod
    def get_by_email_and_org(cls, email, org):
        """Return the single row in ``org`` with the given email, via ``.one()``."""
        org_members = cls.get_by_org(org)
        return org_members.filter(cls.email == email).one()

    @classmethod
    def get_by_api_key_and_org(cls, api_key, org):
        """Return the single row in ``org`` with the given API key, via ``.one()``."""
        org_members = cls.get_by_org(org)
        return org_members.filter(cls.api_key == api_key).one()

    @classmethod
    def all(cls, org):
        return cls.get_by_org(org).filter(cls.disabled_at.is_(None))

    @classmethod
    def all_disabled(cls, org):
        """Return a query for the org's rows whose ``disabled_at`` is not NULL."""
        is_disabled = cls.disabled_at.isnot(None)
        return cls.get_by_org(org).filter(is_disabled)

    @classmethod
    def search(cls, base_query, term):
        """Narrow ``base_query`` to rows whose name or email contains ``term``.

        The term is wrapped in ``%`` wildcards; ``name`` is compared with
        ``ilike`` and ``email`` with ``like``.
        """
        pattern = "%{}%".format(term)
        name_or_email = or_(cls.name.ilike(pattern), cls.email.like(pattern))
        return base_query.filter(name_or_email)

    @classmethod
    def pending(cls, base_query, pending):
        """Filter ``base_query`` by invitation state.

        A truthy ``pending`` keeps rows where ``is_invitation_pending`` IS
        true; otherwise keeps rows where it IS NOT true.
        """
        flag = cls.is_invitation_pending
        if not pending:
            # IS NOT TRUE covers both `false` and `null`.
            return base_query.filter(flag.isnot(True))

        return base_query.filter(flag.is_(True))

    @classmethod
    def find_by_email(cls, email):
        """Return a query (not a row) for rows with this email, in any organization."""
        same_email = cls.email == email
        return cls.query.filter(same_email)

    def hash_password(self, password):
        """Store the ``pwd_context`` hash of ``password`` in ``password_hash``."""
        # NOTE(review): if pwd_context is a passlib CryptContext, encrypt()
        # is deprecated since passlib 1.7 in favor of hash() -- confirm the
        # pinned version before switching.
        hashed = pwd_context.encrypt(password)
        self.password_hash = hashed

    def verify_password(self, password):
        """Check ``password`` against the stored hash.

        Returns:
            The falsy ``password_hash`` itself (e.g. ``None``) when no hash is
            stored; otherwise the result of ``pwd_context.verify``.
        """
        stored_hash = self.password_hash
        if not stored_hash:
            return stored_hash

        return pwd_context.verify(password, stored_hash)

    def update_group_assignments(self, group_names):
        """Replace the user's group ids and commit the change.

        The new assignment is the groups found by name in the user's org plus
        the org's default group (always appended last). The user is added to
        the session and the session is committed.
        """
        assigned = Group.find_by_name(self.org, group_names)
        assigned.append(self.org.default_group)
        self.group_ids = [group.id for group in assigned]

        session = db.session
        session.add(self)
        session.commit()

    def has_access(self, obj, access_type):
        """Return ``AccessPermission.exists`` for ``obj``/``access_type`` with this user as grantee."""
        granted = AccessPermission.exists(obj, access_type, grantee=self)
        return granted

    def get_id(self):
        """Return ``"<id>-<digest>"`` identifying this user.

        The digest is the MD5 hex digest of ``"<email>,<password_hash>"``, so
        the returned id changes whenever either of those values changes.
        """
        fingerprint = "{},{}".format(self.email, self.password_hash)
        digest = hashlib.md5(fingerprint.encode()).hexdigest()
        return "{0}-{1}".format(self.id, digest)