예제 #1
0
def fix_CollectionItem_v0_constraint(db_conn):
    """Create the unique (collection, media_entry) constraint on collection items.

    The module-level flag ``collectionitem_unique_constraint_done`` works as a
    one-shot skip marker: when it is set, this call only clears the flag and
    returns without touching the database.
    """
    global collectionitem_unique_constraint_done

    if collectionitem_unique_constraint_done:
        # Clear the marker so that a later run (possibly against a
        # different database) goes through the migration again.
        collectionitem_unique_constraint_done = False
        return

    items_table = inspect_table(
        MetaData(bind=db_conn.bind), 'core__collection_items')
    unique_key = UniqueConstraint(
        'collection', 'media_entry',
        name='core__collection_items_collection_media_entry_key',
        table=items_table)

    try:
        unique_key.create()
    except ProgrammingError:
        # Most likely the install was created after the collection tables
        # were introduced, so there is nothing left for this migration to do.
        pass

    db_conn.commit()
예제 #2
0
def unique_collections_slug(db):
    """Add a unique (creator, slug) constraint to the collections table.

    Before the constraint is created, every row whose slug repeats the slug
    of an earlier row by the same creator is given a freshly generated UUID
    slug, so that existing duplicates do not block the constraint.
    """
    metadata = MetaData(bind=db.bind)
    collection_table = inspect_table(metadata, "core__collections")

    # creator -> set of slugs already seen for that creator.  A set keeps the
    # duplicate check constant-time per row instead of rescanning a list.
    seen_slugs = {}
    duplicate_ids = []

    for row in db.execute(collection_table.select()):
        creator_slugs = seen_slugs.setdefault(row.creator, set())
        if row.slug in creator_slugs:
            # Duplicate slug for this creator: remember the row so it can
            # be given a unique slug below.
            duplicate_ids.append(row.id)
        else:
            creator_slugs.add(row.slug)

    for row_id in duplicate_ids:
        new_slug = six.text_type(uuid.uuid4())
        db.execute(collection_table.update().
                   where(collection_table.c.id == row_id).
                   values(slug=new_slug))
    # sqlite does not like to change the schema when a transaction(update) is
    # not yet completed
    db.commit()

    constraint = UniqueConstraint('creator', 'slug',
                                  name='core__collection_creator_slug_key',
                                  table=collection_table)
    constraint.create()

    db.commit()
예제 #3
0
def pw_hash_nullable(db):
    """Allow NULL values in the ``pw_hash`` column of the users table."""
    users = inspect_table(MetaData(bind=db.bind), "core__users")
    users.c.pw_hash.alter(nullable=True)

    # On sqlite the column alteration appears to drop the unique constraint
    # on the username, so it is re-created by hand for that driver only.
    running_on_sqlite = db.bind.url.drivername == 'sqlite'
    if running_on_sqlite:
        UniqueConstraint('username', table=users).create()

    db.commit()
예제 #4
0
class Item(DB.Model):
    """
    This data model represents a single lendable item
    """
    __tablename__ = 'Item'

    id = DB.Column(DB.Integer, primary_key=True)
    name = DB.Column(DB.String(STD_STRING_SIZE))
    update_name_from_schema = DB.Column(DB.Boolean, default=True, nullable=False)
    type_id = DB.Column(DB.Integer, DB.ForeignKey('ItemType.id'))
    lending_id = DB.Column(DB.Integer, DB.ForeignKey('Lending.id'), default=None, nullable=True)
    lending_duration = DB.Column(DB.Integer, nullable=True)  # in seconds
    due = DB.Column(DB.Integer, default=-1)  # unix time
    deleted_time = DB.Column(DB.Integer, default=None)
    visible_for = DB.Column(DB.String(STD_STRING_SIZE), nullable=True)

    type = DB.relationship('ItemType', lazy='joined')
    lending = DB.relationship('Lending', lazy='select',
                              backref=DB.backref('_items', lazy='select'))

    __table_args__ = (
        UniqueConstraint('name', 'type_id', name='_name_type_id_uc'),
    )

    def __init__(self, update_name_from_schema: bool, name: str, type_id: int, lending_duration: int = -1,
                 visible_for: str = ''):
        """
        Create a new item.

        A negative lending_duration and an empty (or None) visible_for are
        treated as "not given": the corresponding column is left unset.
        """
        self.update_name_from_schema = update_name_from_schema

        self.name = name

        self.type_id = type_id

        if lending_duration >= 0:
            self.lending_duration = lending_duration

        if visible_for is not None and visible_for != '':
            self.visible_for = visible_for

    def update(self, update_name_from_schema: bool, name: str, type_id: int, lending_duration: int = -1,
               visible_for: str = ''):
        """
        Function to update the objects data

        When update_name_from_schema is true the given name is ignored and
        the name is regenerated from the type's name schema. Unlike the
        constructor, lending_duration and visible_for are always assigned.
        """
        self.update_name_from_schema = update_name_from_schema

        if self.update_name_from_schema:
            self.name = self.name_schema_name
        else:
            self.name = name

        self.type_id = type_id
        self.lending_duration = lending_duration
        self.visible_for = visible_for

    @property
    def deleted(self):
        """
        If the item is marked as deleted (a deletion time is stored).
        """
        return self.deleted_time is not None

    @deleted.setter
    def deleted(self, value: bool):
        # Setting a truthy value stamps the current unix time; a falsy
        # value clears the deletion marker again.
        if value:
            self.deleted_time = int(time.time())
        else:
            self.deleted_time = None

    @property
    def is_currently_lent(self):
        """
        If the item is currently lent.
        """
        return self.lending is not None

    @property
    def parent(self):
        """
        The parent of the first entry in self._parents, or None without entries.
        """
        # NOTE(review): _parents is not defined in this class; presumably a
        # backref created by another model - confirm.
        if self._parents:
            return self._parents[0].parent
        return None

    @property
    def effective_lending_duration(self):
        """
        The effective lending duration computed from item type, tags and item

        Precedence: the item's own positive duration, then the smallest
        positive duration among its tags, then the duration of its type.
        """
        if self.lending_duration and (self.lending_duration >= 0):
            return self.lending_duration

        tag_lending_duration = min((t.tag.lending_duration for t in self._tags if t.tag.lending_duration > 0), default=-1)

        if tag_lending_duration >= 0:
            return tag_lending_duration

        return self.type.lending_duration

    @property
    def name_schema_name(self):
        """
        The item name generated from the name schema of the item type.

        The schema is a string.Template. Available placeholders are the
        names of the item's attribute definitions plus type, parent, c_year,
        c_month, c_day, c_date and c_date_iso. Placeholders without a value
        are left in place by safe_substitute.
        """
        template = string.Template(self.type.name_schema)
        attributes = {}
        for attr in self._attributes:
            if attr.value:
                try:
                    attributes[attr.attribute_definition.name] = loads(attr.value)
                except Exception:
                    # Best effort: a value that cannot be decoded is left out
                    # of the substitution instead of breaking name generation.
                    # Exception (not a bare except) keeps KeyboardInterrupt
                    # and SystemExit propagating.
                    pass
            else:
                # No value stored: fall back to the default declared in the
                # attribute definition's schema, or an empty string.
                attr_def = loads(attr.attribute_definition.jsonschema)
                attributes[attr.attribute_definition.name] = attr_def.get('default', '')

        parent = "".join(link.parent.name for link in ItemToItem.query.filter(ItemToItem.item_id == self.id).all())

        today = date.today()
        times = {
            'c_year': today.year,
            'c_month': today.month,
            'c_day': today.day,
            'c_date': today.strftime('%d.%b.%Y'),
            'c_date_iso': today.isoformat(),
        }
        return template.safe_substitute(attributes, type=self.type.name, parent=parent, **times)

    def delete(self):
        """
        Mark the item and its attributes as deleted.

        Returns a tuple (status code, message, success flag). An item that
        is currently lent is not deleted and yields (400, message, False).
        """
        if self.is_currently_lent:
            return 400, "Requested item is currently lent!", False
        self.deleted = True

        for element in self._attributes:
            element.delete()

        # Contained items and tags are deliberately left untouched here
        # (removing them was marked as "not intended" by neumantm).
        return 204, "", True

    def _count_attribute_sources(self, def_id):
        """
        Count the sources (the item type and every tag of the item) which
        bring the attribute definition with the given id, ignoring deleted
        attribute definitions.
        """
        sources = 0
        if any(ittad.attribute_definition_id == def_id
               for ittad in self.type._item_type_to_attribute_definitions
               if not ittad.attribute_definition.deleted):
            sources += 1
        for item_to_tag in self._tags:
            if any(ttad.attribute_definition_id == def_id
                   for ttad in item_to_tag.tag._tag_to_attribute_definitions
                   if not ttad.attribute_definition.deleted):
                sources += 1
        return sources

    def get_attribute_changes(self, definition_ids, remove: bool = False):
        """
        Get a list of attributes to add, to delete and to undelete,
        considering all definition_ids in the list and whether to add or remove them.
        """
        attributes_to_add = []
        attributes_to_delete = []
        attributes_to_undelete = []
        for def_id in definition_ids:
            exists = False
            for itad in self._attributes:
                if itad.attribute_definition_id != def_id:
                    continue
                exists = True
                if remove:
                    # Only delete the attribute when exactly one source
                    # brings it; if several do, it has to stay.
                    if self._count_attribute_sources(def_id) == 1:
                        attributes_to_delete.append(itad)
                elif itad.deleted:
                    attributes_to_undelete.append(itad)

            if not exists and not remove:
                attributes_to_add.append(ItemToAttributeDefinition(self.id,
                                                                   def_id,
                                                                   ""))  # TODO: Get default if possible.
        return attributes_to_add, attributes_to_delete, attributes_to_undelete

    def get_new_attributes_from_type(self, type_id: int):
        """
        Get a list of attributes to add to a new item which has the given type.
        """

        item_type_attribute_definitions = (itemType.ItemTypeToAttributeDefinition
                                           .query
                                           .filter(itemType.ItemTypeToAttributeDefinition.item_type_id == type_id)
                                           .all())
        attributes_to_add, _, _ = self.get_attribute_changes(
            [ittad.attribute_definition_id for ittad in item_type_attribute_definitions if not ittad.item_type.deleted], False)

        return attributes_to_add

    def get_attribute_changes_from_type_change(self, from_type_id: int, to_type_id: int):
        """
        Get a list of attributes to add, to delete and to undelete,
        when this item would now switch from the first to the second type.
        """
        old_item_type_attr_defs = (itemType.ItemTypeToAttributeDefinition
                                   .query
                                   .filter(itemType.ItemTypeToAttributeDefinition.item_type_id == from_type_id)
                                   .all())

        new_item_type_attr_defs = (itemType.ItemTypeToAttributeDefinition
                                   .query
                                   .filter(itemType.ItemTypeToAttributeDefinition.item_type_id == to_type_id)
                                   .all())

        old_attr_def_ids = [ittad.attribute_definition_id for ittad in old_item_type_attr_defs]
        new_attr_def_ids = [ittad.attribute_definition_id for ittad in new_item_type_attr_defs]

        # Sets give constant-time membership tests; the lists keep the
        # original query order of the resulting ids.
        old_id_set = set(old_attr_def_ids)
        new_id_set = set(new_attr_def_ids)

        added_attr_def_ids = [attr_def_id for attr_def_id in new_attr_def_ids if attr_def_id not in old_id_set]
        removed_attr_def_ids = [attr_def_id for attr_def_id in old_attr_def_ids if attr_def_id not in new_id_set]

        attributes_to_add, _, attributes_to_undelete = self.get_attribute_changes(added_attr_def_ids, False)
        _, attributes_to_delete, _ = self.get_attribute_changes(removed_attr_def_ids, True)

        return attributes_to_add, attributes_to_delete, attributes_to_undelete

    def get_attribute_changes_from_tag(self, tag_id: int, remove: bool = False):
        """
        Get a list of attributes to add, to delete and to undelete,
        when this item would now get that tag or lose that tag.
        """

        tag_attribute_definitions = (TagToAttributeDefinition
                                     .query
                                     .filter(TagToAttributeDefinition.tag_id == tag_id)
                                     .all())
        return self.get_attribute_changes([ttad.attribute_definition_id for ttad in tag_attribute_definitions], remove)
예제 #5
0
class TableColumn(Model, BaseColumn):
    """ORM mapping of one column of a ``SqlaTable`` (a table has many columns)."""

    __tablename__ = 'table_columns'
    __table_args__ = (UniqueConstraint('table_id', 'column_name'), )
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship(
        'SqlaTable',
        backref=backref('columns', cascade='all, delete-orphan'),
        foreign_keys=[table_id],
    )
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text, default='')
    python_date_format = Column(String(255))
    database_expression = Column(String(255))

    export_fields = (
        'table_id',
        'column_name',
        'verbose_name',
        'is_dttm',
        'is_active',
        'type',
        'groupby',
        'count_distinct',
        'sum',
        'avg',
        'max',
        'min',
        'filterable',
        'expression',
        'description',
        'python_date_format',
        'database_expression',
    )

    # Every exported field except the parent reference.
    update_from_object_fields = [
        field for field in export_fields if field != 'table_id'
    ]
    export_parent = 'table'

    def get_sqla_col(self, label=None):
        """Return this column (or its SQL expression) as a labelled SQLA column."""
        engine_spec = self.table.database.db_engine_spec
        compatible_label = engine_spec.make_label_compatible(
            label or self.column_name)
        if self.expression:
            base = literal_column(self.expression)
        else:
            base = column(self.column_name)
        return base.label(compatible_label)

    @property
    def datasource(self):
        """The table this column belongs to."""
        return self.table

    def get_time_filter(self, start_dttm, end_dttm):
        """Build the AND-ed time range condition; falsy bounds are left out."""
        time_col = self.get_sqla_col(label='__time')
        bounds = []
        if start_dttm:
            bounds.append(time_col >= text(self.dttm_sql_literal(start_dttm)))
        if end_dttm:
            bounds.append(time_col <= text(self.dttm_sql_literal(end_dttm)))
        return and_(*bounds)

    def get_timestamp_expression(self, time_grain):
        """Getting the time component of the query"""
        date_format = self.python_date_format
        uses_epoch = date_format in ('epoch_s', 'epoch_ms')

        # Plain column, no grain and no epoch conversion: reference the
        # column directly.
        if not (self.expression or time_grain or uses_epoch):
            return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)

        sql = self.expression or self.column_name
        if uses_epoch:
            # if epoch, translate to DATE using db specific conf
            engine_spec = self.table.database.db_engine_spec
            if date_format == 'epoch_s':
                converter = engine_spec.epoch_to_dttm()
            else:
                converter = engine_spec.epoch_ms_to_dttm()
            sql = converter.format(col=sql)

        if time_grain:
            grain = self.table.database.grains_dict().get(time_grain)
            if grain:
                sql = grain.function.format(col=sql)

        return literal_column(sql, type_=DateTime).label(DTTM_ALIAS)

    @classmethod
    def import_obj(cls, i_column):
        """Import a column, matching an existing one by table id and column name."""
        def find_existing(candidate):
            query = db.session.query(TableColumn).filter(
                TableColumn.table_id == candidate.table_id,
                TableColumn.column_name == candidate.column_name)
            return query.first()

        return import_util.import_simple_obj(
            db.session, i_column, find_existing)

    def dttm_sql_literal(self, dttm):
        """Convert datetime object to a SQL expression string

        If database_expression is set, it is used as a format template that
        receives the datetime rendered as ``%Y-%m-%d %H:%M:%S``.
        Otherwise, if python_date_format is set, the datetime is rendered
        with that pattern (or as epoch seconds / milliseconds).
        Without either, the engine spec converts the datetime, falling back
        to a quoted timestamp with microseconds.
        """
        date_format = self.python_date_format

        if self.database_expression:
            return self.database_expression.format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))

        if date_format:
            if date_format in ('epoch_s', 'epoch_ms'):
                seconds = (dttm - datetime(1970, 1, 1)).total_seconds()
                if date_format == 'epoch_s':
                    return str(seconds)
                return str(seconds * 1000.0)
            return "'{}'".format(dttm.strftime(date_format))

        converted = self.table.database.db_engine_spec.convert_dttm(
            self.type or '', dttm)
        return converted or "'{}'".format(
            dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))

    def get_metrics(self):
        """Return the enabled aggregate metrics of this column, keyed by name."""
        # TODO deprecate, this is not needed since MetricsControl
        column_name = self.column_name
        # (enabled flag, metric type, metric name prefix, SQL template)
        aggregates = (
            (self.sum, 'sum', 'sum__', 'SUM({})'),
            (self.avg, 'avg', 'avg__', 'AVG({})'),
            (self.max, 'max', 'max__', 'MAX({})'),
            (self.min, 'min', 'min__', 'MIN({})'),
            (self.count_distinct, 'count_distinct', 'count_distinct__',
             'COUNT(DISTINCT {})'),
        )
        metrics = [
            SqlMetric(
                metric_name=prefix + column_name,
                metric_type=metric_type,
                expression=template.format(column_name),
            )
            for enabled, metric_type, prefix, template in aggregates
            if enabled
        ]
        return {metric.metric_name: metric for metric in metrics}
예제 #6
0
class Participation(Base):
    """A single participation of a user in a contest.

    """
    __tablename__ = 'participations'

    # Auto increment primary key.
    id = Column(Integer, primary_key=True)

    # IP address or subnet the user is restricted to when logging in to CWS.
    ip = Column(CastingArray(CIDR), nullable=True)

    # For contests where each user gets at most x hours out of the y > x
    # hours available in total: the moment the user chose to start their
    # own time-frame.
    starting_time = Column(DateTime, nullable=True)

    # Shift applied to the time interval in which the user may submit.
    delay_time = Column(
        Interval,
        CheckConstraint("delay_time >= '0 seconds'"),
        nullable=False,
        default=timedelta(),
    )

    # Additional time granted to this user.
    extra_time = Column(
        Interval,
        CheckConstraint("extra_time >= '0 seconds'"),
        nullable=False,
        default=timedelta(),
    )

    # Contest-specific password. When not null it "replaces" the usual
    # user.password field, for this participation only.
    password = Column(Unicode, nullable=True)

    # Hidden participations (e.g. not shown in public rankings); also
    # usable for debugging purposes.
    hidden = Column(Boolean, nullable=False, default=False)

    # Unrestricted participations (e.g. contest time, maximum number of
    # submissions, minimum interval between submissions, maximum number of
    # user tests, minimum interval between user tests); also usable for
    # debugging purposes.
    unrestricted = Column(Boolean, nullable=False, default=False)

    # Contest (id and object) the user participates in.
    contest_id = Column(
        Integer,
        ForeignKey(Contest.id, onupdate="CASCADE", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    contest = relationship(
        Contest,
        backref=backref(
            "participations",
            cascade="all, delete-orphan",
            passive_deletes=True,
        ),
    )

    # User (id and object) who participates.
    user_id = Column(
        Integer,
        ForeignKey(User.id, onupdate="CASCADE", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    user = relationship(
        User,
        backref=backref(
            "participations",
            cascade="all, delete-orphan",
            passive_deletes=True,
        ),
    )

    # A user can take part in a given contest at most once.
    __table_args__ = (UniqueConstraint('contest_id', 'user_id'), )

    # Team (id and object) the user represents with this participation.
    team_id = Column(
        Integer,
        ForeignKey(Team.id, onupdate="CASCADE", ondelete="RESTRICT"),
        nullable=True,
    )
    team = relationship(
        Team,
        backref=backref(
            "participations",
            cascade="all, delete-orphan",
            passive_deletes=True,
        ),
    )
예제 #7
0
    Column('name', String(255), unique=True),
    Column('manager', String(255), nullable=False),
    Column('default_version_id', String(36)),
)


# Versions of a datastore. Besides the per-column unique flag on 'name',
# the (datastore_id, name) pair is covered by the named constraint
# 'ds_versions'.
datastore_versions = Table(
    'datastore_versions', meta,
    Column('id', String(36), primary_key=True, nullable=False),
    Column('datastore_id', String(36), ForeignKey('datastores.id')),
    Column('name', String(255), unique=True),
    Column('image_id', String(36), nullable=False),
    Column('packages', String(511)),
    Column('active', Boolean(), nullable=False),
    UniqueConstraint('datastore_id', 'name', name='ds_versions'),
)


def upgrade(migrate_engine):
    """Create the datastore tables and link instances to a datastore version.

    The 'instances' table gains a 'datastore_version_id' foreign key column
    and loses its 'service_type' column.
    """
    meta.bind = migrate_engine
    create_tables([datastores, datastore_versions])

    instances = Table('instances', meta, autoload=True)
    instances.create_column(
        Column('datastore_version_id', String(36),
               ForeignKey('datastore_versions.id')))
    instances.drop_column('service_type')
    # The 'service_images' table is deprecated as of this version but is
    # deliberately kept for a few releases, i.e. no
    # drop_tables([service_images]) yet.
예제 #8
0
class Task(Base):
    """A task, with its token rules, limits, scoring options and children.

    """
    __tablename__ = 'tasks'
    __table_args__ = (
        UniqueConstraint('contest_id', 'num'),
        UniqueConstraint('contest_id', 'name'),
        ForeignKeyConstraint(
            ("id", "active_dataset_id"),
            ("datasets.task_id", "datasets.id"),
            onupdate="SET NULL",
            ondelete="SET NULL",
            # Emitted through ALTER once both tables have been CREATEd,
            # which avoids the circular dependency between them.
            use_alter=True,
            name="fk_active_dataset_id",
        ),
        CheckConstraint("token_gen_initial <= token_gen_max"),
    )

    # Auto increment primary key. 'ignore_fk' is required to keep
    # autoincrement enabled on an integer primary key that is referenced by
    # a foreign key defined on this same table.
    id = Column(Integer, primary_key=True, autoincrement='ignore_fk')

    # Number of the task for sorting.
    num = Column(Integer, nullable=True)

    # Contest (id and object) owning the task.
    contest_id = Column(
        Integer,
        ForeignKey(Contest.id, onupdate="CASCADE", ondelete="CASCADE"),
        nullable=True,
        index=True,
    )
    contest = relationship(Contest, back_populates="tasks")

    # Short name and long human readable title of the task.
    name = Column(
        Unicode, CodenameConstraint("name"), nullable=False, unique=True)
    title = Column(Unicode, nullable=False)

    # Names of the files the contestant has to submit (language-specific
    # extensions are replaced by "%l").
    submission_format = Column(
        ARRAY(String),
        FilenameListConstraint("submission_format"),
        nullable=False,
        default=[],
    )

    # Language codes of the statements highlighted to all users for this
    # task.
    primary_statements = Column(ARRAY(String), nullable=False, default=[])

    # Task-token parameters follow. Their effect during the contest depends
    # on how they interact with the contest-token parameters defined on the
    # Contest.

    # The "kind" of token rules active during the contest.
    # - disabled: the user can never use a token.
    # - finite: the user has a finite amount of tokens and chooses when to
    #   use them, subject to some limitations. Tokens may not all be
    #   available at start, but be handed out periodically during the
    #   contest instead.
    # - infinite: the user can always use a token.
    token_mode = Column(
        Enum(
            TOKEN_MODE_DISABLED,
            TOKEN_MODE_FINITE,
            TOKEN_MODE_INFINITE,
            name="token_mode",
        ),
        nullable=False,
        default="disabled",
    )

    # Maximum number of tokens a contestant may use during the whole
    # contest (on this task).
    token_max_number = Column(
        Integer, CheckConstraint("token_max_number > 0"), nullable=True)

    # Minimum interval between two successive token uses by the same user
    # (on this task).
    token_min_interval = Column(
        Interval,
        CheckConstraint("token_min_interval >= '0 seconds'"),
        nullable=False,
        default=timedelta(),
    )

    # Token generation parameters (used when the mode is "finite"): the
    # user starts with "initial" tokens and receives "number" more every
    # "interval", with the total capped to "max".
    token_gen_initial = Column(
        Integer,
        CheckConstraint("token_gen_initial >= 0"),
        nullable=False,
        default=2,
    )
    token_gen_number = Column(
        Integer,
        CheckConstraint("token_gen_number >= 0"),
        nullable=False,
        default=2,
    )
    token_gen_interval = Column(
        Interval,
        CheckConstraint("token_gen_interval > '0 seconds'"),
        nullable=False,
        default=timedelta(minutes=30),
    )
    token_gen_max = Column(
        Integer, CheckConstraint("token_gen_max > 0"), nullable=True)

    # Maximum number of submissions / user_tests each user may make on this
    # task during the whole contest; None disables the limit.
    max_submission_number = Column(
        Integer, CheckConstraint("max_submission_number >= 0"), nullable=True)
    max_user_test_number = Column(
        Integer, CheckConstraint("max_user_test_number >= 0"), nullable=True)

    # Minimum interval between two submissions / user_tests for this task;
    # None disables the limit.
    min_submission_interval = Column(
        Interval,
        CheckConstraint("min_submission_interval > '0 seconds'"),
        nullable=True,
    )
    min_user_test_interval = Column(
        Interval,
        CheckConstraint("min_user_test_interval > '0 seconds'"),
        nullable=True,
    )

    # How much users get to see about the evaluation of their submissions.
    # Full information might help some users reverse engineer task data.
    feedback_level = Column(
        Enum(
            FEEDBACK_LEVEL_FULL,
            FEEDBACK_LEVEL_RESTRICTED,
            name="feedback_level",
        ),
        nullable=False,
        default=FEEDBACK_LEVEL_RESTRICTED,
    )

    # Number of decimal places the scores of this task are rounded to.
    score_precision = Column(
        Integer,
        CheckConstraint("score_precision >= 0"),
        nullable=False,
        default=0,
    )

    # Score mode for the task.
    score_mode = Column(
        Enum(
            SCORE_MODE_MAX_TOKENED_LAST,
            SCORE_MODE_MAX,
            SCORE_MODE_MAX_SUBTASK,
            name="score_mode",
        ),
        nullable=False,
        default=SCORE_MODE_MAX_TOKENED_LAST,
    )

    # Active Dataset (id and object) currently used for scoring. The
    # matching ForeignKeyConstraint is declared at table level.
    active_dataset_id = Column(Integer, nullable=True)
    active_dataset = relationship(
        'Dataset',
        foreign_keys=[active_dataset_id],
        # Set (and unset) the column behind this relationship with an
        # UPDATE query issued *after* the INSERT (and *before* the DELETE).
        post_update=True,
    )

    # The one-to-many relationships below are the reverse directions of
    # the ones the "child" classes define through their foreign keys.

    statements = relationship(
        "Statement",
        collection_class=attribute_mapped_collection("language"),
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="task",
    )

    attachments = relationship(
        "Attachment",
        collection_class=attribute_mapped_collection("filename"),
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="task",
    )

    datasets = relationship(
        "Dataset",
        # Because of active_dataset_id SQLAlchemy cannot work out on its
        # own which foreign key to use, so it is named explicitly.
        foreign_keys="[Dataset.task_id]",
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="task",
    )

    submissions = relationship(
        "Submission",
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="task",
    )

    user_tests = relationship(
        "UserTest",
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="task",
    )
예제 #9
0
파일: build.py 프로젝트: rzachary/changes
class Build(db.Model):
    """
    A collection of builds for a single target, together with the sum of
    their results.

    Each Build contains many Jobs (usually linked to a JobPlan).
    """
    __tablename__ = 'build'
    __table_args__ = (
        Index('idx_buildfamily_project_id', 'project_id'),
        Index('idx_buildfamily_author_id', 'author_id'),
        Index('idx_buildfamily_source_id', 'source_id'),
        UniqueConstraint('project_id', 'number', name='unq_build_number'),
    )

    id = Column(GUID, primary_key=True, default=uuid.uuid4)
    number = Column(Integer)
    project_id = Column(
        GUID, ForeignKey('project.id', ondelete="CASCADE"), nullable=False)
    collection_id = Column(GUID)
    source_id = Column(GUID, ForeignKey('source.id', ondelete="CASCADE"))
    author_id = Column(GUID, ForeignKey('author.id', ondelete="CASCADE"))
    cause = Column(EnumType(Cause), nullable=False, default=Cause.unknown)
    label = Column(String(128), nullable=False)
    target = Column(String(128))
    tags = Column(ARRAY(String(16)), nullable=True)
    status = Column(EnumType(Status), nullable=False, default=Status.unknown)
    result = Column(EnumType(Result), nullable=False, default=Result.unknown)
    message = Column(Text)
    duration = Column(Integer)
    priority = Column(
        EnumType(BuildPriority),
        nullable=False,
        default=BuildPriority.default,
        server_default='0',
    )
    date_started = Column(DateTime)
    date_finished = Column(DateTime)
    date_created = Column(DateTime, default=datetime.utcnow)
    date_modified = Column(DateTime, default=datetime.utcnow)
    data = Column(JSONEncodedDict)

    project = relationship('Project', innerjoin=True)
    source = relationship('Source', innerjoin=True)
    author = relationship('Author')
    stats = relationship(
        'ItemStat',
        primaryjoin='Build.id == ItemStat.item_id',
        foreign_keys=[id],
        uselist=True,
    )

    __repr__ = model_repr('label', 'target')

    def __init__(self, **kwargs):
        """Create a build and fill in every default that was not passed in."""
        super(Build, self).__init__(**kwargs)

        # Apply the Python-side defaults immediately so they are available
        # on the instance right after construction.
        if self.id is None:
            self.id = uuid.uuid4()
        if self.result is None:
            self.result = Result.unknown
        if self.status is None:
            self.status = Status.unknown
        if self.date_created is None:
            self.date_created = datetime.utcnow()
        if self.date_modified is None:
            self.date_modified = self.date_created

        # Derive the duration (in milliseconds) when both timestamps are
        # known and no duration was given.
        has_both_dates = self.date_started and self.date_finished
        if has_both_dates and not self.duration:
            elapsed = self.date_finished - self.date_started
            self.duration = elapsed.total_seconds() * 1000

        # Without an explicit number, assign a SQL expression calling the
        # database function next_item_value with the project id.
        if self.number is None and self.project:
            self.number = select([func.next_item_value(self.project.id.hex)])
예제 #10
0
class TableColumn(Model, BaseColumn):
    """ORM model for one column of a ``SqlaTable``.

    Rows live in ``table_columns``; a table can own many columns and the
    pair ``(table_id, column_name)`` is unique.
    """

    __tablename__ = "table_columns"
    __table_args__ = (UniqueConstraint("table_id", "column_name"), )
    table_id = Column(Integer, ForeignKey("tables.id"))
    table = relationship(
        "SqlaTable",
        backref=backref("columns", cascade="all, delete-orphan"),
        foreign_keys=[table_id],
    )
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text)
    python_date_format = Column(String(255))

    export_fields = [
        "table_id",
        "column_name",
        "verbose_name",
        "is_dttm",
        "is_active",
        "type",
        "groupby",
        "filterable",
        "expression",
        "description",
        "python_date_format",
    ]

    # Every exported field except the owning table's id.
    update_from_object_fields = [
        field for field in export_fields if field != "table_id"
    ]
    export_parent = "table"

    def _matches_generic_type(self, generic_type) -> bool:
        """Ask the database's engine spec whether ``self.type`` matches."""
        engine_spec = self.table.database.db_engine_spec
        return engine_spec.is_db_column_type_match(self.type, generic_type)

    @property
    def is_numeric(self) -> bool:
        """Whether the engine spec classifies this column's type as numeric."""
        return self._matches_generic_type(utils.DbColumnType.NUMERIC)

    @property
    def is_string(self) -> bool:
        """Whether the engine spec classifies this column's type as string."""
        return self._matches_generic_type(utils.DbColumnType.STRING)

    @property
    def is_temporal(self) -> bool:
        """Whether the engine spec classifies this column's type as temporal."""
        return self._matches_generic_type(utils.DbColumnType.TEMPORAL)

    def get_sqla_col(self, label: Optional[str] = None) -> Column:
        """Build the SQLAlchemy column element for this column.

        :param label: Label to apply; defaults to the column name
        :return: The expression (if one is set) or the plain column, passed
            through the table's ``make_sqla_column_compatible``
        """
        effective_label = label or self.column_name
        if self.expression:
            sqla_col = literal_column(self.expression)
        else:
            engine_spec = self.table.database.db_engine_spec
            sqla_type = engine_spec.get_sqla_column_type(self.type)
            sqla_col = column(self.column_name, type_=sqla_type)
        return self.table.make_sqla_column_compatible(sqla_col,
                                                      effective_label)

    @property
    def datasource(self) -> RelationshipProperty:
        """The owning table (alias of ``table``)."""
        return self.table

    def get_time_filter(
        self,
        start_dttm: DateTime,
        end_dttm: DateTime,
        time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint,
                                             utils.TimeRangeEndpoint]],
    ) -> ColumnElement:
        """Build the time-range condition for this column.

        :param start_dttm: Lower bound (always inclusive); skipped if falsy
        :param end_dttm: Upper bound; skipped if falsy
        :param time_range_endpoints: Optional endpoint pair; the upper bound
            is exclusive only when its second item is ``EXCLUSIVE``
        :return: ``and_`` of the bounds that were supplied
        """
        time_col = self.get_sqla_col(label="__time")
        conditions = []

        if start_dttm:
            lower = text(
                self.dttm_sql_literal(start_dttm, time_range_endpoints))
            conditions.append(time_col >= lower)

        if end_dttm:
            exclusive_end = (time_range_endpoints and time_range_endpoints[1]
                             == utils.TimeRangeEndpoint.EXCLUSIVE)
            if exclusive_end:
                upper = text(
                    self.dttm_sql_literal(end_dttm, time_range_endpoints))
                conditions.append(time_col < upper)
            else:
                # Inclusive bound: the literal is built without endpoints.
                upper = text(self.dttm_sql_literal(end_dttm, None))
                conditions.append(time_col <= upper)

        return and_(*conditions)

    def get_timestamp_expression(
            self,
            time_grain: Optional[str]) -> Union[TimestampExpression, Label]:
        """
        Return a SQLAlchemy Core element representation of self to be used in a query.

        :param time_grain: Optional time grain, e.g. P1Y
        :return: A TimeExpression object wrapped in a Label if supported by db
        """
        database = self.table.database
        date_format = self.python_date_format
        is_epoch = date_format in ("epoch_s", "epoch_ms")

        # Nothing to transform: no expression, no grain, no epoch format.
        if not (self.expression or time_grain or is_epoch):
            plain_col = column(self.column_name, type_=DateTime)
            return self.table.make_sqla_column_compatible(
                plain_col, utils.DTTM_ALIAS)

        source_col = (literal_column(self.expression)
                      if self.expression else column(self.column_name))
        time_expr = database.db_engine_spec.get_timestamp_expr(
            source_col, date_format, time_grain, self.type)
        return self.table.make_sqla_column_compatible(time_expr,
                                                      utils.DTTM_ALIAS)

    @classmethod
    def import_obj(cls, i_column):
        """Import ``i_column`` via ``import_datasource.import_simple_obj``.

        An existing row is looked up by ``table_id`` and ``column_name``.
        """
        def find_existing(lookup_column):
            query = db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name,
            )
            return query.first()

        return import_datasource.import_simple_obj(db.session, i_column,
                                                   find_existing)

    def dttm_sql_literal(
        self,
        dttm: DateTime,
        time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint,
                                             utils.TimeRangeEndpoint]],
    ) -> str:
        """Convert datetime object to a SQL expression string"""
        # The engine spec's own conversion wins whenever it yields something.
        if self.type:
            converted = self.table.database.db_engine_spec.convert_dttm(
                self.type, dttm)
            if converted:
                return converted

        date_format = self.python_date_format

        # Fallback to the default format (if defined) only if the SIP-15 time range
        # endpoints, i.e., [start, end) are enabled.
        if not date_format and time_range_endpoints == (
                utils.TimeRangeEndpoint.INCLUSIVE,
                utils.TimeRangeEndpoint.EXCLUSIVE,
        ):
            formats_by_column = self.table.database.get_extra().get(
                "python_date_format_by_column_name", {})
            date_format = formats_by_column.get(self.column_name)

        if date_format == "epoch_s":
            return str(int(dttm.timestamp()))
        if date_format == "epoch_ms":
            # Whole seconds scaled to milliseconds.
            return str(int(dttm.timestamp()) * 1000)
        if date_format:
            return f"'{dttm.strftime(date_format)}'"

        # TODO(john-bodley): SIP-15 will explicitly require a type conversion.
        return f"""'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'"""
