def fix_CollectionItem_v0_constraint(db_conn):
    """Add the forgotten Constraint on CollectionItem"""
    global collectionitem_unique_constraint_done
    if collectionitem_unique_constraint_done:
        # Reset it. Maybe the whole thing gets run again
        # for a different db?
        collectionitem_unique_constraint_done = False
        return

    metadata = MetaData(bind=db_conn.bind)

    CollectionItem_table = inspect_table(metadata, 'core__collection_items')

    constraint = UniqueConstraint(
        'collection', 'media_entry',
        name='core__collection_items_collection_media_entry_key',
        table=CollectionItem_table)

    try:
        constraint.create()
    except ProgrammingError:
        # User probably has an install that was run since the
        # collection tables were added, so we don't need to run
        # this migration.
        pass

    db_conn.commit()
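# --- Hedged, standalone sketch (not part of the migration above) -------------
# The same "add the constraint, tolerate it already existing" pattern using
# plain SQLAlchemy instead of the sqlalchemy-migrate UniqueConstraint.create()
# API.  The engine URL, table layout and index name are illustrative
# assumptions.
from sqlalchemy import create_engine, text
from sqlalchemy.exc import OperationalError, ProgrammingError


def ensure_unique_index(engine, table, columns, name):
    ddl = text('CREATE UNIQUE INDEX {name} ON {table} ({cols})'.format(
        name=name, table=table, cols=', '.join(columns)))
    try:
        with engine.begin() as conn:
            conn.execute(ddl)
    except (OperationalError, ProgrammingError):
        # The index (or an equivalent constraint) already exists.
        pass


engine = create_engine('sqlite:///:memory:')
with engine.begin() as conn:
    conn.execute(text('CREATE TABLE core__collection_items '
                      '(id INTEGER PRIMARY KEY, collection INTEGER, '
                      'media_entry INTEGER)'))
ensure_unique_index(engine, 'core__collection_items',
                    ['collection', 'media_entry'],
                    'collection_media_entry_key')
ensure_unique_index(engine, 'core__collection_items',  # second call is a no-op
                    ['collection', 'media_entry'],
                    'collection_media_entry_key')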
def unique_collections_slug(db):
    """Add unique constraint to collection slug"""
    metadata = MetaData(bind=db.bind)
    collection_table = inspect_table(metadata, "core__collections")
    existing_slugs = {}
    slugs_to_change = []

    for row in db.execute(collection_table.select()):
        # if duplicate slug, generate a unique slug
        if row.creator in existing_slugs and row.slug in \
                existing_slugs[row.creator]:
            slugs_to_change.append(row.id)
        else:
            if row.creator not in existing_slugs:
                existing_slugs[row.creator] = [row.slug]
            else:
                existing_slugs[row.creator].append(row.slug)

    for row_id in slugs_to_change:
        new_slug = six.text_type(uuid.uuid4())
        db.execute(collection_table.update().
                   where(collection_table.c.id == row_id).
                   values(slug=new_slug))

    # sqlite does not like to change the schema when a transaction (update) is
    # not yet completed
    db.commit()

    constraint = UniqueConstraint('creator', 'slug',
                                  name='core__collection_creator_slug_key',
                                  table=collection_table)
    constraint.create()

    db.commit()
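# --- Illustrative sketch of the duplicate-slug detection above ---------------
# The same per-creator bookkeeping run on plain Python tuples instead of
# database rows; the (id, creator, slug) row layout is an assumption made for
# the example.
import uuid


def find_duplicate_slug_ids(rows):
    existing_slugs = {}
    slugs_to_change = []
    for row_id, creator, slug in rows:
        if slug in existing_slugs.setdefault(creator, set()):
            slugs_to_change.append(row_id)
        else:
            existing_slugs[creator].add(slug)
    return slugs_to_change


rows = [(1, 'alice', 'pets'), (2, 'alice', 'pets'), (3, 'bob', 'pets')]
assert find_duplicate_slug_ids(rows) == [2]   # only the second 'alice'/'pets'
replacements = {rid: str(uuid.uuid4()) for rid in find_duplicate_slug_ids(rows)}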
def pw_hash_nullable(db):
    """Make pw_hash column nullable"""
    metadata = MetaData(bind=db.bind)
    user_table = inspect_table(metadata, "core__users")

    user_table.c.pw_hash.alter(nullable=True)

    # sqlite+sqlalchemy seems to drop this constraint during the
    # migration, so we add it back here for now a bit manually.
    if db.bind.url.drivername == 'sqlite':
        constraint = UniqueConstraint('username', table=user_table)
        constraint.create()

    db.commit()
class Item(DB.Model): """ This data model represents a single lendable item """ __tablename__ = 'Item' id = DB.Column(DB.Integer, primary_key=True) name = DB.Column(DB.String(STD_STRING_SIZE)) update_name_from_schema = DB.Column(DB.Boolean, default=True, nullable=False) type_id = DB.Column(DB.Integer, DB.ForeignKey('ItemType.id')) lending_id = DB.Column(DB.Integer, DB.ForeignKey('Lending.id'), default=None, nullable=True) lending_duration = DB.Column(DB.Integer, nullable=True) # in seconds due = DB.Column(DB.Integer, default=-1) # unix time deleted_time = DB.Column(DB.Integer, default=None) visible_for = DB.Column(DB.String(STD_STRING_SIZE), nullable=True) type = DB.relationship('ItemType', lazy='joined') lending = DB.relationship('Lending', lazy='select', backref=DB.backref('_items', lazy='select')) __table_args__ = ( UniqueConstraint('name', 'type_id', name='_name_type_id_uc'), ) def __init__(self, update_name_from_schema: bool, name: str, type_id: int, lending_duration: int = -1, visible_for: str = ''): self.update_name_from_schema = update_name_from_schema self.name = name self.type_id = type_id if lending_duration >= 0: self.lending_duration = lending_duration if visible_for != '' and visible_for != None: self.visible_for = visible_for def update(self, update_name_from_schema: bool, name: str, type_id: int, lending_duration: int = -1, visible_for: str = ''): """ Function to update the objects data """ self.update_name_from_schema = update_name_from_schema if self.update_name_from_schema: self.name = self.name_schema_name else: self.name = name self.type_id = type_id self.lending_duration = lending_duration self.visible_for = visible_for @property def deleted(self): return self.deleted_time is not None @deleted.setter def deleted(self, value: bool): if value: self.deleted_time = int(time.time()) else: self.deleted_time = None @property def is_currently_lent(self): """ If the item is currently lent. 
""" return self.lending is not None @property def parent(self): if self._parents: return self._parents[0].parent return None @property def effective_lending_duration(self): """ The effective lending duration computed from item type, tags and item """ if self.lending_duration and (self.lending_duration >= 0): return self.lending_duration tag_lending_duration = min((t.tag.lending_duration for t in self._tags if t.tag.lending_duration > 0), default=-1) if tag_lending_duration >= 0: return tag_lending_duration return self.type.lending_duration @property def name_schema_name(self): template = string.Template(self.type.name_schema) attributes = {} for attr in self._attributes: if attr.value: try: attributes[attr.attribute_definition.name] = loads(attr.value) except: pass else: attr_def = loads(attr.attribute_definition.jsonschema) attributes[attr.attribute_definition.name] = attr_def.get('default', '') parent = "".join(item.parent.name for item in ItemToItem.query.filter(ItemToItem.item_id == self.id).all()) today = date.today() times = { 'c_year': today.year, 'c_month': today.month, 'c_day': today.day, 'c_date': today.strftime('%d.%b.%Y'), 'c_date_iso': today.isoformat(), } return template.safe_substitute(attributes, type=self.type.name, parent=parent, **times) def delete(self): if self.is_currently_lent: return(400, "Requested item is currently lent!", False) self.deleted = True for element in self._attributes: element.delete() # Not intended -neumantm # for element in self._contained_items: # DB.session.delete(element) # for element in self._tags: # DB.session.delete(element) return(204, "", True) def get_attribute_changes(self, definition_ids, remove: bool = False): """ Get a list of attributes to add, to delete and to undelete, considering all definition_ids in the list and whether to add or remove them. """ attributes_to_add = [] attributes_to_delete = [] attributes_to_undelete = [] for def_id in definition_ids: itads = self._attributes exists = False for itad in itads: if(itad.attribute_definition_id == def_id): exists = True if(remove): # Check if multiple sources bring it, if yes don't delete it. sources = 0 if(def_id in [ittad.attribute_definition_id for ittad in self.type._item_type_to_attribute_definitions if not ittad.attribute_definition.deleted]): sources += 1 for tag in [itt.tag for itt in self._tags]: if(def_id in [ttad.attribute_definition_id for ttad in tag._tag_to_attribute_definitions if not ttad.attribute_definition.deleted]): sources += 1 if sources == 1: attributes_to_delete.append(itad) elif(itad.deleted): attributes_to_undelete.append(itad) if not exists and not remove: attributes_to_add.append(ItemToAttributeDefinition(self.id, def_id, "")) # TODO: Get default if possible. return attributes_to_add, attributes_to_delete, attributes_to_undelete def get_new_attributes_from_type(self, type_id: int): """ Get a list of attributes to add to a new item which has the given type. """ item_type_attribute_definitions = (itemType.ItemTypeToAttributeDefinition .query .filter(itemType.ItemTypeToAttributeDefinition.item_type_id == type_id) .all()) attributes_to_add, _, _ = self.get_attribute_changes( [ittad.attribute_definition_id for ittad in item_type_attribute_definitions if not ittad.item_type.deleted], False) return attributes_to_add def get_attribute_changes_from_type_change(self, from_type_id: int, to_type_id: int): """ Get a list of attributes to add, to delete and to undelete, when this item would now switch from the first to the second type. 
""" old_item_type_attr_defs = (itemType.ItemTypeToAttributeDefinition .query .filter(itemType.ItemTypeToAttributeDefinition.item_type_id == from_type_id) .all()) new_item_type_attr_defs = (itemType.ItemTypeToAttributeDefinition .query .filter(itemType.ItemTypeToAttributeDefinition.item_type_id == to_type_id) .all()) old_attr_def_ids = [ittad.attribute_definition_id for ittad in old_item_type_attr_defs] new_attr_def_ids = [ittad.attribute_definition_id for ittad in new_item_type_attr_defs] added_attr_def_ids = [attr_def_id for attr_def_id in new_attr_def_ids if attr_def_id not in old_attr_def_ids] removed_attr_def_ids = [attr_def_id for attr_def_id in old_attr_def_ids if attr_def_id not in new_attr_def_ids] attributes_to_add, _, attributes_to_undelete = self.get_attribute_changes(added_attr_def_ids, False) _, attributes_to_delete, _ = self.get_attribute_changes(removed_attr_def_ids, True) return attributes_to_add, attributes_to_delete, attributes_to_undelete def get_attribute_changes_from_tag(self, tag_id: int, remove: bool = False): """ Get a list of attributes to add, to delete and to undelete, when this item would now get that tag or loose that tag. """ tag_attribute_definitions = (TagToAttributeDefinition .query .filter(TagToAttributeDefinition.tag_id == tag_id) .all()) return self.get_attribute_changes([ttad.attribute_definition_id for ttad in tag_attribute_definitions], remove)
class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = 'table_columns' __table_args__ = (UniqueConstraint('table_id', 'column_name'), ) table_id = Column(Integer, ForeignKey('tables.id')) table = relationship('SqlaTable', backref=backref('columns', cascade='all, delete-orphan'), foreign_keys=[table_id]) is_dttm = Column(Boolean, default=False) expression = Column(Text, default='') python_date_format = Column(String(255)) database_expression = Column(String(255)) export_fields = ( 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', 'expression', 'description', 'python_date_format', 'database_expression', ) update_from_object_fields = [ s for s in export_fields if s not in ('table_id', ) ] export_parent = 'table' def get_sqla_col(self, label=None): db_engine_spec = self.table.database.db_engine_spec label = db_engine_spec.make_label_compatible( label if label else self.column_name) if not self.expression: col = column(self.column_name).label(label) else: col = literal_column(self.expression).label(label) return col @property def datasource(self): return self.table def get_time_filter(self, start_dttm, end_dttm): col = self.get_sqla_col(label='__time') l = [] # noqa: E741 if start_dttm: l.append(col >= text(self.dttm_sql_literal(start_dttm))) if end_dttm: l.append(col <= text(self.dttm_sql_literal(end_dttm))) return and_(*l) def get_timestamp_expression(self, time_grain): """Getting the time component of the query""" pdf = self.python_date_format is_epoch = pdf in ('epoch_s', 'epoch_ms') if not self.expression and not time_grain and not is_epoch: return column(self.column_name, type_=DateTime).label(DTTM_ALIAS) expr = self.expression or self.column_name if is_epoch: # if epoch, translate to DATE using db specific conf db_spec = self.table.database.db_engine_spec if pdf == 'epoch_s': expr = db_spec.epoch_to_dttm().format(col=expr) elif pdf == 'epoch_ms': expr = db_spec.epoch_ms_to_dttm().format(col=expr) if time_grain: grain = self.table.database.grains_dict().get(time_grain) if grain: expr = grain.function.format(col=expr) return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) @classmethod def import_obj(cls, i_column): def lookup_obj(lookup_column): return db.session.query(TableColumn).filter( TableColumn.table_id == lookup_column.table_id, TableColumn.column_name == lookup_column.column_name).first() return import_util.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal(self, dttm): """Convert datetime object to a SQL expression string If database_expression is empty, the internal dttm will be parsed as the string with the pattern that the user inputted (python_date_format) If database_expression is not empty, the internal dttm will be parsed as the sql sentence for the database to convert """ tf = self.python_date_format if self.database_expression: return self.database_expression.format( dttm.strftime('%Y-%m-%d %H:%M:%S')) elif tf: if tf == 'epoch_s': return str((dttm - datetime(1970, 1, 1)).total_seconds()) elif tf == 'epoch_ms': return str( (dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) return "'{}'".format(dttm.strftime(tf)) else: s = self.table.database.db_engine_spec.convert_dttm( self.type or '', dttm) return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f')) def get_metrics(self): # TODO deprecate, this is not needed since MetricsControl metrics = [] M = SqlMetric # noqa 
quoted = self.column_name if self.sum: metrics.append( M( metric_name='sum__' + self.column_name, metric_type='sum', expression='SUM({})'.format(quoted), )) if self.avg: metrics.append( M( metric_name='avg__' + self.column_name, metric_type='avg', expression='AVG({})'.format(quoted), )) if self.max: metrics.append( M( metric_name='max__' + self.column_name, metric_type='max', expression='MAX({})'.format(quoted), )) if self.min: metrics.append( M( metric_name='min__' + self.column_name, metric_type='min', expression='MIN({})'.format(quoted), )) if self.count_distinct: metrics.append( M( metric_name='count_distinct__' + self.column_name, metric_type='count_distinct', expression='COUNT(DISTINCT {})'.format(quoted), )) return {m.metric_name: m for m in metrics}
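# --- Illustrative sketch of the get_metrics() pattern above ------------------
# One auto-generated aggregate per enabled flag, keyed by metric_name.  The
# SqlMetric model is replaced by a namedtuple stand-in; the column name and
# flag values are assumptions.
from collections import namedtuple

Metric = namedtuple('Metric', 'metric_name metric_type expression')

_TEMPLATES = {
    'sum': 'SUM({})',
    'avg': 'AVG({})',
    'max': 'MAX({})',
    'min': 'MIN({})',
    'count_distinct': 'COUNT(DISTINCT {})',
}


def build_metrics(column_name, flags):
    metrics = [Metric('{}__{}'.format(kind, column_name), kind, tpl.format(column_name))
               for kind, tpl in _TEMPLATES.items() if flags.get(kind)]
    return {m.metric_name: m for m in metrics}


print(build_metrics('amount', {'sum': True, 'count_distinct': True}))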
class Participation(Base): """Class to store a single participation of a user in a contest. """ __tablename__ = 'participations' # Auto increment primary key. id = Column(Integer, primary_key=True) # The user can log in CWS only from this IP address or subnet. ip = Column(CastingArray(CIDR), nullable=True) # Starting time: for contests where every user has at most x hours # of the y > x hours totally available, this is the time the user # decided to start their time-frame. starting_time = Column(DateTime, nullable=True) # A shift in the time interval during which the user is allowed to # submit. delay_time = Column(Interval, CheckConstraint("delay_time >= '0 seconds'"), nullable=False, default=timedelta()) # An extra amount of time allocated for this user. extra_time = Column(Interval, CheckConstraint("extra_time >= '0 seconds'"), nullable=False, default=timedelta()) # Contest-specific password. If this password is not null then the # traditional user.password field will be "replaced" by this field's # value (only for this participation). password = Column(Unicode, nullable=True) # A hidden participation (e.g. does not appear in public rankings), can # also be used for debugging purposes. hidden = Column(Boolean, nullable=False, default=False) # An unrestricted participation (e.g. contest time, # maximum number of submissions, minimum interval between submissions, # maximum number of user tests, minimum interval between user tests), # can also be used for debugging purposes. unrestricted = Column(Boolean, nullable=False, default=False) # Contest (id and object) to which the user is participating. contest_id = Column(Integer, ForeignKey(Contest.id, onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True) contest = relationship(Contest, backref=backref("participations", cascade="all, delete-orphan", passive_deletes=True)) # User (id and object) which is participating. user_id = Column(Integer, ForeignKey(User.id, onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True) user = relationship(User, backref=backref("participations", cascade="all, delete-orphan", passive_deletes=True)) __table_args__ = (UniqueConstraint('contest_id', 'user_id'), ) # Team (id and object) that the user is representing with this # participation. team_id = Column(Integer, ForeignKey(Team.id, onupdate="CASCADE", ondelete="RESTRICT"), nullable=True) team = relationship(Team, backref=backref("participations", cascade="all, delete-orphan", passive_deletes=True))
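# --- Minimal sketch of the per-participation password override ---------------
# As the comment above says, a non-null Participation.password replaces the
# user's global password for that contest only.
def effective_password(participation_password, user_password):
    return participation_password if participation_password is not None else user_password


assert effective_password(None, 'global-secret') == 'global-secret'
assert effective_password('contest-secret', 'global-secret') == 'contest-secret'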
    Column('name', String(255), unique=True),
    Column('manager', String(255), nullable=False),
    Column('default_version_id', String(36)),
)

datastore_versions = Table(
    'datastore_versions', meta,
    Column('id', String(36), primary_key=True, nullable=False),
    Column('datastore_id', String(36), ForeignKey('datastores.id')),
    Column('name', String(255), unique=True),
    Column('image_id', String(36), nullable=False),
    Column('packages', String(511)),
    Column('active', Boolean(), nullable=False),
    UniqueConstraint('datastore_id', 'name', name='ds_versions'),
)


def upgrade(migrate_engine):
    meta.bind = migrate_engine
    create_tables([datastores, datastore_versions])

    instances = Table('instances', meta, autoload=True)
    datastore_version_id = Column('datastore_version_id', String(36),
                                  ForeignKey('datastore_versions.id'))
    instances.create_column(datastore_version_id)
    instances.drop_column('service_type')

    # Table 'service_images' is deprecated since this version.
    # Leave it for a few releases.
    # drop_tables([service_images])
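# --- Hedged sketch of a matching downgrade for the migration above -----------
# Not the original file's code: it simply reverses upgrade(), restoring the
# dropped 'service_type' column and removing what upgrade() added.  The
# column type and the drop_tables() helper are assumptions.
def downgrade(migrate_engine):
    meta.bind = migrate_engine
    instances = Table('instances', meta, autoload=True)
    instances.create_column(Column('service_type', String(36)))
    instances.drop_column('datastore_version_id')
    drop_tables([datastore_versions, datastores])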
class Task(Base): """Class to store a task. """ __tablename__ = 'tasks' __table_args__ = ( UniqueConstraint('contest_id', 'num'), UniqueConstraint('contest_id', 'name'), ForeignKeyConstraint( ("id", "active_dataset_id"), ("datasets.task_id", "datasets.id"), onupdate="SET NULL", ondelete="SET NULL", # Use an ALTER query to set this foreign key after # both tables have been CREATEd, to avoid circular # dependencies. use_alter=True, name="fk_active_dataset_id" ), CheckConstraint("token_gen_initial <= token_gen_max"), ) # Auto increment primary key. id = Column( Integer, primary_key=True, # Needed to enable autoincrement on integer primary keys that # are referenced by a foreign key defined on this table. autoincrement='ignore_fk') # Number of the task for sorting. num = Column( Integer, nullable=True) # Contest (id and object) owning the task. contest_id = Column( Integer, ForeignKey(Contest.id, onupdate="CASCADE", ondelete="CASCADE"), nullable=True, index=True) contest = relationship( Contest, back_populates="tasks") # Short name and long human readable title of the task. name = Column( Unicode, CodenameConstraint("name"), nullable=False, unique=True) title = Column( Unicode, nullable=False) # The names of the files that the contestant needs to submit (with # language-specific extensions replaced by "%l"). submission_format = Column( ARRAY(String), FilenameListConstraint("submission_format"), nullable=False, default=[]) # The language codes of the statements that will be highlighted to # all users for this task. primary_statements = Column( ARRAY(String), nullable=False, default=[]) # The parameters that control task-tokens follow. Note that their # effect during the contest depends on the interaction with the # parameters that control contest-tokens, defined on the Contest. # The "kind" of token rules that will be active during the contest. # - disabled: The user will never be able to use any token. # - finite: The user has a finite amount of tokens and can choose # when to use them, subject to some limitations. Tokens may not # be all available at start, but given periodically during the # contest instead. # - infinite: The user will always be able to use a token. token_mode = Column( Enum(TOKEN_MODE_DISABLED, TOKEN_MODE_FINITE, TOKEN_MODE_INFINITE, name="token_mode"), nullable=False, default="disabled") # The maximum number of tokens a contestant is allowed to use # during the whole contest (on this tasks). token_max_number = Column( Integer, CheckConstraint("token_max_number > 0"), nullable=True) # The minimum interval between two successive uses of tokens for # the same user (on this task). token_min_interval = Column( Interval, CheckConstraint("token_min_interval >= '0 seconds'"), nullable=False, default=timedelta()) # The parameters that control generation (if mode is "finite"): # the user starts with "initial" tokens and receives "number" more # every "interval", but their total number is capped to "max". token_gen_initial = Column( Integer, CheckConstraint("token_gen_initial >= 0"), nullable=False, default=2) token_gen_number = Column( Integer, CheckConstraint("token_gen_number >= 0"), nullable=False, default=2) token_gen_interval = Column( Interval, CheckConstraint("token_gen_interval > '0 seconds'"), nullable=False, default=timedelta(minutes=30)) token_gen_max = Column( Integer, CheckConstraint("token_gen_max > 0"), nullable=True) # Maximum number of submissions or user_tests allowed for each user # on this task during the whole contest or None to not enforce # this limitation. 
max_submission_number = Column( Integer, CheckConstraint("max_submission_number >= 0"), nullable=True) max_user_test_number = Column( Integer, CheckConstraint("max_user_test_number >= 0"), nullable=True) # Minimum interval between two submissions or user_tests for this # task, or None to not enforce this limitation. min_submission_interval = Column( Interval, CheckConstraint("min_submission_interval > '0 seconds'"), nullable=True) min_user_test_interval = Column( Interval, CheckConstraint("min_user_test_interval > '0 seconds'"), nullable=True) # What information users can see about the evaluations of their # submissions. Offering full information might help some users to # reverse engineer task data. feedback_level = Column( Enum(FEEDBACK_LEVEL_FULL, FEEDBACK_LEVEL_RESTRICTED, name="feedback_level"), nullable=False, default=FEEDBACK_LEVEL_RESTRICTED) # The scores for this task will be rounded to this number of # decimal places. score_precision = Column( Integer, CheckConstraint("score_precision >= 0"), nullable=False, default=0) # Score mode for the task. score_mode = Column( Enum(SCORE_MODE_MAX_TOKENED_LAST, SCORE_MODE_MAX, SCORE_MODE_MAX_SUBTASK, name="score_mode"), nullable=False, default=SCORE_MODE_MAX_TOKENED_LAST) # Active Dataset (id and object) currently being used for scoring. # The ForeignKeyConstraint for this column is set at table-level. active_dataset_id = Column( Integer, nullable=True) active_dataset = relationship( 'Dataset', foreign_keys=[active_dataset_id], # Use an UPDATE query *after* an INSERT query (and *before* a # DELETE query) to set (and unset) the column associated to # this relationship. post_update=True) # These one-to-many relationships are the reversed directions of # the ones defined in the "child" classes using foreign keys. statements = relationship( "Statement", collection_class=attribute_mapped_collection("language"), cascade="all, delete-orphan", passive_deletes=True, back_populates="task") attachments = relationship( "Attachment", collection_class=attribute_mapped_collection("filename"), cascade="all, delete-orphan", passive_deletes=True, back_populates="task") datasets = relationship( "Dataset", # Due to active_dataset_id, SQLAlchemy cannot unambiguously # figure out by itself which foreign key to use. foreign_keys="[Dataset.task_id]", cascade="all, delete-orphan", passive_deletes=True, back_populates="task") submissions = relationship( "Submission", cascade="all, delete-orphan", passive_deletes=True, back_populates="task") user_tests = relationship( "UserTest", cascade="all, delete-orphan", passive_deletes=True, back_populates="task")
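# --- Worked sketch of the "finite" token-generation rule documented above ----
# Start with token_gen_initial tokens, gain token_gen_number more every
# token_gen_interval, capped at token_gen_max; tokens already spent are
# ignored here to keep the example small.
from datetime import timedelta


def generated_tokens(elapsed, initial=2, number=2,
                     interval=timedelta(minutes=30), maximum=None):
    total = initial + number * (elapsed // interval)
    return min(total, maximum) if maximum is not None else total


assert generated_tokens(timedelta(0)) == 2
assert generated_tokens(timedelta(minutes=90)) == 8
assert generated_tokens(timedelta(hours=10), maximum=10) == 10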
class Build(db.Model): """ Represents a collection of builds for a single target, as well as the sum of their results. Each Build contains many Jobs (usually linked to a JobPlan). """ __tablename__ = 'build' __table_args__ = ( Index('idx_buildfamily_project_id', 'project_id'), Index('idx_buildfamily_author_id', 'author_id'), Index('idx_buildfamily_source_id', 'source_id'), UniqueConstraint('project_id', 'number', name='unq_build_number'), ) id = Column(GUID, primary_key=True, default=uuid.uuid4) number = Column(Integer) project_id = Column(GUID, ForeignKey('project.id', ondelete="CASCADE"), nullable=False) collection_id = Column(GUID) source_id = Column(GUID, ForeignKey('source.id', ondelete="CASCADE")) author_id = Column(GUID, ForeignKey('author.id', ondelete="CASCADE")) cause = Column(EnumType(Cause), nullable=False, default=Cause.unknown) label = Column(String(128), nullable=False) target = Column(String(128)) tags = Column(ARRAY(String(16)), nullable=True) status = Column(EnumType(Status), nullable=False, default=Status.unknown) result = Column(EnumType(Result), nullable=False, default=Result.unknown) message = Column(Text) duration = Column(Integer) priority = Column(EnumType(BuildPriority), nullable=False, default=BuildPriority.default, server_default='0') date_started = Column(DateTime) date_finished = Column(DateTime) date_created = Column(DateTime, default=datetime.utcnow) date_modified = Column(DateTime, default=datetime.utcnow) data = Column(JSONEncodedDict) project = relationship('Project', innerjoin=True) source = relationship('Source', innerjoin=True) author = relationship('Author') stats = relationship('ItemStat', primaryjoin='Build.id == ItemStat.item_id', foreign_keys=[id], uselist=True) __repr__ = model_repr('label', 'target') def __init__(self, **kwargs): super(Build, self).__init__(**kwargs) if self.id is None: self.id = uuid.uuid4() if self.result is None: self.result = Result.unknown if self.status is None: self.status = Status.unknown if self.date_created is None: self.date_created = datetime.utcnow() if self.date_modified is None: self.date_modified = self.date_created if self.date_started and self.date_finished and not self.duration: self.duration = (self.date_finished - self.date_started).total_seconds() * 1000 if self.number is None and self.project: self.number = select([func.next_item_value(self.project.id.hex)])
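# --- Small sketch of the duration bookkeeping in Build.__init__ above --------
# When date_started and date_finished are both set and duration is not,
# duration is derived in milliseconds.
from datetime import datetime

date_started = datetime(2024, 1, 1, 12, 0, 0)
date_finished = datetime(2024, 1, 1, 12, 0, 42)
duration = (date_finished - date_started).total_seconds() * 1000
assert duration == 42000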
class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"), ) table_id = Column(Integer, ForeignKey("tables.id")) table = relationship( "SqlaTable", backref=backref("columns", cascade="all, delete-orphan"), foreign_keys=[table_id], ) is_dttm = Column(Boolean, default=False) expression = Column(Text) python_date_format = Column(String(255)) export_fields = [ "table_id", "column_name", "verbose_name", "is_dttm", "is_active", "type", "groupby", "filterable", "expression", "description", "python_date_format", ] update_from_object_fields = [ s for s in export_fields if s not in ("table_id", ) ] export_parent = "table" @property def is_numeric(self) -> bool: db_engine_spec = self.table.database.db_engine_spec return db_engine_spec.is_db_column_type_match( self.type, utils.DbColumnType.NUMERIC) @property def is_string(self) -> bool: db_engine_spec = self.table.database.db_engine_spec return db_engine_spec.is_db_column_type_match( self.type, utils.DbColumnType.STRING) @property def is_temporal(self) -> bool: db_engine_spec = self.table.database.db_engine_spec return db_engine_spec.is_db_column_type_match( self.type, utils.DbColumnType.TEMPORAL) def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.column_name if self.expression: col = literal_column(self.expression) else: db_engine_spec = self.table.database.db_engine_spec type_ = db_engine_spec.get_sqla_column_type(self.type) col = column(self.column_name, type_=type_) col = self.table.make_sqla_column_compatible(col, label) return col @property def datasource(self) -> RelationshipProperty: return self.table def get_time_filter( self, start_dttm: DateTime, end_dttm: DateTime, time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint, utils.TimeRangeEndpoint]], ) -> ColumnElement: col = self.get_sqla_col(label="__time") l = [] if start_dttm: l.append(col >= text( self.dttm_sql_literal(start_dttm, time_range_endpoints))) if end_dttm: if (time_range_endpoints and time_range_endpoints[1] == utils.TimeRangeEndpoint.EXCLUSIVE): l.append(col < text( self.dttm_sql_literal(end_dttm, time_range_endpoints))) else: l.append(col <= text(self.dttm_sql_literal(end_dttm, None))) return and_(*l) def get_timestamp_expression( self, time_grain: Optional[str]) -> Union[TimestampExpression, Label]: """ Return a SQLAlchemy Core element representation of self to be used in a query. :param time_grain: Optional time grain, e.g. 
P1Y :return: A TimeExpression object wrapped in a Label if supported by db """ label = utils.DTTM_ALIAS db = self.table.database pdf = self.python_date_format is_epoch = pdf in ("epoch_s", "epoch_ms") if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=DateTime) return self.table.make_sqla_column_compatible(sqla_col, label) if self.expression: col = literal_column(self.expression) else: col = column(self.column_name) time_expr = db.db_engine_spec.get_timestamp_expr( col, pdf, time_grain, self.type) return self.table.make_sqla_column_compatible(time_expr, label) @classmethod def import_obj(cls, i_column): def lookup_obj(lookup_column): return (db.session.query(TableColumn).filter( TableColumn.table_id == lookup_column.table_id, TableColumn.column_name == lookup_column.column_name, ).first()) return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal( self, dttm: DateTime, time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint, utils.TimeRangeEndpoint]], ) -> str: """Convert datetime object to a SQL expression string""" sql = (self.table.database.db_engine_spec.convert_dttm( self.type, dttm) if self.type else None) if sql: return sql tf = self.python_date_format # Fallback to the default format (if defined) only if the SIP-15 time range # endpoints, i.e., [start, end) are enabled. if not tf and time_range_endpoints == ( utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE, ): tf = (self.table.database.get_extra().get( "python_date_format_by_column_name", {}).get(self.column_name)) if tf: if tf in ["epoch_ms", "epoch_s"]: seconds_since_epoch = int(dttm.timestamp()) if tf == "epoch_s": return str(seconds_since_epoch) return str(seconds_since_epoch * 1000) return f"'{dttm.strftime(tf)}'" # TODO(john-bodley): SIP-15 will explicitly require a type conversion. return f"""'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'"""
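# --- Illustrative sketch of the epoch branches in dttm_sql_literal above -----
# For python_date_format 'epoch_s'/'epoch_ms' the datetime is rendered as
# whole seconds or milliseconds since the Unix epoch.
from datetime import datetime, timezone

dttm = datetime(2020, 1, 1, tzinfo=timezone.utc)
seconds_since_epoch = int(dttm.timestamp())
assert seconds_since_epoch == 1577836800                # 'epoch_s'
assert seconds_since_epoch * 1000 == 1577836800000      # 'epoch_ms'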
class Task(Base): """Class to store a task. Not to be used directly (import it from SQLAlchemyAll). """ __tablename__ = 'tasks' __table_args__ = ( UniqueConstraint('contest_id', 'num', name='cst_task_contest_id_num'), UniqueConstraint('contest_id', 'name', name='cst_task_contest_id_name'), CheckConstraint("token_initial <= token_max"), ) # Auto increment primary key. id = Column(Integer, primary_key=True) # Number of the task for sorting. num = Column(Integer, nullable=False) # Contest (id and object) owning the task. contest_id = Column(Integer, ForeignKey(Contest.id, onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True) contest = relationship(Contest, backref=backref( 'tasks', collection_class=ordering_list('num'), order_by=[num], cascade="all, delete-orphan", passive_deletes=True)) # Short name and long human readable title of the task. name = Column(String, nullable=False) title = Column(String, nullable=False) # A JSON-encoded lists of strings: the language codes of the # statments that will be highlighted to all users for this task. primary_statements = Column(String, nullable=False) # Time and memory limits for every testcase. time_limit = Column(Float, nullable=True) memory_limit = Column(Integer, nullable=True) # Name of the TaskType child class suited for the task. task_type = Column(String, nullable=False) # Parameters for the task type class, JSON encoded. task_type_parameters = Column(String, nullable=False) # Name of the ScoreType child class suited for the task. score_type = Column(String, nullable=False) # Parameters for the scorer class, JSON encoded. score_parameters = Column(String, nullable=False) # Parameter to define the token behaviour. See Contest.py for # details. The only change is that these parameters influence the # contest in a task-per-task behaviour. To play a token on a given # task, a user must satisfy the condition of the contest and the # one of the task. token_initial = Column(Integer, CheckConstraint("token_initial >= 0"), nullable=True) token_max = Column(Integer, CheckConstraint("token_max > 0"), nullable=True) token_total = Column(Integer, CheckConstraint("token_total > 0"), nullable=True) token_min_interval = Column( Interval, CheckConstraint("token_min_interval >= '0 seconds'"), nullable=False) token_gen_time = Column(Interval, CheckConstraint("token_gen_time >= '0 seconds'"), nullable=False) token_gen_number = Column(Integer, CheckConstraint("token_gen_number >= 0"), nullable=False) # Maximum number of submissions or user_tests allowed for each user # on this task during the whole contest or None to not enforce # this limitation. max_submission_number = Column( Integer, CheckConstraint("max_submission_number > 0"), nullable=True) max_user_test_number = Column(Integer, CheckConstraint("max_user_test_number > 0"), nullable=True) # Minimum interval between two submissions or user_tests for this # task, or None to not enforce this limitation. min_submission_interval = Column( Interval, CheckConstraint("min_submission_interval > '0 seconds'"), nullable=True) min_user_test_interval = Column( Interval, CheckConstraint("min_user_test_interval > '0 seconds'"), nullable=True) # Follows the description of the fields automatically added by # SQLAlchemy. 
# submission_format (list of SubmissionFormatElement objects) # testcases (list of Testcase objects) # attachments (dict of Attachment objects indexed by filename) # managers (dict of Manager objects indexed by filename) # statements (dict of Statement objects indexed by language code) # submissions (list of Submission objects) # user_tests (list of UserTest objects) # This object (independent from SQLAlchemy) is the instance of the # ScoreType class with the given parameters, taking care of # building the scores of the submissions. scorer = None def __init__(self, name, title, statements, attachments, time_limit, memory_limit, primary_statements, task_type, task_type_parameters, submission_format, managers, score_type, score_parameters, testcases, token_initial=None, token_max=None, token_total=None, token_min_interval=timedelta(), token_gen_time=timedelta(), token_gen_number=0, max_submission_number=None, max_user_test_number=None, min_submission_interval=None, min_user_test_interval=None, contest=None, num=0): for filename, attachment in attachments.iteritems(): attachment.filename = filename for filename, manager in managers.iteritems(): manager.filename = filename for language, statement in statements.iteritems(): statement.language = language self.num = num self.name = name self.title = title self.statements = statements self.attachments = attachments self.time_limit = time_limit self.memory_limit = memory_limit self.primary_statements = primary_statements \ if primary_statements is not None else "[]" self.task_type = task_type self.task_type_parameters = task_type_parameters self.submission_format = submission_format self.managers = managers self.score_type = score_type self.score_parameters = score_parameters self.testcases = testcases self.token_initial = token_initial self.token_max = token_max self.token_total = token_total self.token_min_interval = token_min_interval self.token_gen_time = token_gen_time self.token_gen_number = token_gen_number self.max_submission_number = max_submission_number self.max_user_test_number = max_user_test_number self.min_submission_interval = min_submission_interval self.min_user_test_interval = min_user_test_interval self.contest = contest def export_to_dict(self): """Return object data as a dictionary. 
""" return { 'name': self.name, 'title': self.title, 'num': self.num, 'statements': [ statement.export_to_dict() for statement in self.statements.itervalues() ], 'attachments': [ attachment.export_to_dict() for attachment in self.attachments.itervalues() ], 'time_limit': self.time_limit, 'memory_limit': self.memory_limit, 'primary_statements': self.primary_statements, 'task_type': self.task_type, 'task_type_parameters': self.task_type_parameters, 'submission_format': [element.export_to_dict() for element in self.submission_format], 'managers': [ manager.export_to_dict() for manager in self.managers.itervalues() ], 'score_type': self.score_type, 'score_parameters': self.score_parameters, 'testcases': [testcase.export_to_dict() for testcase in self.testcases], 'token_initial': self.token_initial, 'token_max': self.token_max, 'token_total': self.token_total, 'token_min_interval': self.token_min_interval.total_seconds(), 'token_gen_time': self.token_gen_time.total_seconds(), 'token_gen_number': self.token_gen_number, 'max_submission_number': self.max_submission_number, 'max_user_test_number': self.max_user_test_number, 'min_submission_interval': self.min_submission_interval.total_seconds() if self.min_submission_interval is not None else None, 'min_user_test_interval': self.min_user_test_interval.total_seconds() if self.min_user_test_interval is not None else None, } @classmethod def import_from_dict(cls, data): """Build the object using data from a dictionary. """ data['attachments'] = [ Attachment.import_from_dict(attch_data) for attch_data in data['attachments'] ] data['attachments'] = dict([(attachment.filename, attachment) for attachment in data['attachments']]) data['submission_format'] = [ SubmissionFormatElement.import_from_dict(sfe_data) for sfe_data in data['submission_format'] ] data['managers'] = [ Manager.import_from_dict(manager_data) for manager_data in data['managers'] ] data['managers'] = dict([(manager.filename, manager) for manager in data['managers']]) data['testcases'] = [ Testcase.import_from_dict(testcase_data) for testcase_data in data['testcases'] ] data['statements'] = [ Statement.import_from_dict(statement_data) for statement_data in data['statements'] ] data['statements'] = dict([(statement.language, statement) for statement in data['statements']]) if 'token_min_interval' in data: data['token_min_interval'] = \ timedelta(seconds=data['token_min_interval']) if 'token_gen_time' in data: data['token_gen_time'] = timedelta(seconds=data['token_gen_time']) if 'min_submission_interval' in data and \ data['min_submission_interval'] is not None: data['min_submission_interval'] = \ timedelta(seconds=data['min_submission_interval']) if 'min_user_test_interval' in data and \ data['min_user_test_interval'] is not None: data['min_user_test_interval'] = \ timedelta(seconds=data['min_user_test_interval']) return cls(**data)
class Participant(db.Model): __tablename__ = "participants" __table_args__ = (UniqueConstraint( "session_token", name="participants_session_token_key"), ) id = Column(BigInteger, Sequence('participants_id_seq'), nullable=False, unique=True) username = Column(Text, nullable=False, primary_key=True) username_lower = Column(Text, nullable=False, unique=True) statement = Column(Text, default="", nullable=False) stripe_customer_id = Column(Text) last_bill_result = Column(Text) session_token = Column(Text) session_expires = Column(TIMESTAMP(timezone=True), default="now()") ctime = Column(TIMESTAMP(timezone=True), nullable=False, default="now()") claimed_time = Column(TIMESTAMP(timezone=True)) is_admin = Column(Boolean, nullable=False, default=False) balance = Column(Numeric(precision=35, scale=2), CheckConstraint("balance >= 0", name="min_balance"), default=0.0, nullable=False) pending = Column(Numeric(precision=35, scale=2), default=None) anonymous = Column(Boolean, default=False, nullable=False) goal = Column(Numeric(precision=35, scale=2), default=None) balanced_account_uri = Column(Text) last_ach_result = Column(Text) is_suspicious = Column(Boolean) type = Column(Enum('individual', 'group', 'open group', nullable=False)) ### Relations ### accounts_elsewhere = relationship("Elsewhere", backref="participant_orm", lazy="dynamic") exchanges = relationship("Exchange", backref="participant_orm") # TODO: Once tippee/tipper are renamed to tippee_id/tipper_idd, we can go # ahead and drop the foreign_keys & rename backrefs to tipper/tippee _tips_giving = relationship("Tip", backref="tipper_participant", foreign_keys="Tip.tipper", lazy="dynamic") _tips_receiving = relationship("Tip", backref="tippee_participant", foreign_keys="Tip.tippee", lazy="dynamic") transferer = relationship("Transfer", backref="transferer", foreign_keys="Transfer.tipper") transferee = relationship("Transfer", backref="transferee", foreign_keys="Transfer.tippee") def __eq__(self, other): return self.id == other.id def __ne__(self, other): return self.id != other.id # Class-specific exceptions class ProblemChangingUsername(Exception): pass class UsernameTooLong(ProblemChangingUsername): pass class UsernameContainsInvalidCharacters(ProblemChangingUsername): pass class UsernameIsRestricted(ProblemChangingUsername): pass class UsernameAlreadyTaken(ProblemChangingUsername): pass class UnknownPlatform(Exception): pass @property def IS_INDIVIDUAL(self): return self.type == 'individual' @property def IS_GROUP(self): return self.type == 'group' @property def IS_OPEN_GROUP(self): return self.type == 'open group' @property def tips_giving(self): return self._tips_giving.distinct("tips.tippee")\ .order_by("tips.tippee, tips.mtime DESC") @property def tips_receiving(self): return self._tips_receiving.distinct("tips.tipper")\ .order_by("tips.tipper, tips.mtime DESC") @property def accepts_tips(self): return (self.goal is None) or (self.goal >= 0) @property def valid_tips_receiving(self): ''' SELECT count(anon_1.amount) AS count_1 FROM ( SELECT DISTINCT ON (tips.tipper) tips.id AS id , tips.ctime AS ctime , tips.mtime AS mtime , tips.tipper AS tipper , tips.tippee AS tippee , tips.amount AS amount FROM tips JOIN participants ON tips.tipper = participants.username WHERE %(param_1)s = tips.tippee AND participants.is_suspicious IS NOT true AND participants.last_bill_result = %(last_bill_result_1)s ORDER BY tips.tipper, tips.mtime DESC ) AS anon_1 WHERE anon_1.amount > %(amount_1)s ''' return self.tips_receiving \ .join( Participant , 
Tip.tipper.op('=')(Participant.username) ) \ .filter( 'participants.is_suspicious IS NOT true' , Participant.last_bill_result == '' ) def resolve_unclaimed(self): if self.accounts_elsewhere: return self.accounts_elsewhere[0].resolve_unclaimed() else: return None def set_as_claimed(self, claimed_at=None): if claimed_at is None: claimed_at = datetime.datetime.now(pytz.utc) self.claimed_time = claimed_at db.session.add(self) db.session.commit() def change_username(self, desired_username): """Raise self.ProblemChangingUsername, or return None. We want to be pretty loose with usernames. Unicode is allowed--XXX aspen bug :(. So are spaces. Control characters aren't. We also limit to 32 characters in length. """ for i, c in enumerate(desired_username): if i == 32: raise self.UsernameTooLong # Request Entity Too Large (more or less) elif ord(c) < 128 and c not in ASCII_ALLOWED_IN_USERNAME: raise self.UsernameContainsInvalidCharacters # Yeah, no. elif c not in ASCII_ALLOWED_IN_USERNAME: # XXX Burned by an Aspen bug. :`-( # https://github.com/gittip/aspen/issues/102 raise self.UsernameContainsInvalidCharacters lowercased = desired_username.lower() if lowercased in gittip.RESTRICTED_USERNAMES: raise self.UsernameIsRestricted if desired_username != self.username: try: self.username = desired_username self.username_lower = lowercased db.session.add(self) db.session.commit() # Will raise sqlalchemy.exc.IntegrityError if the # desired_username is taken. except IntegrityError: db.session.rollback() raise self.UsernameAlreadyTaken def get_accounts_elsewhere(self): github_account = twitter_account = bitbucket_account = bountysource_account = None for account in self.accounts_elsewhere.all(): if account.platform == "github": github_account = account elif account.platform == "twitter": twitter_account = account elif account.platform == "bitbucket": bitbucket_account = account elif account.platform == "bountysource": bountysource_account = account else: raise self.UnknownPlatform(account.platform) return (github_account, twitter_account, bitbucket_account, bountysource_account) def get_img_src(self, size=128): """Return a value for <img src="..." />. Until we have our own profile pics, delegate. XXX Is this an attack vector? Can someone inject this value? Don't think so, but if you make it happen, let me know, eh? Thanks. :) https://www.gittip.com/security.txt """ typecheck(size, int) src = '/assets/%s/avatar-default.gif' % os.environ['__VERSION__'] github, twitter, bitbucket, bountysource = self.get_accounts_elsewhere( ) if github is not None: # GitHub -> Gravatar: http://en.gravatar.com/site/implement/images/ if 'gravatar_id' in github.user_info: gravatar_hash = github.user_info['gravatar_id'] src = "https://www.gravatar.com/avatar/%s.jpg?s=%s" src %= (gravatar_hash, size) elif twitter is not None: # https://dev.twitter.com/docs/api/1/get/users/profile_image/%3Ascreen_name if 'profile_image_url_https' in twitter.user_info: src = twitter.user_info['profile_image_url_https'] # For Twitter, we don't have good control over size. We don't # want the original, cause that can be huge. The next option is # 73px(?!). 
src = src.replace('_normal.', '_bigger.') return src def get_tip_to(self, tippee): tip = self.tips_giving.filter_by(tippee=tippee).first() if tip: amount = tip.amount else: amount = Decimal('0.00') return amount def get_dollars_receiving(self): return sum(tip.amount for tip in self.valid_tips_receiving) + Decimal('0.00') def get_number_of_backers(self): amount_column = self.valid_tips_receiving.subquery().columns.amount count = func.count(amount_column) nbackers = db.session.query(count).filter(amount_column > 0).one()[0] return nbackers def get_og_title(self): out = self.username receiving = self.get_dollars_receiving() giving = self.get_dollars_giving() if (giving > receiving) and not self.anonymous: out += " gives $%.2f/wk" % giving elif receiving > 0: out += " receives $%.2f/wk" % receiving else: out += " is" return out + " on Gittip" def get_age_in_seconds(self): out = -1 if self.claimed_time is not None: now = datetime.datetime.now(self.claimed_time.tzinfo) out = (now - self.claimed_time).total_seconds() return out def allowed_to_vote_on(self, group): """Given a Participant object, return a boolean. """ if self.ANON: return False if self.is_suspicious: return False if group == self: return True return self.username in group.get_voters() def get_voters(self): identifications = list( gittip.db.fetchall( """ SELECT member , weight , identified_by , mtime FROM current_identifications WHERE "group"=%s AND weight > 0 ORDER BY mtime ASC """, (self.username, ))) voters = set() vote_counts = defaultdict(int) class Break(Exception): pass while identifications: try: for i in range(len(identifications)): identification = identifications[i] identified_by = identification['identified_by'] member = identification['member'] def record(): voters.add(member) identifications.pop(i) # Group participant itself has chosen. if identified_by == self.username: record() break # Group member has voted. elif identified_by in voters: vote_counts[member] += 1 target = math.ceil(len(voters) * 0.667) if vote_counts[member] == target: record() break else: # This pass through didn't find any new voters. raise Break except Break: break voters = list(voters) voters.sort() return voters def compute_split(self): if not self.IS_OPEN_GROUP: return [{"username": self.username, "weight": "1.0"}] split = [] voters = self.get_voters() identifications = gittip.db.fetchall( """ SELECT member , weight , identified_by , mtime FROM current_identifications WHERE "group"=%s AND weight > 0 ORDER BY mtime ASC """, (self.username, )) splitmap = defaultdict(int) total = 0 for row in identifications: if row['identified_by'] not in voters: continue splitmap[row['member']] += row['weight'] total += row['weight'] total = Decimal(total) for username, weight in splitmap.items(): split.append({ "username": username, "weight": Decimal(weight) / total }) split.sort(key=lambda r: r['weight'], reverse=True) return voters, split # TODO: Move these queries into this class. 
    def set_tip_to(self, tippee, amount):
        return OldParticipant(self.username).set_tip_to(tippee, amount)

    def get_dollars_giving(self):
        return OldParticipant(self.username).get_dollars_giving()

    def get_tip_distribution(self):
        return OldParticipant(self.username).get_tip_distribution()

    def get_giving_for_profile(self, db=None):
        return OldParticipant(self.username).get_giving_for_profile(db)

    def get_tips_and_total(self, for_payday=False, db=None):
        return OldParticipant(self.username).get_tips_and_total(for_payday, db)

    def take_over(self, account_elsewhere, have_confirmation=False):
        OldParticipant(self.username).take_over(account_elsewhere,
                                                have_confirmation)
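# --- Hedged, standalone restatement of the change_username() checks above ----
# The allowed-character set and restricted-name list are illustrative
# stand-ins for ASCII_ALLOWED_IN_USERNAME and gittip.RESTRICTED_USERNAMES.
import string

ALLOWED = set(string.ascii_letters + string.digits + ' .,-_;:@')
RESTRICTED = {'assets', 'about', 'security.txt'}


def validate_username(desired):
    if len(desired) > 32:
        raise ValueError('username too long')
    if any(c not in ALLOWED for c in desired):
        raise ValueError('username contains invalid characters')
    if desired.lower() in RESTRICTED:
        raise ValueError('username is restricted')
    return desired.lower()


assert validate_username('Alice_42') == 'alice_42'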
class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = 'table_columns' __table_args__ = (UniqueConstraint('table_id', 'column_name'), ) table_id = Column(Integer, ForeignKey('tables.id')) table = relationship('SqlaTable', backref=backref('columns', cascade='all, delete-orphan'), foreign_keys=[table_id]) is_dttm = Column(Boolean, default=False) expression = Column(Text, default='') python_date_format = Column(String(255)) database_expression = Column(String(255)) export_fields = ( 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', 'expression', 'description', 'python_date_format', 'database_expression', ) export_parent = 'table' @property def sqla_col(self): name = self.column_name if not self.expression: col = column(self.column_name).label(name) else: col = literal_column(self.expression).label(name) return col def get_time_filter(self, start_dttm, end_dttm): col = self.sqla_col.label('__time') l = [] # noqa: E741 if start_dttm: l.append(col >= text(self.dttm_sql_literal(start_dttm))) if end_dttm: l.append(col <= text(self.dttm_sql_literal(end_dttm))) return and_(*l) def get_timestamp_expression(self, time_grain): """Getting the time component of the query""" expr = self.expression or self.column_name if not self.expression and not time_grain: return column(expr, type_=DateTime).label(DTTM_ALIAS) if time_grain: pdf = self.python_date_format if pdf in ('epoch_s', 'epoch_ms'): # if epoch, translate to DATE using db specific conf db_spec = self.table.database.db_engine_spec if pdf == 'epoch_s': expr = db_spec.epoch_to_dttm().format(col=expr) elif pdf == 'epoch_ms': expr = db_spec.epoch_ms_to_dttm().format(col=expr) grain = self.table.database.grains_dict().get(time_grain, '{col}') expr = grain.function.format(col=expr) return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) @classmethod def import_obj(cls, i_column): def lookup_obj(lookup_column): return db.session.query(TableColumn).filter( TableColumn.table_id == lookup_column.table_id, TableColumn.column_name == lookup_column.column_name).first() return import_util.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal(self, dttm): """Convert datetime object to a SQL expression string If database_expression is empty, the internal dttm will be parsed as the string with the pattern that the user inputted (python_date_format) If database_expression is not empty, the internal dttm will be parsed as the sql sentence for the database to convert """ tf = self.python_date_format if self.database_expression: return self.database_expression.format( dttm.strftime('%Y-%m-%d %H:%M:%S')) elif tf: if tf == 'epoch_s': return str((dttm - datetime(1970, 1, 1)).total_seconds()) elif tf == 'epoch_ms': return str( (dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) return "'{}'".format(dttm.strftime(tf)) else: s = self.table.database.db_engine_spec.convert_dttm( self.type or '', dttm) return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))
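# --- Illustrative sketch of how get_timestamp_expression above builds SQL ----
# An optional epoch-to-datetime template is applied first, then the time-grain
# template; both template strings below are assumed Postgres-style examples,
# not the actual db_engine_spec values.
expr = 'created_ts'
epoch_to_dttm = "(timestamp 'epoch' + {col} * interval '1 second')"
grain_function = "DATE_TRUNC('day', {col})"

expr = epoch_to_dttm.format(col=expr)    # python_date_format == 'epoch_s'
expr = grain_function.format(col=expr)   # time_grain supplied
print(expr)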
class SqlaTable(Model, BaseDatasource): """An ORM object for SqlAlchemy table references""" type = 'table' query_language = 'sql' metric_class = SqlMetric column_class = TableColumn __tablename__ = 'tables' __table_args__ = (UniqueConstraint('database_id', 'table_name'), ) table_name = Column(String(250)) main_dttm_col = Column(String(250)) database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) fetch_values_predicate = Column(String(1000)) user_id = Column(Integer, ForeignKey('ab_user.id')) owner = relationship(sm.user_model, backref='tables', foreign_keys=[user_id]) database = relationship('Database', backref=backref('tables', cascade='all, delete-orphan'), foreign_keys=[database_id]) schema = Column(String(255)) sql = Column(Text) baselink = 'tablemodelview' export_fields = ('table_name', 'main_dttm_col', 'description', 'default_endpoint', 'database_id', 'offset', 'cache_timeout', 'schema', 'sql', 'params') export_parent = 'database' export_children = ['metrics', 'columns'] def __repr__(self): return self.name @property def connection(self): return str(self.database) @property def description_markeddown(self): return utils.markdown(self.description) @property def link(self): name = escape(self.name) return Markup( '<a href="{self.explore_url}">{name}</a>'.format(**locals())) @property def schema_perm(self): """Returns schema permission if present, database one otherwise.""" return utils.get_schema_perm(self.database, self.schema) def get_perm(self): return ('[{obj.database}].[{obj.table_name}]' '(id:{obj.id})').format(obj=self) @property def name(self): if not self.schema: return self.table_name return '{}.{}'.format(self.schema, self.table_name) @property def full_name(self): return utils.get_datasource_full_name(self.database, self.table_name, schema=self.schema) @property def dttm_cols(self): l = [c.column_name for c in self.columns if c.is_dttm] # noqa: E741 if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property def num_cols(self): return [c.column_name for c in self.columns if c.is_num] @property def any_dttm_col(self): cols = self.dttm_cols if cols: return cols[0] @property def html(self): t = ((c.column_name, c.type) for c in self.columns) df = pd.DataFrame(t) df.columns = ['field', 'type'] return df.to_html( index=False, classes=('dataframe table table-striped table-bordered ' 'table-condensed')) @property def sql_url(self): return self.database.sql_url + '?table_name=' + str(self.table_name) @property def time_column_grains(self): return { 'time_columns': self.dttm_cols, 'time_grains': [grain.name for grain in self.database.grains()], } def get_col(self, col_name): columns = self.columns for col in columns: if col_name == col.column_name: return col @property def data(self): d = super(SqlaTable, self).data if self.type == 'table': grains = self.database.grains() or [] if grains: grains = [(g.name, g.name) for g in grains] d['granularity_sqla'] = utils.choicify(self.dttm_cols) d['time_grain_sqla'] = grains return d def values_for_column(self, column_name, limit=10000): """Runs query against sqla to retrieve some sample values for the given column. 
""" cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() db_engine_spec = self.database.db_engine_spec qry = (select([target_col.sqla_col]).select_from( self.get_from_clause(tp, db_engine_spec)).distinct(column_name)) if limit: qry = qry.limit(limit) if self.fetch_values_predicate: tp = self.get_template_processor() qry = qry.where(tp.process_template(self.fetch_values_predicate)) engine = self.database.get_sqla_engine() sql = '{}'.format( qry.compile( engine, compile_kwargs={'literal_binds': True}, ), ) df = pd.read_sql_query(sql=sql, con=engine) return [row[0] for row in df.to_records(index=False)] def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) def get_query_str(self, query_obj): engine = self.database.get_sqla_engine() qry = self.get_sqla_query(**query_obj) sql = six.text_type( qry.compile( engine, compile_kwargs={'literal_binds': True}, ), ) logging.info(sql) sql = sqlparse.format(sql, reindent=True) return sql def get_sqla_table(self): tbl = table(self.table_name) if self.schema: tbl.schema = self.schema return tbl def get_from_clause(self, template_processor=None, db_engine_spec=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) if db_engine_spec: from_sql = db_engine_spec.escape_sql(from_sql) return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() def get_sqla_query( # sqla self, groupby, metrics, granularity, from_dttm, to_dttm, filter=None, # noqa is_timeseries=True, timeseries_limit=15, timeseries_limit_metric=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, orderby=None, extras=None, columns=None, form_data=None, order_desc=True): """Querying any sqla table from this common interface""" template_kwargs = { 'from_dttm': from_dttm, 'groupby': groupby, 'metrics': metrics, 'row_limit': row_limit, 'to_dttm': to_dttm, 'form_data': form_data, } template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.database.db_engine_spec orderby = orderby or [] # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline cols = {col.column_name: col for col in self.columns} metrics_dict = {m.metric_name: m for m in self.metrics} if not granularity and is_timeseries: raise Exception( _('Datetime column not provided as part table configuration ' 'and is required by this type of chart')) if not groupby and not metrics and not columns: raise Exception(_('Empty query?')) for m in metrics: if m not in metrics_dict: raise Exception(_("Metric '{}' is not valid".format(m))) metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics] if metrics_exprs: main_metric_expr = metrics_exprs[0] else: main_metric_expr = literal_column('COUNT(*)').label('ccount') select_exprs = [] groupby_exprs = [] if groupby: select_exprs = [] inner_select_exprs = [] inner_groupby_exprs = [] for s in groupby: col = cols[s] outer = col.sqla_col inner = col.sqla_col.label(col.column_name + '__') groupby_exprs.append(outer) select_exprs.append(outer) inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) elif columns: for s in columns: select_exprs.append(cols[s].sqla_col) metrics_exprs = [] if granularity: dttm_col = cols[granularity] 
time_grain = extras.get('time_grain_sqla') time_filters = [] if is_timeseries: timestamp = dttm_col.get_timestamp_expression(time_grain) select_exprs += [timestamp] groupby_exprs += [timestamp] # Use main dttm column to support index with secondary dttm columns if db_engine_spec.time_secondary_columns and \ self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col].get_time_filter( from_dttm, to_dttm)) time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) select_exprs += metrics_exprs qry = sa.select(select_exprs) tbl = self.get_from_clause(template_processor, db_engine_spec) if not columns: qry = qry.group_by(*groupby_exprs) where_clause_and = [] having_clause_and = [] for flt in filter: if not all([flt.get(s) for s in ['col', 'op', 'val']]): continue col = flt['col'] op = flt['op'] eq = flt['val'] col_obj = cols.get(col) if col_obj: if op in ('in', 'not in'): values = [] for v in eq: # For backwards compatibility and edge cases # where a column data type might have changed if isinstance(v, basestring): v = v.strip("'").strip('"') if col_obj.is_num: v = utils.string_to_num(v) # Removing empty strings and non numeric values # targeting numeric columns if v is not None: values.append(v) cond = col_obj.sqla_col.in_(values) if op == 'not in': cond = ~cond where_clause_and.append(cond) else: if col_obj.is_num: eq = utils.string_to_num(flt['val']) if op == '==': where_clause_and.append(col_obj.sqla_col == eq) elif op == '!=': where_clause_and.append(col_obj.sqla_col != eq) elif op == '>': where_clause_and.append(col_obj.sqla_col > eq) elif op == '<': where_clause_and.append(col_obj.sqla_col < eq) elif op == '>=': where_clause_and.append(col_obj.sqla_col >= eq) elif op == '<=': where_clause_and.append(col_obj.sqla_col <= eq) elif op == 'LIKE': where_clause_and.append(col_obj.sqla_col.like(eq)) if extras: where = extras.get('where') if where: where = template_processor.process_template(where) where_clause_and += [sa.text('({})'.format(where))] having = extras.get('having') if having: having = template_processor.process_template(having) having_clause_and += [sa.text('({})'.format(having))] if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) if not orderby and not columns: orderby = [(main_metric_expr, not order_desc)] for col, ascending in orderby: direction = asc if ascending else desc qry = qry.order_by(direction(col)) if row_limit: qry = qry.limit(row_limit) if is_timeseries and \ timeseries_limit and groupby and not time_groupby_inline: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. 
vertica, # require a unique inner alias inner_main_metric_expr = main_metric_expr.label('mme_inner__') inner_select_exprs += [inner_main_metric_expr] subq = select(inner_select_exprs) subq = subq.select_from(tbl) inner_time_filter = dttm_col.get_time_filter( inner_from_dttm or from_dttm, inner_to_dttm or to_dttm, ) subq = subq.where(and_(*(where_clause_and + [inner_time_filter]))) subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr if timeseries_limit_metric: timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric) ob = timeseries_limit_metric.sqla_col direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) subq = subq.limit(timeseries_limit) on_clause = [] for i, gb in enumerate(groupby): on_clause.append(groupby_exprs[i] == column(gb + '__')) tbl = tbl.join(subq.alias(), and_(*on_clause)) return qry.select_from(tbl) def query(self, query_obj): qry_start_dttm = datetime.now() sql = self.get_query_str(query_obj) status = QueryStatus.SUCCESS error_message = None df = None try: df = self.database.get_df(sql, self.schema) except Exception as e: status = QueryStatus.FAILED logging.exception(e) error_message = ( self.database.db_engine_spec.extract_error_message(e)) return QueryResult(status=status, df=df, duration=datetime.now() - qry_start_dttm, query=sql, error_message=error_message) def get_sqla_table_object(self): return self.database.get_table(self.table_name, schema=self.schema) def fetch_metadata(self): """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() except Exception: raise Exception( _("Table [{}] doesn't seem to exist in the specified database, " "couldn't fetch column information").format(self.table_name)) M = SqlMetric # noqa metrics = [] any_date_col = None db_dialect = self.database.get_dialect() dbcols = (db.session.query(TableColumn).filter( TableColumn.table == self).filter( or_(TableColumn.column_name == col.name for col in table.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} for col in table.columns: try: datatype = col.type.compile(dialect=db_dialect).upper() except Exception as e: datatype = 'UNKNOWN' logging.error('Unrecognized data type in {}.{}'.format( table, col.name)) logging.exception(e) dbcol = dbcols.get(col.name, None) if not dbcol: dbcol = TableColumn(column_name=col.name, type=datatype) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string dbcol.sum = dbcol.is_num dbcol.avg = dbcol.is_num dbcol.is_dttm = dbcol.is_time self.columns.append(dbcol) if not any_date_col and dbcol.is_time: any_date_col = col.name quoted = str(col.compile(dialect=db_dialect)) if dbcol.sum: metrics.append( M( metric_name='sum__' + dbcol.column_name, verbose_name='sum__' + dbcol.column_name, metric_type='sum', expression='SUM({})'.format(quoted), )) if dbcol.avg: metrics.append( M( metric_name='avg__' + dbcol.column_name, verbose_name='avg__' + dbcol.column_name, metric_type='avg', expression='AVG({})'.format(quoted), )) if dbcol.max: metrics.append( M( metric_name='max__' + dbcol.column_name, verbose_name='max__' + dbcol.column_name, metric_type='max', expression='MAX({})'.format(quoted), )) if dbcol.min: metrics.append( M( metric_name='min__' + dbcol.column_name, verbose_name='min__' + dbcol.column_name, metric_type='min', expression='MIN({})'.format(quoted), )) if dbcol.count_distinct: metrics.append( M( metric_name='count_distinct__' + dbcol.column_name, verbose_name='count_distinct__' + dbcol.column_name, metric_type='count_distinct', 
expression='COUNT(DISTINCT {})'.format(quoted), )) dbcol.type = datatype metrics.append( M( metric_name='count', verbose_name='COUNT(*)', metric_type='count', expression='COUNT(*)', )) dbmetrics = db.session.query(M).filter(M.table_id == self.id).filter( or_(M.metric_name == metric.metric_name for metric in metrics)) dbmetrics = {metric.metric_name: metric for metric in dbmetrics} for metric in metrics: metric.table_id = self.id if not dbmetrics.get(metric.metric_name, None): db.session.add(metric) if not self.main_dttm_col: self.main_dttm_col = any_date_col db.session.merge(self) db.session.commit() @classmethod def import_obj(cls, i_datasource, import_time=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. """ def lookup_sqlatable(table): return db.session.query(SqlaTable).join(Database).filter( SqlaTable.table_name == table.table_name, SqlaTable.schema == table.schema, Database.id == table.database_id, ).first() def lookup_database(table): return db.session.query(Database).filter_by( database_name=table.params_dict['database_name']).one() return import_util.import_datasource(db.session, i_datasource, lookup_database, lookup_sqlatable, import_time) @classmethod def query_datasources_by_name(cls, session, database, datasource_name, schema=None): query = (session.query(cls).filter_by( database_id=database.id).filter_by(table_name=datasource_name)) if schema: query = query.filter_by(schema=schema) return query.all()
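# --- Illustrative sketch (not part of the original sources) -----------------
# values_for_column() above builds a SELECT DISTINCT with an optional LIMIT and
# compiles it with literal binds before handing it to pandas. The table and
# column names below are hypothetical; only the SQLAlchemy Core calls mirror
# the pattern used above.
import pandas as pd
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
birth_names = sa.Table(
    "birth_names", metadata,
    sa.Column("name", sa.String(100)),
    sa.Column("num", sa.Integer),
)
metadata.create_all(engine)

qry = sa.select([birth_names.c.name]).distinct().limit(10)
# Compile with literal binds, as get_query_str() does, so the statement can be
# logged and passed to pandas as a plain SQL string.
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
df = pd.read_sql_query(sql=sql, con=engine)
print([row[0] for row in df.to_records(index=False)])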
class Database(Model, AuditMixinNullable, ImportMixin): """An ORM object that stores Database related information""" __tablename__ = "dbs" type = "table" __table_args__ = (UniqueConstraint("database_name"),) id = Column(Integer, primary_key=True) verbose_name = Column(String(250), unique=True) # short unique name, used in permissions database_name = Column(String(250), unique=True, nullable=False) sqlalchemy_uri = Column(String(1024)) password = Column(EncryptedType(String(1024), config.get("SECRET_KEY"))) cache_timeout = Column(Integer) select_as_create_table_as = Column(Boolean, default=False) expose_in_sqllab = Column(Boolean, default=True) allow_run_async = Column(Boolean, default=False) allow_csv_upload = Column(Boolean, default=False) allow_ctas = Column(Boolean, default=False) allow_dml = Column(Boolean, default=False) force_ctas_schema = Column(String(250)) allow_multi_schema_metadata_fetch = Column(Boolean, default=False) extra = Column( Text, default=textwrap.dedent( """\ { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_csv_upload": [] } """ ), ) perm = Column(String(1000)) impersonate_user = Column(Boolean, default=False) export_fields = ( "database_name", "sqlalchemy_uri", "cache_timeout", "expose_in_sqllab", "allow_run_async", "allow_ctas", "allow_csv_upload", "extra", ) export_children = ["tables"] def __repr__(self): return self.name @property def name(self): return self.verbose_name if self.verbose_name else self.database_name @property def allows_subquery(self): return self.db_engine_spec.allows_subqueries @property def allows_cost_estimate(self) -> bool: extra = self.get_extra() database_version = extra.get("version") cost_estimate_enabled = extra.get("cost_estimate_enabled") return ( self.db_engine_spec.get_allow_cost_estimate(database_version) and cost_estimate_enabled ) @property def data(self): return { "id": self.id, "name": self.database_name, "backend": self.backend, "allow_multi_schema_metadata_fetch": self.allow_multi_schema_metadata_fetch, "allows_subquery": self.allows_subquery, "allows_cost_estimate": self.allows_cost_estimate, } @property def unique_name(self): return self.database_name @property def url_object(self): return make_url(self.sqlalchemy_uri_decrypted) @property def backend(self): url = make_url(self.sqlalchemy_uri_decrypted) return url.get_backend_name() @property def metadata_cache_timeout(self): return self.get_extra().get("metadata_cache_timeout", {}) @property def schema_cache_enabled(self): return "schema_cache_timeout" in self.metadata_cache_timeout @property def schema_cache_timeout(self): return self.metadata_cache_timeout.get("schema_cache_timeout") @property def table_cache_enabled(self): return "table_cache_timeout" in self.metadata_cache_timeout @property def table_cache_timeout(self): return self.metadata_cache_timeout.get("table_cache_timeout") @property def default_schemas(self): return self.get_extra().get("default_schemas", []) @classmethod def get_password_masked_url_from_uri(cls, uri): url = make_url(uri) return cls.get_password_masked_url(url) @classmethod def get_password_masked_url(cls, url): url_copy = deepcopy(url) if url_copy.password is not None and url_copy.password != PASSWORD_MASK: url_copy.password = PASSWORD_MASK return url_copy def set_sqlalchemy_uri(self, uri): conn = sqla.engine.url.make_url(uri.strip()) if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password conn.password 
= PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password def get_effective_user(self, url, user_name=None): """ Get the effective user, especially during impersonation. :param url: SQL Alchemy URL object :param user_name: Default username :return: The effective username """ effective_username = None if self.impersonate_user: effective_username = url.username if user_name: effective_username = user_name elif ( hasattr(g, "user") and hasattr(g.user, "username") and g.user.username is not None ): effective_username = g.user.username return effective_username @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra")) def get_sqla_engine(self, schema=None, nullpool=True, user_name=None, source=None): extra = self.get_extra() url = make_url(self.sqlalchemy_uri_decrypted) url = self.db_engine_spec.adjust_database_uri(url, schema) effective_username = self.get_effective_user(url, user_name) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a # configuration parameter instead. self.db_engine_spec.modify_url_for_impersonation( url, self.impersonate_user, effective_username ) masked_url = self.get_password_masked_url(url) logging.info("Database.get_sqla_engine(). Masked URL: {0}".format(masked_url)) params = extra.get("engine_params", {}) if nullpool: params["poolclass"] = NullPool # If using Hive, this will set hive.server2.proxy.user=$effective_username configuration = {} configuration.update( self.db_engine_spec.get_configuration_for_impersonation( str(url), self.impersonate_user, effective_username ) ) if configuration: d = params.get("connect_args", {}) d["configuration"] = configuration params["connect_args"] = d DB_CONNECTION_MUTATOR = config.get("DB_CONNECTION_MUTATOR") if DB_CONNECTION_MUTATOR: url, params = DB_CONNECTION_MUTATOR( url, params, effective_username, security_manager, source ) return create_engine(url, **params) def get_reserved_words(self): return self.get_dialect().preparer.reserved_words def get_quoter(self): return self.get_dialect().identifier_preparer.quote def get_df(self, sql, schema, mutator=None): sqls = [str(s).strip().strip(";") for s in sqlparse.parse(sql)] source_key = None if request and request.referrer: if "/superset/dashboard/" in request.referrer: source_key = "dashboard" elif "/superset/explore/" in request.referrer: source_key = "chart" engine = self.get_sqla_engine( schema=schema, source=utils.sources.get(source_key, None) ) username = utils.get_username() def needs_conversion(df_series): if df_series.empty: return False if isinstance(df_series[0], (list, dict)): return True return False def _log_query(sql): if log_query: log_query(engine.url, sql, schema, username, __name__, security_manager) with closing(engine.raw_connection()) as conn: with closing(conn.cursor()) as cursor: for sql in sqls[:-1]: _log_query(sql) self.db_engine_spec.execute(cursor, sql) cursor.fetchall() _log_query(sqls[-1]) self.db_engine_spec.execute(cursor, sqls[-1]) if cursor.description is not None: columns = [col_desc[0] for col_desc in cursor.description] else: columns = [] df = pd.DataFrame.from_records( data=list(cursor.fetchall()), columns=columns, coerce_float=True ) if mutator: df = mutator(df) for k, v in df.dtypes.items(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates) return df def compile_sqla_query(self, qry, schema=None): engine = self.get_sqla_engine(schema=schema) sql = 
str(qry.compile(engine, compile_kwargs={"literal_binds": True})) if engine.dialect.identifier_preparer._double_percents: sql = sql.replace("%%", "%") return sql def select_star( self, table_name, schema=None, limit=100, show_cols=False, indent=True, latest_partition=False, cols=None, ): """Generates a ``select *`` statement in the proper dialect""" eng = self.get_sqla_engine( schema=schema, source=utils.sources.get("sql_lab", None) ) return self.db_engine_spec.select_star( self, table_name, schema=schema, engine=eng, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols, ) def apply_limit_to_sql(self, sql, limit=1000): return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri @property def inspector(self): engine = self.get_sqla_engine() return sqla.inspect(engine) @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema:None:table_list", attribute_in_key="id", ) def get_all_table_names_in_database( self, cache: bool = False, cache_timeout: bool = None, force=False ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] return self.db_engine_spec.get_all_datasource_names(self, "table") @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema:None:view_list", attribute_in_key="id" ) def get_all_view_names_in_database( self, cache: bool = False, cache_timeout: bool = None, force: bool = False ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] return self.db_engine_spec.get_all_datasource_names(self, "view") @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{{}}:schema:{}:table_list".format( kwargs.get("schema") ), attribute_in_key="id", ) def get_all_table_names_in_schema( self, schema: str, cache: bool = False, cache_timeout: int = None, force: bool = False, ): """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: list of tables """ try: tables = self.db_engine_spec.get_table_names( database=self, inspector=self.inspector, schema=schema ) return [ utils.DatasourceName(table=table, schema=schema) for table in tables ] except Exception as e: logging.exception(e) @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{{}}:schema:{}:view_list".format( kwargs.get("schema") ), attribute_in_key="id", ) def get_all_view_names_in_schema( self, schema: str, cache: bool = False, cache_timeout: int = None, force: bool = False, ): """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. 
:param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: list of views """ try: views = self.db_engine_spec.get_view_names( database=self, inspector=self.inspector, schema=schema ) return [utils.DatasourceName(table=view, schema=schema) for view in views] except Exception as e: logging.exception(e) @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema_list", attribute_in_key="id" ) def get_all_schema_names( self, cache: bool = False, cache_timeout: int = None, force: bool = False ) -> List[str]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: schema list """ return self.db_engine_spec.get_schema_names(self.inspector) @property def db_engine_spec(self): return db_engine_specs.engines.get(self.backend, db_engine_specs.BaseEngineSpec) @classmethod def get_db_engine_spec_for_backend(cls, backend): return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) def grains(self): """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain from a datetime (maybe the source grain is arbitrary timestamps, daily or 5 minutes increments) to another, "truncated" datetime. Since each database has slightly different but similar datetime functions, this allows a mapping between database engines and actual functions. """ return self.db_engine_spec.get_time_grains() def get_extra(self): extra = {} if self.extra: try: extra = json.loads(self.extra) except Exception as e: logging.error(e) raise e return extra def get_table(self, table_name, schema=None): extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) return Table( table_name, meta, schema=schema or None, autoload=True, autoload_with=self.get_sqla_engine(), ) def get_columns(self, table_name, schema=None): return self.db_engine_spec.get_columns(self.inspector, table_name, schema) def get_indexes(self, table_name, schema=None): return self.inspector.get_indexes(table_name, schema) def get_pk_constraint(self, table_name, schema=None): return self.inspector.get_pk_constraint(table_name, schema) def get_foreign_keys(self, table_name, schema=None): return self.inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_csv_upload(self): return self.get_extra().get("schemas_allowed_for_csv_upload", []) @property def sqlalchemy_uri_decrypted(self): conn = sqla.engine.url.make_url(self.sqlalchemy_uri) if custom_password_store: conn.password = custom_password_store(conn) else: conn.password = self.password return str(conn) @property def sql_url(self): return "/superset/sql/{}/".format(self.id) def get_perm(self): return ("[{obj.database_name}].(id:{obj.id})").format(obj=self) def has_table(self, table): engine = self.get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) def has_table_by_name(self, table_name, schema=None): engine = self.get_sqla_engine() return engine.has_table(table_name, schema) @utils.memoized def get_dialect(self): sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()()
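# --- Illustrative sketch (not part of the original sources) -----------------
# Database.get_sqla_engine() above merges engine parameters from the JSON
# "extra" blob with a NullPool when nullpool=True. A minimal standalone
# version of that assembly, with a made-up SQLite URI and extra payload:
import json
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

extra = json.loads('{"engine_params": {"connect_args": {"timeout": 30}}}')
params = extra.get("engine_params", {})
params["poolclass"] = NullPool  # same effect as the nullpool=True branch above
engine = create_engine("sqlite:///example.db", **params)
print(engine.url)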
return "/superset/explore/?form_data=%7B%22slice_id%22%3A%20{0}%7D".format( self.id ) sqla.event.listen(Slice, "before_insert", set_related_perm) sqla.event.listen(Slice, "before_update", set_related_perm) dashboard_slices = Table( "dashboard_slices", metadata, Column("id", Integer, primary_key=True), Column("dashboard_id", Integer, ForeignKey("dashboards.id")), Column("slice_id", Integer, ForeignKey("slices.id")), UniqueConstraint("dashboard_id", "slice_id"), ) dashboard_user = Table( "dashboard_user", metadata, Column("id", Integer, primary_key=True), Column("user_id", Integer, ForeignKey("ab_user.id")), Column("dashboard_id", Integer, ForeignKey("dashboards.id")), ) class Dashboard(Model, AuditMixinNullable, ImportMixin): """The dashboard object!"""
class Label(MailSyncBase, UpdatedAtMixin, DeletedAtMixin): """ Labels from the remote account backend (Gmail). """ # TOFIX this causes an import error due to circular dependencies # from inbox.models.account import Account # `use_alter` required here to avoid circular dependency w/Account account_id = Column(ForeignKey('account.id', use_alter=True, name='label_fk1', ondelete='CASCADE'), nullable=False) account = relationship( 'Account', backref=backref( 'labels', # Don't load labels if the account is deleted, # (the labels will be deleted by the foreign key delete casade). passive_deletes=True), load_on_pending=True) name = Column(CategoryNameString(), nullable=False) canonical_name = Column(String(MAX_INDEXABLE_LENGTH), nullable=False, default='') category_id = Column(ForeignKey(Category.id, ondelete='CASCADE')) category = relationship(Category, backref=backref('labels', cascade='all, delete-orphan')) @validates('name') def validate_name(self, key, name): sanitized_name = sanitize_name(name) if sanitized_name != name: log.warning("Truncating label name for account", account_id=self.account_id, name=name) return sanitized_name @classmethod def find_or_create(cls, session, account, name, role=None): q = session.query(cls).filter(cls.account_id == account.id) role = role or '' if role: q = q.filter(cls.canonical_name == role) else: q = q.filter(cls.name == name) obj = q.first() if obj is None: obj = cls(account=account, name=name, canonical_name=role) obj.category = Category.find_or_create( session, namespace_id=account.namespace.id, name=role, display_name=name, type_='label') session.add(obj) return obj __table_args__ = \ (UniqueConstraint('account_id', 'name', 'canonical_name'),)
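# --- Illustrative sketch (not part of the original sources) -----------------
# Label.find_or_create() above looks a row up first and only inserts when
# nothing matches, relying on the UniqueConstraint to keep the table clean.
# A generic get-or-create version of that idiom; the Tag model and session
# here are hypothetical stand-ins.
from sqlalchemy import Column, Integer, String, UniqueConstraint, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()

class Tag(Base):
    __tablename__ = "tags"
    __table_args__ = (UniqueConstraint("name"),)
    id = Column(Integer, primary_key=True)
    name = Column(String(64), nullable=False)

    @classmethod
    def find_or_create(cls, session, name):
        obj = session.query(cls).filter(cls.name == name).first()
        if obj is None:
            obj = cls(name=name)
            session.add(obj)
        return obj

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
# The second call autoflushes the pending row and returns the same object.
assert Tag.find_or_create(session, "inbox") is Tag.find_or_create(session, "inbox")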
class SqlaTable(Model, BaseDatasource): """An ORM object for SqlAlchemy table references""" type = "table" query_language = "sql" metric_class = SqlMetric column_class = TableColumn owner_class = security_manager.user_model __tablename__ = "tables" __table_args__ = (UniqueConstraint("database_id", "table_name"), ) table_name = Column(String(250), nullable=False) main_dttm_col = Column(String(250)) database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False) fetch_values_predicate = Column(String(1000)) owners = relationship(owner_class, secondary=sqlatable_user, backref="tables") database = relationship( "Database", backref=backref("tables", cascade="all, delete-orphan"), foreign_keys=[database_id], ) schema = Column(String(255)) sql = Column(Text) is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) baselink = "tablemodelview" export_fields = [ "table_name", "main_dttm_col", "description", "default_endpoint", "database_id", "offset", "cache_timeout", "schema", "sql", "params", "template_params", "filter_select_enabled", "fetch_values_predicate", ] update_from_object_fields = [ f for f in export_fields if f not in ("table_name", "database_id") ] export_parent = "database" export_children = ["metrics", "columns"] sqla_aggregations = { "COUNT_DISTINCT": lambda column_name: sa.func.COUNT(sa.distinct(column_name)), "COUNT": sa.func.COUNT, "SUM": sa.func.SUM, "AVG": sa.func.AVG, "MIN": sa.func.MIN, "MAX": sa.func.MAX, } def make_sqla_column_compatible(self, sqla_col: Column, label: Optional[str] = None) -> Column: """Takes a sqlalchemy column object and adds label info if supported by engine. :param sqla_col: sqlalchemy column instance :param label: alias/label that column is expected to have :return: either a sql alchemy column or label instance if supported by engine """ label_expected = label or sqla_col.name db_engine_spec = self.database.db_engine_spec if db_engine_spec.allows_column_aliases: label = db_engine_spec.make_label_compatible(label_expected) sqla_col = sqla_col.label(label) sqla_col._df_label_expected = label_expected return sqla_col def __repr__(self): return self.name @property def changed_by_name(self) -> str: if not self.changed_by: return "" return str(self.changed_by) @property def changed_by_url(self) -> str: if not self.changed_by: return "" return f"/superset/profile/{self.changed_by.username}" @property def connection(self) -> str: return str(self.database) @property def description_markeddown(self) -> str: return utils.markdown(self.description) @property def datasource_name(self) -> str: return self.table_name @property def database_name(self) -> str: return self.database.name @classmethod def get_datasource_by_name( cls, session: Session, datasource_name: str, schema: Optional[str], database_name: str, ) -> Optional["SqlaTable"]: schema = schema or None query = (session.query(cls).join(Database).filter( cls.table_name == datasource_name).filter( Database.database_name == database_name)) # Handling schema being '' or None, which is easier to handle # in python than in the SQLA query in a multi-dialect way for tbl in query.all(): if schema == (tbl.schema or None): return tbl return None @property def link(self) -> Markup: name = escape(self.name) anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>' return Markup(anchor) def get_schema_perm(self) -> Optional[str]: """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) def get_perm(self) -> 
str: return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self) @property def name(self) -> str: # type: ignore if not self.schema: return self.table_name return "{}.{}".format(self.schema, self.table_name) @property def full_name(self) -> str: return utils.get_datasource_full_name(self.database, self.table_name, schema=self.schema) @property def dttm_cols(self) -> List: l = [c.column_name for c in self.columns if c.is_dttm] if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property def num_cols(self) -> List: return [c.column_name for c in self.columns if c.is_numeric] @property def any_dttm_col(self) -> Optional[str]: cols = self.dttm_cols return cols[0] if cols else None @property def html(self) -> str: t = ((c.column_name, c.type) for c in self.columns) df = pd.DataFrame(t) df.columns = ["field", "type"] return df.to_html( index=False, classes=("dataframe table table-striped table-bordered " "table-condensed"), ) @property def sql_url(self) -> str: return self.database.sql_url + "?table_name=" + str(self.table_name) def external_metadata(self): cols = self.database.get_columns(self.table_name, schema=self.schema) for col in cols: try: col["type"] = str(col["type"]) except CompileError: col["type"] = "UNKNOWN" return cols @property def time_column_grains(self) -> Dict[str, Any]: return { "time_columns": self.dttm_cols, "time_grains": [grain.name for grain in self.database.grains()], } @property def select_star(self) -> str: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star(self.table_name, schema=self.schema, show_cols=False, latest_partition=False) @property def data(self) -> Dict: d = super().data if self.type == "table": grains = self.database.grains() or [] if grains: grains = [(g.duration, g.name) for g in grains] d["granularity_sqla"] = utils.choicify(self.dttm_cols) d["time_grain_sqla"] = grains d["main_dttm_col"] = self.main_dttm_col d["fetch_values_predicate"] = self.fetch_values_predicate d["template_params"] = self.template_params return d def values_for_column(self, column_name: str, limit: int = 10000) -> List: """Runs query against sqla to retrieve some sample values for the given column. 
""" cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() qry = (select([target_col.get_sqla_col() ]).select_from(self.get_from_clause(tp)).distinct()) if limit: qry = qry.limit(limit) if self.fetch_values_predicate: tp = self.get_template_processor() qry = qry.where(tp.process_template(self.fetch_values_predicate)) engine = self.database.get_sqla_engine() sql = "{}".format( qry.compile(engine, compile_kwargs={"literal_binds": True})) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) return df[column_name].to_list() def mutate_query_from_config(self, sql: str) -> str: """Apply config's SQL_QUERY_MUTATOR Typically adds comments to the query with context""" SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"] if SQL_QUERY_MUTATOR: username = utils.get_username() sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database) return sql def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) def get_query_str_extended( self, query_obj: Dict[str, Any]) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) logger.info(sql) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended(labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries) def get_query_str(self, query_obj: Dict[str, Any]) -> str: query_str_ext = self.get_query_str_extended(query_obj) all_queries = query_str_ext.prequeries + [query_str_ext.sql] return ";\n\n".join(all_queries) + ";" def get_sqla_table(self): tbl = table(self.table_name) if self.schema: tbl.schema = self.schema return tbl def get_from_clause(self, template_processor=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) from_sql = sqlparse.format(from_sql, strip_comments=True) return TextAsFrom(sa.text(from_sql), []).alias("expr_qry") return self.get_sqla_table() def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]: """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition :param dict cols: Columns for the current table :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ expression_type = metric.get("expressionType") label = utils.get_metric_name(metric) if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]: column_name = metric["column"].get("column_name") table_column = cols.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() else: sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]]( sqla_column) elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]: sqla_metric = literal_column(metric.get("sqlExpression")) else: return None return self.make_sqla_column_compatible(sqla_metric, label) def _get_sqla_row_level_filters(self, template_processor) -> List[str]: """ Return the appropriate row level security filters for this table and the current user. :param BaseTemplateProcessor template_processor: The template processor to apply to the filters. :returns: A list of SQL clauses to be ANDed together. 
:rtype: List[str] """ return [ text("({})".format(template_processor.process_template(f.clause))) for f in security_manager.get_rls_filters(self) ] def get_sqla_query( # sqla self, metrics, granularity, from_dttm, to_dttm, columns=None, groupby=None, filter=None, is_timeseries=True, timeseries_limit=15, timeseries_limit_metric=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, orderby=None, extras=None, order_desc=True, ) -> SqlaQuery: """Querying any sqla table from this common interface""" template_kwargs = { "from_dttm": from_dttm, "groupby": groupby, "metrics": metrics, "row_limit": row_limit, "to_dttm": to_dttm, "filter": filter, "columns": {col.column_name: col for col in self.columns}, } is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") template_kwargs.update(self.template_params_dict) extra_cache_keys: List[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.database.db_engine_spec prequeries: List[str] = [] orderby = orderby or [] # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline cols: Dict[str, Column] = {col.column_name: col for col in self.columns} metrics_dict: Dict[str, SqlMetric] = { m.metric_name: m for m in self.metrics } if not granularity and is_timeseries: raise Exception( _("Datetime column not provided as part table configuration " "and is required by this type of chart")) if (not metrics and not columns and (is_sip_38 or (not is_sip_38 and not groupby))): raise Exception(_("Empty query?")) metrics_exprs: List[ColumnElement] = [] for m in metrics: if utils.is_adhoc_metric(m): metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols)) elif m in metrics_dict: metrics_exprs.append(metrics_dict[m].get_sqla_col()) else: raise Exception( _("Metric '%(metric)s' does not exist", metric=m)) if metrics_exprs: main_metric_expr = metrics_exprs[0] else: main_metric_expr, label = literal_column("COUNT(*)"), "ccount" main_metric_expr = self.make_sqla_column_compatible( main_metric_expr, label) select_exprs: List[Column] = [] groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby): # dedup columns while preserving order groupby = list(dict.fromkeys(columns if is_sip_38 else groupby)) select_exprs = [] for s in groupby: if s in cols: outer = cols[s].get_sqla_col() else: outer = literal_column(f"({s})") outer = self.make_sqla_column_compatible(outer, s) groupby_exprs_sans_timestamp[outer.name] = outer select_exprs.append(outer) elif columns: for s in columns: select_exprs.append( cols[s].get_sqla_col() if s in cols else self. make_sqla_column_compatible(literal_column(s))) metrics_exprs = [] time_range_endpoints = extras.get("time_range_endpoints") groupby_exprs_with_timestamp = OrderedDict( groupby_exprs_sans_timestamp.items()) if granularity: dttm_col = cols[granularity] time_grain = extras.get("time_grain_sqla") time_filters = [] if is_timeseries: timestamp = dttm_col.get_timestamp_expression(time_grain) select_exprs += [timestamp] groupby_exprs_with_timestamp[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns. 
if (db_engine_spec.time_secondary_columns and self.main_dttm_col in self.dttm_cols and self.main_dttm_col != dttm_col.column_name): time_filters.append(cols[self.main_dttm_col].get_time_filter( from_dttm, to_dttm, time_range_endpoints)) time_filters.append( dttm_col.get_time_filter(from_dttm, to_dttm, time_range_endpoints)) select_exprs += metrics_exprs labels_expected = [c._df_label_expected for c in select_exprs] select_exprs = db_engine_spec.make_select_compatible( groupby_exprs_with_timestamp.values(), select_exprs) qry = sa.select(select_exprs) tbl = self.get_from_clause(template_processor) if (is_sip_38 and metrics) or (not is_sip_38 and not columns): qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] having_clause_and: List = [] for flt in filter: if not all([flt.get(s) for s in ["col", "op"]]): continue col = flt["col"] op = flt["op"].upper() col_obj = cols.get(col) if col_obj: is_list_target = op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ) eq = self.filter_values_handler( values=flt.get("val"), target_column_is_numeric=col_obj.is_numeric, is_list_target=is_list_target, ) if op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ): cond = col_obj.get_sqla_col().in_(eq) if isinstance(eq, str) and NULL_STRING in eq: cond = or_(cond, col_obj.get_sqla_col() is None) if op == utils.FilterOperator.NOT_IN.value: cond = ~cond where_clause_and.append(cond) else: if col_obj.is_numeric: eq = utils.cast_to_num(flt["val"]) if op == utils.FilterOperator.EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() == eq) elif op == utils.FilterOperator.NOT_EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() != eq) elif op == utils.FilterOperator.GREATER_THAN.value: where_clause_and.append(col_obj.get_sqla_col() > eq) elif op == utils.FilterOperator.LESS_THAN.value: where_clause_and.append(col_obj.get_sqla_col() < eq) elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() >= eq) elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() <= eq) elif op == utils.FilterOperator.LIKE.value: where_clause_and.append( col_obj.get_sqla_col().like(eq)) elif op == utils.FilterOperator.IS_NULL.value: where_clause_and.append(col_obj.get_sqla_col() == None) elif op == utils.FilterOperator.IS_NOT_NULL.value: where_clause_and.append(col_obj.get_sqla_col() != None) else: raise Exception( _("Invalid filter operation type: %(op)s", op=op)) where_clause_and += self._get_sqla_row_level_filters( template_processor) if extras: where = extras.get("where") if where: where = template_processor.process_template(where) where_clause_and += [sa.text("({})".format(where))] having = extras.get("having") if having: having = template_processor.process_template(having) having_clause_and += [sa.text("({})".format(having))] if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) if not orderby and ((is_sip_38 and metrics) or (not is_sip_38 and not columns)): orderby = [(main_metric_expr, not order_desc)] # To ensure correct handling of the ORDER BY labeling we need to reference the # metric instance if defined in the SELECT clause. 
metrics_exprs_by_label = {m._label: m for m in metrics_exprs} for col, ascending in orderby: direction = asc if ascending else desc if utils.is_adhoc_metric(col): col = self.adhoc_metric_to_sqla(col, cols) elif col in cols: col = cols[col].get_sqla_col() if isinstance(col, Label) and col._label in metrics_exprs_by_label: col = metrics_exprs_by_label[col._label] qry = qry.order_by(direction(col)) if row_limit: qry = qry.limit(row_limit) if (is_timeseries and timeseries_limit and not time_groupby_inline and ((is_sip_38 and columns) or (not is_sip_38 and groupby))): if self.database.db_engine_spec.allows_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. vertica, # require a unique inner alias inner_main_metric_expr = self.make_sqla_column_compatible( main_metric_expr, "mme_inner__") inner_groupby_exprs = [] inner_select_exprs = [] for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): inner = self.make_sqla_column_compatible( gby_obj, gby_name + "__") inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) inner_select_exprs += [inner_main_metric_expr] subq = select(inner_select_exprs).select_from(tbl) inner_time_filter = dttm_col.get_time_filter( inner_from_dttm or from_dttm, inner_to_dttm or to_dttm, time_range_endpoints, ) subq = subq.where( and_(*(where_clause_and + [inner_time_filter]))) subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr if timeseries_limit_metric: ob = self._get_timeseries_orderby(timeseries_limit_metric, metrics_dict, cols) direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) subq = subq.limit(timeseries_limit) on_clause = [] for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): # in this case the column name, not the alias, needs to be # conditionally mutated, as it refers to the column alias in # the inner query col_name = db_engine_spec.make_label_compatible(gby_name + "__") on_clause.append(gby_obj == column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) else: if timeseries_limit_metric: orderby = [( self._get_timeseries_orderby(timeseries_limit_metric, metrics_dict, cols), False, )] # run prequery to get top groups prequery_obj = { "is_timeseries": False, "row_limit": timeseries_limit, "metrics": metrics, "granularity": granularity, "from_dttm": inner_from_dttm or from_dttm, "to_dttm": inner_to_dttm or to_dttm, "filter": filter, "orderby": orderby, "extras": extras, "columns": columns, "order_desc": True, } if not is_sip_38: prequery_obj["groupby"] = groupby result = self.query(prequery_obj) prequeries.append(result.query) dimensions = [ c for c in result.df.columns if c not in metrics and c in groupby_exprs_sans_timestamp ] top_groups = self._get_top_groups( result.df, dimensions, groupby_exprs_sans_timestamp) qry = qry.where(top_groups) return SqlaQuery( extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry.select_from(tbl), prequeries=prequeries, ) def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols): if utils.is_adhoc_metric(timeseries_limit_metric): ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols) elif timeseries_limit_metric in metrics_dict: timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric) ob = timeseries_limit_metric.get_sqla_col() else: raise Exception( _("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)) return ob def _get_top_groups(self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict) -> 
ColumnElement: groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: group.append(groupby_exprs[dimension] == row[dimension]) groups.append(and_(*group)) return or_(*groups) def query(self, query_obj: Dict[str, Any]) -> QueryResult: qry_start_dttm = datetime.now() query_str_ext = self.get_query_str_extended(query_obj) sql = query_str_ext.sql status = utils.QueryStatus.SUCCESS error_message = None def mutator(df: pd.DataFrame) -> None: """ Some engines change the case or generate bespoke column names, either by default or due to lack of support for aliasing. This function ensures that the column names in the DataFrame correspond to what is expected by the viz components. :param df: Original DataFrame returned by the engine """ labels_expected = query_str_ext.labels_expected if df is not None and not df.empty: if len(df.columns) != len(labels_expected): raise Exception(f"For {sql}, df.columns: {df.columns}" f" differs from {labels_expected}") else: df.columns = labels_expected try: df = self.database.get_df(sql, self.schema, mutator) except Exception as ex: df = pd.DataFrame() status = utils.QueryStatus.FAILED logger.exception(f"Query {sql} on schema {self.schema} failed") db_engine_spec = self.database.db_engine_spec error_message = db_engine_spec.extract_error_message(ex) return QueryResult( status=status, df=df, duration=datetime.now() - qry_start_dttm, query=sql, error_message=error_message, ) def get_sqla_table_object(self) -> Table: return self.database.get_table(self.table_name, schema=self.schema) def fetch_metadata(self, commit=True) -> None: """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() except Exception as ex: logger.exception(ex) raise Exception( _("Table [{}] doesn't seem to exist in the specified database, " "couldn't fetch column information").format(self.table_name)) metrics = [] any_date_col = None db_engine_spec = self.database.db_engine_spec db_dialect = self.database.get_dialect() dbcols = (db.session.query(TableColumn).filter( TableColumn.table == self).filter( or_(TableColumn.column_name == col.name for col in table.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} for col in table.columns: try: datatype = db_engine_spec.column_datatype_to_string( col.type, db_dialect) except Exception as ex: datatype = "UNKNOWN" logger.error("Unrecognized data type in {}.{}".format( table, col.name)) logger.exception(ex) dbcol = dbcols.get(col.name, None) if not dbcol: dbcol = TableColumn(column_name=col.name, type=datatype, table=self) dbcol.sum = dbcol.is_numeric dbcol.avg = dbcol.is_numeric dbcol.is_dttm = dbcol.is_temporal db_engine_spec.alter_new_orm_column(dbcol) else: dbcol.type = datatype dbcol.groupby = True dbcol.filterable = True self.columns.append(dbcol) if not any_date_col and dbcol.is_temporal: any_date_col = col.name metrics.append( SqlMetric( metric_name="count", verbose_name="COUNT(*)", metric_type="count", expression="COUNT(*)", )) if not self.main_dttm_col: self.main_dttm_col = any_date_col self.add_missing_metrics(metrics) db.session.merge(self) if commit: db.session.commit() @classmethod def import_obj(cls, i_datasource, import_time=None) -> int: """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. 
""" def lookup_sqlatable(table): return (db.session.query(SqlaTable).join(Database).filter( SqlaTable.table_name == table.table_name, SqlaTable.schema == table.schema, Database.id == table.database_id, ).first()) def lookup_database(table): try: return (db.session.query(Database).filter_by( database_name=table.params_dict["database_name"]).one()) except NoResultFound: raise DatabaseNotFound( _( "Database '%(name)s' is not found", name=table.params_dict["database_name"], )) return import_datasource.import_datasource(db.session, i_datasource, lookup_database, lookup_sqlatable, import_time) @classmethod def query_datasources_by_name(cls, session: Session, database: Database, datasource_name: str, schema=None) -> List["SqlaTable"]: query = (session.query(cls).filter_by( database_id=database.id).filter_by(table_name=datasource_name)) if schema: query = query.filter_by(schema=schema) return query.all() @staticmethod def default_query(qry) -> Query: return qry.filter_by(is_sqllab_view=False) def has_calls_to_cache_key_wrapper(self, query_obj: Dict[str, Any]) -> bool: """ Detects the presence of calls to `cache_key_wrapper` in items in query_obj that can be templated. If any are present, the query must be evaluated to extract additional keys for the cache key. This method is needed to avoid executing the template code unnecessarily, as it may contain expensive calls, e.g. to extract the latest partition of a database. :param query_obj: query object to analyze :return: True if at least one item calls `cache_key_wrapper`, otherwise False """ regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}") templatable_statements: List[str] = [] if self.sql: templatable_statements.append(self.sql) if self.fetch_values_predicate: templatable_statements.append(self.fetch_values_predicate) extras = query_obj.get("extras", {}) if "where" in extras: templatable_statements.append(extras["where"]) if "having" in extras: templatable_statements.append(extras["having"]) for statement in templatable_statements: if regex.search(statement): return True return False def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: """ The cache key of a SqlaTable needs to consider any keys added by the parent class and any keys added via `cache_key_wrapper`. :param query_obj: query object to analyze :return: True if at least one item calls `cache_key_wrapper`, otherwise False """ extra_cache_keys = super().get_extra_cache_keys(query_obj) if self.has_calls_to_cache_key_wrapper(query_obj): sqla_query = self.get_sqla_query(**query_obj) extra_cache_keys += sqla_query.extra_cache_keys return extra_cache_keys
class Dataset(Base): """Class to store the information about a data set. """ __tablename__ = 'datasets' __table_args__ = ( UniqueConstraint('task_id', 'description'), # Useless, in theory, because 'id' is already unique. Yet, we # need this because it's a target of a foreign key. UniqueConstraint('id', 'task_id'), ) # Auto increment primary key. id = Column( Integer, primary_key=True) # Task (id and object) owning the dataset. task_id = Column( Integer, ForeignKey(Task.id, onupdate="CASCADE", ondelete="CASCADE"), nullable=False) task = relationship( Task, foreign_keys=[task_id], back_populates="datasets") # A human-readable text describing the dataset. description = Column( Unicode, CodenameConstraint("description"), nullable=False) # Whether this dataset will be automatically judged by ES and SS # "in background", together with the active dataset of each task. autojudge = Column( Boolean, nullable=False, default=False) # Time and memory limits for every testcase. time_limit = Column( Float, CheckConstraint("time_limit > 0"), nullable=True) memory_limit = Column( BigInteger, CheckConstraint("memory_limit > 0"), nullable=True) # Name of the TaskType child class suited for the task. task_type = Column( String, nullable=False) # Parameters for the task type class. task_type_parameters = Column( JSONB, nullable=False) # Name of the ScoreType child class suited for the task. score_type = Column( String, nullable=False) # Parameters for the score type class. score_type_parameters = Column( JSONB, nullable=False) # These one-to-many relationships are the reversed directions of # the ones defined in the "child" classes using foreign keys. managers = relationship( "Manager", collection_class=attribute_mapped_collection("filename"), cascade="all, delete-orphan", passive_deletes=True, back_populates="dataset") testcases = relationship( "Testcase", collection_class=attribute_mapped_collection("codename"), cascade="all, delete-orphan", passive_deletes=True, back_populates="dataset") @property def active(self): """Shorthand for detecting if the dataset is active. return (bool): True if this dataset is the active one for its task. """ return self is self.task.active_dataset @property def task_type_object(self): if not hasattr(self, "_cached_task_type_object") \ or self.task_type != self._cached_task_type \ or self.task_type_parameters \ != self._cached_task_type_parameters: # Import late to avoid a circular dependency. from cms.grading.tasktypes import get_task_type # This can raise. self._cached_task_type_object = get_task_type( self.task_type, self.task_type_parameters) # If an exception is raised these updates don't take place: # that way, next time this property is accessed, we get a # cache miss again and the same exception is raised again. self._cached_task_type = self.task_type self._cached_task_type_parameters = \ copy.deepcopy(self.task_type_parameters) return self._cached_task_type_object @property def score_type_object(self): public_testcases = {k: tc.public for k, tc in iteritems(self.testcases)} if not hasattr(self, "_cached_score_type_object") \ or self.score_type != self._cached_score_type \ or self.score_type_parameters \ != self._cached_score_type_parameters \ or public_testcases != self._cached_public_testcases: # Import late to avoid a circular dependency. from cms.grading.scoretypes import get_score_type # This can raise. 
self._cached_score_type_object = get_score_type( self.score_type, self.score_type_parameters, public_testcases) # If an exception is raised these updates don't take place: # that way, next time this property is accessed, we get a # cache miss again and the same exception is raised again. self._cached_score_type = self.score_type self._cached_score_type_parameters = \ copy.deepcopy(self.score_type_parameters) self._cached_public_testcases = public_testcases return self._cached_score_type_object def clone_from(self, old_dataset, clone_managers=True, clone_testcases=True, clone_results=False): """Overwrite the data with that in dataset. old_dataset (Dataset): original dataset to copy from. clone_managers (bool): copy dataset managers. clone_testcases (bool): copy dataset testcases. clone_results (bool): copy submission results (will also copy managers and testcases). """ new_testcases = dict() if clone_testcases or clone_results: for old_t in itervalues(old_dataset.testcases): new_t = old_t.clone() new_t.dataset = self new_testcases[new_t.codename] = new_t if clone_managers or clone_results: for old_m in itervalues(old_dataset.managers): new_m = old_m.clone() new_m.dataset = self # TODO: why is this needed? self.sa_session.flush() if clone_results: old_results = self.get_submission_results(old_dataset) for old_sr in old_results: # Create the submission result. new_sr = old_sr.clone() new_sr.submission = old_sr.submission new_sr.dataset = self # Create executables. for old_e in itervalues(old_sr.executables): new_e = old_e.clone() new_e.submission_result = new_sr # Create evaluations. for old_e in old_sr.evaluations: new_e = old_e.clone() new_e.submission_result = new_sr new_e.testcase = new_testcases[old_e.codename] self.sa_session.flush()
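# --- Illustrative sketch (not part of the original sources) -----------------
# Dataset.task_type_object and score_type_object above memoise a constructed
# object and rebuild it only when the inputs it was derived from change, and
# they update the cache keys only after a successful build so a failure is
# retried on the next access. A generic version of that pattern; the Renderer
# class is a made-up stand-in.
import copy

class Renderer:
    def __init__(self, template, parameters):
        self.template = template
        self.parameters = parameters

    @property
    def compiled(self):
        if (not hasattr(self, "_cached_compiled")
                or self.template != self._cached_template
                or self.parameters != self._cached_parameters):
            self._cached_compiled = (self.template, copy.deepcopy(self.parameters))
            # Cache keys are refreshed only after the rebuild succeeded.
            self._cached_template = self.template
            self._cached_parameters = copy.deepcopy(self.parameters)
        return self._cached_compiled

r = Renderer("batch", {"checker": "diff"})
first = r.compiled
r.parameters["checker"] = "custom"
assert r.compiled is not first  # cache invalidated by the parameter change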
class Contact( MailSyncBase, HasRevisions, HasPublicID, HasEmailAddress, UpdatedAtMixin, DeletedAtMixin, ): """Data for a user's contact.""" API_OBJECT_NAME = "contact" namespace_id = Column(BigInteger, nullable=False, index=True) namespace = relationship( Namespace, primaryjoin="foreign(Contact.namespace_id) == remote(Namespace.id)", load_on_pending=True, ) # A server-provided unique ID. # NB: We specify the collation here so that the test DB gets setup correctly. uid = Column(String(64, collation="utf8mb4_bin"), nullable=False) # A constant, unique identifier for the remote backend this contact came # from. E.g., 'google', 'eas', 'inbox' provider_name = Column(String(64)) name = Column(Text) raw_data = Column(Text) # A score to use for ranking contact search results. This should be # precomputed to facilitate performant search. score = Column(Integer) # Flag to set if the contact is deleted in a remote backend. # (This is an unmapped attribute, i.e., it does not correspond to a # database column.) deleted = False __table_args__ = ( UniqueConstraint("uid", "namespace_id", "provider_name"), Index("idx_namespace_created", "namespace_id", "created_at"), Index("ix_contact_ns_uid_provider_name", "namespace_id", "uid", "provider_name"), ) @validates("raw_data") def validate_text_column_length(self, key, value): if value is None: return None return unicode_safe_truncate(value, MAX_TEXT_CHARS) @property def versioned_relationships(self): return ["phone_numbers"] def merge_from(self, new_contact): # This must be updated when new fields are added to the class. merge_attrs = ["name", "email_address", "raw_data"] for attr in merge_attrs: if getattr(self, attr) != getattr(new_contact, attr): setattr(self, attr, getattr(new_contact, attr))
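# --- Illustrative sketch (not part of the original sources) -----------------
# Contact.validate_text_column_length() above uses the @validates hook to clamp
# oversized input before it reaches a bounded column. A minimal standalone
# version; the Note model, the limit, and the slice-based truncation are
# illustrative stand-ins for unicode_safe_truncate/MAX_TEXT_CHARS.
from sqlalchemy import Column, Integer, Text, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, validates

Base = declarative_base()
MAX_NOTE_CHARS = 20  # assumed limit for the example

class Note(Base):
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    body = Column(Text)

    @validates("body")
    def validate_body(self, key, value):
        if value is None:
            return None
        return value[:MAX_NOTE_CHARS]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
note = Note(body="x" * 100)
session.add(note)
session.commit()
assert len(note.body) == MAX_NOTE_CHARS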
class Paging(Base): __tablename__ = 'paging' __table_args__ = ( PrimaryKeyConstraint('id'), UniqueConstraint('number'), ) id = Column(Integer, nullable=False) tenant_uuid = Column(String(36), ForeignKey('tenant.uuid', ondelete='CASCADE'), nullable=False) number = Column(String(32)) name = Column(String(128)) duplex = Column(Integer, nullable=False, server_default='0') ignore = Column(Integer, nullable=False, server_default='0') record = Column(Integer, nullable=False, server_default='0') quiet = Column(Integer, nullable=False, server_default='0') timeout = Column(Integer, nullable=False, server_default='30') announcement_file = Column(String(64)) announcement_play = Column(Integer, nullable=False, server_default='0') announcement_caller = Column(Integer, nullable=False, server_default='0') commented = Column(Integer, nullable=False, server_default='0') description = Column(Text) paging_members = relationship( 'PagingUser', primaryjoin="""and_( PagingUser.pagingid == Paging.id, PagingUser.caller == 0 )""", cascade='all, delete-orphan', ) users_member = association_proxy( 'paging_members', 'user', creator=lambda _user: PagingUser(user=_user, caller=0), ) paging_callers = relationship( 'PagingUser', primaryjoin="""and_( PagingUser.pagingid == Paging.id, PagingUser.caller == 1 )""", cascade='all, delete-orphan') users_caller = association_proxy( 'paging_callers', 'user', creator=lambda _user: PagingUser(user=_user, caller=1), ) func_keys = relationship('FuncKeyDestPaging', cascade='all, delete-orphan') @hybrid_property def enabled(self): return self.commented == 0 @enabled.expression def enabled(cls): return not_(cast(cls.commented, Boolean)) @enabled.setter def enabled(self, value): self.commented = int(value is False) @hybrid_property def duplex_bool(self): return self.duplex == 1 @duplex_bool.expression def duplex_bool(cls): return cast(cls.duplex, Boolean) @duplex_bool.setter def duplex_bool(self, value): self.duplex = int(value) @hybrid_property def record_bool(self): return self.record == 1 @record_bool.expression def record_bool(cls): return cast(cls.record, Boolean) @record_bool.setter def record_bool(self, value): self.record = int(value) @hybrid_property def ignore_forward(self): return self.ignore == 1 @ignore_forward.expression def ignore_forward(cls): return cast(cls.ignore, Boolean) @ignore_forward.setter def ignore_forward(self, value): self.ignore = int(value) @hybrid_property def caller_notification(self): return self.quiet == 0 @caller_notification.expression def caller_notification(cls): return not_(cast(cls.quiet, Boolean)) @caller_notification.setter def caller_notification(self, value): self.quiet = int(value == 0) @hybrid_property def announce_caller(self): return self.announcement_caller == 0 @announce_caller.expression def announce_caller(cls): return not_(cast(cls.announcement_caller, Boolean)) @announce_caller.setter def announce_caller(self, value): self.announcement_caller = int(value == 0) @hybrid_property def announce_sound(self): return self.announcement_file @announce_sound.setter def announce_sound(self, value): self.announcement_play = int(value is not None) self.announcement_file = value
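# --- Illustrative sketch (not part of the original sources) -----------------
# Paging above exposes integer flag columns as booleans through hybrid_property,
# so the same attribute works both on instances and inside SQL expressions.
# A reduced version of that pattern; the Device model is a made-up stand-in.
from sqlalchemy import Boolean, Column, Integer, cast, create_engine, not_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker

Base = declarative_base()

class Device(Base):
    __tablename__ = "devices"
    id = Column(Integer, primary_key=True)
    commented = Column(Integer, nullable=False, server_default="0")

    @hybrid_property
    def enabled(self):
        return self.commented == 0

    @enabled.expression
    def enabled(cls):
        return not_(cast(cls.commented, Boolean))

    @enabled.setter
    def enabled(self, value):
        self.commented = int(value is False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add_all([Device(commented=0), Device(commented=1)])
session.commit()
print(session.query(Device).filter(Device.enabled).count())  # 1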
class Database(Model, AuditMixinNullable, ImportMixin): """An ORM object that stores Database related information""" __tablename__ = 'dbs' type = 'table' __table_args__ = (UniqueConstraint('database_name'),) id = Column(Integer, primary_key=True) verbose_name = Column(String(250), unique=True) # short unique name, used in permissions database_name = Column(String(250), unique=True) sqlalchemy_uri = Column(String(1024)) password = Column(EncryptedType(String(1024), config.get('SECRET_KEY'))) cache_timeout = Column(Integer) select_as_create_table_as = Column(Boolean, default=False) expose_in_sqllab = Column(Boolean, default=False) allow_run_sync = Column(Boolean, default=True) allow_run_async = Column(Boolean, default=False) allow_csv_upload = Column(Boolean, default=False) allow_ctas = Column(Boolean, default=False) allow_dml = Column(Boolean, default=False) force_ctas_schema = Column(String(250)) allow_multi_schema_metadata_fetch = Column(Boolean, default=True) extra = Column(Text, default=textwrap.dedent("""\ { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_csv_upload": [] } """)) perm = Column(String(1000)) impersonate_user = Column(Boolean, default=False) export_fields = ('database_name', 'sqlalchemy_uri', 'cache_timeout', 'expose_in_sqllab', 'allow_run_sync', 'allow_run_async', 'allow_ctas', 'allow_csv_upload', 'extra') export_children = ['tables'] def __repr__(self): return self.verbose_name if self.verbose_name else self.database_name @property def name(self): return self.verbose_name if self.verbose_name else self.database_name @property def allows_subquery(self): return self.db_engine_spec.allows_subquery @property def data(self): return { 'id': self.id, 'name': self.database_name, 'backend': self.backend, 'allow_multi_schema_metadata_fetch': self.allow_multi_schema_metadata_fetch, 'allows_subquery': self.allows_subquery, } @property def unique_name(self): return self.database_name @property def url_object(self): return make_url(self.sqlalchemy_uri_decrypted) @property def backend(self): url = make_url(self.sqlalchemy_uri_decrypted) return url.get_backend_name() @classmethod def get_password_masked_url_from_uri(cls, uri): url = make_url(uri) return cls.get_password_masked_url(url) @classmethod def get_password_masked_url(cls, url): url_copy = deepcopy(url) if url_copy.password is not None and url_copy.password != PASSWORD_MASK: url_copy.password = PASSWORD_MASK return url_copy def set_sqlalchemy_uri(self, uri): conn = sqla.engine.url.make_url(uri.strip()) if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password conn.password = PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password def get_effective_user(self, url, user_name=None): """ Get the effective user, especially during impersonation. 
:param url: SQL Alchemy URL object :param user_name: Default username :return: The effective username """ effective_username = None if self.impersonate_user: effective_username = url.username if user_name: effective_username = user_name elif ( hasattr(g, 'user') and hasattr(g.user, 'username') and g.user.username is not None ): effective_username = g.user.username return effective_username @utils.memoized( watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra')) def get_sqla_engine(self, schema=None, nullpool=True, user_name=None): extra = self.get_extra() url = make_url(self.sqlalchemy_uri_decrypted) url = self.db_engine_spec.adjust_database_uri(url, schema) effective_username = self.get_effective_user(url, user_name) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a # configuration parameter instead. self.db_engine_spec.modify_url_for_impersonation( url, self.impersonate_user, effective_username) masked_url = self.get_password_masked_url(url) logging.info('Database.get_sqla_engine(). Masked URL: {0}'.format(masked_url)) params = extra.get('engine_params', {}) if nullpool: params['poolclass'] = NullPool # If using Hive, this will set hive.server2.proxy.user=$effective_username configuration = {} configuration.update( self.db_engine_spec.get_configuration_for_impersonation( str(url), self.impersonate_user, effective_username)) if configuration: params['connect_args'] = {'configuration': configuration} DB_CONNECTION_MUTATOR = config.get('DB_CONNECTION_MUTATOR') if DB_CONNECTION_MUTATOR: url, params = DB_CONNECTION_MUTATOR( url, params, effective_username, security_manager) return create_engine(url, **params) def get_reserved_words(self): return self.get_dialect().preparer.reserved_words def get_quoter(self): return self.get_dialect().identifier_preparer.quote def get_df(self, sql, schema): sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)] engine = self.get_sqla_engine(schema=schema) def needs_conversion(df_series): if df_series.empty: return False if isinstance(df_series[0], (list, dict)): return True return False with closing(engine.raw_connection()) as conn: with closing(conn.cursor()) as cursor: for sql in sqls[:-1]: self.db_engine_spec.execute(cursor, sql) cursor.fetchall() self.db_engine_spec.execute(cursor, sqls[-1]) if cursor.description is not None: columns = [col_desc[0] for col_desc in cursor.description] else: columns = [] df = pd.DataFrame.from_records( data=list(cursor.fetchall()), columns=columns, coerce_float=True, ) for k, v in df.dtypes.items(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates) return df def compile_sqla_query(self, qry, schema=None): engine = self.get_sqla_engine(schema=schema) sql = str( qry.compile( engine, compile_kwargs={'literal_binds': True}, ), ) if engine.dialect.identifier_preparer._double_percents: sql = sql.replace('%%', '%') return sql def select_star( self, table_name, schema=None, limit=100, show_cols=False, indent=True, latest_partition=False, cols=None): """Generates a ``select *`` statement in the proper dialect""" eng = self.get_sqla_engine(schema=schema) return self.db_engine_spec.select_star( self, table_name, schema=schema, engine=eng, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols) def apply_limit_to_sql(self, sql, limit=1000): return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return 
self.sqlalchemy_uri @property def inspector(self): engine = self.get_sqla_engine() return sqla.inspect(engine) def all_table_names(self, schema=None, force=False): if not schema: if not self.allow_multi_schema_metadata_fetch: return [] tables_dict = self.db_engine_spec.fetch_result_sets( self, 'table', force=force) return tables_dict.get('', []) extra = self.get_extra() metadata_cache_timeout = extra.get('metadata_cache_timeout', {}) table_cache_timeout = metadata_cache_timeout.get('table_cache_timeout') enable_cache = 'table_cache_timeout' in metadata_cache_timeout return sorted(self.db_engine_spec.get_table_names( inspector=self.inspector, db_id=self.id, schema=schema, enable_cache=enable_cache, cache_timeout=table_cache_timeout, force=force)) def all_view_names(self, schema=None, force=False): if not schema: if not self.allow_multi_schema_metadata_fetch: return [] views_dict = self.db_engine_spec.fetch_result_sets( self, 'view', force=force) return views_dict.get('', []) views = [] try: extra = self.get_extra() metadata_cache_timeout = extra.get('metadata_cache_timeout', {}) table_cache_timeout = metadata_cache_timeout.get('table_cache_timeout') enable_cache = 'table_cache_timeout' in metadata_cache_timeout views = self.db_engine_spec.get_view_names( inspector=self.inspector, db_id=self.id, schema=schema, enable_cache=enable_cache, cache_timeout=table_cache_timeout, force=force) except Exception: pass return views def all_schema_names(self, force_refresh=False): extra = self.get_extra() metadata_cache_timeout = extra.get('metadata_cache_timeout', {}) schema_cache_timeout = metadata_cache_timeout.get('schema_cache_timeout') enable_cache = 'schema_cache_timeout' in metadata_cache_timeout return sorted(self.db_engine_spec.get_schema_names( inspector=self.inspector, enable_cache=enable_cache, cache_timeout=schema_cache_timeout, db_id=self.id, force=force_refresh)) @property def db_engine_spec(self): return db_engine_specs.engines.get( self.backend, db_engine_specs.BaseEngineSpec) @classmethod def get_db_engine_spec_for_backend(cls, backend): return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) def grains(self): """Defines the database-specific expressions for time granularities. The idea here is to make it easy for users to change the time grain from a datetime (maybe the source grain is arbitrary timestamps, daily or 5-minute increments) to another, "truncated" datetime. Since each database has slightly different but similar datetime functions, this allows a mapping between database engines and actual functions.
""" return self.db_engine_spec.get_time_grains() def grains_dict(self): """Allowing to lookup grain by either label or duration For backward compatibility""" d = {grain.duration: grain for grain in self.grains()} d.update({grain.label: grain for grain in self.grains()}) return d def get_extra(self): extra = {} if self.extra: try: extra = json.loads(self.extra) except Exception as e: logging.error(e) raise e return extra def get_table(self, table_name, schema=None): extra = self.get_extra() meta = MetaData(**extra.get('metadata_params', {})) return Table( table_name, meta, schema=schema or None, autoload=True, autoload_with=self.get_sqla_engine()) def get_columns(self, table_name, schema=None): return self.inspector.get_columns(table_name, schema) def get_indexes(self, table_name, schema=None): return self.inspector.get_indexes(table_name, schema) def get_pk_constraint(self, table_name, schema=None): return self.inspector.get_pk_constraint(table_name, schema) def get_foreign_keys(self, table_name, schema=None): return self.inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_csv_upload(self): return self.get_extra().get('schemas_allowed_for_csv_upload', []) @property def sqlalchemy_uri_decrypted(self): conn = sqla.engine.url.make_url(self.sqlalchemy_uri) if custom_password_store: conn.password = custom_password_store(conn) else: conn.password = self.password return str(conn) @property def sql_url(self): return '/superset/sql/{}/'.format(self.id) def get_perm(self): return ( '[{obj.database_name}].(id:{obj.id})').format(obj=self) def has_table(self, table): engine = self.get_sqla_engine() return engine.has_table( table.table_name, table.schema or None) @utils.memoized def get_dialect(self): sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()()
class UserCustom(Base): __tablename__ = 'usercustom' __table_args__ = ( PrimaryKeyConstraint('id'), UniqueConstraint('interface', 'intfsuffix', 'category'), Index('usercustom__idx__category', 'category'), Index('usercustom__idx__context', 'context'), Index('usercustom__idx__name', 'name'), ) id = Column(Integer, nullable=False) tenant_uuid = Column(String(36), ForeignKey('tenant.uuid', ondelete='CASCADE'), nullable=False) name = Column(String(40)) context = Column(String(39)) interface = Column(String(128), nullable=False) intfsuffix = Column(String(32), nullable=False, server_default='') commented = Column(Integer, nullable=False, server_default='0') protocol = Column(enum.trunk_protocol, nullable=False, server_default='custom') category = Column(Enum('user', 'trunk', name='usercustom_category', metadata=Base.metadata), nullable=False) line = relationship('LineFeatures', uselist=False, viewonly=True) trunk = relationship('TrunkFeatures', uselist=False, viewonly=True) @hybrid_property def enabled(self): return not bool(self.commented) @enabled.expression def enabled(cls): return not_(cast(cls.commented, Boolean)) @enabled.setter def enabled(self, value): if value is None: self.commented = None else: self.commented = int(value is False) def endpoint_protocol(self): return 'custom' def same_protocol(self, protocol, protocolid): return protocol == 'custom' and self.id == int(protocolid) @hybrid_property def interface_suffix(self): if self.intfsuffix == '': return None return self.intfsuffix @interface_suffix.expression def interface_suffix(cls): return func.nullif(cls.intfsuffix, '') @interface_suffix.setter def interface_suffix(self, value): if value is None: self.intfsuffix = '' else: self.intfsuffix = value
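# Tiny illustrative sketch of the interface_suffix hybrid above (the interface
# value is invented): an empty string in the database maps to None in Python,
# and func.nullif produces the same result in SQL expressions.
custom = UserCustom(interface='custom/abc', intfsuffix='')
assert custom.interface_suffix is None
custom.interface_suffix = None      # written back as ''
assert custom.intfsuffix == ''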
class User(Base): """Class to store a 'user participating in a contest'. """ # TODO: we really need to split this as a user (as in: not paired # with a contest) and a participation. __tablename__ = 'users' __table_args__ = (UniqueConstraint('contest_id', 'username'), ) # Auto increment primary key. id = Column(Integer, primary_key=True) # Real name (human readable) of the user. first_name = Column(Unicode, nullable=False) last_name = Column(Unicode, nullable=False) # Username and password to log in the CWS. username = Column(Unicode, nullable=False) password = Column(Unicode, nullable=False, default=generate_random_password) # Email for any communications in case of remote contest. email = Column(Unicode, nullable=True) # User can log in CWS only from this ip. ip = Column(Unicode, nullable=True) # A hidden user is used only for debugging purpose. hidden = Column(Boolean, nullable=False, default=False) # Contest (id and object) to which the user is participating. contest_id = Column(Integer, ForeignKey(Contest.id, onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True) contest = relationship(Contest, backref=backref("users", cascade="all, delete-orphan", passive_deletes=True)) # A JSON-encoded dictionary of lists of strings: statements["a"] # contains the language codes of the statements that will be # highlighted to this user for task "a". primary_statements = Column(String, nullable=False, default="{}") # Timezone for the user. All timestamps in CWS will be shown using # the timezone associated to the logged-in user or (if it's None # or an invalid string) the timezone associated to the contest or # (if it's None or an invalid string) the local timezone of the # server. This value has to be a string like "Europe/Rome", # "Australia/Sydney", "America/New_York", etc. timezone = Column(Unicode, nullable=True) # Starting time: for contests where every user has at most x hours # of the y > x hours totally available, this is the time the user # decided to start his/her time-frame. starting_time = Column(DateTime, nullable=True) # An extra amount of time allocated for this user. extra_time = Column(Interval, nullable=False, default=timedelta())
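# Illustrative helpers (not part of the original codebase) for the JSON-encoded
# primary_statements field described above; they assume only the model as defined.
import json

def set_primary_statements(user, task_name, language_codes):
    """Mark which statement languages are highlighted for one task."""
    statements = json.loads(user.primary_statements or "{}")
    statements[task_name] = list(language_codes)
    user.primary_statements = json.dumps(statements)

def primary_statements_for(user, task_name):
    return json.loads(user.primary_statements or "{}").get(task_name, [])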
class SqlaTable(Model, BaseDatasource): """An ORM object for SqlAlchemy table references""" type = 'table' query_language = 'sql' metric_class = SqlMetric column_class = TableColumn __tablename__ = 'tables' __table_args__ = (UniqueConstraint('database_id', 'table_name'), ) table_name = Column(String(250)) main_dttm_col = Column(String(250)) database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) fetch_values_predicate = Column(String(1000)) user_id = Column(Integer, ForeignKey('ab_user.id')) owner = relationship(security_manager.user_model, backref='tables', foreign_keys=[user_id]) database = relationship('Database', backref=backref('tables', cascade='all, delete-orphan'), foreign_keys=[database_id]) schema = Column(String(255)) sql = Column(Text) is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) baselink = 'tablemodelview' export_fields = ('table_name', 'main_dttm_col', 'description', 'default_endpoint', 'database_id', 'offset', 'cache_timeout', 'schema', 'sql', 'params', 'template_params', 'filter_select_enabled') update_from_object_fields = [ f for f in export_fields if f not in ('table_name', 'database_id') ] export_parent = 'database' export_children = ['metrics', 'columns'] sqla_aggregations = { 'COUNT_DISTINCT': lambda column_name: sa.func.COUNT(sa.distinct(column_name)), 'COUNT': sa.func.COUNT, 'SUM': sa.func.SUM, 'AVG': sa.func.AVG, 'MIN': sa.func.MIN, 'MAX': sa.func.MAX, } def __repr__(self): return self.name @property def connection(self): return str(self.database) @property def description_markeddown(self): return utils.markdown(self.description) @property def datasource_name(self): return self.table_name @property def database_name(self): return self.database.name @property def link(self): name = escape(self.name) anchor = '<a target="_blank" href="{self.explore_url}">{name}</a>' return Markup(anchor.format(**locals())) @property def schema_perm(self): """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) def get_perm(self): return ('[{obj.database}].[{obj.table_name}]' '(id:{obj.id})').format(obj=self) @property def name(self): if not self.schema: return self.table_name return '{}.{}'.format(self.schema, self.table_name) @property def full_name(self): return utils.get_datasource_full_name(self.database, self.table_name, schema=self.schema) @property def dttm_cols(self): l = [c.column_name for c in self.columns if c.is_dttm] # noqa: E741 if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property def num_cols(self): return [c.column_name for c in self.columns if c.is_num] @property def any_dttm_col(self): cols = self.dttm_cols if cols: return cols[0] @property def html(self): t = ((c.column_name, c.type) for c in self.columns) df = pd.DataFrame(t) df.columns = ['field', 'type'] return df.to_html( index=False, classes=('dataframe table table-striped table-bordered ' 'table-condensed')) @property def sql_url(self): return self.database.sql_url + '?table_name=' + str(self.table_name) def external_metadata(self): cols = self.database.get_columns(self.table_name, schema=self.schema) for col in cols: col['type'] = '{}'.format(col['type']) return cols @property def time_column_grains(self): return { 'time_columns': self.dttm_cols, 'time_grains': [grain.name for grain in self.database.grains()], } @property def select_star(self): # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the 
DB return self.database.select_star(self.name, show_cols=False, latest_partition=False) def get_col(self, col_name): columns = self.columns for col in columns: if col_name == col.column_name: return col @property def data(self): d = super(SqlaTable, self).data if self.type == 'table': grains = self.database.grains() or [] if grains: grains = [(g.duration, g.name) for g in grains] d['granularity_sqla'] = utils.choicify(self.dttm_cols) d['time_grain_sqla'] = grains d['main_dttm_col'] = self.main_dttm_col return d def values_for_column(self, column_name, limit=10000): """Runs query against sqla to retrieve some sample values for the given column. """ cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() qry = (select([target_col.get_sqla_col() ]).select_from(self.get_from_clause(tp)).distinct()) if limit: qry = qry.limit(limit) if self.fetch_values_predicate: tp = self.get_template_processor() qry = qry.where(tp.process_template(self.fetch_values_predicate)) engine = self.database.get_sqla_engine() sql = '{}'.format( qry.compile(engine, compile_kwargs={'literal_binds': True}), ) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) return [row[0] for row in df.to_records(index=False)] def mutate_query_from_config(self, sql): """Apply config's SQL_QUERY_MUTATOR Typically adds comments to the query with context""" SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR') if SQL_QUERY_MUTATOR: username = utils.get_username() sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database) return sql def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) def get_query_str(self, query_obj): qry = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(qry) logging.info(sql) sql = sqlparse.format(sql, reindent=True) if query_obj['is_prequery']: query_obj['prequeries'].append(sql) sql = self.mutate_query_from_config(sql) return sql def get_sqla_table(self): tbl = table(self.table_name) if self.schema: tbl.schema = self.schema return tbl def get_from_clause(self, template_processor=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) from_sql = sqlparse.format(from_sql, strip_comments=True) return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() def adhoc_metric_to_sqla(self, metric, cols): """ Turn an adhoc metric into a sqlalchemy column. 
:param dict metric: Adhoc metric definition :param dict cols: Columns for the current table :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ expression_type = metric.get('expressionType') db_engine_spec = self.database.db_engine_spec label = db_engine_spec.make_label_compatible(metric.get('label')) if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: column_name = metric.get('column').get('column_name') sqla_column = column(column_name) table_column = cols.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() sqla_metric = self.sqla_aggregations[metric.get('aggregate')]( sqla_column) sqla_metric = sqla_metric.label(label) return sqla_metric elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: sqla_metric = literal_column(metric.get('sqlExpression')) sqla_metric = sqla_metric.label(label) return sqla_metric else: return None def get_sqla_query( # sqla self, groupby, metrics, granularity, from_dttm, to_dttm, filter=None, # noqa is_timeseries=True, timeseries_limit=15, timeseries_limit_metric=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, orderby=None, extras=None, columns=None, order_desc=True, prequeries=None, is_prequery=False, ): """Querying any sqla table from this common interface""" template_kwargs = { 'from_dttm': from_dttm, 'groupby': groupby, 'metrics': metrics, 'row_limit': row_limit, 'to_dttm': to_dttm, 'filter': filter, 'columns': {col.column_name: col for col in self.columns}, } template_kwargs.update(self.template_params_dict) template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.database.db_engine_spec orderby = orderby or [] # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline cols = {col.column_name: col for col in self.columns} metrics_dict = {m.metric_name: m for m in self.metrics} if not granularity and is_timeseries: raise Exception( _('Datetime column not provided as part table configuration ' 'and is required by this type of chart')) if not groupby and not metrics and not columns: raise Exception(_('Empty query?')) metrics_exprs = [] for m in metrics: if utils.is_adhoc_metric(m): metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols)) elif m in metrics_dict: metrics_exprs.append(metrics_dict.get(m).get_sqla_col()) else: raise Exception(_("Metric '{}' is not valid".format(m))) if metrics_exprs: main_metric_expr = metrics_exprs[0] else: main_metric_expr = literal_column('COUNT(*)').label( db_engine_spec.make_label_compatible('count')) select_exprs = [] groupby_exprs = [] if groupby: select_exprs = [] inner_select_exprs = [] inner_groupby_exprs = [] for s in groupby: col = cols[s] outer = col.get_sqla_col() inner = col.get_sqla_col(col.column_name + '__') groupby_exprs.append(outer) select_exprs.append(outer) inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) elif columns: for s in columns: select_exprs.append(cols[s].get_sqla_col()) metrics_exprs = [] if granularity: dttm_col = cols[granularity] time_grain = extras.get('time_grain_sqla') time_filters = [] if is_timeseries: timestamp = dttm_col.get_timestamp_expression(time_grain) select_exprs += [timestamp] groupby_exprs += [timestamp] # Use main dttm column to support index with secondary dttm columns if db_engine_spec.time_secondary_columns and \ self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col 
!= dttm_col.column_name: time_filters.append(cols[self.main_dttm_col].get_time_filter( from_dttm, to_dttm)) time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) select_exprs += metrics_exprs qry = sa.select(select_exprs) tbl = self.get_from_clause(template_processor) if not columns: qry = qry.group_by(*groupby_exprs) where_clause_and = [] having_clause_and = [] for flt in filter: if not all([flt.get(s) for s in ['col', 'op']]): continue col = flt['col'] op = flt['op'] col_obj = cols.get(col) if col_obj: is_list_target = op in ('in', 'not in') eq = self.filter_values_handler( flt.get('val'), target_column_is_numeric=col_obj.is_num, is_list_target=is_list_target) if op in ('in', 'not in'): cond = col_obj.get_sqla_col().in_(eq) if '<NULL>' in eq: cond = or_(cond, col_obj.get_sqla_col() == None) # noqa if op == 'not in': cond = ~cond where_clause_and.append(cond) else: if col_obj.is_num: eq = utils.string_to_num(flt['val']) if op == '==': where_clause_and.append(col_obj.get_sqla_col() == eq) elif op == '!=': where_clause_and.append(col_obj.get_sqla_col() != eq) elif op == '>': where_clause_and.append(col_obj.get_sqla_col() > eq) elif op == '<': where_clause_and.append(col_obj.get_sqla_col() < eq) elif op == '>=': where_clause_and.append(col_obj.get_sqla_col() >= eq) elif op == '<=': where_clause_and.append(col_obj.get_sqla_col() <= eq) elif op == 'LIKE': where_clause_and.append( col_obj.get_sqla_col().like(eq)) elif op == 'IS NULL': where_clause_and.append( col_obj.get_sqla_col() == None) # noqa elif op == 'IS NOT NULL': where_clause_and.append( col_obj.get_sqla_col() != None) # noqa if extras: where = extras.get('where') if where: where = template_processor.process_template(where) where_clause_and += [sa.text('({})'.format(where))] having = extras.get('having') if having: having = template_processor.process_template(having) having_clause_and += [sa.text('({})'.format(having))] if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) if not orderby and not columns: orderby = [(main_metric_expr, not order_desc)] for col, ascending in orderby: direction = asc if ascending else desc if utils.is_adhoc_metric(col): col = self.adhoc_metric_to_sqla(col, cols) qry = qry.order_by(direction(col)) if row_limit: qry = qry.limit(row_limit) if is_timeseries and \ timeseries_limit and groupby and not time_groupby_inline: if self.database.db_engine_spec.inner_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. 
vertica, # require a unique inner alias inner_main_metric_expr = main_metric_expr.label('mme_inner__') inner_select_exprs += [inner_main_metric_expr] subq = select(inner_select_exprs) subq = subq.select_from(tbl) inner_time_filter = dttm_col.get_time_filter( inner_from_dttm or from_dttm, inner_to_dttm or to_dttm, ) subq = subq.where( and_(*(where_clause_and + [inner_time_filter]))) subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr if timeseries_limit_metric: if utils.is_adhoc_metric(timeseries_limit_metric): ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols) elif timeseries_limit_metric in metrics_dict: timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric, ) ob = timeseries_limit_metric.get_sqla_col() else: raise Exception(_( "Metric '{}' is not valid".format(m))) direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) subq = subq.limit(timeseries_limit) on_clause = [] for i, gb in enumerate(groupby): on_clause.append(groupby_exprs[i] == column(gb + '__')) tbl = tbl.join(subq.alias(), and_(*on_clause)) else: # run subquery to get top groups subquery_obj = { 'prequeries': prequeries, 'is_prequery': True, 'is_timeseries': False, 'row_limit': timeseries_limit, 'groupby': groupby, 'metrics': metrics, 'granularity': granularity, 'from_dttm': inner_from_dttm or from_dttm, 'to_dttm': inner_to_dttm or to_dttm, 'filter': filter, 'orderby': orderby, 'extras': extras, 'columns': columns, 'order_desc': True, } result = self.query(subquery_obj) cols = {col.column_name: col for col in self.columns} dimensions = [ c for c in result.df.columns if c not in metrics and c in cols ] top_groups = self._get_top_groups(result.df, dimensions) qry = qry.where(top_groups) return qry.select_from(tbl) def _get_top_groups(self, df, dimensions): cols = {col.column_name: col for col in self.columns} groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: col_obj = cols.get(dimension) group.append(col_obj.get_sqla_col() == row[dimension]) groups.append(and_(*group)) return or_(*groups) def query(self, query_obj): qry_start_dttm = datetime.now() sql = self.get_query_str(query_obj) status = QueryStatus.SUCCESS error_message = None df = None try: df = self.database.get_df(sql, self.schema) except Exception as e: status = QueryStatus.FAILED logging.exception(e) error_message = ( self.database.db_engine_spec.extract_error_message(e)) # if this is a main query with prequeries, combine them together if not query_obj['is_prequery']: query_obj['prequeries'].append(sql) sql = ';\n\n'.join(query_obj['prequeries']) sql += ';' return QueryResult(status=status, df=df, duration=datetime.now() - qry_start_dttm, query=sql, error_message=error_message) def get_sqla_table_object(self): return self.database.get_table(self.table_name, schema=self.schema) def fetch_metadata(self): """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() except Exception as e: logging.exception(e) raise Exception( _("Table [{}] doesn't seem to exist in the specified database, " "couldn't fetch column information").format(self.table_name)) M = SqlMetric # noqa metrics = [] any_date_col = None db_dialect = self.database.get_dialect() dbcols = (db.session.query(TableColumn).filter( TableColumn.table == self).filter( or_(TableColumn.column_name == col.name for col in table.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} db_engine_spec = self.database.db_engine_spec for col in table.columns: try: datatype = 
col.type.compile(dialect=db_dialect).upper() except Exception as e: datatype = 'UNKNOWN' logging.error('Unrecognized data type in {}.{}'.format( table, col.name)) logging.exception(e) dbcol = dbcols.get(col.name, None) if not dbcol: dbcol = TableColumn(column_name=col.name, type=datatype) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string dbcol.sum = dbcol.is_num dbcol.avg = dbcol.is_num dbcol.is_dttm = dbcol.is_time else: dbcol.type = datatype self.columns.append(dbcol) if not any_date_col and dbcol.is_time: any_date_col = col.name metrics += dbcol.get_metrics().values() metrics.append( M( metric_name='count', verbose_name='COUNT(*)', metric_type='count', expression='COUNT(*)', )) if not self.main_dttm_col: self.main_dttm_col = any_date_col for metric in metrics: metric.metric_name = db_engine_spec.mutate_expression_label( metric.metric_name) self.add_missing_metrics(metrics) db.session.merge(self) db.session.commit() @classmethod def import_obj(cls, i_datasource, import_time=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. """ def lookup_sqlatable(table): return db.session.query(SqlaTable).join(Database).filter( SqlaTable.table_name == table.table_name, SqlaTable.schema == table.schema, Database.id == table.database_id, ).first() def lookup_database(table): return db.session.query(Database).filter_by( database_name=table.params_dict['database_name']).one() return import_util.import_datasource(db.session, i_datasource, lookup_database, lookup_sqlatable, import_time) @classmethod def query_datasources_by_name(cls, session, database, datasource_name, schema=None): query = (session.query(cls).filter_by( database_id=database.id).filter_by(table_name=datasource_name)) if schema: query = query.filter_by(schema=schema) return query.all() @staticmethod def default_query(qry): return qry.filter_by(is_sqllab_view=False)
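# A minimal, illustrative query_obj for SqlaTable.get_query_str()/query(); the
# real dictionaries come from Superset's viz layer, so treat these keys and
# values as assumptions rather than the canonical request format.
example_query_obj = {
    'groupby': ['gender'],
    'metrics': ['count'],
    'granularity': 'ds',
    'from_dttm': None,
    'to_dttm': None,
    'filter': [],
    'is_timeseries': False,
    'timeseries_limit': 0,
    'timeseries_limit_metric': None,
    'row_limit': 100,
    'extras': {'time_grain_sqla': None},
    'columns': [],
    'orderby': [],
    'order_desc': True,
    'prequeries': [],
    'is_prequery': False,
}
# result = some_sqla_table.query(example_query_obj)   # -> QueryResult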
class Attachment(__db.Base): __tablename__ = "attachment" idAttachment = Column(Integer, primary_key=True) title = Column(String(255)) uuid = Column(String(512)) idProject = Column(Integer, ForeignKey('project.idProject')) idOwner = Column(Integer, ForeignKey('student.idStudent')) advisorApproved = Column(Boolean) uploadDate = Column(DateTime) constraints = list() constraints.append(UniqueConstraint('idProject', 'idOwner')) if len(constraints) > 0: __table_args__ = tuple(constraints) def __init__(self, dictModel): if ("idAttachment" in dictModel) and (dictModel["idAttachment"] != None): self.idAttachment = dictModel["idAttachment"] if ("title" in dictModel) and (dictModel["title"] != None): self.title = dictModel["title"] if ("uuid" in dictModel) and (dictModel["uuid"] != None): self.uuid = dictModel["uuid"] if ("idProject" in dictModel) and (dictModel["idProject"] != None): self.idProject = dictModel["idProject"] if ("idOwner" in dictModel) and (dictModel["idOwner"] != None): self.idOwner = dictModel["idOwner"] if ("advisorApproved" in dictModel) and (dictModel["advisorApproved"] != None): self.advisorApproved = dictModel["advisorApproved"] if ("uploadDate" in dictModel) and (dictModel["uploadDate"] != None): self.uploadDate = dictModel["uploadDate"] def __repr__(self): return '<Attachment idAttachment={} title={} uuid={} idProject={} idOwner={} advisorApproved={} uploadDate={} >'.format( self.idAttachment, self.title, self.uuid, self.idProject, self.idOwner, self.advisorApproved, self.uploadDate, ) def json(self): return { "idAttachment": self.idAttachment, "title": self.title, "uuid": self.uuid, "idProject": self.idProject, "idOwner": self.idOwner, "advisorApproved": self.advisorApproved, "uploadDate": self.uploadDate, } def update(self, dictModel): if ("idAttachment" in dictModel) and (dictModel["idAttachment"] != None): self.idAttachment = dictModel["idAttachment"] if ("title" in dictModel) and (dictModel["title"] != None): self.title = dictModel["title"] if ("uuid" in dictModel) and (dictModel["uuid"] != None): self.uuid = dictModel["uuid"] if ("idProject" in dictModel) and (dictModel["idProject"] != None): self.idProject = dictModel["idProject"] if ("idOwner" in dictModel) and (dictModel["idOwner"] != None): self.idOwner = dictModel["idOwner"] if ("advisorApproved" in dictModel) and (dictModel["advisorApproved"] != None): self.advisorApproved = dictModel["advisorApproved"] if ("uploadDate" in dictModel) and (dictModel["uploadDate"] != None): self.uploadDate = dictModel["uploadDate"]
class SystemBaseline(db.Model): __tablename__ = "system_baselines" # do not allow two records in the same account to have the same display name __table_args__ = (UniqueConstraint("account", "display_name", name="_account_display_name_uc"), ) id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) account = db.Column(db.String(10), nullable=False) display_name = db.Column(db.String(200), nullable=False) created_on = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) modified_on = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) baseline_facts = db.Column(JSONB) mapped_systems = relationship( "SystemBaselineMappedSystem", cascade="all, delete, delete-orphan", ) @property def fact_count(self): return len(self.baseline_facts) @validates("baseline_facts") def validate_facts(self, key, value): validators.check_facts_length(value) validators.check_for_duplicate_names(value) return value def mapped_system_ids(self): mapped_system_ids = [] for mapped_system in self.mapped_systems: mapped_system_ids.append(str(mapped_system.system_id)) return mapped_system_ids def to_json(self, withhold_facts=False, withhold_system_ids=True): json_dict = {} json_dict["id"] = str(self.id) json_dict["account"] = self.account json_dict["display_name"] = self.display_name json_dict["fact_count"] = self.fact_count json_dict["created"] = self.created_on.isoformat() + "Z" json_dict["updated"] = self.modified_on.isoformat() + "Z" if not withhold_facts: json_dict["baseline_facts"] = self.baseline_facts if not withhold_system_ids: json_dict["system_ids"] = self.mapped_system_ids() return json_dict def add_mapped_system(self, system_id): new_mapped_system = SystemBaselineMappedSystem(system_id=system_id, account=self.account) self.mapped_systems.append(new_mapped_system) db.session.add(new_mapped_system) def remove_mapped_system(self, system_id): system_id_removed = False for mapped_system in self.mapped_systems: if str(mapped_system.system_id) == str(system_id): self.mapped_systems.remove(mapped_system) system_id_removed = True break if not system_id_removed: # do we want to raise exception here? raise ValueError( "Failed to remove system id %s from mapped systems - not in list" % system_id)
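# Hedged usage sketch; assumes a Flask-SQLAlchemy session (db.session) and a
# persisted SystemBaseline row. The system_id value is an arbitrary UUID string.
def attach_and_serialize(baseline, system_id):
    """Map a system onto the baseline (if not already mapped) and return its JSON."""
    if str(system_id) not in baseline.mapped_system_ids():
        baseline.add_mapped_system(system_id)
        db.session.commit()
    # withhold_system_ids defaults to True, so ask for the ids explicitly.
    return baseline.to_json(withhold_facts=True, withhold_system_ids=False)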
class UserTestResult(Base): """Class to store the execution results of a user_test. """ # Possible statuses of a user test result. COMPILING and # EVALUATING do not necessarily imply we are going to schedule # compilation and run for these user test results: for # example, they might be for datasets not scheduled for # evaluation, or they might have passed the maximum number of # tries. If a user test result does not exists for a pair # (user test, dataset), its status can be implicitly assumed to # be COMPILING. COMPILING = 1 COMPILATION_FAILED = 2 EVALUATING = 3 EVALUATED = 4 __tablename__ = 'user_test_results' __table_args__ = (UniqueConstraint('user_test_id', 'dataset_id'), ) # Primary key is (user_test_id, dataset_id). user_test_id = Column(Integer, ForeignKey(UserTest.id, onupdate="CASCADE", ondelete="CASCADE"), primary_key=True) user_test = relationship(UserTest, back_populates="results") dataset_id = Column(Integer, ForeignKey(Dataset.id, onupdate="CASCADE", ondelete="CASCADE"), primary_key=True) dataset = relationship(Dataset) # Now below follow the actual result fields. # Output file's digest for this test output = Column(Digest, nullable=True) # Compilation outcome (can be None = yet to compile, "ok" = # compilation successful and we can evaluate, "fail" = # compilation unsuccessful, throw it away). compilation_outcome = Column(String, nullable=True) # The output from the sandbox (to allow localization the first item # of the list is a format string, possibly containing some "%s", # that will be filled in using the remaining items of the list). compilation_text = Column(ARRAY(String), nullable=False, default=[]) # Number of attempts of compilation. compilation_tries = Column(Integer, nullable=False, default=0) # The compiler stdout and stderr. compilation_stdout = Column(Unicode, nullable=True) compilation_stderr = Column(Unicode, nullable=True) # Other information about the compilation. compilation_time = Column(Float, nullable=True) compilation_wall_clock_time = Column(Float, nullable=True) compilation_memory = Column(BigInteger, nullable=True) # Worker shard and sandbox where the compilation was performed. compilation_shard = Column(Integer, nullable=True) compilation_sandbox = Column(String, nullable=True) # Evaluation outcome (can be None = yet to evaluate, "ok" = # evaluation successful). evaluation_outcome = Column(String, nullable=True) # The output from the grader, usually "Correct", "Time limit", ... # (to allow localization the first item of the list is a format # string, possibly containing some "%s", that will be filled in # using the remaining items of the list). evaluation_text = Column(ARRAY(String), nullable=False, default=[]) # Number of attempts of evaluation. evaluation_tries = Column(Integer, nullable=False, default=0) # Other information about the execution. execution_time = Column(Float, nullable=True) execution_wall_clock_time = Column(Float, nullable=True) execution_memory = Column(BigInteger, nullable=True) # Worker shard and sandbox where the evaluation was performed. evaluation_shard = Column(Integer, nullable=True) evaluation_sandbox = Column(String, nullable=True) # These one-to-many relationships are the reversed directions of # the ones defined in the "child" classes using foreign keys. executables = relationship( "UserTestExecutable", collection_class=attribute_mapped_collection("filename"), cascade="all, delete-orphan", passive_deletes=True, back_populates="user_test_result") def get_status(self): """Return the status of this object. 
""" if not self.compiled(): return UserTestResult.COMPILING elif self.compilation_failed(): return UserTestResult.COMPILATION_FAILED elif not self.evaluated(): return UserTestResult.EVALUATING else: return UserTestResult.EVALUATED def compiled(self): """Return whether the user test result has been compiled. return (bool): True if compiled, False otherwise. """ return self.compilation_outcome is not None @staticmethod def filter_compiled(): """Return a filtering expression for compiled user test results. """ return UserTestResult.compilation_outcome.isnot(None) def compilation_failed(self): """Return whether the user test result did not compile. return (bool): True if the compilation failed (in the sense that there is a problem in the user's source), False if not yet compiled or compilation was successful. """ return self.compilation_outcome == "fail" @staticmethod def filter_compilation_failed(): """Return a filtering expression for user test results failing compilation. """ return UserTestResult.compilation_outcome == "fail" def compilation_succeeded(self): """Return whether the user test compiled. return (bool): True if the compilation succeeded (in the sense that an executable was created), False if not yet compiled or compilation was unsuccessful. """ return self.compilation_outcome == "ok" @staticmethod def filter_compilation_succeeded(): """Return a filtering expression for user test results failing compilation. """ return UserTestResult.compilation_outcome == "ok" def evaluated(self): """Return whether the user test result has been evaluated. return (bool): True if evaluated, False otherwise. """ return self.evaluation_outcome is not None @staticmethod def filter_evaluated(): """Return a filtering lambda for evaluated user test results. """ return UserTestResult.evaluation_outcome.isnot(None) def invalidate_compilation(self): """Blank all compilation and evaluation outcomes. """ self.invalidate_evaluation() self.compilation_outcome = None self.compilation_text = [] self.compilation_tries = 0 self.compilation_time = None self.compilation_wall_clock_time = None self.compilation_memory = None self.compilation_shard = None self.compilation_sandbox = None self.executables = {} def invalidate_evaluation(self): """Blank the evaluation outcome. """ self.evaluation_outcome = None self.evaluation_text = [] self.evaluation_tries = 0 self.execution_time = None self.execution_wall_clock_time = None self.execution_memory = None self.evaluation_shard = None self.evaluation_sandbox = None self.output = None def set_compilation_outcome(self, success): """Set the compilation outcome based on the success. success (bool): if the compilation was successful. """ self.compilation_outcome = "ok" if success else "fail" def set_evaluation_outcome(self): """Set the evaluation outcome (always ok now). """ self.evaluation_outcome = "ok"
class Permission(Base): """A class to hold permissions. Permissions in Stalker define what one can or cannot do. A Permission instance is composed of three attributes: access, action and class_name. Permissions for all the classes in SOM are generally created by Stalker when initializing the database. If you created any custom classes to extend SOM you are also responsible for creating the Permissions for them by calling :meth:`stalker.db.register` and passing your class to it. See the :mod:`stalker.db` documentation for details. :param str access: An Enum value which can be one of ``Allow`` or ``Deny``. :param str action: An Enum value from the list ['Create', 'Read', 'Update', 'Delete', 'List']. Can not be None. The list can be changed from stalker.config.Config.default_actions. :param str class_name: The name of the class that this action is applied to. Can not be None or an empty string. Example: Let's say that you want to create a Permission specifying that a Group of Users is allowed to create Projects:: from stalker import db from stalker.models.auth import User, Group, Permission # first setup the db with the default database # # stalker.db.init() will create all the Actions possible with the # SOM classes automatically # # What is left to you is to create the permissions db.setup() user1 = User( name='Test User', login='******', password='******', email='*****@*****.**' ) user2 = User( name='Test User', login='******', password='******', email='*****@*****.**' ) group1 = Group(name='users') group1.users = [user1, user2] # get the permissions for the Project class project_permissions = Permission.query\ .filter(Permission.access == 'Allow')\ .filter(Permission.action == 'Create')\ .filter(Permission.class_name == 'Project')\ .first() # now we have the permission specifying the allowance of creating a # Project # to make group1 users able to create a Project we simply add this # Permission to the group's permissions attribute group1.permissions.append(project_permissions) # and persist this information in the database DBSession.add(group1) DBSession.commit() """ __tablename__ = 'Permissions' __table_args__ = (UniqueConstraint('access', 'action', 'class_name'), { "extend_existing": True }) id = Column(Integer, primary_key=True) _access = Column('access', Enum('Allow', 'Deny', name='AccessNames')) _action = Column('action', Enum(*defaults.actions, name='ActionNames')) _class_name = Column('class_name', String(32)) def __init__(self, access, action, class_name): self._access = self._validate_access(access) self._action = self._validate_action(action) self._class_name = self._validate_class_name(class_name) def _validate_access(self, access): """validates the given access value """ from stalker import __string_types__ if not isinstance(access, __string_types__): raise TypeError( '%s.access should be an instance of str not %s' % (self.__class__.__name__, access.__class__.__name__)) if access not in ['Allow', 'Deny']: raise ValueError('%s.access should be "Allow" or "Deny" not %s' % (self.__class__.__name__, access)) return access def _access_getter(self): """returns the _access value """ return self._access access = synonym('_access', descriptor=property(_access_getter)) def _validate_class_name(self, class_name): """validates the given class_name value """ from stalker import __string_types__ if not isinstance(class_name, __string_types__): raise TypeError( '%s.class_name should be an instance of str not %s' % (self.__class__.__name__, class_name.__class__.__name__)) return
class_name def _class_name_getter(self): """returns the _class_name attribute value """ return self._class_name class_name = synonym('_class_name', descriptor=property(_class_name_getter)) def _validate_action(self, action): """validates the given action value """ from stalker import __string_types__ if not isinstance(action, __string_types__): raise TypeError( '%s.action should be an instance of str not %s' % (self.__class__.__name__, action.__class__.__name__)) if action not in defaults.actions: raise ValueError( '%s.action should be one of the values of %s not %s' % (self.__class__.__name__, defaults.actions, action)) return action def _action_getter(self): """returns the _action value """ return self._action action = synonym('_action', descriptor=property(_action_getter)) def __eq__(self, other): """the equality of two Permissions """ return isinstance(other, Permission) \ and other.access == self.access \ and other.action == self.action \ and other.class_name == self.class_name
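# Small illustrative sketch of the validation behaviour implemented above.
perm = Permission('Allow', 'Create', 'Project')
assert perm.access == 'Allow' and perm.action == 'Create' and perm.class_name == 'Project'
assert perm == Permission('Allow', 'Create', 'Project')   # __eq__ compares all three fields

# Invalid values are rejected at construction time, for example:
#     Permission('Maybe', 'Create', 'Project')   # ValueError: access must be 'Allow' or 'Deny'
#     Permission(1, 'Create', 'Project')         # TypeError: access must be a string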
Column('description', String(length=255)), Column('tenant_id', String(length=64), nullable=True), Column('datastore_id', String(length=64), nullable=True), Column('datastore_version_id', String(length=64), nullable=True), Column('auto_apply', Boolean(), default=0, nullable=False), Column('visible', Boolean(), default=1, nullable=False), Column('live_update', Boolean(), default=0, nullable=False), Column('md5', String(length=32), nullable=False), Column('created', DateTime(), nullable=False), Column('updated', DateTime(), nullable=False), Column('deleted', Boolean(), default=0, nullable=False), Column('deleted_at', DateTime()), UniqueConstraint('type', 'tenant_id', 'datastore_id', 'datastore_version_id', 'name', 'deleted_at', name='UQ_type_tenant_datastore_datastore_version_name'), ) instance_modules = Table( 'instance_modules', meta, Column('id', String(length=64), primary_key=True, nullable=False), Column('instance_id', String(length=64), ForeignKey('instances.id', ondelete="CASCADE", onupdate="CASCADE"), nullable=False), Column('module_id', String(length=64),
class TestCase(db.Model): """ A single run of a single test, together with any captured output, retry-count and its return value. Every test that gets run ever has a row in this table. At the time this was written, it seems to have 400-500M rows (how is this still surviving?) NOTE: DO NOT MODIFY THIS TABLE! Running migration on this table has caused unavailability in the past. If you need to add a new column, consider doing that on a new table and linking it back to tests via the ID. """ __tablename__ = 'test' __table_args__ = ( UniqueConstraint('job_id', 'label_sha', name='unq_test_name'), Index('idx_test_step_id', 'step_id'), Index('idx_test_project_key', 'project_id', 'label_sha'), Index('idx_task_date_created', 'date_created'), Index('idx_test_project_key_date', 'project_id', 'label_sha', 'date_created'), ) id = Column(GUID, nullable=False, primary_key=True, default=uuid.uuid4) job_id = Column(GUID, ForeignKey('job.id', ondelete="CASCADE"), nullable=False) project_id = Column(GUID, ForeignKey('project.id', ondelete="CASCADE"), nullable=False) step_id = Column(GUID, ForeignKey('jobstep.id', ondelete="CASCADE")) name_sha = Column('label_sha', String(40), nullable=False) name = Column(Text, nullable=False) _package = Column('package', Text, nullable=True) result = Column(Enum(Result), default=Result.unknown, nullable=False) duration = Column(Integer, default=0) message = deferred(Column(Text)) date_created = Column(DateTime, default=datetime.utcnow, nullable=False) reruns = Column(Integer) # owner should be considered an unstructured string field. It may contain # email address ("Foo <*****@*****.**>", a username ("foo"), or something # else. This field is not used directly by Changes, so # providers + consumers on either side of Changes should be sure they know # what they're doing. owner = Column(Text) job = relationship('Job') step = relationship('JobStep') project = relationship('Project') __repr__ = model_repr('name', '_package', 'result') def __init__(self, **kwargs): super(TestCase, self).__init__(**kwargs) if self.id is None: self.id = uuid.uuid4() if self.result is None: self.result = Result.unknown if self.date_created is None: self.date_created = datetime.utcnow() @classmethod def calculate_name_sha(self, name): if name: return sha1(name).hexdigest() raise ValueError @property def sep(self): name = (self._package or self.name) # handle the case where it might begin with some special character if not re.match(r'^[a-zA-Z0-9]', name): return '/' elif '/' in name: return '/' return '.' def _get_package(self): if not self._package: try: package, _ = self.name.rsplit(self.sep, 1) except ValueError: package = None else: package = self._package return package def _set_package(self, value): self._package = value package = property(_get_package, _set_package) @property def short_name(self): name, package = self.name, self.package if package and name.startswith(package) and name != package: return name[len(package) + 1:] return name
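# Hedged illustration of how `sep`, `package` and `short_name` interact; the
# sample test names are invented and nothing here is flushed to the database.
dotted = TestCase(name='tests.unit.test_models.TestUser.test_create')
# No '/' in the name and it starts with an alphanumeric character, so the
# separator falls back to '.' and the package is everything before the last dot.
assert dotted.sep == '.'
assert dotted.package == 'tests.unit.test_models.TestUser'
assert dotted.short_name == 'test_create'

path_style = TestCase(name='tests/unit/test_login.py')
assert path_style.sep == '/'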
class Category(MailSyncBase, HasRevisions, HasPublicID, UpdatedAtMixin, DeletedAtMixin): @property def API_OBJECT_NAME(self): return self.type_ # Override the default `deleted_at` column with one that is NOT NULL -- # Category.deleted_at is needed in a UniqueConstraint. # Set the default Category.deleted_at = EPOCH instead. deleted_at = Column(DateTime, index=True, nullable=False, default='1970-01-01 00:00:00') # Need `use_alter` here to avoid circular dependencies namespace_id = Column(ForeignKey('namespace.id', use_alter=True, name='category_fk1', ondelete='CASCADE'), nullable=False) namespace = relationship('Namespace', load_on_pending=True) # STOPSHIP(emfree): need to index properly for API filtering performance. name = Column(String(MAX_INDEXABLE_LENGTH), nullable=False, default='') display_name = Column(CategoryNameString(), nullable=False) type_ = Column(Enum('folder', 'label'), nullable=False, default='folder') @validates('display_name') def validate_display_name(self, key, display_name): sanitized_name = sanitize_name(display_name) if sanitized_name != display_name: log.warning("Truncating category display_name", type_=self.type_, original=display_name) return sanitized_name @classmethod def find_or_create(cls, session, namespace_id, name, display_name, type_): name = name or '' objects = session.query(cls).filter( cls.namespace_id == namespace_id, cls.display_name == display_name).all() if not objects: obj = cls(namespace_id=namespace_id, name=name, display_name=display_name, type_=type_, deleted_at=EPOCH) session.add(obj) elif len(objects) == 1: obj = objects[0] if not obj.name: # There is an existing category with this `display_name` and no # `name`, so update it's `name` as needed. # This is needed because the first time we sync generic IMAP # folders, they may initially have `name` == '' but later they may # get a `name`. At this point, it *is* the same folder so we # merely want to update its `name`, not create a new one. obj.name = name else: log.error('Duplicate category rows for namespace_id {}, ' 'name {}, display_name: {}'.format( namespace_id, name, display_name)) raise MultipleResultsFound( 'Duplicate category rows for namespace_id {}, name {}, ' 'display_name: {}'.format(namespace_id, name, display_name)) return obj @classmethod def create(cls, session, namespace_id, name, display_name, type_): name = name or '' obj = cls(namespace_id=namespace_id, name=name, display_name=display_name, type_=type_, deleted_at=EPOCH) session.add(obj) return obj @property def account(self): return self.namespace.account @property def type(self): return self.account.category_type @hybrid_property def lowercase_name(self): return self.display_name.lower() @lowercase_name.comparator def lowercase_name(cls): return CaseInsensitiveComparator(cls.display_name) @property def api_display_name(self): if self.namespace.account.provider == 'gmail': if self.display_name.startswith('[Gmail]/'): return self.display_name[8:] elif self.display_name.startswith('[Google Mail]/'): return self.display_name[14:] if self.namespace.account.provider not in ['gmail', 'eas']: return fs_folder_path( self.display_name, separator=self.namespace.account.folder_separator, prefix=self.namespace.account.folder_prefix) return self.display_name @property def is_deleted(self): return self.deleted_at > EPOCH __table_args__ = (UniqueConstraint('namespace_id', 'name', 'display_name', 'deleted_at'), UniqueConstraint('namespace_id', 'public_id'))
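# Minimal usage sketch for Category.find_or_create; the session, namespace id
# and folder name are illustrative assumptions, not values from the codebase.
def ensure_folder_category(db_session, namespace_id):
    category = Category.find_or_create(
        db_session,
        namespace_id=namespace_id,
        name='',                       # generic IMAP folders may start out with no backend name
        display_name='Receipts/2019',  # hypothetical folder path
        type_='folder',
    )
    db_session.commit()
    return category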
def revert_username_index(db): """ Revert the stuff we did in migration 22 above. There were a couple of problems with what we did: - There was never a need for this migration! The unique constraint had an implicit b-tree index, so it wasn't really needed. (This is my (Chris Webber's) fault for suggesting it needed to happen without knowing what's going on... my bad!) - On top of that, databases created after the models.py was changed weren't the same as those that had been run through migration 22 above. As such, we're setting things back to the way they were before, but as it turns out, that's tricky to do! """ metadata = MetaData(bind=db.bind) user_table = inspect_table(metadata, "core__users") indexes = dict( [(index.name, index) for index in user_table.indexes]) # index from unnecessary migration users_uploader_index = indexes.get(u'ix_core__users_uploader') # index created from models.py after (unique=True, index=True) # was set in models.py users_username_index = indexes.get(u'ix_core__users_username') if users_uploader_index is None and users_username_index is None: # We don't need to do anything. # The database isn't in a state where it needs fixing # # (ie, either went through the previous borked migration or # was initialized with a models.py where core__users was both # unique=True and index=True) return if db.bind.url.drivername == 'sqlite': # Again, sqlite has problems. So this is tricky. # Yes, this is correct to use User_vR1! Nothing has changed # between the *correct* version of this table and migration 18. User_vR1.__table__.create(db.bind) db.commit() new_user_table = inspect_table(metadata, 'rename__users') replace_table_hack(db, user_table, new_user_table) else: # If the db is not run using SQLite, we don't need to do crazy # table copying. # Remove whichever of the not-used indexes are in place if users_uploader_index is not None: users_uploader_index.drop() if users_username_index is not None: users_username_index.drop() # Given we're removing indexes then adding a unique constraint # which *we know might fail*, thus probably rolling back the # session, let's commit here. db.commit() try: # Add the unique constraint constraint = UniqueConstraint( 'username', table=user_table) constraint.create() except ProgrammingError: # constraint already exists, no need to add db.rollback() db.commit()
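# Hedged sketch of how the leftover indexes discussed above could be detected
# with plain SQLAlchemy introspection (sqlalchemy.inspect), outside of
# MediaGoblin's inspect_table helper.
from sqlalchemy import inspect as sqla_inspect

def leftover_username_indexes(bind):
    inspector = sqla_inspect(bind)
    index_names = {ix['name'] for ix in inspector.get_indexes('core__users')}
    return index_names & {'ix_core__users_uploader', 'ix_core__users_username'}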
class JobPhase(db.Model): """A JobPhase is a grouping of one or more JobSteps performing the same basic task. The phases of a Job are intended to be executed sequentially, though that isn't necessarily enforced. One example of phase usage: a Job may have a test collection phase and a test execution phase, with a single JobStep collecting tests in the first phase and an arbitrary number of JobSteps executing shards of the collected tests in the second phase. By using two phases, the types of JobSteps can be tracked and managed independently. Though JobPhases are typically created to group newly created JobSteps, they can also be constructed retroactively once a JobStep has finished based on phased artifacts. This is convenient but a little confusing, and perhaps should be handled by another mechanism. """ # TODO(dcramer): add order column rather than implicity date_started ordering # TODO(dcramer): make duration a column __tablename__ = 'jobphase' __table_args__ = ( UniqueConstraint('job_id', 'label', name='unq_jobphase_key'), ) id = Column(GUID, nullable=False, primary_key=True, default=uuid.uuid4) job_id = Column(GUID, ForeignKey('job.id', ondelete="CASCADE"), nullable=False) project_id = Column(GUID, ForeignKey('project.id', ondelete="CASCADE"), nullable=False) label = Column(String(128), nullable=False) status = Column(Enum(Status), nullable=False, default=Status.unknown) result = Column(Enum(Result), nullable=False, default=Result.unknown) date_started = Column(DateTime) date_finished = Column(DateTime) date_created = Column(DateTime, default=datetime.utcnow) job = relationship('Job', backref=backref('phases', order_by='JobPhase.date_started')) project = relationship('Project') def __init__(self, **kwargs): super(JobPhase, self).__init__(**kwargs) if self.id is None: self.id = uuid.uuid4() if self.result is None: self.result = Result.unknown if self.status is None: self.status = Status.unknown if self.date_created is None: self.date_created = datetime.utcnow() @property def duration(self): """ Return the duration (in milliseconds) that this item was in-progress. """ if self.date_started and self.date_finished: duration = (self.date_finished - self.date_started).total_seconds() * 1000 else: duration = None return duration @property def current_steps(self): """ Return only steps from this phase that have not been replaced. """ # note that the self.steps property exists because of a backref in JobStep return [s for s in self.steps if s.replacement_id is None]
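# Illustrative helper; assumes `phase` is a persisted JobPhase whose related
# JobStep rows populate the `steps` backref (defined on JobStep, not shown here).
def phase_summary(phase):
    return {
        'label': phase.label,
        'status': phase.status,
        'duration_ms': phase.duration,             # None until both dates are set
        'active_steps': len(phase.current_steps),  # replaced steps filtered out
    }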