예제 #11
0
파일: Task.py 프로젝트: pombredanne/cms-1
class Task(Base):
    """Class to store a task. Not to be used directly (import it from
    SQLAlchemyAll).

    """
    __tablename__ = 'tasks'
    __table_args__ = (
        UniqueConstraint('contest_id', 'num', name='cst_task_contest_id_num'),
        UniqueConstraint('contest_id', 'name',
                         name='cst_task_contest_id_name'),
        CheckConstraint("token_initial <= token_max"),
    )

    # Auto increment primary key.
    id = Column(Integer, primary_key=True)

    # Number of the task for sorting.
    num = Column(Integer, nullable=False)

    # Contest (id and object) owning the task.
    contest_id = Column(Integer,
                        ForeignKey(Contest.id,
                                   onupdate="CASCADE",
                                   ondelete="CASCADE"),
                        nullable=False,
                        index=True)
    contest = relationship(Contest,
                           backref=backref(
                               'tasks',
                               collection_class=ordering_list('num'),
                               order_by=[num],
                               cascade="all, delete-orphan",
                               passive_deletes=True))

    # Short name and long human readable title of the task.
    name = Column(String, nullable=False)
    title = Column(String, nullable=False)

    # A JSON-encoded list of strings: the language codes of the
    # statements that will be highlighted to all users for this task.
    primary_statements = Column(String, nullable=False)

    # Time and memory limits for every testcase.
    time_limit = Column(Float, nullable=True)
    memory_limit = Column(Integer, nullable=True)

    # Name of the TaskType child class suited for the task.
    task_type = Column(String, nullable=False)

    # Parameters for the task type class, JSON encoded.
    task_type_parameters = Column(String, nullable=False)

    # Name of the ScoreType child class suited for the task.
    score_type = Column(String, nullable=False)

    # Parameters for the scorer class, JSON encoded.
    score_parameters = Column(String, nullable=False)

    # Parameter to define the token behaviour. See Contest.py for
    # details. The only change is that these parameters influence the
    # contest in a task-per-task behaviour. To play a token on a given
    # task, a user must satisfy the condition of the contest and the
    # one of the task.
    token_initial = Column(Integer,
                           CheckConstraint("token_initial >= 0"),
                           nullable=True)
    token_max = Column(Integer,
                       CheckConstraint("token_max > 0"),
                       nullable=True)
    token_total = Column(Integer,
                         CheckConstraint("token_total > 0"),
                         nullable=True)
    token_min_interval = Column(
        Interval,
        CheckConstraint("token_min_interval >= '0 seconds'"),
        nullable=False)
    token_gen_time = Column(Interval,
                            CheckConstraint("token_gen_time >= '0 seconds'"),
                            nullable=False)
    token_gen_number = Column(Integer,
                              CheckConstraint("token_gen_number >= 0"),
                              nullable=False)

    # Maximum number of submissions or user_tests allowed for each user
    # on this task during the whole contest or None to not enforce
    # this limitation.
    max_submission_number = Column(
        Integer, CheckConstraint("max_submission_number > 0"), nullable=True)
    max_user_test_number = Column(Integer,
                                  CheckConstraint("max_user_test_number > 0"),
                                  nullable=True)

    # Minimum interval between two submissions or user_tests for this
    # task, or None to not enforce this limitation.
    min_submission_interval = Column(
        Interval,
        CheckConstraint("min_submission_interval > '0 seconds'"),
        nullable=True)
    min_user_test_interval = Column(
        Interval,
        CheckConstraint("min_user_test_interval > '0 seconds'"),
        nullable=True)

    # Follows the description of the fields automatically added by
    # SQLAlchemy.
    # submission_format (list of SubmissionFormatElement objects)
    # testcases (list of Testcase objects)
    # attachments (dict of Attachment objects indexed by filename)
    # managers (dict of Manager objects indexed by filename)
    # statements (dict of Statement objects indexed by language code)
    # submissions (list of Submission objects)
    # user_tests (list of UserTest objects)

    # This object (independent from SQLAlchemy) is the instance of the
    # ScoreType class with the given parameters, taking care of
    # building the scores of the submissions.
    scorer = None

    def __init__(self,
                 name,
                 title,
                 statements,
                 attachments,
                 time_limit,
                 memory_limit,
                 primary_statements,
                 task_type,
                 task_type_parameters,
                 submission_format,
                 managers,
                 score_type,
                 score_parameters,
                 testcases,
                 token_initial=None,
                 token_max=None,
                 token_total=None,
                 token_min_interval=timedelta(),
                 token_gen_time=timedelta(),
                 token_gen_number=0,
                 max_submission_number=None,
                 max_user_test_number=None,
                 min_submission_interval=None,
                 min_user_test_interval=None,
                 contest=None,
                 num=0):
        """Store the given values on the new task.

        ``attachments`` and ``managers`` are dicts keyed by filename and
        ``statements`` is a dict keyed by language code; each value gets
        its ``filename`` / ``language`` attribute set from its key before
        the dict is stored.  A ``primary_statements`` of ``None`` is
        stored as ``"[]"``.
        """
        # ``items()`` is available on both Python 2 and Python 3 dicts
        # (``iteritems()`` exists only on Python 2).
        for filename, attachment in attachments.items():
            attachment.filename = filename
        for filename, manager in managers.items():
            manager.filename = filename
        for language, statement in statements.items():
            statement.language = language

        self.num = num
        self.name = name
        self.title = title
        self.statements = statements
        self.attachments = attachments
        self.time_limit = time_limit
        self.memory_limit = memory_limit
        self.primary_statements = primary_statements \
            if primary_statements is not None else "[]"
        self.task_type = task_type
        self.task_type_parameters = task_type_parameters
        self.submission_format = submission_format
        self.managers = managers
        self.score_type = score_type
        self.score_parameters = score_parameters
        self.testcases = testcases
        self.token_initial = token_initial
        self.token_max = token_max
        self.token_total = token_total
        self.token_min_interval = token_min_interval
        self.token_gen_time = token_gen_time
        self.token_gen_number = token_gen_number
        self.max_submission_number = max_submission_number
        self.max_user_test_number = max_user_test_number
        self.min_submission_interval = min_submission_interval
        self.min_user_test_interval = min_user_test_interval
        self.contest = contest

    def export_to_dict(self):
        """Return object data as a dictionary.

        Child objects are exported through their own ``export_to_dict``
        and intervals are exported as a number of seconds (``None`` is
        kept for the two optional minimum intervals).

        """
        min_submission_interval = None
        if self.min_submission_interval is not None:
            min_submission_interval = \
                self.min_submission_interval.total_seconds()
        min_user_test_interval = None
        if self.min_user_test_interval is not None:
            min_user_test_interval = \
                self.min_user_test_interval.total_seconds()

        return {
            'name': self.name,
            'title': self.title,
            'num': self.num,
            'statements': [
                statement.export_to_dict()
                for statement in self.statements.values()
            ],
            'attachments': [
                attachment.export_to_dict()
                for attachment in self.attachments.values()
            ],
            'time_limit': self.time_limit,
            'memory_limit': self.memory_limit,
            'primary_statements': self.primary_statements,
            'task_type': self.task_type,
            'task_type_parameters': self.task_type_parameters,
            'submission_format': [
                element.export_to_dict()
                for element in self.submission_format
            ],
            'managers': [
                manager.export_to_dict()
                for manager in self.managers.values()
            ],
            'score_type': self.score_type,
            'score_parameters': self.score_parameters,
            'testcases': [
                testcase.export_to_dict() for testcase in self.testcases
            ],
            'token_initial': self.token_initial,
            'token_max': self.token_max,
            'token_total': self.token_total,
            'token_min_interval': self.token_min_interval.total_seconds(),
            'token_gen_time': self.token_gen_time.total_seconds(),
            'token_gen_number': self.token_gen_number,
            'max_submission_number': self.max_submission_number,
            'max_user_test_number': self.max_user_test_number,
            'min_submission_interval': min_submission_interval,
            'min_user_test_interval': min_user_test_interval,
        }

    @classmethod
    def import_from_dict(cls, data):
        """Build the object using data from a dictionary.

        ``data`` has the layout produced by ``export_to_dict``.  It is not
        modified: the conversions are applied to a shallow copy, so the
        same dictionary can be imported more than once.

        """
        # Work on a copy; the entries below are replaced by model objects.
        data = dict(data)

        attachments = [
            Attachment.import_from_dict(attch_data)
            for attch_data in data['attachments']
        ]
        data['attachments'] = {
            attachment.filename: attachment for attachment in attachments
        }
        data['submission_format'] = [
            SubmissionFormatElement.import_from_dict(sfe_data)
            for sfe_data in data['submission_format']
        ]
        managers = [
            Manager.import_from_dict(manager_data)
            for manager_data in data['managers']
        ]
        data['managers'] = {manager.filename: manager for manager in managers}
        data['testcases'] = [
            Testcase.import_from_dict(testcase_data)
            for testcase_data in data['testcases']
        ]
        statements = [
            Statement.import_from_dict(statement_data)
            for statement_data in data['statements']
        ]
        data['statements'] = {
            statement.language: statement for statement in statements
        }

        # Intervals travel as seconds; these two are converted whenever
        # the key is present.
        for key in ('token_min_interval', 'token_gen_time'):
            if key in data:
                data[key] = timedelta(seconds=data[key])
        # These two may be None ("no limit"), which is left untouched.
        for key in ('min_submission_interval', 'min_user_test_interval'):
            if data.get(key) is not None:
                data[key] = timedelta(seconds=data[key])

        return cls(**data)
예제 #12
0
class Participant(db.Model):
    """ORM model for a row of the ``participants`` table.

    Besides the columns and relationships it offers helpers for tips,
    usernames, avatars and group voting.  Several methods still delegate
    to ``OldParticipant`` (see the TODO near the end of the class).
    """
    __tablename__ = "participants"
    __table_args__ = (UniqueConstraint(
        "session_token", name="participants_session_token_key"), )

    id = Column(BigInteger,
                Sequence('participants_id_seq'),
                nullable=False,
                unique=True)
    username = Column(Text, nullable=False, primary_key=True)
    username_lower = Column(Text, nullable=False, unique=True)
    statement = Column(Text, default="", nullable=False)
    stripe_customer_id = Column(Text)
    last_bill_result = Column(Text)
    session_token = Column(Text)
    session_expires = Column(TIMESTAMP(timezone=True), default="now()")
    ctime = Column(TIMESTAMP(timezone=True), nullable=False, default="now()")
    claimed_time = Column(TIMESTAMP(timezone=True))
    is_admin = Column(Boolean, nullable=False, default=False)
    balance = Column(Numeric(precision=35, scale=2),
                     CheckConstraint("balance >= 0", name="min_balance"),
                     default=0.0,
                     nullable=False)
    pending = Column(Numeric(precision=35, scale=2), default=None)
    anonymous = Column(Boolean, default=False, nullable=False)
    goal = Column(Numeric(precision=35, scale=2), default=None)
    balanced_account_uri = Column(Text)
    last_ach_result = Column(Text)
    is_suspicious = Column(Boolean)
    # ``nullable`` is a Column option; it used to be passed to Enum(),
    # where it does not make the column NOT NULL.
    type = Column(Enum('individual', 'group', 'open group'), nullable=False)

    ### Relations ###
    accounts_elsewhere = relationship("Elsewhere",
                                      backref="participant_orm",
                                      lazy="dynamic")
    exchanges = relationship("Exchange", backref="participant_orm")

    # TODO: Once tippee/tipper are renamed to tippee_id/tipper_id, we can go
    # ahead and drop the foreign_keys & rename backrefs to tipper/tippee

    _tips_giving = relationship("Tip",
                                backref="tipper_participant",
                                foreign_keys="Tip.tipper",
                                lazy="dynamic")
    _tips_receiving = relationship("Tip",
                                   backref="tippee_participant",
                                   foreign_keys="Tip.tippee",
                                   lazy="dynamic")

    transferer = relationship("Transfer",
                              backref="transferer",
                              foreign_keys="Transfer.tipper")
    transferee = relationship("Transfer",
                              backref="transferee",
                              foreign_keys="Transfer.tippee")

    def __eq__(self, other):
        """Compare by ``id``; objects without an ``id`` are not comparable."""
        if not hasattr(other, 'id'):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError (e.g. for ``participant == None``).
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        """Inverse of ``__eq__``, with the same handling of foreign objects."""
        if not hasattr(other, 'id'):
            return NotImplemented
        return self.id != other.id

    # Class-specific exceptions
    class ProblemChangingUsername(Exception):
        pass

    class UsernameTooLong(ProblemChangingUsername):
        pass

    class UsernameContainsInvalidCharacters(ProblemChangingUsername):
        pass

    class UsernameIsRestricted(ProblemChangingUsername):
        pass

    class UsernameAlreadyTaken(ProblemChangingUsername):
        pass

    class UnknownPlatform(Exception):
        pass

    @property
    def IS_INDIVIDUAL(self):
        """True when ``type`` is ``'individual'``."""
        return self.type == 'individual'

    @property
    def IS_GROUP(self):
        """True when ``type`` is ``'group'``."""
        return self.type == 'group'

    @property
    def IS_OPEN_GROUP(self):
        """True when ``type`` is ``'open group'``."""
        return self.type == 'open group'

    @property
    def tips_giving(self):
        """Query over tips given: distinct per tippee, newest ``mtime`` first."""
        return self._tips_giving.distinct("tips.tippee")\
                                .order_by("tips.tippee, tips.mtime DESC")

    @property
    def tips_receiving(self):
        """Query over tips received: distinct per tipper, newest ``mtime`` first."""
        return self._tips_receiving.distinct("tips.tipper")\
                                   .order_by("tips.tipper, tips.mtime DESC")

    @property
    def accepts_tips(self):
        """True unless ``goal`` is set to a negative value."""
        return (self.goal is None) or (self.goal >= 0)

    @property
    def valid_tips_receiving(self):
        '''Query over received tips whose tipper is not suspicious and has
        an empty ``last_bill_result``.

        Illustrative SQL; the outer SELECT is the counting wrapper that
        ``get_number_of_backers`` puts around this query:

      SELECT count(anon_1.amount) AS count_1
        FROM ( SELECT DISTINCT ON (tips.tipper)
                      tips.id AS id
                    , tips.ctime AS ctime
                    , tips.mtime AS mtime
                    , tips.tipper AS tipper
                    , tips.tippee AS tippee
                    , tips.amount AS amount
                 FROM tips
                 JOIN participants ON tips.tipper = participants.username
                WHERE %(param_1)s = tips.tippee
                  AND participants.is_suspicious IS NOT true
                  AND participants.last_bill_result = %(last_bill_result_1)s
             ORDER BY tips.tipper, tips.mtime DESC
              ) AS anon_1
       WHERE anon_1.amount > %(amount_1)s

        '''
        return self.tips_receiving \
                   .join(Participant,
                         Tip.tipper.op('=')(Participant.username)) \
                   .filter('participants.is_suspicious IS NOT true',
                           Participant.last_bill_result == '')

    def resolve_unclaimed(self):
        """Delegate to the first elsewhere account, or return None if none.

        ``accounts_elsewhere`` is a ``lazy="dynamic"`` relationship, i.e. a
        query object whose truthiness does not reflect whether it has
        rows, so emptiness is detected by the failed index lookup.
        """
        try:
            first_account = self.accounts_elsewhere[0]
        except IndexError:
            return None
        return first_account.resolve_unclaimed()

    def set_as_claimed(self, claimed_at=None):
        """Set ``claimed_time`` (default: now, in UTC) and commit."""
        if claimed_at is None:
            claimed_at = datetime.datetime.now(pytz.utc)
        self.claimed_time = claimed_at
        db.session.add(self)
        db.session.commit()

    def change_username(self, desired_username):
        """Raise self.ProblemChangingUsername, or return None.

        We want to be pretty loose with usernames. Unicode is allowed--XXX
        aspen bug :(. So are spaces. Control characters aren't. We also limit
        to 32 characters in length.

        """
        for i, c in enumerate(desired_username):
            if i == 32:
                raise self.UsernameTooLong  # Request Entity Too Large (more or less)
            if c not in ASCII_ALLOWED_IN_USERNAME:
                # Covers disallowed ASCII as well as every non-ASCII
                # character; the latter only because of an Aspen bug:
                # https://github.com/gittip/aspen/issues/102
                raise self.UsernameContainsInvalidCharacters

        lowercased = desired_username.lower()

        if lowercased in gittip.RESTRICTED_USERNAMES:
            raise self.UsernameIsRestricted

        if desired_username != self.username:
            try:
                self.username = desired_username
                self.username_lower = lowercased
                db.session.add(self)
                db.session.commit()
                # Will raise sqlalchemy.exc.IntegrityError if the
                # desired_username is taken.
            except IntegrityError:
                db.session.rollback()
                raise self.UsernameAlreadyTaken

    def get_accounts_elsewhere(self):
        """Return ``(github, twitter, bitbucket, bountysource)`` accounts.

        Each item is the matching account or None.  Raises
        ``self.UnknownPlatform`` for an account on any other platform.
        """
        github_account = twitter_account = bitbucket_account = bountysource_account = None
        for account in self.accounts_elsewhere.all():
            if account.platform == "github":
                github_account = account
            elif account.platform == "twitter":
                twitter_account = account
            elif account.platform == "bitbucket":
                bitbucket_account = account
            elif account.platform == "bountysource":
                bountysource_account = account
            else:
                raise self.UnknownPlatform(account.platform)
        return (github_account, twitter_account, bitbucket_account,
                bountysource_account)

    def get_img_src(self, size=128):
        """Return a value for <img src="..." />.

        Until we have our own profile pics, delegate. XXX Is this an attack
        vector? Can someone inject this value? Don't think so, but if you make
        it happen, let me know, eh? Thanks. :)

            https://www.gittip.com/security.txt

        """
        typecheck(size, int)

        src = '/assets/%s/avatar-default.gif' % os.environ['__VERSION__']

        github, twitter, bitbucket, bountysource = \
            self.get_accounts_elsewhere()
        if github is not None:
            # GitHub -> Gravatar: http://en.gravatar.com/site/implement/images/
            if 'gravatar_id' in github.user_info:
                gravatar_hash = github.user_info['gravatar_id']
                src = "https://www.gravatar.com/avatar/%s.jpg?s=%s"
                src %= (gravatar_hash, size)

        elif twitter is not None:
            # https://dev.twitter.com/docs/api/1/get/users/profile_image/%3Ascreen_name
            if 'profile_image_url_https' in twitter.user_info:
                src = twitter.user_info['profile_image_url_https']

                # For Twitter, we don't have good control over size. We don't
                # want the original, cause that can be huge. The next option is
                # 73px(?!).
                src = src.replace('_normal.', '_bigger.')

        return src

    def get_tip_to(self, tippee):
        """Return the amount tipped to ``tippee``, or 0.00 if there is no tip."""
        tip = self.tips_giving.filter_by(tippee=tippee).first()
        return tip.amount if tip else Decimal('0.00')

    def get_dollars_receiving(self):
        """Return the sum of the amounts of all valid received tips."""
        return sum(tip.amount
                   for tip in self.valid_tips_receiving) + Decimal('0.00')

    def get_number_of_backers(self):
        """Count the valid received tips whose amount is greater than zero."""
        amount_column = self.valid_tips_receiving.subquery().columns.amount
        count = func.count(amount_column)
        return db.session.query(count).filter(amount_column > 0).one()[0]

    def get_og_title(self):
        """Return a one-line "<username> gives/receives ... on Gittip" title.

        The giving amount is shown only when it exceeds the receiving
        amount and the participant is not anonymous.
        """
        out = self.username
        receiving = self.get_dollars_receiving()
        giving = self.get_dollars_giving()
        if (giving > receiving) and not self.anonymous:
            out += " gives $%.2f/wk" % giving
        elif receiving > 0:
            out += " receives $%.2f/wk" % receiving
        else:
            out += " is"
        return out + " on Gittip"

    def get_age_in_seconds(self):
        """Return seconds elapsed since ``claimed_time``, or -1 if unclaimed."""
        if self.claimed_time is None:
            return -1
        now = datetime.datetime.now(self.claimed_time.tzinfo)
        return (now - self.claimed_time).total_seconds()

    def allowed_to_vote_on(self, group):
        """Given a Participant object, return a boolean.
        """
        # NOTE(review): ``ANON`` is not defined anywhere in this class as
        # shown -- confirm that it is provided elsewhere.
        if self.ANON:
            return False
        if self.is_suspicious:
            return False
        if group == self:
            return True
        return self.username in group.get_voters()

    def _fetch_current_identifications(self):
        """Fetch this group's positive-weight identifications, oldest first."""
        return gittip.db.fetchall(
            """

            SELECT member
                 , weight
                 , identified_by
                 , mtime
              FROM current_identifications
             WHERE "group"=%s
               AND weight > 0
          ORDER BY mtime ASC

        """, (self.username, ))

    def get_voters(self):
        """Return the sorted list of members admitted as voters of this group.

        A member is admitted when identified by the group participant
        itself, or once the votes counted from already admitted voters
        reach ``ceil(len(voters) * 0.667)``.
        """
        identifications = list(self._fetch_current_identifications())

        voters = set()
        vote_counts = defaultdict(int)

        # Each sweep admits at most one member and then restarts, because
        # a newly admitted voter can make earlier rows count.  The loop
        # ends when a full sweep admits nobody.
        # NOTE(review): vote_counts is not reset between sweeps, so a row
        # that stays in the list is counted again on every sweep --
        # confirm this is intended.
        while identifications:
            for i, identification in enumerate(identifications):
                identified_by = identification['identified_by']
                member = identification['member']

                if identified_by == self.username:
                    # Group participant itself has chosen.
                    admitted = True
                elif identified_by in voters:
                    # Group member has voted.
                    vote_counts[member] += 1
                    target = math.ceil(len(voters) * 0.667)
                    admitted = vote_counts[member] == target
                else:
                    admitted = False

                if admitted:
                    voters.add(member)
                    # Safe while iterating: the sweep stops right here.
                    identifications.pop(i)
                    break
            else:
                # This pass through didn't find any new voters.
                break

        return sorted(voters)

    def compute_split(self):
        """Compute how this participant's receipts are split among members.

        NOTE(review): the return shape differs by branch -- a one-item
        list for anything that is not an open group, but a
        ``(voters, split)`` tuple for open groups.  Kept as is for
        existing callers; confirm which shape they expect.
        """
        if not self.IS_OPEN_GROUP:
            return [{"username": self.username, "weight": "1.0"}]

        voters = self.get_voters()
        identifications = self._fetch_current_identifications()

        # Only weights assigned by admitted voters count.
        splitmap = defaultdict(int)
        total = 0
        for row in identifications:
            if row['identified_by'] not in voters:
                continue
            splitmap[row['member']] += row['weight']
            total += row['weight']

        total = Decimal(total)
        split = [{
            "username": username,
            "weight": Decimal(weight) / total
        } for username, weight in splitmap.items()]
        split.sort(key=lambda r: r['weight'], reverse=True)

        return voters, split

    # TODO: Move these queries into this class.

    def set_tip_to(self, tippee, amount):
        """Delegate to ``OldParticipant.set_tip_to``."""
        return OldParticipant(self.username).set_tip_to(tippee, amount)

    def get_dollars_giving(self):
        """Delegate to ``OldParticipant.get_dollars_giving``."""
        return OldParticipant(self.username).get_dollars_giving()

    def get_tip_distribution(self):
        """Delegate to ``OldParticipant.get_tip_distribution``."""
        return OldParticipant(self.username).get_tip_distribution()

    def get_giving_for_profile(self, db=None):
        """Delegate to ``OldParticipant.get_giving_for_profile``."""
        return OldParticipant(self.username).get_giving_for_profile(db)

    def get_tips_and_total(self, for_payday=False, db=None):
        """Delegate to ``OldParticipant.get_tips_and_total``."""
        return OldParticipant(self.username).get_tips_and_total(for_payday, db)

    def take_over(self, account_elsewhere, have_confirmation=False):
        """Delegate to ``OldParticipant.take_over``; returns None."""
        OldParticipant(self.username).take_over(account_elsewhere,
                                                have_confirmation)
예제 #13
0
class TableColumn(Model, BaseColumn):
    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = 'table_columns'
    __table_args__ = (UniqueConstraint('table_id', 'column_name'), )
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship('SqlaTable',
                         backref=backref('columns',
                                         cascade='all, delete-orphan'),
                         foreign_keys=[table_id])
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text, default='')
    python_date_format = Column(String(255))
    database_expression = Column(String(255))

    export_fields = (
        'table_id',
        'column_name',
        'verbose_name',
        'is_dttm',
        'is_active',
        'type',
        'groupby',
        'count_distinct',
        'sum',
        'avg',
        'max',
        'min',
        'filterable',
        'expression',
        'description',
        'python_date_format',
        'database_expression',
    )
    export_parent = 'table'

    @property
    def sqla_col(self):
        """Return the SQLAlchemy expression for this column.

        A plain column reference is used when no ``expression`` is set;
        otherwise the expression is emitted verbatim as a literal column.
        Either way the result is labelled with ``column_name``.
        """
        name = self.column_name
        if not self.expression:
            return column(self.column_name).label(name)
        return literal_column(self.expression).label(name)

    def get_time_filter(self, start_dttm, end_dttm):
        """Build a conjunction bounding this column by the given datetimes.

        A bound is only added when the corresponding datetime is truthy.
        """
        col = self.sqla_col.label('__time')
        bounds = []
        if start_dttm:
            bounds.append(col >= text(self.dttm_sql_literal(start_dttm)))
        if end_dttm:
            bounds.append(col <= text(self.dttm_sql_literal(end_dttm)))
        return and_(*bounds)

    def get_timestamp_expression(self, time_grain):
        """Getting the time component of the query

        Without an ``expression`` and without a ``time_grain`` the raw column
        is returned. With a ``time_grain``, epoch columns are first converted
        through the engine spec and then the grain's SQL function template is
        applied. The result is always labelled ``DTTM_ALIAS``.
        """
        expr = self.expression or self.column_name
        if not self.expression and not time_grain:
            return column(expr, type_=DateTime).label(DTTM_ALIAS)
        if time_grain:
            pdf = self.python_date_format
            if pdf in ('epoch_s', 'epoch_ms'):
                # if epoch, translate to DATE using db specific conf
                db_spec = self.table.database.db_engine_spec
                if pdf == 'epoch_s':
                    expr = db_spec.epoch_to_dttm().format(col=expr)
                elif pdf == 'epoch_ms':
                    expr = db_spec.epoch_ms_to_dttm().format(col=expr)
            # An unknown grain used to fall back to the string '{col}', on
            # which ``.function`` raised AttributeError. Treat an unknown
            # grain as the identity transform instead.
            grain = self.table.database.grains_dict().get(time_grain)
            if grain:
                expr = grain.function.format(col=expr)
        return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)

    @classmethod
    def import_obj(cls, i_column):
        """Import ``i_column`` through ``import_util.import_simple_obj``.

        An existing row is looked up by ``(table_id, column_name)``.
        """
        def lookup_obj(lookup_column):
            return db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name).first()

        return import_util.import_simple_obj(db.session, i_column, lookup_obj)

    def dttm_sql_literal(self, dttm):
        """Convert datetime object to a SQL expression string

        If database_expression is empty, the internal dttm
        will be parsed as the string with the pattern that
        the user inputted (python_date_format)
        If database_expression is not empty, the internal dttm
        will be parsed as the sql sentence for the database to convert
        """
        tf = self.python_date_format
        if self.database_expression:
            return self.database_expression.format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        if tf:
            if tf in ('epoch_s', 'epoch_ms'):
                # Offset from the (naive) Unix epoch, in seconds.
                seconds = (dttm - datetime(1970, 1, 1)).total_seconds()
                if tf == 'epoch_s':
                    return str(seconds)
                return str(seconds * 1000.0)
            return "'{}'".format(dttm.strftime(tf))
        # No explicit format: let the engine spec render the literal and
        # fall back to a quoted timestamp when it returns nothing.
        s = self.table.database.db_engine_spec.convert_dttm(
            self.type or '', dttm)
        return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))
예제 #14
0
class SqlaTable(Model, BaseDatasource):
    """An ORM object for SqlAlchemy table references"""

    type = 'table'
    query_language = 'sql'
    metric_class = SqlMetric
    column_class = TableColumn

    __tablename__ = 'tables'
    __table_args__ = (UniqueConstraint('database_id', 'table_name'), )

    table_name = Column(String(250))
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
    fetch_values_predicate = Column(String(1000))
    user_id = Column(Integer, ForeignKey('ab_user.id'))
    owner = relationship(sm.user_model,
                         backref='tables',
                         foreign_keys=[user_id])
    database = relationship('Database',
                            backref=backref('tables',
                                            cascade='all, delete-orphan'),
                            foreign_keys=[database_id])
    schema = Column(String(255))
    sql = Column(Text)

    baselink = 'tablemodelview'

    export_fields = ('table_name', 'main_dttm_col', 'description',
                     'default_endpoint', 'database_id', 'offset',
                     'cache_timeout', 'schema', 'sql', 'params')
    export_parent = 'database'
    export_children = ['metrics', 'columns']

    def __repr__(self):
        return self.name

    @property
    def connection(self):
        """String form of the owning database."""
        return str(self.database)

    @property
    def description_markeddown(self):
        """The description passed through ``utils.markdown``."""
        return utils.markdown(self.description)

    @property
    def link(self):
        """HTML anchor pointing at ``explore_url`` with the escaped name."""
        name = escape(self.name)
        return Markup(
            '<a href="{self.explore_url}">{name}</a>'.format(
                self=self, name=name))

    @property
    def schema_perm(self):
        """Returns schema permission if present, database one otherwise."""
        return utils.get_schema_perm(self.database, self.schema)

    def get_perm(self):
        """Return the permission string ``[database].[table_name](id:N)``."""
        return ('[{obj.database}].[{obj.table_name}]'
                '(id:{obj.id})').format(obj=self)

    @property
    def name(self):
        """``schema.table_name`` when a schema is set, else ``table_name``."""
        if not self.schema:
            return self.table_name
        return '{}.{}'.format(self.schema, self.table_name)

    @property
    def full_name(self):
        return utils.get_datasource_full_name(self.database,
                                              self.table_name,
                                              schema=self.schema)

    @property
    def dttm_cols(self):
        """Names of columns flagged ``is_dttm``, plus ``main_dttm_col``."""
        names = [c.column_name for c in self.columns if c.is_dttm]
        if self.main_dttm_col and self.main_dttm_col not in names:
            names.append(self.main_dttm_col)
        return names

    @property
    def num_cols(self):
        """Names of columns whose ``is_num`` is truthy."""
        return [c.column_name for c in self.columns if c.is_num]

    @property
    def any_dttm_col(self):
        """First entry of ``dttm_cols``, or None when there is none."""
        cols = self.dttm_cols
        if cols:
            return cols[0]
        return None

    @property
    def html(self):
        """HTML table listing each column's name and type."""
        t = ((c.column_name, c.type) for c in self.columns)
        df = pd.DataFrame(t)
        df.columns = ['field', 'type']
        return df.to_html(
            index=False,
            classes=('dataframe table table-striped table-bordered '
                     'table-condensed'))

    @property
    def sql_url(self):
        return self.database.sql_url + '?table_name=' + str(self.table_name)

    @property
    def time_column_grains(self):
        return {
            'time_columns': self.dttm_cols,
            'time_grains': [grain.name for grain in self.database.grains()],
        }

    def get_col(self, col_name):
        """Return the column named ``col_name``, or None if absent."""
        for col in self.columns:
            if col_name == col.column_name:
                return col
        return None

    @property
    def data(self):
        """Parent ``data`` dict extended with time column/grain choices."""
        d = super(SqlaTable, self).data
        if self.type == 'table':
            grains = [(g.name, g.name)
                      for g in self.database.grains() or []]
            d['granularity_sqla'] = utils.choicify(self.dttm_cols)
            d['time_grain_sqla'] = grains
        return d

    def values_for_column(self, column_name, limit=10000):
        """Runs query against sqla to retrieve some
        sample values for the given column.

        Raises ``KeyError`` when ``column_name`` is not one of this table's
        columns.
        """
        cols = {col.column_name: col for col in self.columns}
        target_col = cols[column_name]
        tp = self.get_template_processor()
        db_engine_spec = self.database.db_engine_spec

        qry = (select([target_col.sqla_col]).select_from(
            self.get_from_clause(tp, db_engine_spec)).distinct(column_name))
        if limit:
            qry = qry.limit(limit)

        if self.fetch_values_predicate:
            # Reuse the processor created above rather than building a
            # second identical one.
            qry = qry.where(tp.process_template(self.fetch_values_predicate))

        engine = self.database.get_sqla_engine()
        sql = '{}'.format(
            qry.compile(
                engine,
                compile_kwargs={'literal_binds': True},
            ), )

        df = pd.read_sql_query(sql=sql, con=engine)
        return [row[0] for row in df.to_records(index=False)]

    def get_template_processor(self, **kwargs):
        return get_template_processor(table=self,
                                      database=self.database,
                                      **kwargs)

    def get_query_str(self, query_obj):
        """Compile ``query_obj`` to a reindented SQL string with literal binds."""
        engine = self.database.get_sqla_engine()
        qry = self.get_sqla_query(**query_obj)
        sql = six.text_type(
            qry.compile(
                engine,
                compile_kwargs={'literal_binds': True},
            ), )
        logging.info(sql)
        sql = sqlparse.format(sql, reindent=True)
        return sql

    def get_sqla_table(self):
        tbl = table(self.table_name)
        if self.schema:
            tbl.schema = self.schema
        return tbl

    def get_from_clause(self, template_processor=None, db_engine_spec=None):
        """Return the FROM target: the aliased custom SQL, or the table."""
        # Supporting arbitrary SQL statements in place of tables
        if self.sql:
            from_sql = self.sql
            if template_processor:
                from_sql = template_processor.process_template(from_sql)
            if db_engine_spec:
                from_sql = db_engine_spec.escape_sql(from_sql)
            return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
        return self.get_sqla_table()

    def get_sqla_query(  # sqla
            self,
            groupby,
            metrics,
            granularity,
            from_dttm,
            to_dttm,
            filter=None,  # noqa
            is_timeseries=True,
            timeseries_limit=15,
            timeseries_limit_metric=None,
            row_limit=None,
            inner_from_dttm=None,
            inner_to_dttm=None,
            orderby=None,
            extras=None,
            columns=None,
            form_data=None,
            order_desc=True):
        """Querying any sqla table from this common interface

        Raises ``Exception`` when a timeseries is requested without a usable
        datetime column, when nothing is selected, or when a requested metric
        is unknown.
        """
        template_kwargs = {
            'from_dttm': from_dttm,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'to_dttm': to_dttm,
            'form_data': form_data,
        }
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec

        # Normalise the optional arguments once: the defaults are None, and
        # the code below iterates ``filter`` and calls ``extras.get``.
        orderby = orderby or []
        extras = extras or {}
        filters = filter or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols = {col.column_name: col for col in self.columns}
        metrics_dict = {m.metric_name: m for m in self.metrics}

        if not granularity and is_timeseries:
            raise Exception(
                _('Datetime column not provided as part table configuration '
                  'and is required by this type of chart'))
        if not groupby and not metrics and not columns:
            raise Exception(_('Empty query?'))
        for m in metrics:
            if m not in metrics_dict:
                # Translate the template first, then substitute, so the
                # message id matches the translation catalogue.
                raise Exception(_("Metric '{}' is not valid").format(m))
        metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics]
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            main_metric_expr = literal_column('COUNT(*)').label('ccount')

        select_exprs = []
        groupby_exprs = []

        if groupby:
            inner_select_exprs = []
            inner_groupby_exprs = []
            for s in groupby:
                col = cols[s]
                outer = col.sqla_col
                # The inner (subquery) alias gets a '__' suffix so the join
                # condition built below can reference it unambiguously.
                inner = col.sqla_col.label(col.column_name + '__')

                groupby_exprs.append(outer)
                select_exprs.append(outer)
                inner_groupby_exprs.append(inner)
                inner_select_exprs.append(inner)
        elif columns:
            for s in columns:
                select_exprs.append(cols[s].sqla_col)
            metrics_exprs = []

        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get('time_grain_sqla')
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs += [timestamp]

            # Use main dttm column to support index with secondary dttm columns
            if db_engine_spec.time_secondary_columns and \
                    self.main_dttm_col in self.dttm_cols and \
                    self.main_dttm_col != dttm_col.column_name:
                time_filters.append(cols[self.main_dttm_col].get_time_filter(
                    from_dttm, to_dttm))
            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))

        select_exprs += metrics_exprs
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor, db_engine_spec)

        if not columns:
            qry = qry.group_by(*groupby_exprs)

        where_clause_and = []
        having_clause_and = []
        for flt in filters:
            # Skip incomplete filters (missing or falsy col/op/val).
            if not all([flt.get(s) for s in ['col', 'op', 'val']]):
                continue
            col = flt['col']
            op = flt['op']
            eq = flt['val']
            col_obj = cols.get(col)
            if col_obj:
                if op in ('in', 'not in'):
                    values = []
                    for v in eq:
                        # For backwards compatibility and edge cases
                        # where a column data type might have changed
                        if isinstance(v, basestring):
                            v = v.strip("'").strip('"')
                            if col_obj.is_num:
                                v = utils.string_to_num(v)

                        # Removing empty strings and non numeric values
                        # targeting numeric columns
                        if v is not None:
                            values.append(v)
                    cond = col_obj.sqla_col.in_(values)
                    if op == 'not in':
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_num:
                        eq = utils.string_to_num(flt['val'])
                    if op == '==':
                        where_clause_and.append(col_obj.sqla_col == eq)
                    elif op == '!=':
                        where_clause_and.append(col_obj.sqla_col != eq)
                    elif op == '>':
                        where_clause_and.append(col_obj.sqla_col > eq)
                    elif op == '<':
                        where_clause_and.append(col_obj.sqla_col < eq)
                    elif op == '>=':
                        where_clause_and.append(col_obj.sqla_col >= eq)
                    elif op == '<=':
                        where_clause_and.append(col_obj.sqla_col <= eq)
                    elif op == 'LIKE':
                        where_clause_and.append(col_obj.sqla_col.like(eq))
        if extras:
            where = extras.get('where')
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text('({})'.format(where))]
            having = extras.get('having')
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text('({})'.format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and not columns:
            orderby = [(main_metric_expr, not order_desc)]

        for col, ascending in orderby:
            direction = asc if ascending else desc
            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

        if is_timeseries and \
                timeseries_limit and groupby and not time_groupby_inline:
            # some sql dialects require for order by expressions
            # to also be in the select clause -- others, e.g. vertica,
            # require a unique inner alias
            inner_main_metric_expr = main_metric_expr.label('mme_inner__')
            inner_select_exprs += [inner_main_metric_expr]
            subq = select(inner_select_exprs)
            subq = subq.select_from(tbl)
            inner_time_filter = dttm_col.get_time_filter(
                inner_from_dttm or from_dttm,
                inner_to_dttm or to_dttm,
            )
            subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
            subq = subq.group_by(*inner_groupby_exprs)

            ob = inner_main_metric_expr
            if timeseries_limit_metric:
                timeseries_limit_metric = metrics_dict.get(
                    timeseries_limit_metric)
                ob = timeseries_limit_metric.sqla_col
            direction = desc if order_desc else asc
            subq = subq.order_by(direction(ob))
            subq = subq.limit(timeseries_limit)

            # Join the outer query to the limited subquery on every
            # group-by column, using the '__'-suffixed inner aliases.
            on_clause = []
            for i, gb in enumerate(groupby):
                on_clause.append(groupby_exprs[i] == column(gb + '__'))

            tbl = tbl.join(subq.alias(), and_(*on_clause))

        return qry.select_from(tbl)

    def query(self, query_obj):
        """Run ``query_obj`` and wrap the outcome in a ``QueryResult``.

        Any exception raised while fetching the dataframe is logged and
        reported through ``status``/``error_message`` instead of propagating.
        """
        qry_start_dttm = datetime.now()
        sql = self.get_query_str(query_obj)
        status = QueryStatus.SUCCESS
        error_message = None
        df = None
        try:
            df = self.database.get_df(sql, self.schema)
        except Exception as e:
            status = QueryStatus.FAILED
            logging.exception(e)
            error_message = (
                self.database.db_engine_spec.extract_error_message(e))

        return QueryResult(status=status,
                           df=df,
                           duration=datetime.now() - qry_start_dttm,
                           query=sql,
                           error_message=error_message)

    def get_sqla_table_object(self):
        return self.database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self):
        """Fetches the metadata for the table and merges it in"""
        try:
            table = self.get_sqla_table_object()
        except Exception:
            raise Exception(
                _("Table [{}] doesn't seem to exist in the specified database, "
                  "couldn't fetch column information").format(self.table_name))

        M = SqlMetric  # noqa
        metrics = []
        any_date_col = None
        db_dialect = self.database.get_dialect()
        dbcols = (db.session.query(TableColumn).filter(
            TableColumn.table == self).filter(
                or_(TableColumn.column_name == col.name
                    for col in table.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}

        for col in table.columns:
            try:
                datatype = col.type.compile(dialect=db_dialect).upper()
            except Exception as e:
                datatype = 'UNKNOWN'
                logging.error('Unrecognized data type in {}.{}'.format(
                    table, col.name))
                logging.exception(e)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                # New column: derive default flags from its detected type.
                dbcol = TableColumn(column_name=col.name, type=datatype)
                dbcol.groupby = dbcol.is_string
                dbcol.filterable = dbcol.is_string
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.is_dttm = dbcol.is_time
            self.columns.append(dbcol)
            if not any_date_col and dbcol.is_time:
                any_date_col = col.name

            quoted = str(col.compile(dialect=db_dialect))
            if dbcol.sum:
                metrics.append(
                    M(
                        metric_name='sum__' + dbcol.column_name,
                        verbose_name='sum__' + dbcol.column_name,
                        metric_type='sum',
                        expression='SUM({})'.format(quoted),
                    ))
            if dbcol.avg:
                metrics.append(
                    M(
                        metric_name='avg__' + dbcol.column_name,
                        verbose_name='avg__' + dbcol.column_name,
                        metric_type='avg',
                        expression='AVG({})'.format(quoted),
                    ))
            if dbcol.max:
                metrics.append(
                    M(
                        metric_name='max__' + dbcol.column_name,
                        verbose_name='max__' + dbcol.column_name,
                        metric_type='max',
                        expression='MAX({})'.format(quoted),
                    ))
            if dbcol.min:
                metrics.append(
                    M(
                        metric_name='min__' + dbcol.column_name,
                        verbose_name='min__' + dbcol.column_name,
                        metric_type='min',
                        expression='MIN({})'.format(quoted),
                    ))
            if dbcol.count_distinct:
                metrics.append(
                    M(
                        metric_name='count_distinct__' + dbcol.column_name,
                        verbose_name='count_distinct__' + dbcol.column_name,
                        metric_type='count_distinct',
                        expression='COUNT(DISTINCT {})'.format(quoted),
                    ))
            dbcol.type = datatype

        metrics.append(
            M(
                metric_name='count',
                verbose_name='COUNT(*)',
                metric_type='count',
                expression='COUNT(*)',
            ))

        # Only add metrics whose name is not already stored for this table.
        dbmetrics = db.session.query(M).filter(M.table_id == self.id).filter(
            or_(M.metric_name == metric.metric_name for metric in metrics))
        dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
        for metric in metrics:
            metric.table_id = self.id
            if not dbmetrics.get(metric.metric_name, None):
                db.session.add(metric)
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        db.session.merge(self)
        db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None):
        """Imports the datasource from the object to the database.

         Metrics, columns and the datasource are overridden if they exist.
         This function can be used to import/export dashboards between multiple
         superset instances. Audit metadata isn't copied over.
        """
        def lookup_sqlatable(table):
            return db.session.query(SqlaTable).join(Database).filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            ).first()

        def lookup_database(table):
            return db.session.query(Database).filter_by(
                database_name=table.params_dict['database_name']).one()

        return import_util.import_datasource(db.session, i_datasource,
                                             lookup_database, lookup_sqlatable,
                                             import_time)

    @classmethod
    def query_datasources_by_name(cls,
                                  session,
                                  database,
                                  datasource_name,
                                  schema=None):
        """Return all tables of ``database`` named ``datasource_name``.

        The result is additionally narrowed by ``schema`` when it is truthy.
        """
        query = (session.query(cls).filter_by(
            database_id=database.id).filter_by(table_name=datasource_name))
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()
예제 #15
0
class Database(Model, AuditMixinNullable, ImportMixin):

    """An ORM object that stores Database related information"""

    __tablename__ = "dbs"
    type = "table"
    __table_args__ = (UniqueConstraint("database_name"),)

    id = Column(Integer, primary_key=True)
    verbose_name = Column(String(250), unique=True)
    # short unique name, used in permissions
    database_name = Column(String(250), unique=True, nullable=False)
    sqlalchemy_uri = Column(String(1024))
    password = Column(EncryptedType(String(1024), config.get("SECRET_KEY")))
    cache_timeout = Column(Integer)
    select_as_create_table_as = Column(Boolean, default=False)
    expose_in_sqllab = Column(Boolean, default=True)
    allow_run_async = Column(Boolean, default=False)
    allow_csv_upload = Column(Boolean, default=False)
    allow_ctas = Column(Boolean, default=False)
    allow_dml = Column(Boolean, default=False)
    force_ctas_schema = Column(String(250))
    allow_multi_schema_metadata_fetch = Column(Boolean, default=False)
    extra = Column(
        Text,
        default=textwrap.dedent(
            """\
    {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_csv_upload": []
    }
    """
        ),
    )
    perm = Column(String(1000))
    impersonate_user = Column(Boolean, default=False)
    export_fields = (
        "database_name",
        "sqlalchemy_uri",
        "cache_timeout",
        "expose_in_sqllab",
        "allow_run_async",
        "allow_ctas",
        "allow_csv_upload",
        "extra",
    )
    export_children = ["tables"]

    def __repr__(self):
        """Represent the database by its display ``name``."""
        display_name = self.name
        return display_name

    @property
    def name(self):
        """``verbose_name`` when it is truthy, otherwise ``database_name``."""
        return self.verbose_name or self.database_name

    @property
    def allows_subquery(self):
        """The engine spec's ``allows_subqueries`` value."""
        engine_spec = self.db_engine_spec
        return engine_spec.allows_subqueries

    @property
    def allows_cost_estimate(self) -> bool:
        """Whether cost estimation is both supported and enabled.

        The engine spec is asked whether the ``version`` from the extra
        configuration supports estimates, and the ``cost_estimate_enabled``
        flag from the same configuration must be truthy.
        """
        extra = self.get_extra()
        database_version = extra.get("version")
        cost_estimate_enabled = extra.get("cost_estimate_enabled")
        # Coerce to a real bool: a bare ``and`` would return the raw operand
        # (e.g. None when the flag is absent), contradicting the annotation.
        return bool(
            self.db_engine_spec.get_allow_cost_estimate(database_version)
            and cost_estimate_enabled
        )

    @property
    def data(self):
        """Dictionary summary of this database's id, name and capabilities."""
        summary = {"id": self.id}
        summary["name"] = self.database_name
        summary["backend"] = self.backend
        summary["allow_multi_schema_metadata_fetch"] = (
            self.allow_multi_schema_metadata_fetch
        )
        summary["allows_subquery"] = self.allows_subquery
        summary["allows_cost_estimate"] = self.allows_cost_estimate
        return summary

    @property
    def unique_name(self):
        """Alias for ``database_name``."""
        identifier = self.database_name
        return identifier

    @property
    def url_object(self):
        """Result of ``make_url`` applied to the decrypted SQLAlchemy URI."""
        decrypted_uri = self.sqlalchemy_uri_decrypted
        return make_url(decrypted_uri)

    @property
    def backend(self):
        """Backend name reported by the parsed, decrypted SQLAlchemy URI."""
        return make_url(self.sqlalchemy_uri_decrypted).get_backend_name()

    @property
    def metadata_cache_timeout(self):
        """The ``metadata_cache_timeout`` entry of the extra config, or ``{}``."""
        extra = self.get_extra()
        return extra.get("metadata_cache_timeout", {})

    @property
    def schema_cache_enabled(self):
        """True when a ``schema_cache_timeout`` key is configured."""
        timeouts = self.metadata_cache_timeout
        return "schema_cache_timeout" in timeouts

    @property
    def schema_cache_timeout(self):
        """Configured ``schema_cache_timeout``, or None when unset."""
        timeouts = self.metadata_cache_timeout
        return timeouts.get("schema_cache_timeout")

    @property
    def table_cache_enabled(self):
        """True when a ``table_cache_timeout`` key is configured."""
        timeouts = self.metadata_cache_timeout
        return "table_cache_timeout" in timeouts

    @property
    def table_cache_timeout(self):
        """Configured ``table_cache_timeout``, or None when unset."""
        timeouts = self.metadata_cache_timeout
        return timeouts.get("table_cache_timeout")

    @property
    def default_schemas(self):
        """The ``default_schemas`` entry of the extra config, or ``[]``."""
        extra = self.get_extra()
        return extra.get("default_schemas", [])

    @classmethod
    def get_password_masked_url_from_uri(cls, uri):
        """Parse ``uri`` with ``make_url`` and return its password-masked form."""
        return cls.get_password_masked_url(make_url(uri))

    @classmethod
    def get_password_masked_url(cls, url):
        """Return a deep copy of ``url`` whose password is masked.

        The copy's password is only overwritten when one is set and it is
        not already ``PASSWORD_MASK``; ``url`` itself is not modified.
        """
        masked = deepcopy(url)
        has_password = masked.password is not None
        if has_password and masked.password != PASSWORD_MASK:
            masked.password = PASSWORD_MASK
        return masked

    def set_sqlalchemy_uri(self, uri):
        """Store ``uri`` with its password masked, keeping the password apart.

        The password parsed from ``uri`` is saved on ``self.password`` unless
        it is the mask itself or ``custom_password_store`` is truthy.
        """
        parsed = sqla.engine.url.make_url(uri.strip())
        if not (parsed.password == PASSWORD_MASK or custom_password_store):
            # do not over-write the password with the password mask
            self.password = parsed.password
        if parsed.password:
            parsed.password = PASSWORD_MASK
        else:
            parsed.password = None
        self.sqlalchemy_uri = str(parsed)  # hides the password

    def get_effective_user(self, url, user_name=None):
        """
        Resolve the username to act as, which matters during impersonation.

        :param url: SQL Alchemy URL object
        :param user_name: Default username
        :return: None when impersonation is off; otherwise ``user_name`` if
            truthy, else the username of ``g.user`` when it is available and
            not None, else ``url.username``
        """
        if not self.impersonate_user:
            return None

        fallback = url.username
        if user_name:
            return user_name
        if (
            hasattr(g, "user")
            and hasattr(g.user, "username")
            and g.user.username is not None
        ):
            return g.user.username
        return fallback

    @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
    def get_sqla_engine(self, schema=None, nullpool=True, user_name=None, source=None):
        """Create a SQLAlchemy engine for this database.

        The decrypted URI is adjusted for ``schema`` and for impersonation by
        the engine spec, engine parameters come from the ``engine_params``
        entry of the extra config, and an optional ``DB_CONNECTION_MUTATOR``
        from the app config may rewrite both before the engine is created.
        """
        extra = self.get_extra()
        url = make_url(self.sqlalchemy_uri_decrypted)
        url = self.db_engine_spec.adjust_database_uri(url, schema)
        effective_username = self.get_effective_user(url, user_name)
        # If using MySQL or Presto for example, will set url.username
        # If using Hive, will not do anything yet since that relies on a
        # configuration parameter instead.
        self.db_engine_spec.modify_url_for_impersonation(
            url, self.impersonate_user, effective_username
        )

        # Only the password-masked form of the URL is written to the log.
        masked_url = self.get_password_masked_url(url)
        logging.info("Database.get_sqla_engine(). Masked URL: {0}".format(masked_url))

        engine_kwargs = extra.get("engine_params", {})
        if nullpool:
            engine_kwargs["poolclass"] = NullPool

        # If using Hive, this will set hive.server2.proxy.user=$effective_username
        configuration = dict(
            self.db_engine_spec.get_configuration_for_impersonation(
                str(url), self.impersonate_user, effective_username
            )
        )
        if configuration:
            engine_kwargs.setdefault("connect_args", {})[
                "configuration"
            ] = configuration

        mutator = config.get("DB_CONNECTION_MUTATOR")
        if mutator:
            url, engine_kwargs = mutator(
                url, engine_kwargs, effective_username, security_manager, source
            )
        return create_engine(url, **engine_kwargs)

    def get_reserved_words(self):
        """Return ``preparer.reserved_words`` of this database's dialect."""
        dialect = self.get_dialect()
        return dialect.preparer.reserved_words

    def get_quoter(self):
        """Return the ``quote`` callable of the dialect's identifier preparer."""
        preparer = self.get_dialect().identifier_preparer
        return preparer.quote

    def get_df(self, sql, schema, mutator=None):
        """Run ``sql`` against this database and return the last result as a DataFrame.

        ``sql`` is split into statements with ``sqlparse``. Every statement but
        the last is executed and its rows fetched and discarded; the rows of
        the last statement populate the returned DataFrame.

        :param sql: one or more SQL statements
        :param schema: schema passed to ``get_sqla_engine``
        :param mutator: optional callable applied to the DataFrame before the
            list/dict columns are serialized
        :return: DataFrame built from the last statement's result set
        """
        sqls = [str(s).strip().strip(";") for s in sqlparse.parse(sql)]
        source_key = None
        if request and request.referrer:
            if "/superset/dashboard/" in request.referrer:
                source_key = "dashboard"
            elif "/superset/explore/" in request.referrer:
                source_key = "chart"
        engine = self.get_sqla_engine(
            schema=schema, source=utils.sources.get(source_key, None)
        )
        username = utils.get_username()

        def needs_conversion(df_series):
            """Tell whether the first value of the series is a list or dict."""
            if df_series.empty:
                return False
            # Positional access: ``mutator`` may return a frame whose index no
            # longer contains the label 0, which would make ``df_series[0]``
            # raise KeyError.
            return isinstance(df_series.iloc[0], (list, dict))

        def _log_query(sql):
            # ``log_query`` is an optional hook; skip logging when it is unset.
            if log_query:
                log_query(engine.url, sql, schema, username, __name__, security_manager)

        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                for statement in sqls[:-1]:
                    _log_query(statement)
                    self.db_engine_spec.execute(cursor, statement)
                    # Drain the intermediate result before the next statement.
                    cursor.fetchall()

                _log_query(sqls[-1])
                self.db_engine_spec.execute(cursor, sqls[-1])

                if cursor.description is not None:
                    columns = [col_desc[0] for col_desc in cursor.description]
                else:
                    columns = []

                df = pd.DataFrame.from_records(
                    data=list(cursor.fetchall()), columns=columns, coerce_float=True
                )

                if mutator:
                    df = mutator(df)

                # Serialize object columns holding lists/dicts to JSON strings.
                for k, v in df.dtypes.items():
                    if v.type == numpy.object_ and needs_conversion(df[k]):
                        df[k] = df[k].apply(utils.json_dumps_w_dates)
                return df

    def compile_sqla_query(self, qry, schema=None):
        """Compile ``qry`` to a SQL string with literal binds.

        When the dialect's identifier preparer doubles percent signs, every
        ``%%`` in the compiled text is collapsed back to ``%``.

        :param qry: object exposing ``compile(engine, compile_kwargs=...)``
        :param schema: schema passed to ``get_sqla_engine``
        :return: the compiled SQL text
        """
        engine = self.get_sqla_engine(schema=schema)
        compiled = qry.compile(engine, compile_kwargs={"literal_binds": True})
        rendered = str(compiled)

        if not engine.dialect.identifier_preparer._double_percents:
            return rendered
        return rendered.replace("%%", "%")

    def select_star(
        self,
        table_name,
        schema=None,
        limit=100,
        show_cols=False,
        indent=True,
        latest_partition=False,
        cols=None,
    ):
        """Generates a ``select *`` statement in the proper dialect.

        The work is delegated to the engine spec's ``select_star``; this
        method only supplies an engine created for the given schema.
        """
        engine = self.get_sqla_engine(
            schema=schema, source=utils.sources.get("sql_lab", None)
        )
        options = {
            "schema": schema,
            "engine": engine,
            "limit": limit,
            "show_cols": show_cols,
            "indent": indent,
            "latest_partition": latest_partition,
            "cols": cols,
        }
        return self.db_engine_spec.select_star(self, table_name, **options)

    def apply_limit_to_sql(self, sql, limit=1000):
        """Delegate to the engine spec's ``apply_limit_to_sql`` for this database."""
        spec = self.db_engine_spec
        return spec.apply_limit_to_sql(sql, limit, self)

    def safe_sqlalchemy_uri(self):
        """Return the stored ``sqlalchemy_uri`` attribute as is."""
        uri = self.sqlalchemy_uri
        return uri

    @property
    def inspector(self):
        """Inspector obtained from ``sqla.inspect`` on a fresh ``get_sqla_engine()``."""
        return sqla.inspect(self.get_sqla_engine())

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{}:schema:None:table_list",
        attribute_in_key="id",
    )
    def get_all_table_names_in_database(
        self, cache: bool = False, cache_timeout: int = None, force: bool = False
    ) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments.

        The parameters are not used in the body; they are declared for the
        ``cache_util.memoized_func`` decorator.

        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: table names reported by the engine spec, or an empty list
            when ``allow_multi_schema_metadata_fetch`` is falsy
        """
        if not self.allow_multi_schema_metadata_fetch:
            return []
        return self.db_engine_spec.get_all_datasource_names(self, "table")

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{}:schema:None:view_list", attribute_in_key="id"
    )
    def get_all_view_names_in_database(
        self, cache: bool = False, cache_timeout: int = None, force: bool = False
    ) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments.

        The parameters are not used in the body; they are declared for the
        ``cache_util.memoized_func`` decorator.

        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: view names reported by the engine spec, or an empty list
            when ``allow_multi_schema_metadata_fetch`` is falsy
        """
        if not self.allow_multi_schema_metadata_fetch:
            return []
        return self.db_engine_spec.get_all_datasource_names(self, "view")

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{{}}:schema:{}:table_list".format(
            kwargs.get("schema")
        ),
        attribute_in_key="id",
    )
    def get_all_table_names_in_schema(
        self,
        schema: str,
        cache: bool = False,
        cache_timeout: int = None,
        force: bool = False,
    ):
        """List the tables of one schema. Pass all parameters as keywords.

        ``cache``, ``cache_timeout`` and ``force`` are not read in the body;
        they are referenced by the ``cache_util.memoized_func`` decorator.

        :param schema: schema name
        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: list of ``utils.DatasourceName``; ``None`` when an exception
            was raised (it is logged, not propagated)
        """
        try:
            table_names = self.db_engine_spec.get_table_names(
                database=self, inspector=self.inspector, schema=schema
            )
            datasources = []
            for table_name in table_names:
                datasources.append(
                    utils.DatasourceName(table=table_name, schema=schema)
                )
            return datasources
        except Exception as e:
            logging.exception(e)

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{{}}:schema:{}:view_list".format(
            kwargs.get("schema")
        ),
        attribute_in_key="id",
    )
    def get_all_view_names_in_schema(
        self,
        schema: str,
        cache: bool = False,
        cache_timeout: int = None,
        force: bool = False,
    ):
        """List the views of one schema. Pass all parameters as keywords.

        ``cache``, ``cache_timeout`` and ``force`` are not read in the body;
        they are referenced by the ``cache_util.memoized_func`` decorator.

        :param schema: schema name
        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: list of ``utils.DatasourceName``; ``None`` when an exception
            was raised (it is logged, not propagated)
        """
        try:
            view_names = self.db_engine_spec.get_view_names(
                database=self, inspector=self.inspector, schema=schema
            )
            datasources = []
            for view_name in view_names:
                datasources.append(
                    utils.DatasourceName(table=view_name, schema=schema)
                )
            return datasources
        except Exception as e:
            logging.exception(e)

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{}:schema_list", attribute_in_key="id"
    )
    def get_all_schema_names(
        self, cache: bool = False, cache_timeout: int = None, force: bool = False
    ) -> List[str]:
        """List the schema names. Pass all parameters as keywords.

        The parameters are not read in the body; they are referenced by the
        ``cache_util.memoized_func`` decorator.

        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: schema list
        """
        schema_names = self.db_engine_spec.get_schema_names(self.inspector)
        return schema_names

    @property
    def db_engine_spec(self):
        """Engine spec registered for ``self.backend``, else ``BaseEngineSpec``."""
        fallback = db_engine_specs.BaseEngineSpec
        return db_engine_specs.engines.get(self.backend, fallback)

    @classmethod
    def get_db_engine_spec_for_backend(cls, backend):
        """Return the engine spec registered for ``backend``, else ``BaseEngineSpec``."""
        fallback = db_engine_specs.BaseEngineSpec
        return db_engine_specs.engines.get(backend, fallback)

    def grains(self):
        """Return the time grains defined by this database's engine spec.

        Time grains let users truncate a datetime (arbitrary timestamps,
        daily or 5-minute increments, ...) to a coarser one. Because each
        database offers slightly different datetime functions, the engine
        spec provides the mapping to the actual expressions.
        """
        spec = self.db_engine_spec
        return spec.get_time_grains()

    def get_extra(self):
        """Parse the JSON stored in ``self.extra``.

        :return: the decoded value, or an empty dict when ``self.extra`` is falsy
        :raises Exception: any error from ``json.loads`` is logged and re-raised
        """
        if not self.extra:
            return {}
        try:
            return json.loads(self.extra)
        except Exception as e:
            logging.error(e)
            raise e

    def get_table(self, table_name, schema=None):
        """Reflect ``table_name`` into a ``Table`` using a fresh engine.

        ``MetaData`` is built from the ``metadata_params`` entry of the
        database's extra settings (empty when absent).

        :param table_name: name of the table to autoload
        :param schema: schema name; falsy values are passed as ``None``
        :return: the autoloaded ``Table``
        """
        metadata_params = self.get_extra().get("metadata_params", {})
        meta = MetaData(**metadata_params)
        return Table(
            table_name,
            meta,
            schema=schema or None,
            autoload=True,
            autoload_with=self.get_sqla_engine(),
        )

    def get_columns(self, table_name, schema=None):
        """Return the columns of ``table_name`` as reported by the engine spec."""
        spec = self.db_engine_spec
        return spec.get_columns(self.inspector, table_name, schema)

    def get_indexes(self, table_name, schema=None):
        """Return the inspector's ``get_indexes`` result for ``table_name``."""
        inspector = self.inspector
        return inspector.get_indexes(table_name, schema)

    def get_pk_constraint(self, table_name, schema=None):
        """Return the inspector's ``get_pk_constraint`` result for ``table_name``."""
        inspector = self.inspector
        return inspector.get_pk_constraint(table_name, schema)

    def get_foreign_keys(self, table_name, schema=None):
        """Return the inspector's ``get_foreign_keys`` result for ``table_name``."""
        inspector = self.inspector
        return inspector.get_foreign_keys(table_name, schema)

    def get_schema_access_for_csv_upload(self):
        """Return ``schemas_allowed_for_csv_upload`` from the extras, or ``[]``."""
        extra = self.get_extra()
        return extra.get("schemas_allowed_for_csv_upload", [])

    @property
    def sqlalchemy_uri_decrypted(self):
        """String form of ``sqlalchemy_uri`` with its password filled in.

        The password comes from ``custom_password_store`` when that hook is
        set, otherwise from ``self.password``.
        """
        parsed = sqla.engine.url.make_url(self.sqlalchemy_uri)
        parsed.password = (
            custom_password_store(parsed) if custom_password_store else self.password
        )
        return str(parsed)

    @property
    def sql_url(self):
        """Path ``/superset/sql/<id>/`` built from this object's ``id``."""
        template = "/superset/sql/{}/"
        return template.format(self.id)

    def get_perm(self):
        """Return the permission string ``[<database_name>].(id:<id>)``."""
        template = "[{obj.database_name}].(id:{obj.id})"
        return template.format(obj=self)

    def has_table(self, table):
        """Ask the engine whether ``table`` exists.

        :param table: object with ``table_name`` and ``schema`` attributes;
            a falsy schema is passed to the engine as ``None``
        """
        schema = table.schema or None
        return self.get_sqla_engine().has_table(table.table_name, schema)

    def has_table_by_name(self, table_name, schema=None):
        """Ask the engine whether ``table_name`` exists in ``schema``."""
        return self.get_sqla_engine().has_table(table_name, schema)

    @utils.memoized
    def get_dialect(self):
        """Instantiate the dialect class resolved from the decrypted URI."""
        parsed_url = url.make_url(self.sqlalchemy_uri_decrypted)
        dialect_cls = parsed_url.get_dialect()
        return dialect_cls()
예제 #16
0
        return "/superset/explore/?form_data=%7B%22slice_id%22%3A%20{0}%7D".format(
            self.id
        )


# Run ``set_related_perm`` before every insert and update of a ``Slice``.
sqla.event.listen(Slice, "before_insert", set_related_perm)
sqla.event.listen(Slice, "before_update", set_related_perm)


# Association table between dashboards and slices; a (dashboard_id, slice_id)
# pair may appear only once.
dashboard_slices = Table(
    "dashboard_slices",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
    Column("slice_id", Integer, ForeignKey("slices.id")),
    UniqueConstraint("dashboard_id", "slice_id"),
)

# Association table between dashboards and ``ab_user`` rows.
dashboard_user = Table(
    "dashboard_user",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("ab_user.id")),
    Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
)


class Dashboard(Model, AuditMixinNullable, ImportMixin):

    """The dashboard object!"""
예제 #17
0
class Label(MailSyncBase, UpdatedAtMixin, DeletedAtMixin):
    """A label from the remote account backend (Gmail)."""
    # TODO: importing Account here fails with a circular import:
    # from inbox.models.account import Account
    # `use_alter` is required on the foreign key to avoid the circular
    # dependency with Account.
    account_id = Column(
        ForeignKey('account.id',
                   use_alter=True,
                   name='label_fk1',
                   ondelete='CASCADE'),
        nullable=False)
    account = relationship(
        'Account',
        backref=backref(
            'labels',
            # Don't load labels when the account is deleted; the foreign
            # key's delete cascade removes them.
            passive_deletes=True),
        load_on_pending=True)

    name = Column(CategoryNameString(), nullable=False)
    canonical_name = Column(String(MAX_INDEXABLE_LENGTH),
                            nullable=False,
                            default='')

    category_id = Column(ForeignKey(Category.id, ondelete='CASCADE'))
    category = relationship(
        Category,
        backref=backref('labels', cascade='all, delete-orphan'))

    @validates('name')
    def validate_name(self, key, name):
        """Return ``sanitize_name(name)``, warning when it differs from ``name``."""
        cleaned = sanitize_name(name)
        if cleaned == name:
            return cleaned
        log.warning("Truncating label name for account",
                    account_id=self.account_id,
                    name=name)
        return cleaned

    @classmethod
    def find_or_create(cls, session, account, name, role=None):
        """Fetch the account's matching label, creating it when missing.

        Labels are matched on ``canonical_name`` when ``role`` is truthy and
        on ``name`` otherwise. A newly created label gets its category from
        ``Category.find_or_create`` and is added to ``session``.
        """
        role = role or ''
        base_query = session.query(cls).filter(cls.account_id == account.id)
        if role:
            matching = base_query.filter(cls.canonical_name == role)
        else:
            matching = base_query.filter(cls.name == name)

        existing = matching.first()
        if existing is not None:
            return existing

        label = cls(account=account, name=name, canonical_name=role)
        label.category = Category.find_or_create(
            session,
            namespace_id=account.namespace.id,
            name=role,
            display_name=name,
            type_='label')
        session.add(label)
        return label

    __table_args__ = (
        UniqueConstraint('account_id', 'name', 'canonical_name'),
    )
예제 #18
0
class SqlaTable(Model, BaseDatasource):
    """An ORM object for SqlAlchemy table references"""

    type = "table"
    query_language = "sql"
    metric_class = SqlMetric
    column_class = TableColumn
    owner_class = security_manager.user_model

    __tablename__ = "tables"
    # A table name may appear only once per database.
    __table_args__ = (UniqueConstraint("database_id", "table_name"), )

    table_name = Column(String(250), nullable=False)
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
    fetch_values_predicate = Column(String(1000))
    owners = relationship(owner_class,
                          secondary=sqlatable_user,
                          backref="tables")
    database = relationship(
        "Database",
        backref=backref("tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    schema = Column(String(255))
    sql = Column(Text)
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)

    baselink = "tablemodelview"

    export_fields = [
        "table_name",
        "main_dttm_col",
        "description",
        "default_endpoint",
        "database_id",
        "offset",
        "cache_timeout",
        "schema",
        "sql",
        "params",
        "template_params",
        "filter_select_enabled",
        "fetch_values_predicate",
    ]
    # Every exported field except the two covered by the unique constraint.
    update_from_object_fields = [
        f for f in export_fields if f not in ("table_name", "database_id")
    ]
    export_parent = "database"
    export_children = ["metrics", "columns"]

    # Maps an aggregate name to a callable producing the SQL function
    # expression; looked up by ``adhoc_metric_to_sqla``.
    sqla_aggregations = {
        "COUNT_DISTINCT":
        lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
        "COUNT":
        sa.func.COUNT,
        "SUM":
        sa.func.SUM,
        "AVG":
        sa.func.AVG,
        "MIN":
        sa.func.MIN,
        "MAX":
        sa.func.MAX,
    }

    def make_sqla_column_compatible(self,
                                    sqla_col: Column,
                                    label: Optional[str] = None) -> Column:
        """Label a sqlalchemy column when the engine supports column aliases.

        The expected label (``label`` or the column's own name) is always
        recorded on the result as ``_df_label_expected``.

        :param sqla_col: sqlalchemy column instance
        :param label: alias/label that column is expected to have
        :return: the column itself, or its labelled form if the engine
            allows column aliases
        """
        expected = label or sqla_col.name
        spec = self.database.db_engine_spec
        if spec.allows_column_aliases:
            sqla_col = sqla_col.label(spec.make_label_compatible(expected))
        sqla_col._df_label_expected = expected
        return sqla_col

    def __repr__(self):
        return self.name

    @property
    def changed_by_name(self) -> str:
        """``str(changed_by)``, or an empty string when ``changed_by`` is falsy."""
        return str(self.changed_by) if self.changed_by else ""

    @property
    def changed_by_url(self) -> str:
        """Profile path of ``changed_by``, or an empty string when it is falsy."""
        if self.changed_by:
            return f"/superset/profile/{self.changed_by.username}"
        return ""

    @property
    def connection(self) -> str:
        """String form of the related ``database``."""
        database = self.database
        return str(database)

    @property
    def description_markeddown(self) -> str:
        """``description`` passed through ``utils.markdown``."""
        rendered = utils.markdown(self.description)
        return rendered

    @property
    def datasource_name(self) -> str:
        """Alias of ``table_name``."""
        table_name = self.table_name
        return table_name

    @property
    def database_name(self) -> str:
        """``name`` attribute of the related ``database``."""
        database = self.database
        return database.name

    @classmethod
    def get_datasource_by_name(
        cls,
        session: Session,
        datasource_name: str,
        schema: Optional[str],
        database_name: str,
    ) -> Optional["SqlaTable"]:
        """Find a table by table name, schema and database name.

        :param session: session used to run the query
        :param datasource_name: value matched against ``table_name``
        :param schema: schema to match; ``''`` and ``None`` are equivalent
        :param database_name: value matched against ``Database.database_name``
        :return: the first matching table, or ``None``
        """
        wanted_schema = schema or None
        candidates = (
            session.query(cls)
            .join(Database)
            .filter(cls.table_name == datasource_name)
            .filter(Database.database_name == database_name)
            .all()
        )
        # The schema is compared in Python: treating '' and None alike is
        # easier here than in a multi-dialect SQLA query.
        return next(
            (tbl for tbl in candidates if wanted_schema == (tbl.schema or None)),
            None,
        )

    @property
    def link(self) -> Markup:
        """Anchor opening ``explore_url`` in a new tab, labelled with the escaped name."""
        name = escape(self.name)
        return Markup(f'<a target="_blank" href="{self.explore_url}">{name}</a>')

    def get_schema_perm(self) -> Optional[str]:
        """Returns schema permission if present, database one otherwise.

        Delegates to ``security_manager.get_schema_perm``.
        """
        database, schema = self.database, self.schema
        return security_manager.get_schema_perm(database, schema)

    def get_perm(self) -> str:
        """Return the permission string ``[<database>].[<table_name>](id:<id>)``."""
        template = "[{obj.database}].[{obj.table_name}](id:{obj.id})"
        return template.format(obj=self)

    @property
    def name(self) -> str:  # type: ignore
        """``<schema>.<table_name>``, or just ``table_name`` when schema is falsy."""
        if self.schema:
            return "{}.{}".format(self.schema, self.table_name)
        return self.table_name

    @property
    def full_name(self) -> str:
        """Result of ``utils.get_datasource_full_name`` for this table."""
        full = utils.get_datasource_full_name(
            self.database, self.table_name, schema=self.schema
        )
        return full

    @property
    def dttm_cols(self) -> List:
        """Names of columns flagged ``is_dttm``, plus ``main_dttm_col`` if missing."""
        names = [col.column_name for col in self.columns if col.is_dttm]
        if self.main_dttm_col and self.main_dttm_col not in names:
            names.append(self.main_dttm_col)
        return names

    @property
    def num_cols(self) -> List:
        """Names of the columns flagged ``is_numeric``."""
        names = []
        for col in self.columns:
            if col.is_numeric:
                names.append(col.column_name)
        return names

    @property
    def any_dttm_col(self) -> Optional[str]:
        """First entry of ``dttm_cols``, or ``None`` when there is none."""
        candidates = self.dttm_cols
        if not candidates:
            return None
        return candidates[0]

    @property
    def html(self) -> str:
        """HTML table listing each column's name and type.

        The column labels are given to the ``DataFrame`` constructor rather
        than assigned afterwards, so a table with no columns renders an
        empty ``field``/``type`` table instead of raising ``ValueError``.

        :return: markup produced by ``DataFrame.to_html`` without the index
        """
        rows = [(c.column_name, c.type) for c in self.columns]
        df = pd.DataFrame(rows, columns=["field", "type"])
        return df.to_html(
            index=False,
            classes=("dataframe table table-striped table-bordered "
                     "table-condensed"),
        )

    @property
    def sql_url(self) -> str:
        """The database's ``sql_url`` with a ``table_name`` query parameter appended."""
        base = self.database.sql_url
        query = "?table_name=" + str(self.table_name)
        return base + query

    def external_metadata(self):
        """Return the database's column descriptions with ``type`` as a string.

        A type whose string conversion raises ``CompileError`` is reported
        as ``"UNKNOWN"``. The dicts returned by ``get_columns`` are updated
        in place.
        """
        columns = self.database.get_columns(self.table_name, schema=self.schema)
        for column_info in columns:
            try:
                type_name = str(column_info["type"])
            except CompileError:
                type_name = "UNKNOWN"
            column_info["type"] = type_name
        return columns

    @property
    def time_column_grains(self) -> Dict[str, Any]:
        """Datetime column names and the names of the database's time grains."""
        time_columns = self.dttm_cols
        grain_names = [grain.name for grain in self.database.grains()]
        return {"time_columns": time_columns, "time_grains": grain_names}

    @property
    def select_star(self) -> str:
        """The database's ``select_star`` statement for this table.

        ``show_cols`` and ``latest_partition`` are disabled to avoid the
        expensive cost of inspecting the DB.
        """
        options = {
            "schema": self.schema,
            "show_cols": False,
            "latest_partition": False,
        }
        return self.database.select_star(self.table_name, **options)

    @property
    def data(self) -> Dict:
        """The parent's ``data`` dict, extended for datasources of type ``table``.

        The extension adds the datetime column choices, the time grains as
        ``(duration, name)`` pairs, and the main datetime column, fetch
        predicate and template params of this table.
        """
        payload = super().data
        if self.type != "table":
            return payload
        # ``or []`` guarantees an iterable when no grains are defined.
        grains = [(g.duration, g.name) for g in (self.database.grains() or [])]
        payload["granularity_sqla"] = utils.choicify(self.dttm_cols)
        payload["time_grain_sqla"] = grains
        payload["main_dttm_col"] = self.main_dttm_col
        payload["fetch_values_predicate"] = self.fetch_values_predicate
        payload["template_params"] = self.template_params
        return payload

    def values_for_column(self, column_name: str, limit: int = 10000) -> List:
        """Runs query against sqla to retrieve some
        sample values for the given column.

        :param column_name: name of one of this table's columns
        :param limit: maximum number of distinct values; a falsy value
            disables the limit
        :return: the distinct values read from the database
        :raises KeyError: if ``column_name`` is not among ``self.columns``
        """
        cols = {col.column_name: col for col in self.columns}
        target_col = cols[column_name]
        tp = self.get_template_processor()

        qry = (select([target_col.get_sqla_col()
                       ]).select_from(self.get_from_clause(tp)).distinct())
        if limit:
            qry = qry.limit(limit)

        if self.fetch_values_predicate:
            # Reuse the processor built above; it was created with the same
            # (empty) arguments, so building a second one is redundant.
            qry = qry.where(tp.process_template(self.fetch_values_predicate))

        engine = self.database.get_sqla_engine()
        sql = "{}".format(
            qry.compile(engine, compile_kwargs={"literal_binds": True}))
        sql = self.mutate_query_from_config(sql)

        df = pd.read_sql_query(sql=sql, con=engine)
        return df[column_name].to_list()

    def mutate_query_from_config(self, sql: str) -> str:
        """Apply config's SQL_QUERY_MUTATOR

        Typically adds comments to the query with context. ``sql`` is
        returned unchanged when no mutator is configured."""
        mutator = config["SQL_QUERY_MUTATOR"]
        if not mutator:
            return sql
        username = utils.get_username()
        return mutator(sql, username, security_manager, self.database)

    def get_template_processor(self, **kwargs):
        """Build a template processor bound to this table and its database.

        Extra keyword arguments are forwarded to the module-level
        ``get_template_processor``.
        """
        processor = get_template_processor(
            table=self, database=self.database, **kwargs
        )
        return processor

    def get_query_str_extended(
            self, query_obj: Dict[str, Any]) -> QueryStringExtended:
        """Compile ``query_obj`` to SQL along with its labels and prequeries.

        The compiled SQL is logged, re-indented with ``sqlparse`` and then
        passed through ``mutate_query_from_config``.
        """
        sqlaq = self.get_sqla_query(**query_obj)
        compiled_sql = self.database.compile_sqla_query(sqlaq.sqla_query)
        logger.info(compiled_sql)
        formatted_sql = sqlparse.format(compiled_sql, reindent=True)
        final_sql = self.mutate_query_from_config(formatted_sql)
        return QueryStringExtended(
            labels_expected=sqlaq.labels_expected,
            sql=final_sql,
            prequeries=sqlaq.prequeries,
        )

    def get_query_str(self, query_obj: Dict[str, Any]) -> str:
        """Return the prequeries and the main query as one string.

        Statements are separated by ``;`` and a blank line, and the result
        ends with ``;``.
        """
        extended = self.get_query_str_extended(query_obj)
        statements = extended.prequeries + [extended.sql]
        joined = ";\n\n".join(statements)
        return joined + ";"

    def get_sqla_table(self):
        """Return a ``table`` clause for ``table_name``, with the schema set when present."""
        sqla_table = table(self.table_name)
        if not self.schema:
            return sqla_table
        sqla_table.schema = self.schema
        return sqla_table

    def get_from_clause(self, template_processor=None):
        """Return the FROM target of queries against this datasource.

        When ``self.sql`` is set, that statement is used in place of a
        table: it is run through ``template_processor`` (if given), stripped
        of comments and aliased ``expr_qry``. Otherwise the plain table
        clause is returned.
        """
        if not self.sql:
            return self.get_sqla_table()

        # Supporting arbitrary SQL statements in place of tables
        raw_sql = self.sql
        if template_processor:
            raw_sql = template_processor.process_template(raw_sql)
        cleaned_sql = sqlparse.format(raw_sql, strip_comments=True)
        return TextAsFrom(sa.text(cleaned_sql), []).alias("expr_qry")

    def adhoc_metric_to_sqla(self, metric: Dict,
                             cols: Dict) -> Optional[Column]:
        """
        Turn an adhoc metric into a sqlalchemy column.

        A ``SIMPLE`` metric applies the aggregate named in the metric to its
        column; a ``SQL`` metric wraps its ``sqlExpression`` in a literal
        column. Any other expression type yields ``None``.

        :param dict metric: Adhoc metric definition
        :param dict cols: Columns for the current table
        :returns: The metric defined as a sqlalchemy column
        :rtype: sqlalchemy.sql.column
        """
        kind = metric.get("expressionType")
        label = utils.get_metric_name(metric)
        expression_types = utils.ADHOC_METRIC_EXPRESSION_TYPES

        if kind == expression_types["SIMPLE"]:
            column_name = metric["column"].get("column_name")
            known_column = cols.get(column_name)
            # Fall back to a bare column reference for unknown column names.
            base_column = (known_column.get_sqla_col()
                           if known_column else column(column_name))
            aggregate = self.sqla_aggregations[metric["aggregate"]]
            metric_expr = aggregate(base_column)
        elif kind == expression_types["SQL"]:
            metric_expr = literal_column(metric.get("sqlExpression"))
        else:
            return None

        return self.make_sqla_column_compatible(metric_expr, label)

    def _get_sqla_row_level_filters(self, template_processor) -> List[str]:
        """
        Return the appropriate row level security filters for this table and the current user.

        Each filter clause from ``security_manager.get_rls_filters`` is run
        through the template processor, parenthesized and wrapped in ``text``.

        :param BaseTemplateProcessor template_processor: The template processor to apply to the filters.
        :returns: A list of SQL clauses to be ANDed together.
        :rtype: List[str]
        """
        clauses = []
        for rls_filter in security_manager.get_rls_filters(self):
            rendered = template_processor.process_template(rls_filter.clause)
            clauses.append(text("({})".format(rendered)))
        return clauses

    def get_sqla_query(  # sqla
        self,
        metrics,
        granularity,
        from_dttm,
        to_dttm,
        columns=None,
        groupby=None,
        filter=None,
        is_timeseries=True,
        timeseries_limit=15,
        timeseries_limit_metric=None,
        row_limit=None,
        inner_from_dttm=None,
        inner_to_dttm=None,
        orderby=None,
        extras=None,
        order_desc=True,
    ) -> SqlaQuery:
        """Querying any sqla table from this common interface"""
        template_kwargs = {
            "from_dttm": from_dttm,
            "groupby": groupby,
            "metrics": metrics,
            "row_limit": row_limit,
            "to_dttm": to_dttm,
            "filter": filter,
            "columns": {col.column_name: col
                        for col in self.columns},
        }
        is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
        template_kwargs.update(self.template_params_dict)
        extra_cache_keys: List[Any] = []
        template_kwargs["extra_cache_keys"] = extra_cache_keys
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec
        prequeries: List[str] = []

        orderby = orderby or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols: Dict[str,
                   Column] = {col.column_name: col
                              for col in self.columns}
        metrics_dict: Dict[str, SqlMetric] = {
            m.metric_name: m
            for m in self.metrics
        }

        if not granularity and is_timeseries:
            raise Exception(
                _("Datetime column not provided as part table configuration "
                  "and is required by this type of chart"))
        if (not metrics and not columns
                and (is_sip_38 or (not is_sip_38 and not groupby))):
            raise Exception(_("Empty query?"))
        metrics_exprs: List[ColumnElement] = []
        for m in metrics:
            if utils.is_adhoc_metric(m):
                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
            elif m in metrics_dict:
                metrics_exprs.append(metrics_dict[m].get_sqla_col())
            else:
                raise Exception(
                    _("Metric '%(metric)s' does not exist", metric=m))
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
            main_metric_expr = self.make_sqla_column_compatible(
                main_metric_expr, label)

        select_exprs: List[Column] = []
        groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()

        if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
            # dedup columns while preserving order
            groupby = list(dict.fromkeys(columns if is_sip_38 else groupby))

            select_exprs = []
            for s in groupby:
                if s in cols:
                    outer = cols[s].get_sqla_col()
                else:
                    outer = literal_column(f"({s})")
                    outer = self.make_sqla_column_compatible(outer, s)

                groupby_exprs_sans_timestamp[outer.name] = outer
                select_exprs.append(outer)
        elif columns:
            for s in columns:
                select_exprs.append(
                    cols[s].get_sqla_col() if s in cols else self.
                    make_sqla_column_compatible(literal_column(s)))
            metrics_exprs = []

        time_range_endpoints = extras.get("time_range_endpoints")
        groupby_exprs_with_timestamp = OrderedDict(
            groupby_exprs_sans_timestamp.items())
        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get("time_grain_sqla")
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs_with_timestamp[timestamp.name] = timestamp

            # Use main dttm column to support index with secondary dttm columns.
            if (db_engine_spec.time_secondary_columns
                    and self.main_dttm_col in self.dttm_cols
                    and self.main_dttm_col != dttm_col.column_name):
                time_filters.append(cols[self.main_dttm_col].get_time_filter(
                    from_dttm, to_dttm, time_range_endpoints))
            time_filters.append(
                dttm_col.get_time_filter(from_dttm, to_dttm,
                                         time_range_endpoints))

        select_exprs += metrics_exprs

        labels_expected = [c._df_label_expected for c in select_exprs]

        select_exprs = db_engine_spec.make_select_compatible(
            groupby_exprs_with_timestamp.values(), select_exprs)
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor)

        if (is_sip_38 and metrics) or (not is_sip_38 and not columns):
            qry = qry.group_by(*groupby_exprs_with_timestamp.values())

        where_clause_and = []
        having_clause_and: List = []
        for flt in filter:
            if not all([flt.get(s) for s in ["col", "op"]]):
                continue
            col = flt["col"]
            op = flt["op"].upper()
            col_obj = cols.get(col)
            if col_obj:
                is_list_target = op in (
                    utils.FilterOperator.IN.value,
                    utils.FilterOperator.NOT_IN.value,
                )
                eq = self.filter_values_handler(
                    values=flt.get("val"),
                    target_column_is_numeric=col_obj.is_numeric,
                    is_list_target=is_list_target,
                )
                if op in (
                        utils.FilterOperator.IN.value,
                        utils.FilterOperator.NOT_IN.value,
                ):
                    cond = col_obj.get_sqla_col().in_(eq)
                    if isinstance(eq, str) and NULL_STRING in eq:
                        cond = or_(cond, col_obj.get_sqla_col() is None)
                    if op == utils.FilterOperator.NOT_IN.value:
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_numeric:
                        eq = utils.cast_to_num(flt["val"])
                    if op == utils.FilterOperator.EQUALS.value:
                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                    elif op == utils.FilterOperator.NOT_EQUALS.value:
                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                    elif op == utils.FilterOperator.GREATER_THAN.value:
                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                    elif op == utils.FilterOperator.LESS_THAN.value:
                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                    elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:
                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                    elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:
                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                    elif op == utils.FilterOperator.LIKE.value:
                        where_clause_and.append(
                            col_obj.get_sqla_col().like(eq))
                    elif op == utils.FilterOperator.IS_NULL.value:
                        where_clause_and.append(col_obj.get_sqla_col() == None)
                    elif op == utils.FilterOperator.IS_NOT_NULL.value:
                        where_clause_and.append(col_obj.get_sqla_col() != None)
                    else:
                        raise Exception(
                            _("Invalid filter operation type: %(op)s", op=op))

        where_clause_and += self._get_sqla_row_level_filters(
            template_processor)
        if extras:
            where = extras.get("where")
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text("({})".format(where))]
            having = extras.get("having")
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text("({})".format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and ((is_sip_38 and metrics) or
                            (not is_sip_38 and not columns)):
            orderby = [(main_metric_expr, not order_desc)]

        # To ensure correct handling of the ORDER BY labeling we need to reference the
        # metric instance if defined in the SELECT clause.
        metrics_exprs_by_label = {m._label: m for m in metrics_exprs}

        for col, ascending in orderby:
            direction = asc if ascending else desc
            if utils.is_adhoc_metric(col):
                col = self.adhoc_metric_to_sqla(col, cols)
            elif col in cols:
                col = cols[col].get_sqla_col()

            if isinstance(col, Label) and col._label in metrics_exprs_by_label:
                col = metrics_exprs_by_label[col._label]

            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

        if (is_timeseries and timeseries_limit and not time_groupby_inline
                and ((is_sip_38 and columns) or (not is_sip_38 and groupby))):
            if self.database.db_engine_spec.allows_joins:
                # some sql dialects require for order by expressions
                # to also be in the select clause -- others, e.g. vertica,
                # require a unique inner alias
                inner_main_metric_expr = self.make_sqla_column_compatible(
                    main_metric_expr, "mme_inner__")
                inner_groupby_exprs = []
                inner_select_exprs = []
                for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
                    inner = self.make_sqla_column_compatible(
                        gby_obj, gby_name + "__")
                    inner_groupby_exprs.append(inner)
                    inner_select_exprs.append(inner)

                inner_select_exprs += [inner_main_metric_expr]
                subq = select(inner_select_exprs).select_from(tbl)
                inner_time_filter = dttm_col.get_time_filter(
                    inner_from_dttm or from_dttm,
                    inner_to_dttm or to_dttm,
                    time_range_endpoints,
                )
                subq = subq.where(
                    and_(*(where_clause_and + [inner_time_filter])))
                subq = subq.group_by(*inner_groupby_exprs)

                ob = inner_main_metric_expr
                if timeseries_limit_metric:
                    ob = self._get_timeseries_orderby(timeseries_limit_metric,
                                                      metrics_dict, cols)
                direction = desc if order_desc else asc
                subq = subq.order_by(direction(ob))
                subq = subq.limit(timeseries_limit)

                on_clause = []
                for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
                    # in this case the column name, not the alias, needs to be
                    # conditionally mutated, as it refers to the column alias in
                    # the inner query
                    col_name = db_engine_spec.make_label_compatible(gby_name +
                                                                    "__")
                    on_clause.append(gby_obj == column(col_name))

                tbl = tbl.join(subq.alias(), and_(*on_clause))
            else:
                if timeseries_limit_metric:
                    orderby = [(
                        self._get_timeseries_orderby(timeseries_limit_metric,
                                                     metrics_dict, cols),
                        False,
                    )]

                # run prequery to get top groups
                prequery_obj = {
                    "is_timeseries": False,
                    "row_limit": timeseries_limit,
                    "metrics": metrics,
                    "granularity": granularity,
                    "from_dttm": inner_from_dttm or from_dttm,
                    "to_dttm": inner_to_dttm or to_dttm,
                    "filter": filter,
                    "orderby": orderby,
                    "extras": extras,
                    "columns": columns,
                    "order_desc": True,
                }
                if not is_sip_38:
                    prequery_obj["groupby"] = groupby

                result = self.query(prequery_obj)
                prequeries.append(result.query)
                dimensions = [
                    c for c in result.df.columns
                    if c not in metrics and c in groupby_exprs_sans_timestamp
                ]
                top_groups = self._get_top_groups(
                    result.df, dimensions, groupby_exprs_sans_timestamp)
                qry = qry.where(top_groups)
        return SqlaQuery(
            extra_cache_keys=extra_cache_keys,
            labels_expected=labels_expected,
            sqla_query=qry.select_from(tbl),
            prequeries=prequeries,
        )

    def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict,
                                cols):
        """Turn the series-limit metric into an expression usable in ORDER BY.

        :param timeseries_limit_metric: either an ad-hoc metric (as recognised
            by ``utils.is_adhoc_metric``) or a key of ``metrics_dict``
        :param metrics_dict: mapping from metric key to an object exposing
            ``get_sqla_col()``
        :param cols: column lookup forwarded to ``adhoc_metric_to_sqla``
        :return: the expression produced for the metric
        :raises Exception: when the metric is neither ad-hoc nor a key of
            ``metrics_dict``
        """
        # Ad-hoc metrics are converted on the fly.
        if utils.is_adhoc_metric(timeseries_limit_metric):
            return self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)

        # Anything else must be a known metric key.
        if timeseries_limit_metric not in metrics_dict:
            raise Exception(
                _("Metric '%(metric)s' does not exist",
                  metric=timeseries_limit_metric))

        metric_obj = metrics_dict.get(timeseries_limit_metric)
        return metric_obj.get_sqla_col()

    def _get_top_groups(self, df: pd.DataFrame, dimensions: List,
                        groupby_exprs: OrderedDict) -> ColumnElement:
        """Build a filter that matches any dimension combination found in ``df``.

        Every row of ``df`` contributes one AND clause made of
        ``groupby_exprs[dimension] == row[dimension]`` for each entry of
        ``dimensions``; the per-row clauses are then OR-ed together.

        :param df: DataFrame whose rows hold the dimension values to keep
        :param dimensions: names used both as ``df`` columns and as keys of
            ``groupby_exprs``
        :param groupby_exprs: mapping from dimension name to its expression
        :return: the combined OR-of-ANDs clause
        """
        row_clauses = [
            and_(*[
                groupby_exprs[dimension] == row[dimension]
                for dimension in dimensions
            ]) for _index, row in df.iterrows()
        ]
        return or_(*row_clauses)

    def query(self, query_obj: Dict[str, Any]) -> QueryResult:
        """Execute the query described by ``query_obj`` and wrap the outcome.

        Errors raised while fetching the DataFrame are not propagated: they
        are logged, and the returned result carries a FAILED status, an empty
        DataFrame and the message extracted by the database's engine spec.

        :param query_obj: query description handed to
            ``get_query_str_extended``
        :return: a ``QueryResult`` with status, DataFrame, elapsed time, the
            SQL text and an optional error message
        """
        qry_start_dttm = datetime.now()
        query_str_ext = self.get_query_str_extended(query_obj)
        sql = query_str_ext.sql

        def mutator(df: pd.DataFrame) -> None:
            """
            Rename the DataFrame columns to the labels the query expects.

            Engines may alter the case of column names or invent their own, so
            the returned columns are overwritten positionally with the expected
            labels. Nothing is done for a missing or empty DataFrame.

            :param df: Original DataFrame returned by the engine
            :raises Exception: if the number of columns does not match the
                number of expected labels
            """
            labels_expected = query_str_ext.labels_expected
            if df is None or df.empty:
                return
            if len(df.columns) != len(labels_expected):
                raise Exception(f"For {sql}, df.columns: {df.columns}"
                                f" differs from {labels_expected}")
            df.columns = labels_expected

        try:
            df = self.database.get_df(sql, self.schema, mutator)
        except Exception as ex:
            # Degrade to an empty frame and report the failure in the result.
            df = pd.DataFrame()
            status = utils.QueryStatus.FAILED
            logger.exception(f"Query {sql} on schema {self.schema} failed")
            error_message = (
                self.database.db_engine_spec.extract_error_message(ex))
        else:
            status = utils.QueryStatus.SUCCESS
            error_message = None

        elapsed = datetime.now() - qry_start_dttm
        return QueryResult(
            status=status,
            df=df,
            duration=elapsed,
            query=sql,
            error_message=error_message,
        )

    def get_sqla_table_object(self) -> Table:
        """Ask the owning database for the table named by this object.

        :return: whatever ``self.database.get_table`` yields for
            ``self.table_name`` within ``self.schema``
        """
        database = self.database
        return database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self, commit: bool = True) -> None:
        """Fetch the table's column metadata and merge it into this object.

        For every column of the table object returned by
        ``get_sqla_table_object`` a ``TableColumn`` is created (or the stored
        one is refreshed), a ``COUNT(*)`` metric is offered to
        ``add_missing_metrics``, and the object is merged into the session.

        :param commit: when true, commit the session after merging
        :raises Exception: if the table object cannot be obtained
        """
        try:
            table = self.get_sqla_table_object()
        except Exception as ex:
            # Log the underlying error, then surface a user-facing message.
            logger.exception(ex)
            raise Exception(
                _("Table [{}] doesn't seem to exist in the specified database, "
                  "couldn't fetch column information").format(self.table_name))

        metrics = []
        any_date_col = None
        db_engine_spec = self.database.db_engine_spec
        db_dialect = self.database.get_dialect()
        # Stored TableColumn rows of this table whose name matches one of the
        # columns of the fetched table object.
        dbcols = (db.session.query(TableColumn).filter(
            TableColumn.table == self).filter(
                or_(TableColumn.column_name == col.name
                    for col in table.columns)))
        # Index them by column name for the lookups below.
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}

        for col in table.columns:
            try:
                datatype = db_engine_spec.column_datatype_to_string(
                    col.type, db_dialect)
            except Exception as ex:
                # Best effort: keep going with a placeholder type.
                datatype = "UNKNOWN"
                logger.error("Unrecognized data type in {}.{}".format(
                    table, col.name))
                logger.exception(ex)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                # New column: derive its flags from the detected type, then
                # let the engine spec adjust it.
                dbcol = TableColumn(column_name=col.name,
                                    type=datatype,
                                    table=self)
                dbcol.sum = dbcol.is_numeric
                dbcol.avg = dbcol.is_numeric
                dbcol.is_dttm = dbcol.is_temporal
                db_engine_spec.alter_new_orm_column(dbcol)
            else:
                # Known column: only its type string is refreshed.
                dbcol.type = datatype
            # These two flags are forced on for new and existing columns alike.
            dbcol.groupby = True
            dbcol.filterable = True
            self.columns.append(dbcol)
            # Remember the name of the first temporal column encountered.
            if not any_date_col and dbcol.is_temporal:
                any_date_col = col.name

        metrics.append(
            SqlMetric(
                metric_name="count",
                verbose_name="COUNT(*)",
                metric_type="count",
                expression="COUNT(*)",
            ))
        # Fall back to the first temporal column (possibly None) when no main
        # datetime column has been set.
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        self.add_missing_metrics(metrics)

        db.session.merge(self)
        if commit:
            db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None) -> int:
        """Import a datasource object into the database.

        Metrics, columns and the datasource itself are overridden if they
        already exist. This can be used to move dashboards between Superset
        instances. Audit metadata isn't copied over.

        :param i_datasource: the datasource object to import
        :param import_time: forwarded unchanged to the import helper
        :return: the value returned by ``import_datasource.import_datasource``
        """
        def lookup_sqlatable(table):
            # First table matching name, schema and owning database id.
            matching = db.session.query(SqlaTable).join(Database).filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            )
            return matching.first()

        def lookup_database(table):
            # The database is identified by the name stored in the table's
            # params.
            by_name = db.session.query(Database).filter_by(
                database_name=table.params_dict["database_name"])
            try:
                return by_name.one()
            except NoResultFound:
                raise DatabaseNotFound(
                    _(
                        "Database '%(name)s' is not found",
                        name=table.params_dict["database_name"],
                    ))

        return import_datasource.import_datasource(
            db.session,
            i_datasource,
            lookup_database,
            lookup_sqlatable,
            import_time,
        )

    @classmethod
    def query_datasources_by_name(cls,
                                  session: Session,
                                  database: Database,
                                  datasource_name: str,
                                  schema=None) -> List["SqlaTable"]:
        """Return the datasources of ``database`` named ``datasource_name``.

        :param session: session used to run the query
        :param database: database whose ``id`` the rows must reference
        :param datasource_name: value matched against ``table_name``
        :param schema: optional schema; only applied as a filter when truthy
        :return: all matching rows
        """
        matches = session.query(cls).filter_by(database_id=database.id)
        matches = matches.filter_by(table_name=datasource_name)
        if schema:
            matches = matches.filter_by(schema=schema)
        return matches.all()

    @staticmethod
    def default_query(qry) -> Query:
        """Narrow ``qry`` to rows whose ``is_sqllab_view`` is False."""
        restricted = qry.filter_by(is_sqllab_view=False)
        return restricted

    def has_calls_to_cache_key_wrapper(self,
                                       query_obj: Dict[str, Any]) -> bool:
        """
        Detects the presence of calls to `cache_key_wrapper` in items in query_obj that
        can be templated. If any are present, the query must be evaluated to extract
        additional keys for the cache key. This method is needed to avoid executing
        the template code unnecessarily, as it may contain expensive calls, e.g. to
        extract the latest partition of a database.

        The statements inspected are `self.sql`, `self.fetch_values_predicate` and
        the `where` / `having` entries of `query_obj["extras"]`. Missing, `None` or
        empty statements (including `extras` itself being `None`) are skipped
        rather than raising.

        :param query_obj: query object to analyze
        :return: True if at least one item calls `cache_key_wrapper`, otherwise False
        """
        regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")
        # `extras` may be absent or explicitly None; treat both as empty.
        extras = query_obj.get("extras") or {}
        candidates = (
            self.sql,
            self.fetch_values_predicate,
            extras.get("where"),
            extras.get("having"),
        )
        # Empty statements cannot contain a template call, so drop them before
        # handing anything to the regex.
        templatable_statements: List[str] = [
            statement for statement in candidates if statement
        ]
        return any(
            regex.search(statement) for statement in templatable_statements)

    def get_extra_cache_keys(self,
                             query_obj: Dict[str, Any]) -> List[Hashable]:
        """
        Gather the extra cache keys for ``query_obj``.

        Starts from the keys supplied by the parent class. When the query
        object contains calls to `cache_key_wrapper`, the SQLA query is built
        and the keys it recorded are added to that same list.

        :param query_obj: query object to analyze
        :return: the parent's keys, extended with any keys found on the built
            query
        """
        keys = super().get_extra_cache_keys(query_obj)
        if not self.has_calls_to_cache_key_wrapper(query_obj):
            return keys

        # Building the query is only done when a template call was detected.
        built_query = self.get_sqla_query(**query_obj)
        keys += built_query.extra_cache_keys
        return keys
예제 #19
0
class Dataset(Base):
    """Class to store the information about a data set.

    A dataset belongs to one task and carries that task's limits, the names
    and parameters of its task type and score type, and its managers and
    testcases. Within a task, descriptions are unique.

    """
    __tablename__ = 'datasets'
    __table_args__ = (
        UniqueConstraint('task_id', 'description'),
        # Useless, in theory, because 'id' is already unique. Yet, we
        # need this because it's a target of a foreign key.
        UniqueConstraint('id', 'task_id'),
    )

    # Auto increment primary key.
    id = Column(
        Integer,
        primary_key=True)

    # Task (id and object) owning the dataset.
    task_id = Column(
        Integer,
        ForeignKey(Task.id,
                   onupdate="CASCADE", ondelete="CASCADE"),
        nullable=False)
    task = relationship(
        Task,
        foreign_keys=[task_id],
        back_populates="datasets")

    # A human-readable text describing the dataset.
    description = Column(
        Unicode,
        CodenameConstraint("description"),
        nullable=False)

    # Whether this dataset will be automatically judged by ES and SS
    # "in background", together with the active dataset of each task.
    autojudge = Column(
        Boolean,
        nullable=False,
        default=False)

    # Time and memory limits for every testcase. Both are nullable and,
    # when set, must be strictly positive.
    time_limit = Column(
        Float,
        CheckConstraint("time_limit > 0"),
        nullable=True)
    memory_limit = Column(
        BigInteger,
        CheckConstraint("memory_limit > 0"),
        nullable=True)

    # Name of the TaskType child class suited for the task.
    task_type = Column(
        String,
        nullable=False)

    # Parameters for the task type class.
    task_type_parameters = Column(
        JSONB,
        nullable=False)

    # Name of the ScoreType child class suited for the task.
    score_type = Column(
        String,
        nullable=False)

    # Parameters for the score type class.
    score_type_parameters = Column(
        JSONB,
        nullable=False)

    # These one-to-many relationships are the reversed directions of
    # the ones defined in the "child" classes using foreign keys.

    # Managers of this dataset, as a mapping keyed by filename.
    managers = relationship(
        "Manager",
        collection_class=attribute_mapped_collection("filename"),
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="dataset")

    # Testcases of this dataset, as a mapping keyed by codename.
    testcases = relationship(
        "Testcase",
        collection_class=attribute_mapped_collection("codename"),
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="dataset")

    @property
    def active(self):
        """Shorthand for detecting if the dataset is active.

        return (bool): True if this dataset is the active one for its
            task.

        """
        return self is self.task.active_dataset

    @property
    def task_type_object(self):
        """Return the task type object for this dataset, cached.

        The object is rebuilt through get_task_type whenever no cached
        object exists yet or the task type name or parameters differ
        from the ones the cached object was built with.

        return (object): the value returned by get_task_type.

        """
        if not hasattr(self, "_cached_task_type_object") \
                or self.task_type != self._cached_task_type \
                or self.task_type_parameters \
                   != self._cached_task_type_parameters:
            # Import late to avoid a circular dependency.
            from cms.grading.tasktypes import get_task_type
            # This can raise.
            self._cached_task_type_object = get_task_type(
                self.task_type, self.task_type_parameters)
            # If an exception is raised these updates don't take place:
            # that way, next time this property is accessed, we get a
            # cache miss again and the same exception is raised again.
            self._cached_task_type = self.task_type
            # A deep copy is stored, so later changes made in place to
            # the parameters are still seen as a difference.
            self._cached_task_type_parameters = \
                copy.deepcopy(self.task_type_parameters)
        return self._cached_task_type_object

    @property
    def score_type_object(self):
        """Return the score type object for this dataset, cached.

        The object is rebuilt through get_score_type whenever no cached
        object exists yet or the score type name, its parameters or the
        codename -> public mapping of the testcases differ from the
        ones the cached object was built with.

        return (object): the value returned by get_score_type.

        """
        # Recomputed on every access so that changes to the testcases
        # invalidate the cache.
        public_testcases = {k: tc.public for k, tc in iteritems(self.testcases)}
        if not hasattr(self, "_cached_score_type_object") \
                or self.score_type != self._cached_score_type \
                or self.score_type_parameters \
                   != self._cached_score_type_parameters \
                or public_testcases != self._cached_public_testcases:
            # Import late to avoid a circular dependency.
            from cms.grading.scoretypes import get_score_type
            # This can raise.
            self._cached_score_type_object = get_score_type(
                self.score_type, self.score_type_parameters, public_testcases)
            # If an exception is raised these updates don't take place:
            # that way, next time this property is accessed, we get a
            # cache miss again and the same exception is raised again.
            self._cached_score_type = self.score_type
            self._cached_score_type_parameters = \
                copy.deepcopy(self.score_type_parameters)
            self._cached_public_testcases = public_testcases
        return self._cached_score_type_object

    def clone_from(self, old_dataset, clone_managers=True,
                   clone_testcases=True, clone_results=False):
        """Overwrite the data with that in dataset.

        The session is flushed after cloning testcases and managers and
        again at the end.

        old_dataset (Dataset): original dataset to copy from.
        clone_managers (bool): copy dataset managers.
        clone_testcases (bool): copy dataset testcases.
        clone_results (bool): copy submission results (will also copy
            managers and testcases).

        """
        # Cloned testcases by codename; used below to attach cloned
        # evaluations to the matching new testcase.
        new_testcases = dict()
        if clone_testcases or clone_results:
            for old_t in itervalues(old_dataset.testcases):
                new_t = old_t.clone()
                new_t.dataset = self
                new_testcases[new_t.codename] = new_t

        if clone_managers or clone_results:
            for old_m in itervalues(old_dataset.managers):
                new_m = old_m.clone()
                new_m.dataset = self

        # TODO: why is this needed?
        self.sa_session.flush()

        if clone_results:
            # NOTE(review): get_submission_results is not defined in this
            # class; presumably it is provided elsewhere - confirm.
            old_results = self.get_submission_results(old_dataset)

            for old_sr in old_results:
                # Create the submission result.
                new_sr = old_sr.clone()
                new_sr.submission = old_sr.submission
                new_sr.dataset = self

                # Create executables.
                for old_e in itervalues(old_sr.executables):
                    new_e = old_e.clone()
                    new_e.submission_result = new_sr

                # Create evaluations.
                for old_e in old_sr.evaluations:
                    new_e = old_e.clone()
                    new_e.submission_result = new_sr
                    new_e.testcase = new_testcases[old_e.codename]

        self.sa_session.flush()
예제 #20
0
class Contact(
        MailSyncBase,
        HasRevisions,
        HasPublicID,
        HasEmailAddress,
        UpdatedAtMixin,
        DeletedAtMixin,
):
    """A contact belonging to a user's namespace."""

    API_OBJECT_NAME = "contact"

    # Owning namespace; joined on namespace_id without a database-level
    # foreign key.
    namespace_id = Column(BigInteger, nullable=False, index=True)
    namespace = relationship(
        Namespace,
        primaryjoin="foreign(Contact.namespace_id) == remote(Namespace.id)",
        load_on_pending=True,
    )

    # Unique ID supplied by the server.
    # NB: The collation is given explicitly so the test DB is set up correctly.
    uid = Column(String(64, collation="utf8mb4_bin"), nullable=False)
    # Constant, unique identifier of the remote backend the contact came
    # from, e.g. 'google', 'eas', 'inbox'.
    provider_name = Column(String(64))

    name = Column(Text)

    raw_data = Column(Text)

    # Ranking score for contact search results; meant to be precomputed so
    # that searching stays fast.
    score = Column(Integer)

    # Set when the contact has been deleted in a remote backend.
    # (Unmapped attribute: there is no database column behind it.)
    deleted = False

    __table_args__ = (
        UniqueConstraint("uid", "namespace_id", "provider_name"),
        Index("idx_namespace_created", "namespace_id", "created_at"),
        Index("ix_contact_ns_uid_provider_name", "namespace_id", "uid",
              "provider_name"),
    )

    @validates("raw_data")
    def validate_text_column_length(self, key, value):
        """Truncate ``raw_data`` to MAX_TEXT_CHARS; ``None`` passes through."""
        return (None if value is None else unicode_safe_truncate(
            value, MAX_TEXT_CHARS))

    @property
    def versioned_relationships(self):
        """Names of the relationships listed as versioned for a contact."""
        return ["phone_numbers"]

    def merge_from(self, new_contact):
        """Copy the mergeable attributes that differ from ``new_contact``.

        The attribute tuple below must be updated when new fields are added
        to the class.
        """
        for attr in ("name", "email_address", "raw_data"):
            current = getattr(self, attr)
            incoming = getattr(new_contact, attr)
            # Only assign on an actual change.
            if current != incoming:
                setattr(self, attr, incoming)
예제 #21
0
class Paging(Base):
    """ORM mapping of the ``paging`` table.

    Most options are stored as integer columns; the hybrid properties below
    expose them as booleans, both on instances and in query expressions.
    """

    __tablename__ = 'paging'
    __table_args__ = (
        PrimaryKeyConstraint('id'),
        UniqueConstraint('number'),
    )

    id = Column(Integer, nullable=False)
    tenant_uuid = Column(String(36), ForeignKey('tenant.uuid', ondelete='CASCADE'), nullable=False)
    number = Column(String(32))
    name = Column(String(128))
    # Integer flags; each has a server-side default.
    duplex = Column(Integer, nullable=False, server_default='0')
    ignore = Column(Integer, nullable=False, server_default='0')
    record = Column(Integer, nullable=False, server_default='0')
    quiet = Column(Integer, nullable=False, server_default='0')
    timeout = Column(Integer, nullable=False, server_default='30')
    announcement_file = Column(String(64))
    announcement_play = Column(Integer, nullable=False, server_default='0')
    announcement_caller = Column(Integer, nullable=False, server_default='0')
    commented = Column(Integer, nullable=False, server_default='0')
    description = Column(Text)

    # PagingUser rows of this paging whose ``caller`` column is 0.
    paging_members = relationship(
        'PagingUser',
        primaryjoin="""and_(
            PagingUser.pagingid == Paging.id,
            PagingUser.caller == 0
        )""",
        cascade='all, delete-orphan',
    )

    # The ``user`` of each member row; appending a user creates a PagingUser
    # with caller=0.
    users_member = association_proxy(
        'paging_members', 'user',
        creator=lambda _user: PagingUser(user=_user, caller=0),
    )

    # PagingUser rows of this paging whose ``caller`` column is 1.
    paging_callers = relationship(
        'PagingUser',
        primaryjoin="""and_(
            PagingUser.pagingid == Paging.id,
            PagingUser.caller == 1
        )""",
        cascade='all, delete-orphan')

    # The ``user`` of each caller row; appending a user creates a PagingUser
    # with caller=1.
    users_caller = association_proxy(
        'paging_callers', 'user',
        creator=lambda _user: PagingUser(user=_user, caller=1),
    )

    func_keys = relationship('FuncKeyDestPaging', cascade='all, delete-orphan')

    # enabled: inverse of the ``commented`` column.
    @hybrid_property
    def enabled(self):
        return self.commented == 0

    @enabled.expression
    def enabled(cls):
        return not_(cast(cls.commented, Boolean))

    @enabled.setter
    def enabled(self, value):
        # Identity test: only the literal False stores 1; any other value
        # (including 0 and None) stores 0.
        self.commented = int(value is False)

    # duplex_bool: boolean view of the ``duplex`` column.
    @hybrid_property
    def duplex_bool(self):
        return self.duplex == 1

    @duplex_bool.expression
    def duplex_bool(cls):
        return cast(cls.duplex, Boolean)

    @duplex_bool.setter
    def duplex_bool(self, value):
        self.duplex = int(value)

    # record_bool: boolean view of the ``record`` column.
    @hybrid_property
    def record_bool(self):
        return self.record == 1

    @record_bool.expression
    def record_bool(cls):
        return cast(cls.record, Boolean)

    @record_bool.setter
    def record_bool(self, value):
        self.record = int(value)

    # ignore_forward: boolean view of the ``ignore`` column.
    @hybrid_property
    def ignore_forward(self):
        return self.ignore == 1

    @ignore_forward.expression
    def ignore_forward(cls):
        return cast(cls.ignore, Boolean)

    @ignore_forward.setter
    def ignore_forward(self, value):
        self.ignore = int(value)

    # caller_notification: inverse of the ``quiet`` column.
    @hybrid_property
    def caller_notification(self):
        return self.quiet == 0

    @caller_notification.expression
    def caller_notification(cls):
        return not_(cast(cls.quiet, Boolean))

    @caller_notification.setter
    def caller_notification(self, value):
        # Equality test: any value equal to 0 (e.g. False) stores 1.
        self.quiet = int(value == 0)

    # announce_caller: inverse of the ``announcement_caller`` column.
    @hybrid_property
    def announce_caller(self):
        return self.announcement_caller == 0

    @announce_caller.expression
    def announce_caller(cls):
        return not_(cast(cls.announcement_caller, Boolean))

    @announce_caller.setter
    def announce_caller(self, value):
        # Equality test: any value equal to 0 (e.g. False) stores 1.
        self.announcement_caller = int(value == 0)

    # announce_sound: alias of ``announcement_file``.
    @hybrid_property
    def announce_sound(self):
        return self.announcement_file

    @announce_sound.setter
    def announce_sound(self, value):
        # Keep the play flag in step: 1 when a value is given, 0 for None.
        self.announcement_play = int(value is not None)
        self.announcement_file = value
예제 #22
0
class Database(Model, AuditMixinNullable, ImportMixin):

    """An ORM object that stores Database related information

    Mapped to the ``dbs`` table.  Besides the persisted columns, the class
    offers helpers to build SQLAlchemy engines from the stored URI, to run
    queries, and to inspect schemas, tables and views through the
    backend-specific ``db_engine_spec``.
    """

    __tablename__ = 'dbs'
    type = 'table'
    __table_args__ = (UniqueConstraint('database_name'),)

    id = Column(Integer, primary_key=True)
    verbose_name = Column(String(250), unique=True)
    # short unique name, used in permissions
    database_name = Column(String(250), unique=True)
    sqlalchemy_uri = Column(String(1024))
    password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
    cache_timeout = Column(Integer)
    select_as_create_table_as = Column(Boolean, default=False)
    expose_in_sqllab = Column(Boolean, default=False)
    allow_run_sync = Column(Boolean, default=True)
    allow_run_async = Column(Boolean, default=False)
    allow_csv_upload = Column(Boolean, default=False)
    allow_ctas = Column(Boolean, default=False)
    allow_dml = Column(Boolean, default=False)
    force_ctas_schema = Column(String(250))
    allow_multi_schema_metadata_fetch = Column(Boolean, default=True)
    # JSON document; the keys below are the ones read by this class.
    extra = Column(Text, default=textwrap.dedent("""\
    {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_csv_upload": []
    }
    """))
    perm = Column(String(1000))
    impersonate_user = Column(Boolean, default=False)
    export_fields = ('database_name', 'sqlalchemy_uri', 'cache_timeout',
                     'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
                     'allow_ctas', 'allow_csv_upload', 'extra')
    export_children = ['tables']

    def __repr__(self):
        """Return the display name (same value as ``name``)."""
        return self.name

    @property
    def name(self):
        """``verbose_name`` when set, otherwise ``database_name``."""
        return self.verbose_name if self.verbose_name else self.database_name

    @property
    def allows_subquery(self):
        """Whether the engine spec of this backend allows subqueries."""
        return self.db_engine_spec.allows_subquery

    @property
    def data(self):
        """Return a small dict describing this database."""
        return {
            'id': self.id,
            'name': self.database_name,
            'backend': self.backend,
            'allow_multi_schema_metadata_fetch':
                self.allow_multi_schema_metadata_fetch,
            'allows_subquery': self.allows_subquery,
        }

    @property
    def unique_name(self):
        """Return ``database_name``."""
        return self.database_name

    @property
    def url_object(self):
        """Return the decrypted SQLAlchemy URI parsed into a URL object."""
        return make_url(self.sqlalchemy_uri_decrypted)

    @property
    def backend(self):
        """Return the backend name taken from the decrypted URI."""
        url = make_url(self.sqlalchemy_uri_decrypted)
        return url.get_backend_name()

    @classmethod
    def get_password_masked_url_from_uri(cls, uri):
        """Parse ``uri`` and return a URL object with the password masked."""
        url = make_url(uri)
        return cls.get_password_masked_url(url)

    @classmethod
    def get_password_masked_url(cls, url):
        """Return a deep copy of ``url`` whose password is ``PASSWORD_MASK``.

        A URL without a password is copied unchanged.
        """
        url_copy = deepcopy(url)
        if url_copy.password is not None and url_copy.password != PASSWORD_MASK:
            url_copy.password = PASSWORD_MASK
        return url_copy

    def set_sqlalchemy_uri(self, uri):
        """Store ``uri`` with its password masked, keeping the real password.

        The real password goes to ``self.password`` unless the incoming URI
        already carries the mask or a custom password store is configured.
        """
        conn = sqla.engine.url.make_url(uri.strip())
        if conn.password != PASSWORD_MASK and not custom_password_store:
            # do not over-write the password with the password mask
            self.password = conn.password
        conn.password = PASSWORD_MASK if conn.password else None
        self.sqlalchemy_uri = str(conn)  # hides the password

    def get_effective_user(self, url, user_name=None):
        """
        Get the effective user, especially during impersonation.
        :param url: SQL Alchemy URL object
        :param user_name: Default username
        :return: The effective username
        """
        effective_username = None
        if self.impersonate_user:
            effective_username = url.username
            if user_name:
                effective_username = user_name
            elif (
                hasattr(g, 'user') and hasattr(g.user, 'username') and
                g.user.username is not None
            ):
                effective_username = g.user.username
        return effective_username

    @utils.memoized(
        watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
    def get_sqla_engine(self, schema=None, nullpool=True, user_name=None):
        """Create a SQLAlchemy engine for this database.

        :param schema: schema handed to the engine spec to adjust the URI
        :param nullpool: when true, force ``poolclass=NullPool``
        :param user_name: username preferred when impersonation is enabled
        :return: the engine returned by ``create_engine``
        """
        extra = self.get_extra()
        url = make_url(self.sqlalchemy_uri_decrypted)
        url = self.db_engine_spec.adjust_database_uri(url, schema)
        effective_username = self.get_effective_user(url, user_name)
        # If using MySQL or Presto for example, will set url.username
        # If using Hive, will not do anything yet since that relies on a
        # configuration parameter instead.
        self.db_engine_spec.modify_url_for_impersonation(
            url,
            self.impersonate_user,
            effective_username)

        masked_url = self.get_password_masked_url(url)
        logging.info(
            'Database.get_sqla_engine(). Masked URL: %s', masked_url)

        params = extra.get('engine_params', {})
        if nullpool:
            params['poolclass'] = NullPool

        # Start from any ``connect_args`` supplied through ``engine_params``
        # so impersonation settings are merged into them instead of
        # replacing them wholesale.
        connect_args = params.get('connect_args') or {}
        configuration = connect_args.get('configuration') or {}

        # If using Hive, this will set hive.server2.proxy.user=$effective_username
        configuration.update(
            self.db_engine_spec.get_configuration_for_impersonation(
                str(url),
                self.impersonate_user,
                effective_username))
        if configuration:
            connect_args['configuration'] = configuration
        if connect_args:
            params['connect_args'] = connect_args

        DB_CONNECTION_MUTATOR = config.get('DB_CONNECTION_MUTATOR')
        if DB_CONNECTION_MUTATOR:
            url, params = DB_CONNECTION_MUTATOR(
                url, params, effective_username, security_manager)
        return create_engine(url, **params)

    def get_reserved_words(self):
        """Return the reserved words of this database's dialect preparer."""
        return self.get_dialect().preparer.reserved_words

    def get_quoter(self):
        """Return the dialect's identifier quoting function."""
        return self.get_dialect().identifier_preparer.quote

    def get_df(self, sql, schema):
        """Run ``sql`` and return the last statement's rows as a DataFrame.

        ``sql`` is split into statements; all but the last are executed and
        their results fetched and discarded.  Object columns whose first
        value is a list or dict are serialized with
        ``utils.json_dumps_w_dates``.
        """
        sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)]
        engine = self.get_sqla_engine(schema=schema)

        def needs_conversion(df_series):
            # Only the first value is inspected.
            if df_series.empty:
                return False
            return isinstance(df_series[0], (list, dict))

        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                # Distinct loop name so the ``sql`` argument is not shadowed.
                for statement in sqls[:-1]:
                    self.db_engine_spec.execute(cursor, statement)
                    cursor.fetchall()

                self.db_engine_spec.execute(cursor, sqls[-1])

                if cursor.description is not None:
                    columns = [col_desc[0] for col_desc in cursor.description]
                else:
                    columns = []

                df = pd.DataFrame.from_records(
                    data=list(cursor.fetchall()),
                    columns=columns,
                    coerce_float=True,
                )

                for k, v in df.dtypes.items():
                    if v.type == numpy.object_ and needs_conversion(df[k]):
                        df[k] = df[k].apply(utils.json_dumps_w_dates)
                return df

    def compile_sqla_query(self, qry, schema=None):
        """Compile ``qry`` to a SQL string with literal binds inlined."""
        engine = self.get_sqla_engine(schema=schema)

        sql = str(
            qry.compile(
                engine,
                compile_kwargs={'literal_binds': True},
            ),
        )

        # Undo percent doubling for dialects whose preparer applies it.
        if engine.dialect.identifier_preparer._double_percents:
            sql = sql.replace('%%', '%')

        return sql

    def select_star(
            self, table_name, schema=None, limit=100, show_cols=False,
            indent=True, latest_partition=False, cols=None):
        """Generates a ``select *`` statement in the proper dialect"""
        eng = self.get_sqla_engine(schema=schema)
        return self.db_engine_spec.select_star(
            self, table_name, schema=schema, engine=eng,
            limit=limit, show_cols=show_cols,
            indent=indent, latest_partition=latest_partition, cols=cols)

    def apply_limit_to_sql(self, sql, limit=1000):
        """Delegate to the engine spec to apply ``limit`` to ``sql``."""
        return self.db_engine_spec.apply_limit_to_sql(sql, limit, self)

    def safe_sqlalchemy_uri(self):
        """Return the stored URI (whose password is masked when saved)."""
        return self.sqlalchemy_uri

    @property
    def inspector(self):
        """Return a SQLAlchemy inspector bound to a fresh engine."""
        engine = self.get_sqla_engine()
        return sqla.inspect(engine)

    def all_table_names(self, schema=None, force=False):
        """Return table names.

        Without ``schema`` the result is the ``''`` entry of the engine
        spec's result sets, or ``[]`` when multi-schema metadata fetching is
        disabled.  With ``schema`` the sorted names of that schema are
        returned.
        """
        if not schema:
            if not self.allow_multi_schema_metadata_fetch:
                return []
            tables_dict = self.db_engine_spec.fetch_result_sets(
                self, 'table', force=force)
            return tables_dict.get('', [])

        extra = self.get_extra()
        metadata_cache_timeout = extra.get('metadata_cache_timeout', {})
        table_cache_timeout = metadata_cache_timeout.get('table_cache_timeout')
        enable_cache = 'table_cache_timeout' in metadata_cache_timeout
        return sorted(self.db_engine_spec.get_table_names(
            inspector=self.inspector,
            db_id=self.id,
            schema=schema,
            enable_cache=enable_cache,
            cache_timeout=table_cache_timeout,
            force=force))

    def all_view_names(self, schema=None, force=False):
        """Return view names; any failure yields ``[]`` for a given schema.

        Errors raised while fetching the views of ``schema`` are logged and
        an empty list is returned, keeping the lookup best-effort.
        """
        if not schema:
            if not self.allow_multi_schema_metadata_fetch:
                return []
            views_dict = self.db_engine_spec.fetch_result_sets(
                self, 'view', force=force)
            return views_dict.get('', [])
        views = []
        try:
            extra = self.get_extra()
            metadata_cache_timeout = extra.get('metadata_cache_timeout', {})
            table_cache_timeout = metadata_cache_timeout.get(
                'table_cache_timeout')
            enable_cache = 'table_cache_timeout' in metadata_cache_timeout
            views = self.db_engine_spec.get_view_names(
                inspector=self.inspector,
                db_id=self.id,
                schema=schema,
                enable_cache=enable_cache,
                cache_timeout=table_cache_timeout,
                force=force)
        except Exception:
            # Still best-effort, but leave a trace instead of hiding errors.
            logging.exception(
                'Failed to fetch view names for schema %s', schema)
        return views

    def all_schema_names(self, force_refresh=False):
        """Return the sorted schema names reported by the engine spec."""
        extra = self.get_extra()
        metadata_cache_timeout = extra.get('metadata_cache_timeout', {})
        schema_cache_timeout = metadata_cache_timeout.get(
            'schema_cache_timeout')
        enable_cache = 'schema_cache_timeout' in metadata_cache_timeout
        return sorted(self.db_engine_spec.get_schema_names(
            inspector=self.inspector,
            enable_cache=enable_cache,
            cache_timeout=schema_cache_timeout,
            db_id=self.id,
            force=force_refresh))

    @property
    def db_engine_spec(self):
        """Engine spec for this backend, defaulting to ``BaseEngineSpec``."""
        return db_engine_specs.engines.get(
            self.backend, db_engine_specs.BaseEngineSpec)

    @classmethod
    def get_db_engine_spec_for_backend(cls, backend):
        """Engine spec for ``backend``, defaulting to ``BaseEngineSpec``."""
        return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)

    def grains(self):
        """Defines time granularity database-specific expressions.

        The idea here is to make it easy for users to change the time grain
        from a datetime (maybe the source grain is arbitrary timestamps, daily
        or 5 minutes increments) to another, "truncated" datetime. Since
        each database has slightly different but similar datetime functions,
        this allows a mapping between database engines and actual functions.
        """
        return self.db_engine_spec.get_time_grains()

    def grains_dict(self):
        """Allowing to lookup grain by either label or duration

        For backward compatibility"""
        d = {grain.duration: grain for grain in self.grains()}
        d.update({grain.label: grain for grain in self.grains()})
        return d

    def get_extra(self):
        """Return ``extra`` parsed from JSON, or ``{}`` when it is empty.

        Parsing errors are logged and re-raised.
        """
        extra = {}
        if self.extra:
            try:
                extra = json.loads(self.extra)
            except Exception as e:
                logging.error(e)
                raise
        return extra

    def get_table(self, table_name, schema=None):
        """Reflect ``table_name`` into a ``Table`` using a fresh engine.

        ``metadata_params`` from ``extra`` are passed to ``MetaData``.
        """
        extra = self.get_extra()
        meta = MetaData(**extra.get('metadata_params', {}))
        return Table(
            table_name, meta,
            schema=schema or None,
            autoload=True,
            autoload_with=self.get_sqla_engine())

    def get_columns(self, table_name, schema=None):
        """Return the inspector's column information for the table."""
        return self.inspector.get_columns(table_name, schema)

    def get_indexes(self, table_name, schema=None):
        """Return the inspector's index information for the table."""
        return self.inspector.get_indexes(table_name, schema)

    def get_pk_constraint(self, table_name, schema=None):
        """Return the inspector's primary key constraint for the table."""
        return self.inspector.get_pk_constraint(table_name, schema)

    def get_foreign_keys(self, table_name, schema=None):
        """Return the inspector's foreign key information for the table."""
        return self.inspector.get_foreign_keys(table_name, schema)

    def get_schema_access_for_csv_upload(self):
        """Return ``schemas_allowed_for_csv_upload`` from ``extra``."""
        return self.get_extra().get('schemas_allowed_for_csv_upload', [])

    @property
    def sqlalchemy_uri_decrypted(self):
        """Return the stored URI with the real password substituted in.

        The password comes from ``custom_password_store`` when one is
        configured, otherwise from the ``password`` column.
        """
        conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
        if custom_password_store:
            conn.password = custom_password_store(conn)
        else:
            conn.password = self.password
        return str(conn)

    @property
    def sql_url(self):
        """Return the ``/superset/sql/<id>/`` path for this database."""
        return '/superset/sql/{}/'.format(self.id)

    def get_perm(self):
        """Return the permission string ``[<database_name>].(id:<id>)``."""
        return (
            '[{obj.database_name}].(id:{obj.id})').format(obj=self)

    def has_table(self, table):
        """Ask a fresh engine whether ``table`` exists in its schema."""
        engine = self.get_sqla_engine()
        return engine.has_table(
            table.table_name, table.schema or None)

    @utils.memoized
    def get_dialect(self):
        """Return a dialect instance for the decrypted URI."""
        sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
        return sqla_url.get_dialect()()
예제 #23
0
class UserCustom(Base):
    """ORM mapping for the ``usercustom`` table."""

    __tablename__ = 'usercustom'
    __table_args__ = (
        PrimaryKeyConstraint('id'),
        UniqueConstraint('interface', 'intfsuffix', 'category'),
        Index('usercustom__idx__category', 'category'),
        Index('usercustom__idx__context', 'context'),
        Index('usercustom__idx__name', 'name'),
    )

    id = Column(Integer, nullable=False)
    tenant_uuid = Column(
        String(36),
        ForeignKey('tenant.uuid', ondelete='CASCADE'),
        nullable=False,
    )
    name = Column(String(40))
    context = Column(String(39))
    interface = Column(String(128), nullable=False)
    intfsuffix = Column(String(32), nullable=False, server_default='')
    commented = Column(Integer, nullable=False, server_default='0')
    protocol = Column(
        enum.trunk_protocol,
        nullable=False,
        server_default='custom',
    )
    category = Column(
        Enum(
            'user',
            'trunk',
            name='usercustom_category',
            metadata=Base.metadata,
        ),
        nullable=False,
    )

    line = relationship('LineFeatures', uselist=False, viewonly=True)
    trunk = relationship('TrunkFeatures', uselist=False, viewonly=True)

    @hybrid_property
    def enabled(self):
        """Return True when ``commented`` is falsy."""
        return not self.commented

    @enabled.expression
    def enabled(cls):
        """SQL expression form: negation of ``commented`` cast to Boolean."""
        commented_as_bool = cast(cls.commented, Boolean)
        return not_(commented_as_bool)

    @enabled.setter
    def enabled(self, value):
        """Store the inverted flag; None is passed through unchanged."""
        self.commented = None if value is None else int(value is False)

    def endpoint_protocol(self):
        """Return the protocol name of this endpoint, always ``'custom'``."""
        return 'custom'

    def same_protocol(self, protocol, protocolid):
        """Tell whether ``(protocol, protocolid)`` designates this row.

        ``protocolid`` is converted with ``int`` only when ``protocol`` is
        ``'custom'``.
        """
        if not protocol == 'custom':
            return False
        return self.id == int(protocolid)

    @hybrid_property
    def interface_suffix(self):
        """Return ``intfsuffix``, mapping the empty string to None."""
        suffix = self.intfsuffix
        return None if suffix == '' else suffix

    @interface_suffix.expression
    def interface_suffix(cls):
        """SQL expression form: ``NULLIF(intfsuffix, '')``."""
        return func.nullif(cls.intfsuffix, '')

    @interface_suffix.setter
    def interface_suffix(self, value):
        """Store ``value`` in ``intfsuffix``, mapping None to ``''``."""
        self.intfsuffix = '' if value is None else value
예제 #24
0
파일: user.py 프로젝트: shaunren/cms
class User(Base):
    """Class to store a 'user participating in a contest'.

    Mapped to the ``users`` table. The pair ``(contest_id, username)`` is
    unique, so a username can be reused across contests but not within one.

    """
    # TODO: we really need to split this as a user (as in: not paired
    # with a contest) and a participation.
    __tablename__ = 'users'
    __table_args__ = (UniqueConstraint('contest_id', 'username'), )

    # Auto increment primary key.
    id = Column(Integer, primary_key=True)

    # Real name (human readable) of the user.
    first_name = Column(Unicode, nullable=False)
    last_name = Column(Unicode, nullable=False)

    # Username and password to log in the CWS. When no password is given,
    # the default is produced by calling generate_random_password.
    username = Column(Unicode, nullable=False)
    password = Column(Unicode,
                      nullable=False,
                      default=generate_random_password)

    # Email for any communications in case of remote contest.
    email = Column(Unicode, nullable=True)

    # User can log in CWS only from this ip.
    ip = Column(Unicode, nullable=True)

    # A hidden user is used only for debugging purpose.
    hidden = Column(Boolean, nullable=False, default=False)

    # Contest (id and object) to which the user is participating.
    # Updates and deletes of the contest row cascade to this row; the
    # relationship exposes the reverse collection as Contest.users.
    contest_id = Column(Integer,
                        ForeignKey(Contest.id,
                                   onupdate="CASCADE",
                                   ondelete="CASCADE"),
                        nullable=False,
                        index=True)
    contest = relationship(Contest,
                           backref=backref("users",
                                           cascade="all, delete-orphan",
                                           passive_deletes=True))

    # A JSON-encoded dictionary of lists of strings: statements["a"]
    # contains the language codes of the statements that will be
    # highlighted to this user for task "a". Defaults to an empty
    # JSON object.
    primary_statements = Column(String, nullable=False, default="{}")

    # Timezone for the user. All timestamps in CWS will be shown using
    # the timezone associated to the logged-in user or (if it's None
    # or an invalid string) the timezone associated to the contest or
    # (if it's None or an invalid string) the local timezone of the
    # server. This value has to be a string like "Europe/Rome",
    # "Australia/Sydney", "America/New_York", etc.
    timezone = Column(Unicode, nullable=True)

    # Starting time: for contests where every user has at most x hours
    # of the y > x hours totally available, this is the time the user
    # decided to start their time-frame.
    starting_time = Column(DateTime, nullable=True)

    # An extra amount of time allocated for this user (zero by default).
    extra_time = Column(Interval, nullable=False, default=timedelta())
예제 #25
0
class SqlaTable(Model, BaseDatasource):
    """An ORM object for SqlAlchemy table references

    Mapped to the ``tables`` table; ``(database_id, table_name)`` is unique.
    """

    # Datasource identification and the classes used for its children.
    type = 'table'
    query_language = 'sql'
    metric_class = SqlMetric
    column_class = TableColumn

    __tablename__ = 'tables'
    __table_args__ = (UniqueConstraint('database_id', 'table_name'), )

    table_name = Column(String(250))
    main_dttm_col = Column(String(250))
    # Owning database (row of ``dbs``); required.
    database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
    fetch_values_predicate = Column(String(1000))
    user_id = Column(Integer, ForeignKey('ab_user.id'))
    owner = relationship(security_manager.user_model,
                         backref='tables',
                         foreign_keys=[user_id])
    # Exposes the reverse collection as ``Database.tables``.
    database = relationship('Database',
                            backref=backref('tables',
                                            cascade='all, delete-orphan'),
                            foreign_keys=[database_id])
    schema = Column(String(255))
    # When set, this SQL text is used as the FROM clause instead of the
    # physical table (see ``get_from_clause``).
    sql = Column(Text)
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)

    baselink = 'tablemodelview'

    export_fields = ('table_name', 'main_dttm_col', 'description',
                     'default_endpoint', 'database_id', 'offset',
                     'cache_timeout', 'schema', 'sql', 'params',
                     'template_params', 'filter_select_enabled')
    # Every exported field except the two that form the unique key.
    update_from_object_fields = [
        f for f in export_fields if f not in ('table_name', 'database_id')
    ]
    export_parent = 'database'
    export_children = ['metrics', 'columns']

    # Maps an aggregate name to a callable that takes a column and returns
    # the SQLAlchemy aggregate expression (used by ``adhoc_metric_to_sqla``).
    sqla_aggregations = {
        'COUNT_DISTINCT':
        lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
        'COUNT':
        sa.func.COUNT,
        'SUM':
        sa.func.SUM,
        'AVG':
        sa.func.AVG,
        'MIN':
        sa.func.MIN,
        'MAX':
        sa.func.MAX,
    }

    def __repr__(self):
        """Represent the table by its (possibly schema-qualified) name."""
        display_name = self.name
        return display_name

    @property
    def connection(self):
        """Return the string form of the owning database."""
        owning_database = self.database
        return str(owning_database)

    @property
    def description_markeddown(self):
        """Return ``description`` passed through ``utils.markdown``."""
        raw_description = self.description
        return utils.markdown(raw_description)

    @property
    def datasource_name(self):
        """Return the table name, used as the datasource name."""
        table_name = self.table_name
        return table_name

    @property
    def database_name(self):
        """Return the ``name`` of the owning database."""
        owning_database = self.database
        return owning_database.name

    @property
    def link(self):
        """Return an HTML anchor to the explore URL, labelled with the name.

        The name is escaped; the anchor opens in a new tab.
        """
        name = escape(self.name)
        template = '<a target="_blank" href="{self.explore_url}">{name}</a>'
        # Pass the two referenced names explicitly instead of ``locals()``.
        return Markup(template.format(self=self, name=name))

    @property
    def schema_perm(self):
        """Returns schema permission if present, database one otherwise."""
        owning_database = self.database
        schema_name = self.schema
        return security_manager.get_schema_perm(owning_database, schema_name)

    def get_perm(self):
        """Return the permission string ``[<database>].[<table>](id:<id>)``."""
        template = '[{obj.database}].[{obj.table_name}](id:{obj.id})'
        return template.format(obj=self)

    @property
    def name(self):
        """Return ``schema.table_name``, or just the table name without schema."""
        if self.schema:
            return '{}.{}'.format(self.schema, self.table_name)
        return self.table_name

    @property
    def full_name(self):
        """Return the full name built by ``utils.get_datasource_full_name``."""
        owning_database = self.database
        table_name = self.table_name
        return utils.get_datasource_full_name(
            owning_database, table_name, schema=self.schema)

    @property
    def dttm_cols(self):
        """Return the names of datetime columns.

        ``main_dttm_col`` is appended when it is set and not already listed.
        """
        names = []
        for col in self.columns:
            if col.is_dttm:
                names.append(col.column_name)
        main_col = self.main_dttm_col
        if main_col and main_col not in names:
            names.append(main_col)
        return names

    @property
    def num_cols(self):
        """Return the names of the columns flagged as numeric."""
        names = []
        for col in self.columns:
            if col.is_num:
                names.append(col.column_name)
        return names

    @property
    def any_dttm_col(self):
        """Return the first datetime column name, or None when there is none."""
        candidates = self.dttm_cols
        if not candidates:
            return None
        return candidates[0]

    @property
    def html(self):
        """Return an HTML table listing each column's name and type.

        The table has the headers ``field`` and ``type`` and no index.
        Column names are passed to the DataFrame constructor so that a
        datasource without columns yields an empty table instead of a
        length-mismatch ``ValueError``.
        """
        rows = [(c.column_name, c.type) for c in self.columns]
        df = pd.DataFrame(rows, columns=['field', 'type'])
        return df.to_html(
            index=False,
            classes=('dataframe table table-striped table-bordered '
                     'table-condensed'))

    @property
    def sql_url(self):
        """Return the database's SQL URL with ``?table_name=<table>`` appended."""
        base_url = self.database.sql_url
        query_string = '?table_name=' + str(self.table_name)
        return base_url + query_string

    def external_metadata(self):
        """Return the database's column metadata with ``type`` as a string.

        The dicts returned by ``database.get_columns`` are updated in place.
        """
        column_infos = self.database.get_columns(
            self.table_name, schema=self.schema)
        for info in column_infos:
            type_text = '{}'.format(info['type'])
            info['type'] = type_text
        return column_infos

    @property
    def time_column_grains(self):
        """Return the datetime column names and the database's grain names."""
        time_columns = self.dttm_cols
        grain_names = []
        for grain in self.database.grains():
            grain_names.append(grain.name)
        return {
            'time_columns': time_columns,
            'time_grains': grain_names,
        }

    @property
    def select_star(self):
        """Return the database's ``select *`` statement for this table."""
        # show_cols and latest_partition set to false to avoid
        # the expensive cost of inspecting the DB
        qualified_name = self.name
        return self.database.select_star(
            qualified_name, show_cols=False, latest_partition=False)

    def get_col(self, col_name):
        """Return the first column named ``col_name``, or None if absent."""
        matches = (c for c in self.columns if col_name == c.column_name)
        return next(matches, None)

    @property
    def data(self):
        """Return the parent class's data dict, extended for table sources.

        When ``type`` is ``'table'`` the datetime column choices, the
        ``(duration, name)`` pairs of the database's time grains and the
        main datetime column are added.
        """
        payload = super(SqlaTable, self).data
        if self.type != 'table':
            return payload

        available_grains = self.database.grains() or []
        grain_choices = [(g.duration, g.name) for g in available_grains]
        payload['granularity_sqla'] = utils.choicify(self.dttm_cols)
        payload['time_grain_sqla'] = grain_choices
        payload['main_dttm_col'] = self.main_dttm_col
        return payload

    def values_for_column(self, column_name, limit=10000):
        """Runs query against sqla to retrieve some
        sample values for the given column.

        Distinct values are selected, capped at ``limit`` when it is truthy
        and filtered by ``fetch_values_predicate`` when one is set.
        """
        columns_by_name = {c.column_name: c for c in self.columns}
        wanted = columns_by_name[column_name]
        processor = self.get_template_processor()

        from_clause = self.get_from_clause(processor)
        query = select([wanted.get_sqla_col()]).select_from(from_clause)
        query = query.distinct()
        if limit:
            query = query.limit(limit)

        if self.fetch_values_predicate:
            processor = self.get_template_processor()
            predicate = processor.process_template(
                self.fetch_values_predicate)
            query = query.where(predicate)

        engine = self.database.get_sqla_engine()
        compiled = query.compile(
            engine, compile_kwargs={'literal_binds': True})
        sql = self.mutate_query_from_config('{}'.format(compiled))

        frame = pd.read_sql_query(sql=sql, con=engine)
        return [record[0] for record in frame.to_records(index=False)]

    def mutate_query_from_config(self, sql):
        """Apply config's SQL_QUERY_MUTATOR

        Typically adds comments to the query with context.  The SQL is
        returned untouched when no mutator is configured."""
        mutator = config.get('SQL_QUERY_MUTATOR')
        if not mutator:
            return sql
        username = utils.get_username()
        return mutator(sql, username, security_manager, self.database)

    def get_template_processor(self, **kwargs):
        """Build a template processor for this table and its database.

        Extra keyword arguments are forwarded to the module-level
        ``get_template_processor`` factory.
        """
        owning_database = self.database
        return get_template_processor(
            table=self, database=owning_database, **kwargs)

    def get_query_str(self, query_obj):
        """Build, log, format and mutate the SQL for ``query_obj``.

        When ``query_obj['is_prequery']`` is truthy, the formatted SQL (before
        mutation) is also appended to ``query_obj['prequeries']``.
        """
        query = self.get_sqla_query(**query_obj)
        compiled_sql = self.database.compile_sqla_query(query)
        logging.info(compiled_sql)
        formatted_sql = sqlparse.format(compiled_sql, reindent=True)
        if query_obj['is_prequery']:
            query_obj['prequeries'].append(formatted_sql)
        return self.mutate_query_from_config(formatted_sql)

    def get_sqla_table(self):
        """Return a SQLAlchemy ``table`` clause for this table.

        The clause's ``schema`` is set only when this object has a schema.
        """
        table_clause = table(self.table_name)
        schema_name = self.schema
        if schema_name:
            table_clause.schema = schema_name
        return table_clause

    def get_from_clause(self, template_processor=None):
        """Return the FROM clause: the table, or ``sql`` aliased as a subquery.

        When ``sql`` is set it is optionally run through
        ``template_processor``, stripped of comments and wrapped as
        ``expr_qry``.
        """
        if not self.sql:
            return self.get_sqla_table()

        # Supporting arbitrary SQL statements in place of tables
        statement = self.sql
        if template_processor:
            statement = template_processor.process_template(statement)
        statement = sqlparse.format(statement, strip_comments=True)
        text_clause = sa.text(statement)
        return TextAsFrom(text_clause, []).alias('expr_qry')

    def adhoc_metric_to_sqla(self, metric, cols):
        """
        Turn an adhoc metric into a sqlalchemy column.

        A ``SIMPLE`` metric applies the named aggregate to its column (the
        table column's expression when the column is known, a bare column
        otherwise); a ``SQL`` metric wraps its expression as a literal
        column.  Any other expression type yields None.

        :param dict metric: Adhoc metric definition
        :param dict cols: Columns for the current table
        :returns: The metric defined as a sqlalchemy column
        :rtype: sqlalchemy.sql.column
        """
        kind = metric.get('expressionType')
        engine_spec = self.database.db_engine_spec
        label = engine_spec.make_label_compatible(metric.get('label'))

        if kind == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
            column_name = metric.get('column').get('column_name')
            known_column = cols.get(column_name)
            if known_column:
                target = known_column.get_sqla_col()
            else:
                target = column(column_name)
            aggregate = self.sqla_aggregations[metric.get('aggregate')]
            return aggregate(target).label(label)

        if kind == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
            expression = literal_column(metric.get('sqlExpression'))
            return expression.label(label)

        return None

    def get_sqla_query(  # sqla
        self,
        groupby,
        metrics,
        granularity,
        from_dttm,
        to_dttm,
        filter=None,  # noqa
        is_timeseries=True,
        timeseries_limit=15,
        timeseries_limit_metric=None,
        row_limit=None,
        inner_from_dttm=None,
        inner_to_dttm=None,
        orderby=None,
        extras=None,
        columns=None,
        order_desc=True,
        prequeries=None,
        is_prequery=False,
    ):
        """Querying any sqla table from this common interface"""
        template_kwargs = {
            'from_dttm': from_dttm,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'to_dttm': to_dttm,
            'filter': filter,
            'columns': {col.column_name: col
                        for col in self.columns},
        }
        template_kwargs.update(self.template_params_dict)
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec

        orderby = orderby or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols = {col.column_name: col for col in self.columns}
        metrics_dict = {m.metric_name: m for m in self.metrics}

        if not granularity and is_timeseries:
            raise Exception(
                _('Datetime column not provided as part table configuration '
                  'and is required by this type of chart'))
        if not groupby and not metrics and not columns:
            raise Exception(_('Empty query?'))
        metrics_exprs = []
        for m in metrics:
            if utils.is_adhoc_metric(m):
                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
            elif m in metrics_dict:
                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
            else:
                raise Exception(_("Metric '{}' is not valid".format(m)))
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            main_metric_expr = literal_column('COUNT(*)').label(
                db_engine_spec.make_label_compatible('count'))

        select_exprs = []
        groupby_exprs = []

        if groupby:
            select_exprs = []
            inner_select_exprs = []
            inner_groupby_exprs = []
            for s in groupby:
                col = cols[s]
                outer = col.get_sqla_col()
                inner = col.get_sqla_col(col.column_name + '__')

                groupby_exprs.append(outer)
                select_exprs.append(outer)
                inner_groupby_exprs.append(inner)
                inner_select_exprs.append(inner)
        elif columns:
            for s in columns:
                select_exprs.append(cols[s].get_sqla_col())
            metrics_exprs = []

        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get('time_grain_sqla')
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs += [timestamp]

            # Use main dttm column to support index with secondary dttm columns
            if db_engine_spec.time_secondary_columns and \
                    self.main_dttm_col in self.dttm_cols and \
                    self.main_dttm_col != dttm_col.column_name:
                time_filters.append(cols[self.main_dttm_col].get_time_filter(
                    from_dttm, to_dttm))
            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))

        select_exprs += metrics_exprs
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor)

        if not columns:
            qry = qry.group_by(*groupby_exprs)

        where_clause_and = []
        having_clause_and = []
        for flt in filter:
            if not all([flt.get(s) for s in ['col', 'op']]):
                continue
            col = flt['col']
            op = flt['op']
            col_obj = cols.get(col)
            if col_obj:
                is_list_target = op in ('in', 'not in')
                eq = self.filter_values_handler(
                    flt.get('val'),
                    target_column_is_numeric=col_obj.is_num,
                    is_list_target=is_list_target)
                if op in ('in', 'not in'):
                    cond = col_obj.get_sqla_col().in_(eq)
                    if '<NULL>' in eq:
                        cond = or_(cond,
                                   col_obj.get_sqla_col() == None)  # noqa
                    if op == 'not in':
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_num:
                        eq = utils.string_to_num(flt['val'])
                    if op == '==':
                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                    elif op == '!=':
                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                    elif op == '>':
                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                    elif op == '<':
                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                    elif op == '>=':
                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                    elif op == '<=':
                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                    elif op == 'LIKE':
                        where_clause_and.append(
                            col_obj.get_sqla_col().like(eq))
                    elif op == 'IS NULL':
                        where_clause_and.append(
                            col_obj.get_sqla_col() == None)  # noqa
                    elif op == 'IS NOT NULL':
                        where_clause_and.append(
                            col_obj.get_sqla_col() != None)  # noqa
        if extras:
            where = extras.get('where')
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text('({})'.format(where))]
            having = extras.get('having')
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text('({})'.format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and not columns:
            orderby = [(main_metric_expr, not order_desc)]

        for col, ascending in orderby:
            direction = asc if ascending else desc
            if utils.is_adhoc_metric(col):
                col = self.adhoc_metric_to_sqla(col, cols)
            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

        if is_timeseries and \
                timeseries_limit and groupby and not time_groupby_inline:
            if self.database.db_engine_spec.inner_joins:
                # some sql dialects require for order by expressions
                # to also be in the select clause -- others, e.g. vertica,
                # require a unique inner alias
                inner_main_metric_expr = main_metric_expr.label('mme_inner__')
                inner_select_exprs += [inner_main_metric_expr]
                subq = select(inner_select_exprs)
                subq = subq.select_from(tbl)
                inner_time_filter = dttm_col.get_time_filter(
                    inner_from_dttm or from_dttm,
                    inner_to_dttm or to_dttm,
                )
                subq = subq.where(
                    and_(*(where_clause_and + [inner_time_filter])))
                subq = subq.group_by(*inner_groupby_exprs)

                ob = inner_main_metric_expr
                if timeseries_limit_metric:
                    if utils.is_adhoc_metric(timeseries_limit_metric):
                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric,
                                                       cols)
                    elif timeseries_limit_metric in metrics_dict:
                        timeseries_limit_metric = metrics_dict.get(
                            timeseries_limit_metric, )
                        ob = timeseries_limit_metric.get_sqla_col()
                    else:
                        raise Exception(_(
                            "Metric '{}' is not valid".format(m)))
                direction = desc if order_desc else asc
                subq = subq.order_by(direction(ob))
                subq = subq.limit(timeseries_limit)

                on_clause = []
                for i, gb in enumerate(groupby):
                    on_clause.append(groupby_exprs[i] == column(gb + '__'))

                tbl = tbl.join(subq.alias(), and_(*on_clause))
            else:
                # run subquery to get top groups
                subquery_obj = {
                    'prequeries': prequeries,
                    'is_prequery': True,
                    'is_timeseries': False,
                    'row_limit': timeseries_limit,
                    'groupby': groupby,
                    'metrics': metrics,
                    'granularity': granularity,
                    'from_dttm': inner_from_dttm or from_dttm,
                    'to_dttm': inner_to_dttm or to_dttm,
                    'filter': filter,
                    'orderby': orderby,
                    'extras': extras,
                    'columns': columns,
                    'order_desc': True,
                }
                result = self.query(subquery_obj)
                cols = {col.column_name: col for col in self.columns}
                dimensions = [
                    c for c in result.df.columns
                    if c not in metrics and c in cols
                ]
                top_groups = self._get_top_groups(result.df, dimensions)
                qry = qry.where(top_groups)

        return qry.select_from(tbl)

    def _get_top_groups(self, df, dimensions):
        """Build a filter clause matching any dimension combination in ``df``.

        Every row of ``df`` is turned into an AND of ``column == value``
        comparisons, one per name in ``dimensions``; the per-row clauses are
        then OR-ed together and returned.
        """
        columns_by_name = {c.column_name: c for c in self.columns}
        row_clauses = []
        for _idx, record in df.iterrows():
            comparisons = [
                columns_by_name.get(name).get_sqla_col() == record[name]
                for name in dimensions
            ]
            row_clauses.append(and_(*comparisons))
        return or_(*row_clauses)

    def query(self, query_obj):
        """Run the SQL generated for ``query_obj`` and wrap the outcome.

        The SQL text comes from ``self.get_query_str``. Any exception raised
        while fetching the dataframe is logged and reported through the
        result's ``status``/``error_message`` instead of propagating.

        For a non-prequery, the SQL is appended to ``query_obj['prequeries']``
        (mutating that list) and the reported query text is all of those
        statements joined together.
        """
        started_at = datetime.now()
        sql = self.get_query_str(query_obj)
        status, error_message, df = QueryStatus.SUCCESS, None, None
        try:
            df = self.database.get_df(sql, self.schema)
        except Exception as exc:
            status = QueryStatus.FAILED
            logging.exception(exc)
            engine_spec = self.database.db_engine_spec
            error_message = engine_spec.extract_error_message(exc)

        # A main query reports its prequeries followed by itself.
        if not query_obj['is_prequery']:
            statements = query_obj['prequeries']
            statements.append(sql)
            sql = ';\n\n'.join(statements)
        sql = sql + ';'

        return QueryResult(
            status=status,
            df=df,
            duration=datetime.now() - started_at,
            query=sql,
            error_message=error_message,
        )

    def get_sqla_table_object(self):
        """Return what ``self.database.get_table`` yields for this datasource.

        The lookup is made with this object's ``table_name`` and ``schema``.
        """
        database = self.database
        return database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self):
        """Fetch the table's column metadata and merge it into this object.

        Obtains the table via ``get_sqla_table_object``, creates or refreshes
        one ``TableColumn`` per table column, collects the metrics those
        columns expose plus a ``COUNT(*)`` metric, then merges ``self`` into
        the session and commits.

        Raises:
            Exception: if the table object cannot be obtained; the
                underlying error is logged first.
        """
        try:
            table = self.get_sqla_table_object()
        except Exception as e:
            logging.exception(e)
            raise Exception(
                _("Table [{}] doesn't seem to exist in the specified database, "
                  "couldn't fetch column information").format(self.table_name))

        M = SqlMetric  # noqa
        metrics = []
        any_date_col = None
        db_dialect = self.database.get_dialect()
        # Existing TableColumn rows of this table whose name matches one of
        # the table's columns, indexed by column name.
        dbcols = (db.session.query(TableColumn).filter(
            TableColumn.table == self).filter(
                or_(TableColumn.column_name == col.name
                    for col in table.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
        db_engine_spec = self.database.db_engine_spec

        for col in table.columns:
            try:
                # Type name as compiled for the database's dialect, upper-cased.
                datatype = col.type.compile(dialect=db_dialect).upper()
            except Exception as e:
                # Fall back to a placeholder type and keep going.
                datatype = 'UNKNOWN'
                logging.error('Unrecognized data type in {}.{}'.format(
                    table, col.name))
                logging.exception(e)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                # New column: seed its flags from the type-derived properties.
                dbcol = TableColumn(column_name=col.name, type=datatype)
                dbcol.groupby = dbcol.is_string
                dbcol.filterable = dbcol.is_string
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.is_dttm = dbcol.is_time
            else:
                # Known column: only its type is refreshed.
                dbcol.type = datatype
            self.columns.append(dbcol)
            # Remember the first time column as a main_dttm_col candidate.
            if not any_date_col and dbcol.is_time:
                any_date_col = col.name
            metrics += dbcol.get_metrics().values()

        # Always offer a COUNT(*) metric alongside the per-column metrics.
        metrics.append(
            M(
                metric_name='count',
                verbose_name='COUNT(*)',
                metric_type='count',
                expression='COUNT(*)',
            ))
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        # Pass every metric name through the engine spec's label mutation.
        for metric in metrics:
            metric.metric_name = db_engine_spec.mutate_expression_label(
                metric.metric_name)
        self.add_missing_metrics(metrics)
        db.session.merge(self)
        db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None):
        """Import a datasource object into the database.

        Existing metrics, columns and the datasource itself are overridden.
        This can be used to move dashboards between multiple Superset
        instances. Audit metadata is not copied over.
        """
        def find_table(table):
            # Match on table name, schema and owning database id.
            joined = db.session.query(SqlaTable).join(Database)
            return joined.filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            ).first()

        def find_database(table):
            # The database name is read from the table's params.
            databases = db.session.query(Database)
            return databases.filter_by(
                database_name=table.params_dict['database_name']).one()

        return import_util.import_datasource(
            db.session,
            i_datasource,
            find_database,
            find_table,
            import_time,
        )

    @classmethod
    def query_datasources_by_name(cls,
                                  session,
                                  database,
                                  datasource_name,
                                  schema=None):
        """Return all ``cls`` rows in ``database`` named ``datasource_name``.

        When ``schema`` is truthy the result is additionally restricted to
        that schema.
        """
        matches = session.query(cls)
        matches = matches.filter_by(database_id=database.id)
        matches = matches.filter_by(table_name=datasource_name)
        if schema:
            matches = matches.filter_by(schema=schema)
        return matches.all()

    @staticmethod
    def default_query(qry):
        """Restrict ``qry`` to rows whose ``is_sqllab_view`` flag is False."""
        criteria = {'is_sqllab_view': False}
        return qry.filter_by(**criteria)
예제 #26
0
파일: db.py 프로젝트: tunghoang/srm
class Attachment(__db.Base):
    """ORM model for a row of the ``attachment`` table.

    Instances are created from, and updated with, dictionaries keyed by
    column attribute name. Keys that are missing or whose value is ``None``
    are ignored, so a partial dictionary only touches the fields it names.
    """
    __tablename__ = "attachment"
    idAttachment = Column(Integer, primary_key=True)
    title = Column(String(255))
    uuid = Column(String(512))
    idProject = Column(Integer, ForeignKey('project.idProject'))
    idOwner = Column(Integer, ForeignKey('student.idStudent'))
    advisorApproved = Column(Boolean)
    uploadDate = Column(DateTime)

    # Table-level constraints: the (idProject, idOwner) pair must be unique.
    constraints = [UniqueConstraint('idProject', 'idOwner')]
    if constraints:
        __table_args__ = tuple(constraints)

    # Attributes copied from a dictionary by __init__ and update, in the
    # order they are applied.
    _FIELDS = (
        "idAttachment",
        "title",
        "uuid",
        "idProject",
        "idOwner",
        "advisorApproved",
        "uploadDate",
    )

    def _copy_fields(self, dictModel):
        """Set every known field that is present and not None in dictModel."""
        for name in self._FIELDS:
            if name in dictModel and dictModel[name] is not None:
                setattr(self, name, dictModel[name])

    def __init__(self, dictModel):
        """Populate a new instance from ``dictModel``.

        :param dictModel: mapping of attribute name to value; absent keys
            and ``None`` values leave the attribute unset.
        """
        self._copy_fields(dictModel)

    def __repr__(self):
        """Return a debugging string listing every field value."""
        return '<Attachment idAttachment={} title={} uuid={} idProject={} idOwner={} advisorApproved={} uploadDate={} >'.format(
            self.idAttachment,
            self.title,
            self.uuid,
            self.idProject,
            self.idOwner,
            self.advisorApproved,
            self.uploadDate,
        )

    def json(self):
        """Return the field values as a dictionary keyed by attribute name."""
        return {
            "idAttachment": self.idAttachment,
            "title": self.title,
            "uuid": self.uuid,
            "idProject": self.idProject,
            "idOwner": self.idOwner,
            "advisorApproved": self.advisorApproved,
            "uploadDate": self.uploadDate,
        }

    def update(self, dictModel):
        """Overwrite fields with the non-None values found in ``dictModel``.

        :param dictModel: mapping of attribute name to new value; absent
            keys and ``None`` values leave the current value untouched.
        """
        self._copy_fields(dictModel)
예제 #27
0
class SystemBaseline(db.Model):
    """ORM model for a row of the ``system_baselines`` table.

    A baseline belongs to an account, has a display name that is unique
    within that account, carries a JSONB blob of facts and owns a collection
    of ``SystemBaselineMappedSystem`` rows.
    """
    __tablename__ = "system_baselines"
    # do not allow two records in the same account to have the same display name
    __table_args__ = (UniqueConstraint("account",
                                       "display_name",
                                       name="_account_display_name_uc"), )

    id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    account = db.Column(db.String(10), nullable=False)
    display_name = db.Column(db.String(200), nullable=False)
    created_on = db.Column(db.DateTime,
                           default=datetime.utcnow,
                           nullable=False)
    modified_on = db.Column(db.DateTime,
                            default=datetime.utcnow,
                            onupdate=datetime.utcnow,
                            nullable=False)
    baseline_facts = db.Column(JSONB)
    mapped_systems = relationship(
        "SystemBaselineMappedSystem",
        cascade="all, delete, delete-orphan",
    )

    @property
    def fact_count(self):
        """Number of entries in ``baseline_facts``; 0 when it is unset.

        The column is not declared ``nullable=False``, so ``baseline_facts``
        may be ``None``; that case is reported as zero facts rather than
        raising ``TypeError``.
        """
        return len(self.baseline_facts or ())

    @validates("baseline_facts")
    def validate_facts(self, key, value):
        """Run the fact validators on ``value`` and return it unchanged."""
        validators.check_facts_length(value)
        validators.check_for_duplicate_names(value)
        return value

    def mapped_system_ids(self):
        """Return the ``system_id`` of every mapped system, as strings."""
        return [str(mapped.system_id) for mapped in self.mapped_systems]

    def to_json(self, withhold_facts=False, withhold_system_ids=True):
        """Serialize this baseline to a dictionary.

        :param withhold_facts: when true, omit the ``baseline_facts`` key.
        :param withhold_system_ids: when true (the default), omit the
            ``system_ids`` key.
        :return: dictionary with id, account, display name, fact count and
            ISO-formatted ``created``/``updated`` timestamps suffixed "Z".
        """
        json_dict = {}
        json_dict["id"] = str(self.id)
        json_dict["account"] = self.account
        json_dict["display_name"] = self.display_name
        json_dict["fact_count"] = self.fact_count
        json_dict["created"] = self.created_on.isoformat() + "Z"
        json_dict["updated"] = self.modified_on.isoformat() + "Z"
        if not withhold_facts:
            json_dict["baseline_facts"] = self.baseline_facts
        if not withhold_system_ids:
            json_dict["system_ids"] = self.mapped_system_ids()
        return json_dict

    def add_mapped_system(self, system_id):
        """Attach a new mapped system for ``system_id`` and add it to the session."""
        new_mapped_system = SystemBaselineMappedSystem(system_id=system_id,
                                                       account=self.account)
        self.mapped_systems.append(new_mapped_system)
        db.session.add(new_mapped_system)

    def remove_mapped_system(self, system_id):
        """Remove the first mapped system whose id matches ``system_id``.

        Ids are compared by their string form.

        :raises ValueError: if no mapped system has that id.
        """
        wanted = str(system_id)
        for mapped_system in self.mapped_systems:
            if str(mapped_system.system_id) == wanted:
                # Leave the loop right away: the collection was just mutated.
                self.mapped_systems.remove(mapped_system)
                break
        else:
            # do we want to raise exception here?
            raise ValueError(
                "Failed to remove system id %s from mapped systems - not in list"
                % system_id)
예제 #28
0
class UserTestResult(Base):
    """Execution results of a user test on one dataset."""

    # Possible statuses of a user test result. COMPILING and EVALUATING
    # do not necessarily mean that compilation or evaluation is going to
    # be scheduled: the result may belong to a dataset not scheduled for
    # evaluation, or may have used up its maximum number of tries. When
    # no result exists for a (user test, dataset) pair, its status can be
    # implicitly assumed to be COMPILING.
    COMPILING = 1
    COMPILATION_FAILED = 2
    EVALUATING = 3
    EVALUATED = 4

    __tablename__ = 'user_test_results'
    __table_args__ = (
        UniqueConstraint('user_test_id', 'dataset_id'),
    )

    # The primary key is the pair (user_test_id, dataset_id).
    user_test_id = Column(
        Integer,
        ForeignKey(UserTest.id, onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True)
    user_test = relationship(UserTest, back_populates="results")

    dataset_id = Column(
        Integer,
        ForeignKey(Dataset.id, onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True)
    dataset = relationship(Dataset)

    # The actual result fields start here.

    # Digest of the output file produced for this test.
    output = Column(Digest, nullable=True)

    # Compilation outcome: None = not compiled yet, "ok" = compiled
    # successfully and ready to be evaluated, "fail" = compilation
    # unsuccessful, throw it away.
    compilation_outcome = Column(String, nullable=True)

    # Output from the sandbox. To allow localization the first item is a
    # format string, possibly containing some "%s", that gets filled in
    # with the remaining items of the list.
    compilation_text = Column(ARRAY(String), nullable=False, default=[])

    # How many times compilation has been attempted.
    compilation_tries = Column(Integer, nullable=False, default=0)

    # Compiler stdout and stderr.
    compilation_stdout = Column(Unicode, nullable=True)
    compilation_stderr = Column(Unicode, nullable=True)

    # Further information about the compilation.
    compilation_time = Column(Float, nullable=True)
    compilation_wall_clock_time = Column(Float, nullable=True)
    compilation_memory = Column(BigInteger, nullable=True)

    # Worker shard and sandbox in which the compilation ran.
    compilation_shard = Column(Integer, nullable=True)
    compilation_sandbox = Column(String, nullable=True)

    # Evaluation outcome: None = not evaluated yet, "ok" = evaluation
    # successful.
    evaluation_outcome = Column(String, nullable=True)

    # Output from the grader, usually "Correct", "Time limit", ... To
    # allow localization the first item is a format string, possibly
    # containing some "%s", that gets filled in with the remaining items
    # of the list.
    evaluation_text = Column(ARRAY(String), nullable=False, default=[])

    # How many times evaluation has been attempted.
    evaluation_tries = Column(Integer, nullable=False, default=0)

    # Further information about the execution.
    execution_time = Column(Float, nullable=True)
    execution_wall_clock_time = Column(Float, nullable=True)
    execution_memory = Column(BigInteger, nullable=True)

    # Worker shard and sandbox in which the evaluation ran.
    evaluation_shard = Column(Integer, nullable=True)
    evaluation_sandbox = Column(String, nullable=True)

    # One-to-many relationships: the reverse direction of the foreign
    # keys declared in the "child" classes.

    executables = relationship(
        "UserTestExecutable",
        collection_class=attribute_mapped_collection("filename"),
        cascade="all, delete-orphan",
        passive_deletes=True,
        back_populates="user_test_result",
    )

    def get_status(self):
        """Return the status constant describing this result.

        return (int): one of COMPILING, COMPILATION_FAILED, EVALUATING
            or EVALUATED.

        """
        if not self.compiled():
            return UserTestResult.COMPILING
        if self.compilation_failed():
            return UserTestResult.COMPILATION_FAILED
        if self.evaluated():
            return UserTestResult.EVALUATED
        return UserTestResult.EVALUATING

    def compiled(self):
        """Tell whether a compilation outcome has been recorded.

        return (bool): True if compiled, False otherwise.

        """
        outcome = self.compilation_outcome
        return outcome is not None

    @staticmethod
    def filter_compiled():
        """Return a filtering expression for compiled user test results.

        """
        outcome_column = UserTestResult.compilation_outcome
        return outcome_column.isnot(None)

    def compilation_failed(self):
        """Tell whether the compilation outcome is "fail".

        return (bool): True if the compilation failed (in the sense
            that there is a problem in the user's source), False if
            not yet compiled or compilation was successful.

        """
        outcome = self.compilation_outcome
        return outcome == "fail"

    @staticmethod
    def filter_compilation_failed():
        """Return a filtering expression for user test results whose
        compilation failed.

        """
        outcome_column = UserTestResult.compilation_outcome
        return outcome_column == "fail"

    def compilation_succeeded(self):
        """Tell whether the compilation outcome is "ok".

        return (bool): True if the compilation succeeded (in the sense
            that an executable was created), False if not yet compiled
            or compilation was unsuccessful.

        """
        outcome = self.compilation_outcome
        return outcome == "ok"

    @staticmethod
    def filter_compilation_succeeded():
        """Return a filtering expression for user test results whose
        compilation succeeded.

        """
        outcome_column = UserTestResult.compilation_outcome
        return outcome_column == "ok"

    def evaluated(self):
        """Tell whether an evaluation outcome has been recorded.

        return (bool): True if evaluated, False otherwise.

        """
        outcome = self.evaluation_outcome
        return outcome is not None

    @staticmethod
    def filter_evaluated():
        """Return a filtering expression for evaluated user test results.

        """
        outcome_column = UserTestResult.evaluation_outcome
        return outcome_column.isnot(None)

    def invalidate_compilation(self):
        """Reset every compilation and evaluation field.

        The evaluation fields are reset first through
        invalidate_evaluation; the executables collection is emptied.

        """
        self.invalidate_evaluation()
        self.compilation_outcome = None
        self.compilation_text = []
        self.compilation_tries = 0
        self.compilation_time = \
            self.compilation_wall_clock_time = \
            self.compilation_memory = None
        self.compilation_shard = self.compilation_sandbox = None
        self.executables = {}

    def invalidate_evaluation(self):
        """Reset every evaluation field, including the output digest.

        """
        self.evaluation_outcome = None
        self.evaluation_text = []
        self.evaluation_tries = 0
        self.execution_time = \
            self.execution_wall_clock_time = \
            self.execution_memory = None
        self.evaluation_shard = self.evaluation_sandbox = None
        self.output = None

    def set_compilation_outcome(self, success):
        """Record the compilation outcome as "ok" or "fail".

        success (bool): if the compilation was successful.

        """
        if success:
            self.compilation_outcome = "ok"
        else:
            self.compilation_outcome = "fail"

    def set_evaluation_outcome(self):
        """Record the evaluation outcome (always "ok" now).

        """
        self.evaluation_outcome = "ok"
예제 #29
0
파일: auth.py 프로젝트: ehsanhm/stalker
class Permission(Base):
    """A class to hold permissions.

    Permissions in Stalker define what one can or cannot do. A Permission
    instance is composed of three attributes: access, action and class_name.
    All three are validated on construction and exposed read-only.

    Permissions for all the classes in SOM are generally created by Stalker
    when initializing the database.

    If you created any custom classes to extend SOM you are also responsible
    for creating the Permissions for them by calling
    :meth:`stalker.db.register` and passing your class to it. See the
    :mod:`stalker.db` documentation for details.

    Two Permissions compare equal when their access, action and class_name
    are equal, and equal Permissions hash to the same value.

    :param str access: An Enum value which can have one of the values
      ``Allow`` or ``Deny``.

    :param str action: An Enum value from the list ['Create', 'Read', 'Update',
      'Delete', 'List']. Can not be None. The list can be changed from
      stalker.config.Config.default_actions.

    :param str class_name: The name of the class that this action is applied
      to. Must be a string.

    Example: Let's say that you want to create a Permission specifying that a
    Group of Users is allowed to create Projects::

      from stalker import db
      from stalker.models.auth import User, Group, Permission

      # first setup the db with the default database
      #
      # stalker.db.init() will create all the Actions possible with the
      # SOM classes automatically
      #
      # What is left to you is to create the permissions
      db.setup()

      user1 = User(
          name='Test User',
          login='******',
          password='******',
          email='*****@*****.**'
      )
      user2 = User(
          name='Test User',
          login='******',
          password='******',
          email='*****@*****.**'
      )

      group1 = Group(name='users')
      group1.users = [user1, user2]

      # get the permission for creating instances of the Project class
      project_permission = Permission.query\
          .filter(Permission.access == 'Allow')\
          .filter(Permission.action == 'Create')\
          .filter(Permission.class_name == 'Project')\
          .first()

      # now we have the permission specifying the allowance of creating a
      # Project

      # to make group1 users able to create a Project we simply add this
      # Permission to the group's permissions attribute
      group1.permissions.append(project_permission)

      # and persist this information in the database
      DBSession.add(group1)
      DBSession.commit()
    """
    __tablename__ = 'Permissions'
    __table_args__ = (UniqueConstraint('access', 'action', 'class_name'), {
        "extend_existing": True
    })

    id = Column(Integer, primary_key=True)
    _access = Column('access', Enum('Allow', 'Deny', name='AccessNames'))
    _action = Column('action', Enum(*defaults.actions, name='ActionNames'))
    _class_name = Column('class_name', String(32))

    def __init__(self, access, action, class_name):
        self._access = self._validate_access(access)
        self._action = self._validate_action(action)
        self._class_name = self._validate_class_name(class_name)

    def _validate_access(self, access):
        """validates the given access value

        :raises TypeError: if access is not a string.
        :raises ValueError: if access is neither 'Allow' nor 'Deny'.
        """
        from stalker import __string_types__
        if not isinstance(access, __string_types__):
            raise TypeError(
                '%s.access should be an instance of str not %s' %
                (self.__class__.__name__, access.__class__.__name__))

        if access not in ['Allow', 'Deny']:
            raise ValueError('%s.access should be "Allow" or "Deny" not %s' %
                             (self.__class__.__name__, access))

        return access

    def _access_getter(self):
        """returns the _access value
        """
        return self._access

    # read-only public view of the _access column
    access = synonym('_access', descriptor=property(_access_getter))

    def _validate_class_name(self, class_name):
        """validates the given class_name value

        :raises TypeError: if class_name is not a string.
        """
        from stalker import __string_types__
        if not isinstance(class_name, __string_types__):
            raise TypeError(
                '%s.class_name should be an instance of str not %s' %
                (self.__class__.__name__, class_name.__class__.__name__))

        return class_name

    def _class_name_getter(self):
        """returns the _class_name attribute value
        """
        return self._class_name

    # read-only public view of the _class_name column
    class_name = synonym('_class_name',
                         descriptor=property(_class_name_getter))

    def _validate_action(self, action):
        """validates the given action value

        :raises TypeError: if action is not a string.
        :raises ValueError: if action is not listed in defaults.actions.
        """
        from stalker import __string_types__
        if not isinstance(action, __string_types__):
            raise TypeError(
                '%s.action should be an instance of str not %s' %
                (self.__class__.__name__, action.__class__.__name__))

        if action not in defaults.actions:
            raise ValueError(
                '%s.action should be one of the values of %s not %s' %
                (self.__class__.__name__, defaults.actions, action))

        return action

    def _action_getter(self):
        """returns the _action value
        """
        return self._action

    # read-only public view of the _action column
    action = synonym('_action', descriptor=property(_action_getter))

    def __eq__(self, other):
        """the equality of two Permissions
        """
        return isinstance(other, Permission) \
            and other.access == self.access \
            and other.action == self.action \
            and other.class_name == self.class_name

    def __hash__(self):
        """the hash of a Permission, consistent with :meth:`__eq__`

        Defining ``__eq__`` alone makes instances unhashable on Python 3, so
        the hash is derived from the same three values equality compares.
        """
        return hash((self.access, self.action, self.class_name))
예제 #30
0
    Column('description', String(length=255)),
    Column('tenant_id', String(length=64), nullable=True),
    Column('datastore_id', String(length=64), nullable=True),
    Column('datastore_version_id', String(length=64), nullable=True),
    Column('auto_apply', Boolean(), default=0, nullable=False),
    Column('visible', Boolean(), default=1, nullable=False),
    Column('live_update', Boolean(), default=0, nullable=False),
    Column('md5', String(length=32), nullable=False),
    Column('created', DateTime(), nullable=False),
    Column('updated', DateTime(), nullable=False),
    Column('deleted', Boolean(), default=0, nullable=False),
    Column('deleted_at', DateTime()),
    UniqueConstraint('type',
                     'tenant_id',
                     'datastore_id',
                     'datastore_version_id',
                     'name',
                     'deleted_at',
                     name='UQ_type_tenant_datastore_datastore_version_name'),
)

instance_modules = Table(
    'instance_modules',
    meta,
    Column('id', String(length=64), primary_key=True, nullable=False),
    Column('instance_id',
           String(length=64),
           ForeignKey('instances.id', ondelete="CASCADE", onupdate="CASCADE"),
           nullable=False),
    Column('module_id',
           String(length=64),
예제 #31
0
class TestCase(db.Model):
    """
    A single run of a single test, together with any captured output, retry-count
    and its return value.

    Every test that gets run ever has a row in this table.

    At the time this was written, it seems to have 400-500M rows

    (how is this still surviving?)

    NOTE: DO NOT MODIFY THIS TABLE! Running migration on this table has caused
    unavailability in the past. If you need to add a new column, consider doing
    that on a new table and linking it back to tests via the ID.
    """
    __tablename__ = 'test'
    __table_args__ = (
        UniqueConstraint('job_id', 'label_sha', name='unq_test_name'),
        Index('idx_test_step_id', 'step_id'),
        Index('idx_test_project_key', 'project_id', 'label_sha'),
        Index('idx_task_date_created', 'date_created'),
        Index('idx_test_project_key_date', 'project_id', 'label_sha',
              'date_created'),
    )

    id = Column(GUID, nullable=False, primary_key=True, default=uuid.uuid4)
    job_id = Column(GUID,
                    ForeignKey('job.id', ondelete="CASCADE"),
                    nullable=False)
    project_id = Column(GUID,
                        ForeignKey('project.id', ondelete="CASCADE"),
                        nullable=False)
    step_id = Column(GUID, ForeignKey('jobstep.id', ondelete="CASCADE"))
    # The attribute is called ``name_sha`` but the column is ``label_sha``.
    name_sha = Column('label_sha', String(40), nullable=False)
    name = Column(Text, nullable=False)
    # Raw ``package`` column; read it through the ``package`` property below,
    # which falls back to a value derived from ``name``.
    _package = Column('package', Text, nullable=True)
    result = Column(Enum(Result), default=Result.unknown, nullable=False)
    duration = Column(Integer, default=0)
    message = deferred(Column(Text))
    date_created = Column(DateTime, default=datetime.utcnow, nullable=False)
    reruns = Column(Integer)

    # owner should be considered an unstructured string field. It may contain
    # an email address ("Foo <*****@*****.**>"), a username ("foo"), or
    # something else. This field is not used directly by Changes, so
    # providers + consumers on either side of Changes should be sure they know
    # what they're doing.
    owner = Column(Text)

    job = relationship('Job')
    step = relationship('JobStep')
    project = relationship('Project')

    __repr__ = model_repr('name', '_package', 'result')

    def __init__(self, **kwargs):
        """Build a TestCase from column keyword arguments.

        ``id``, ``result`` and ``date_created`` are filled in immediately
        when the caller did not supply them, rather than waiting for the
        column defaults to be applied.
        """
        super(TestCase, self).__init__(**kwargs)
        if self.id is None:
            self.id = uuid.uuid4()
        if self.result is None:
            self.result = Result.unknown
        if self.date_created is None:
            self.date_created = datetime.utcnow()

    @classmethod
    def calculate_name_sha(cls, name):
        """Return the hexadecimal SHA-1 digest of ``name``.

        Text (unicode) input is encoded as UTF-8 before hashing, because the
        hash function only accepts bytes-like data. Byte strings are hashed
        as given, so digests for byte and ASCII input are unchanged.

        Raises:
            ValueError: if ``name`` is empty or ``None``.
        """
        if not name:
            raise ValueError('name must be a non-empty string')
        # ``type(u'')`` is ``unicode`` on Python 2 and ``str`` on Python 3.
        if isinstance(name, type(u'')):
            name = name.encode('utf-8')
        return sha1(name).hexdigest()

    @property
    def sep(self):
        """Return the separator ('/' or '.') used to split this test's name.

        The stored package is inspected when set, otherwise the name. '/' is
        chosen when that string starts with a non-alphanumeric character or
        contains a '/'; in every other case '.' is used.
        """
        name = (self._package or self.name)
        # handle the case where it might begin with some special character
        if not re.match(r'^[a-zA-Z0-9]', name):
            return '/'
        elif '/' in name:
            return '/'
        return '.'

    def _get_package(self):
        """Return the stored package, or derive one from ``name``.

        When no package is stored, everything before the last separator in
        ``name`` is used; ``None`` is returned if ``name`` has no separator.
        """
        if not self._package:
            try:
                # Unpacking fails with ValueError when rsplit yields one part.
                package, _ = self.name.rsplit(self.sep, 1)
            except ValueError:
                package = None
        else:
            package = self._package
        return package

    def _set_package(self, value):
        """Store ``value`` in the underlying ``package`` column."""
        self._package = value

    package = property(_get_package, _set_package)

    @property
    def short_name(self):
        """Return ``name`` with the leading package and separator removed.

        ``name`` is returned unchanged when there is no package, when it
        does not start with the package, or when it equals the package.
        """
        name, package = self.name, self.package
        if package and name.startswith(package) and name != package:
            # +1 skips the single separator character after the package.
            return name[len(package) + 1:]
        return name
예제 #32
0
class Category(MailSyncBase, HasRevisions, HasPublicID, UpdatedAtMixin,
               DeletedAtMixin):
    """A mail category row, stored as either a 'folder' or a 'label'
    (see `type_`) and tied to a namespace.
    """

    @property
    def API_OBJECT_NAME(self):
        """The API object name is simply the row's `type_`."""
        return self.type_

    # `deleted_at` takes part in a UniqueConstraint below, so the default
    # column is replaced by a NOT NULL one; rows that are not deleted carry
    # the EPOCH timestamp instead of NULL.
    deleted_at = Column(
        DateTime,
        index=True,
        nullable=False,
        default='1970-01-01 00:00:00',
    )

    # `use_alter` is required to break a circular dependency.
    namespace_id = Column(
        ForeignKey(
            'namespace.id',
            use_alter=True,
            name='category_fk1',
            ondelete='CASCADE',
        ),
        nullable=False,
    )
    namespace = relationship('Namespace', load_on_pending=True)

    # STOPSHIP(emfree): need to index properly for API filtering performance.
    name = Column(String(MAX_INDEXABLE_LENGTH), nullable=False, default='')
    display_name = Column(CategoryNameString(), nullable=False)

    type_ = Column(Enum('folder', 'label'), nullable=False, default='folder')

    @validates('display_name')
    def validate_display_name(self, key, display_name):
        """Sanitize `display_name`, logging a warning if that changed it."""
        cleaned = sanitize_name(display_name)
        if cleaned == display_name:
            return cleaned

        log.warning("Truncating category display_name",
                    type_=self.type_,
                    original=display_name)
        return cleaned

    @classmethod
    def find_or_create(cls, session, namespace_id, name, display_name, type_):
        """Return the category matching `namespace_id` and `display_name`,
        creating and adding it to `session` when none exists.

        Raises MultipleResultsFound if more than one row matches.
        """
        name = name or ''

        query = session.query(cls).filter(
            cls.namespace_id == namespace_id,
            cls.display_name == display_name)
        matches = query.all()

        if len(matches) > 1:
            message = ('Duplicate category rows for namespace_id {}, name {}, '
                       'display_name: {}'.format(
                           namespace_id, name, display_name))
            log.error(message)
            raise MultipleResultsFound(message)

        if matches:
            existing = matches[0]
            if not existing.name:
                # A category with this `display_name` exists but has no
                # `name` yet. Generic IMAP folders may first be synced with
                # `name` == '' and only later receive one; it is still the
                # same folder, so fill in its `name` rather than creating a
                # second row.
                existing.name = name
            return existing

        created = cls(namespace_id=namespace_id,
                      name=name,
                      display_name=display_name,
                      type_=type_,
                      deleted_at=EPOCH)
        session.add(created)
        return created

    @classmethod
    def create(cls, session, namespace_id, name, display_name, type_):
        """Build a new, not-deleted category and add it to `session`."""
        category = cls(
            namespace_id=namespace_id,
            name=name or '',
            display_name=display_name,
            type_=type_,
            deleted_at=EPOCH,
        )
        session.add(category)
        return category

    @property
    def account(self):
        """The account reached through this category's namespace."""
        return self.namespace.account

    @property
    def type(self):
        """The `category_type` of the owning account."""
        return self.account.category_type

    @hybrid_property
    def lowercase_name(self):
        """`display_name` converted to lower case."""
        return self.display_name.lower()

    @lowercase_name.comparator
    def lowercase_name(cls):
        return CaseInsensitiveComparator(cls.display_name)

    @property
    def api_display_name(self):
        """The display name in the form handed to the API.

        For 'gmail' accounts a leading '[Gmail]/' or '[Google Mail]/' is
        dropped. For providers other than 'gmail' and 'eas' the name goes
        through `fs_folder_path` with the account's separator and prefix.
        Otherwise `display_name` is returned unchanged.
        """
        account = self.namespace.account
        provider = account.provider

        if provider == 'gmail':
            for gmail_prefix in ('[Gmail]/', '[Google Mail]/'):
                if self.display_name.startswith(gmail_prefix):
                    return self.display_name[len(gmail_prefix):]

        if provider not in ['gmail', 'eas']:
            return fs_folder_path(
                self.display_name,
                separator=account.folder_separator,
                prefix=account.folder_prefix)

        return self.display_name

    @property
    def is_deleted(self):
        """True when `deleted_at` is later than EPOCH."""
        return self.deleted_at > EPOCH

    __table_args__ = (
        UniqueConstraint('namespace_id', 'name', 'display_name',
                         'deleted_at'),
        UniqueConstraint('namespace_id', 'public_id'),
    )
예제 #33
0
def revert_username_index(db):
    """
    Undo the changes made by migration 22 above.

    That migration turned out to be problematic in two ways:
     - It was never needed.  The unique constraint already came with an
       implicit b-tree index, so an extra index served no purpose.
       (This is my (Chris Webber's) fault for suggesting it needed to
       happen without knowing what's going on... my bad!)
     - Databases created after models.py was changed did not end up
       identical to databases that had gone through migration 22.

    So we put things back the way they were before, which is trickier
    than it sounds.
    """
    metadata = MetaData(bind=db.bind)
    user_table = inspect_table(metadata, "core__users")
    index_by_name = {idx.name: idx for idx in user_table.indexes}

    candidates = (
        # index from unnecessary migration
        index_by_name.get(u'ix_core__users_uploader'),
        # index created from models.py after (unique=True, index=True)
        # was set in models.py
        index_by_name.get(u'ix_core__users_username'),
    )
    stale_indexes = [idx for idx in candidates if idx is not None]

    if not stale_indexes:
        # Nothing to do: the database is not in a state that needs fixing
        # (ie, either went through the previous borked migration or
        #  was initialized with a models.py where core__users was both
        #  unique=True and index=True)
        return

    on_sqlite = db.bind.url.drivername == 'sqlite'
    if on_sqlite:
        # sqlite has problems again, so the table is rebuilt instead.

        # Using User_vR1 here is intentional: nothing changed between the
        # *correct* version of this table and migration 18.
        User_vR1.__table__.create(db.bind)
        db.commit()
        replacement_table = inspect_table(metadata, 'rename__users')
        replace_table_hack(db, user_table, replacement_table)
    else:
        # Any other backend can be fixed in place, without table copying.

        # Drop whichever of the unused indexes are present.
        for stale_index in stale_indexes:
            stale_index.drop()

        # The unique constraint added next *may fail* and roll back the
        # session, so persist the index removal first.
        db.commit()

        try:
            # Add the unique constraint
            UniqueConstraint('username', table=user_table).create()
        except ProgrammingError:
            # constraint already exists, no need to add
            db.rollback()

    db.commit()
예제 #34
0
class JobPhase(db.Model):
    """A group of one or more JobSteps that all perform the same basic task.

    A Job's phases are meant to run one after another, although nothing
    necessarily enforces that ordering.

    Example: a Job may consist of a test collection phase followed by a test
    execution phase. A single JobStep collects the tests in the first phase,
    and any number of JobSteps execute shards of those tests in the second.
    Keeping the two phases separate lets each kind of JobStep be tracked and
    managed on its own.

    JobPhases are usually created to group freshly created JobSteps, but they
    can also be built after the fact, once a JobStep has finished, from
    phased artifacts. That is convenient but somewhat confusing, and might be
    better served by a different mechanism.
    """
    # TODO(dcramer): add order column rather than implicit date_started ordering
    # TODO(dcramer): make duration a column
    __tablename__ = 'jobphase'
    __table_args__ = (
        UniqueConstraint('job_id', 'label', name='unq_jobphase_key'),
    )

    id = Column(GUID, nullable=False, primary_key=True, default=uuid.uuid4)
    job_id = Column(
        GUID,
        ForeignKey('job.id', ondelete="CASCADE"),
        nullable=False,
    )
    project_id = Column(
        GUID,
        ForeignKey('project.id', ondelete="CASCADE"),
        nullable=False,
    )
    label = Column(String(128), nullable=False)
    status = Column(Enum(Status), nullable=False, default=Status.unknown)
    result = Column(Enum(Result), nullable=False, default=Result.unknown)
    date_started = Column(DateTime)
    date_finished = Column(DateTime)
    date_created = Column(DateTime, default=datetime.utcnow)

    job = relationship(
        'Job',
        backref=backref('phases', order_by='JobPhase.date_started'),
    )
    project = relationship('Project')

    def __init__(self, **kwargs):
        """Build a JobPhase, filling in any of id, result, status and
        date_created that the caller left unset.
        """
        super(JobPhase, self).__init__(**kwargs)
        fallbacks = (
            ('id', uuid.uuid4),
            ('result', lambda: Result.unknown),
            ('status', lambda: Status.unknown),
            ('date_created', datetime.utcnow),
        )
        for attr_name, make_default in fallbacks:
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, make_default())

    @property
    def duration(self):
        """
        Time in milliseconds between date_started and date_finished, or
        None when either of the two is not set.
        """
        if not self.date_started or not self.date_finished:
            return None
        elapsed = self.date_finished - self.date_started
        return elapsed.total_seconds() * 1000

    @property
    def current_steps(self):
        """
        The steps of this phase whose replacement_id is None, i.e. those
        that have not been replaced.
        """
        # self.steps is available thanks to a backref declared in JobStep
        active_steps = []
        for step in self.steps:
            if step.replacement_id is None:
                active_steps.append(step)
        return active_steps