class Relation(Base):
    """ Container class for event relations """

    event_id = Column('event_id', BigInteger,
                      ForeignKey('events.event_id',
                                 onupdate='cascade',
                                 ondelete='cascade'),
                      nullable=False, index=True)
    event = relationship("Event",
                         uselist=False,
                         primaryjoin='Event.identifier==Relation.event_id')
    rel_event_id = Column('rel_event_id', BigInteger,
                          ForeignKey('events.event_id',
                                     onupdate='cascade',
                                     ondelete='cascade'),
                          nullable=False, index=True)
    rel_event = relationship(
        "Event",
        uselist=False,
        primaryjoin='Event.identifier==Relation.rel_event_id')
    attribute_id = Column('attribute_id', BigInteger,
                          ForeignKey('attributes.attribute_id',
                                     onupdate='cascade',
                                     ondelete='cascade'),
                          nullable=False, index=True)
    attribute = relationship(
        "Attribute",
        uselist=False,
        primaryjoin='Attribute.identifier==Relation.attribute_id')
    rel_attribute_id = Column('rel_attribute_id', BigInteger,
                              ForeignKey('attributes.attribute_id',
                                         onupdate='cascade',
                                         ondelete='cascade'),
                              nullable=False, index=True)
    rel_attribute = relationship(
        'Attribute',
        uselist=False,
        primaryjoin='Attribute.identifier==Relation.rel_attribute_id')

    UniqueConstraint('event_id', 'attribute_id', 'rel_event_id', 'rel_attribute_id')

    def to_dict(self, complete=True, inflated=False, event_permissions=None, user=None):
        return {
            'event': self.event.to_dict(complete, inflated, event_permissions, user),
            'rel_event': self.rel_event.to_dict(complete, inflated, event_permissions, user),
            'attribute': self.attribute.to_dict(complete, inflated, event_permissions, user),
            'rel_attribute': self.rel_attribute.to_dict(complete, inflated, event_permissions, user),
        }

    def validate(self):
        """ Returns true if the object is valid """
        return True
class Database(Model, AuditMixinNullable, ImportMixin): """An ORM object that stores Database related information""" __tablename__ = "dbs" type = "table" __table_args__ = (UniqueConstraint("database_name"), ) id = Column(Integer, primary_key=True) verbose_name = Column(String(250), unique=True) # short unique name, used in permissions database_name = Column(String(250), unique=True) sqlalchemy_uri = Column(String(1024)) password = Column(EncryptedType(String(1024), config.get("SECRET_KEY"))) cache_timeout = Column(Integer) select_as_create_table_as = Column(Boolean, default=False) expose_in_sqllab = Column(Boolean, default=True) allow_run_async = Column(Boolean, default=False) allow_csv_upload = Column(Boolean, default=False) allow_ctas = Column(Boolean, default=False) allow_dml = Column(Boolean, default=False) force_ctas_schema = Column(String(250)) allow_multi_schema_metadata_fetch = Column(Boolean, default=False) extra = Column( Text, default=textwrap.dedent("""\ { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_csv_upload": [] } """), ) perm = Column(String(1000)) impersonate_user = Column(Boolean, default=False) export_fields = ( "database_name", "sqlalchemy_uri", "cache_timeout", "expose_in_sqllab", "allow_run_async", "allow_ctas", "allow_csv_upload", "extra", ) export_children = ["tables"] def __repr__(self): return self.verbose_name if self.verbose_name else self.database_name @property def name(self): return self.verbose_name if self.verbose_name else self.database_name @property def allows_subquery(self): return self.db_engine_spec.allows_subquery @property def data(self): return { "id": self.id, "name": self.database_name, "backend": self.backend, "allow_multi_schema_metadata_fetch": self.allow_multi_schema_metadata_fetch, "allows_subquery": self.allows_subquery, } @property def unique_name(self): return self.database_name @property def url_object(self): return make_url(self.sqlalchemy_uri_decrypted) @property def backend(self): url = make_url(self.sqlalchemy_uri_decrypted) return url.get_backend_name() @property def metadata_cache_timeout(self): return self.get_extra().get("metadata_cache_timeout", {}) @property def schema_cache_enabled(self): return "schema_cache_timeout" in self.metadata_cache_timeout @property def schema_cache_timeout(self): return self.metadata_cache_timeout.get("schema_cache_timeout") @property def table_cache_enabled(self): return "table_cache_timeout" in self.metadata_cache_timeout @property def table_cache_timeout(self): return self.metadata_cache_timeout.get("table_cache_timeout") @property def default_schemas(self): return self.get_extra().get("default_schemas", []) @classmethod def get_password_masked_url_from_uri(cls, uri): url = make_url(uri) return cls.get_password_masked_url(url) @classmethod def get_password_masked_url(cls, url): url_copy = deepcopy(url) if url_copy.password is not None and url_copy.password != PASSWORD_MASK: url_copy.password = PASSWORD_MASK return url_copy def set_sqlalchemy_uri(self, uri): conn = sqla.engine.url.make_url(uri.strip()) if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password conn.password = PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password def get_effective_user(self, url, user_name=None): """ Get the effective user, especially during impersonation. 
:param url: SQL Alchemy URL object :param user_name: Default username :return: The effective username """ effective_username = None if self.impersonate_user: effective_username = url.username if user_name: effective_username = user_name elif (hasattr(g, "user") and hasattr(g.user, "username") and g.user.username is not None): effective_username = g.user.username return effective_username @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra")) def get_sqla_engine(self, schema=None, nullpool=True, user_name=None, source=None): extra = self.get_extra() url = make_url(self.sqlalchemy_uri_decrypted) url = self.db_engine_spec.adjust_database_uri(url, schema) effective_username = self.get_effective_user(url, user_name) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a # configuration parameter instead. self.db_engine_spec.modify_url_for_impersonation( url, self.impersonate_user, effective_username) masked_url = self.get_password_masked_url(url) logging.info( "Database.get_sqla_engine(). Masked URL: {0}".format(masked_url)) params = extra.get("engine_params", {}) if nullpool: params["poolclass"] = NullPool # If using Hive, this will set hive.server2.proxy.user=$effective_username configuration = {} configuration.update( self.db_engine_spec.get_configuration_for_impersonation( str(url), self.impersonate_user, effective_username)) if configuration: d = params.get("connect_args", {}) d["configuration"] = configuration params["connect_args"] = d DB_CONNECTION_MUTATOR = config.get("DB_CONNECTION_MUTATOR") if DB_CONNECTION_MUTATOR: url, params = DB_CONNECTION_MUTATOR(url, params, effective_username, security_manager, source) return create_engine(url, **params) def get_reserved_words(self): return self.get_dialect().preparer.reserved_words def get_quoter(self): return self.get_dialect().identifier_preparer.quote def get_df(self, sql, schema, mutator=None): sqls = [str(s).strip().strip(";") for s in sqlparse.parse(sql)] source_key = None if request and request.referrer: if "/superset/dashboard/" in request.referrer: source_key = "dashboard" elif "/superset/explore/" in request.referrer: source_key = "chart" engine = self.get_sqla_engine(schema=schema, source=utils.sources.get( source_key, None)) username = utils.get_username() def needs_conversion(df_series): if df_series.empty: return False if isinstance(df_series[0], (list, dict)): return True return False def _log_query(sql): if log_query: log_query(engine.url, sql, schema, username, __name__, security_manager) with closing(engine.raw_connection()) as conn: with closing(conn.cursor()) as cursor: for sql in sqls[:-1]: _log_query(sql) self.db_engine_spec.execute(cursor, sql) cursor.fetchall() _log_query(sqls[-1]) self.db_engine_spec.execute(cursor, sqls[-1]) if cursor.description is not None: columns = [col_desc[0] for col_desc in cursor.description] else: columns = [] df = pd.DataFrame.from_records(data=list(cursor.fetchall()), columns=columns, coerce_float=True) if mutator: df = mutator(df) for k, v in df.dtypes.items(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates) return df def compile_sqla_query(self, qry, schema=None): engine = self.get_sqla_engine(schema=schema) sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) if engine.dialect.identifier_preparer._double_percents: sql = sql.replace("%%", "%") return sql def select_star( self, table_name, schema=None, limit=100, 
show_cols=False, indent=True, latest_partition=False, cols=None, ): """Generates a ``select *`` statement in the proper dialect""" eng = self.get_sqla_engine(schema=schema, source=utils.sources.get("sql_lab", None)) return self.db_engine_spec.select_star( self, table_name, schema=schema, engine=eng, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols, ) def apply_limit_to_sql(self, sql, limit=1000): return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri @property def inspector(self): engine = self.get_sqla_engine() return sqla.inspect(engine) @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema:None:table_list", attribute_in_key="id", ) def get_all_table_names_in_database( self, cache: bool = False, cache_timeout: bool = None, force=False) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] return self.db_engine_spec.get_all_datasource_names(self, "table") @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema:None:view_list", attribute_in_key="id") def get_all_view_names_in_database( self, cache: bool = False, cache_timeout: bool = None, force: bool = False) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] return self.db_engine_spec.get_all_datasource_names(self, "view") @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{{}}:schema:{}:table_list".format( kwargs.get("schema")), attribute_in_key="id", ) def get_all_table_names_in_schema( self, schema: str, cache: bool = False, cache_timeout: int = None, force: bool = False, ): """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: list of tables """ try: tables = self.db_engine_spec.get_table_names( inspector=self.inspector, schema=schema) return [ utils.DatasourceName(table=table, schema=schema) for table in tables ] except Exception as e: logging.exception(e) @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{{}}:schema:{}:view_list".format( kwargs.get("schema")), attribute_in_key="id", ) def get_all_view_names_in_schema( self, schema: str, cache: bool = False, cache_timeout: int = None, force: bool = False, ): """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: list of views """ try: views = self.db_engine_spec.get_view_names( inspector=self.inspector, schema=schema) return [ utils.DatasourceName(table=view, schema=schema) for view in views ] except Exception as e: logging.exception(e) @cache_util.memoized_func(key=lambda *args, **kwargs: "db:{}:schema_list", attribute_in_key="id") def get_all_schema_names(self, cache: bool = False, cache_timeout: int = None, force: bool = False) -> List[str]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. 
:param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: schema list """ return self.db_engine_spec.get_schema_names(self.inspector) @property def db_engine_spec(self): return db_engine_specs.engines.get(self.backend, db_engine_specs.BaseEngineSpec) @classmethod def get_db_engine_spec_for_backend(cls, backend): return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) def grains(self): """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain from a datetime (maybe the source grain is arbitrary timestamps, daily or 5 minutes increments) to another, "truncated" datetime. Since each database has slightly different but similar datetime functions, this allows a mapping between database engines and actual functions. """ return self.db_engine_spec.get_time_grains() def get_extra(self): extra = {} if self.extra: try: extra = json.loads(self.extra) except Exception as e: logging.error(e) raise e return extra def get_table(self, table_name, schema=None): extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) return Table( table_name, meta, schema=schema or None, autoload=True, autoload_with=self.get_sqla_engine(), ) def get_columns(self, table_name, schema=None): return self.db_engine_spec.get_columns(self.inspector, table_name, schema) def get_indexes(self, table_name, schema=None): return self.inspector.get_indexes(table_name, schema) def get_pk_constraint(self, table_name, schema=None): return self.inspector.get_pk_constraint(table_name, schema) def get_foreign_keys(self, table_name, schema=None): return self.inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_csv_upload(self): return self.get_extra().get("schemas_allowed_for_csv_upload", []) @property def sqlalchemy_uri_decrypted(self): conn = sqla.engine.url.make_url(self.sqlalchemy_uri) if custom_password_store: conn.password = custom_password_store(conn) else: conn.password = self.password return str(conn) @property def sql_url(self): return "/superset/sql/{}/".format(self.id) def get_perm(self): return ("[{obj.database_name}].(id:{obj.id})").format(obj=self) def has_table(self, table): engine = self.get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) @utils.memoized def get_dialect(self): sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()()
class ImapUid(MailSyncBase):
    """
    Maps UIDs to their IMAP folders and per-UID flag metadata.
    This table is used solely for bookkeeping by the IMAP mail sync backends.
    """
    account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'),
                        nullable=False)
    account = relationship(ImapAccount)

    message_id = Column(Integer, ForeignKey(Message.id, ondelete='CASCADE'),
                        nullable=False)
    message = relationship(Message,
                           backref=backref('imapuids', passive_deletes=True))

    msg_uid = Column(BigInteger, nullable=False, index=True)

    folder_id = Column(Integer, ForeignKey(Folder.id, ondelete='CASCADE'),
                       nullable=False)
    # We almost always need the folder name too, so eager load by default.
    folder = relationship(Folder, lazy='joined',
                          backref=backref('imapuids', passive_deletes=True))

    # Flags #
    # Message has not completed composition (marked as a draft).
    is_draft = Column(Boolean, server_default=false(), nullable=False)
    # Message has been read
    is_seen = Column(Boolean, server_default=false(), nullable=False)
    # Message is "flagged" for urgent/special attention
    is_flagged = Column(Boolean, server_default=false(), nullable=False)
    # session is the first session to have been notified about this message
    is_recent = Column(Boolean, server_default=false(), nullable=False)
    # Message has been answered
    is_answered = Column(Boolean, server_default=false(), nullable=False)
    # things like: ['$Forwarded', 'nonjunk', 'Junk']
    extra_flags = Column(LittleJSON, default=[], nullable=False)
    # labels (Gmail-specific)
    g_labels = Column(JSON, default=lambda: [], nullable=True)

    def update_flags_and_labels(self, new_flags, x_gm_labels=None):
        """Sets flag and g_labels values based on the new_flags and
        x_gm_labels parameters. Returns True if any values have changed
        compared to what we previously stored."""
        changed = False
        new_flags = set(new_flags)
        col_for_flag = {
            u'\\Draft': 'is_draft',
            u'\\Seen': 'is_seen',
            u'\\Recent': 'is_recent',
            u'\\Answered': 'is_answered',
            u'\\Flagged': 'is_flagged',
        }
        for flag, col in col_for_flag.iteritems():
            prior_flag_value = getattr(self, col)
            new_flag_value = flag in new_flags
            if prior_flag_value != new_flag_value:
                changed = True
                setattr(self, col, new_flag_value)
            new_flags.discard(flag)
        extra_flags = sorted(new_flags)
        if extra_flags != self.extra_flags:
            changed = True
        self.extra_flags = extra_flags
        if x_gm_labels is not None:
            new_labels = sorted(x_gm_labels)
            if new_labels != self.g_labels:
                changed = True
                self.g_labels = new_labels
            # Gmail doesn't use the \Draft flag. Go figure.
            if '\\Draft' in x_gm_labels:
                if not self.is_draft:
                    changed = True
                self.is_draft = True
        return changed

    @property
    def namespace(self):
        return self.imapaccount.namespace

    __table_args__ = (UniqueConstraint('folder_id', 'msg_uid', 'account_id',),)
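# A minimal, standalone sketch of the flag-translation idea behind
# update_flags_and_labels above: IMAP system flags map onto boolean columns,
# anything left over is kept as "extra" flags. The names below are purely
# illustrative and are not part of the ImapUid model.
COL_FOR_FLAG = {
    '\\Draft': 'is_draft',
    '\\Seen': 'is_seen',
    '\\Recent': 'is_recent',
    '\\Answered': 'is_answered',
    '\\Flagged': 'is_flagged',
}

def split_flags(new_flags):
    """Return ({column: bool}, sorted_extra_flags) for a set of IMAP flags."""
    new_flags = set(new_flags)
    columns = {}
    for flag, col in COL_FOR_FLAG.items():
        columns[col] = flag in new_flags
        new_flags.discard(flag)
    return columns, sorted(new_flags)

columns, extra = split_flags({'\\Seen', '\\Flagged', '$Forwarded'})
assert columns['is_seen'] and columns['is_flagged'] and not columns['is_draft']
assert extra == ['$Forwarded']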
class Publisher(Base):
    __tablename__ = 'publisher'
    __table_args__ = (UniqueConstraint('name'), )

    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String, nullable=False)
    is_active = sa.Column(sa.Boolean, default=False)
class Subject(Base):
    __tablename__ = 'subject'
    __table_args__ = (UniqueConstraint('name'), )

    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String, nullable=False)
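# Self-contained sketch (SQLAlchemy 1.4-style, in-memory SQLite) of how the
# UniqueConstraint('name') on models like Publisher/Subject behaves; the
# Demo* names and engine here are illustrative and not part of the module.
import sqlalchemy as sa
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

DemoBase = declarative_base()

class DemoPublisher(DemoBase):
    __tablename__ = 'publisher'
    __table_args__ = (sa.UniqueConstraint('name'),)
    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String, nullable=False)
    is_active = sa.Column(sa.Boolean, default=False)

engine = sa.create_engine('sqlite://')
DemoBase.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DemoPublisher(name='ACME Press'))
    session.commit()
    session.add(DemoPublisher(name='ACME Press'))  # duplicate name
    try:
        session.commit()
    except IntegrityError:
        session.rollback()
        print('duplicate publisher name rejected')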
class Database(Model, AuditMixinNullable, ImportMixin): """An ORM object that stores Database related information""" __tablename__ = 'dbs' type = 'table' __table_args__ = (UniqueConstraint('database_name'), ) id = Column(Integer, primary_key=True) verbose_name = Column(String(250), unique=True) # short unique name, used in permissions database_name = Column(String(250), unique=True) sqlalchemy_uri = Column(String(1024)) password = Column(EncryptedType(String(1024), config.get('SECRET_KEY'))) cache_timeout = Column(Integer) select_as_create_table_as = Column(Boolean, default=False) expose_in_sqllab = Column(Boolean, default=False) allow_run_sync = Column(Boolean, default=True) allow_run_async = Column(Boolean, default=False) allow_ctas = Column(Boolean, default=False) allow_dml = Column(Boolean, default=False) force_ctas_schema = Column(String(250)) extra = Column(Text, default=textwrap.dedent("""\ { "metadata_params": {}, "engine_params": {} } """)) perm = Column(String(1000)) impersonate_user = Column(Boolean, default=False) export_fields = ('database_name', 'sqlalchemy_uri', 'cache_timeout', 'expose_in_sqllab', 'allow_run_sync', 'allow_run_async', 'allow_ctas', 'extra') export_children = ['tables'] def __repr__(self): return self.verbose_name if self.verbose_name else self.database_name @property def name(self): return self.verbose_name if self.verbose_name else self.database_name @property def data(self): return { 'name': self.database_name, 'backend': self.backend, } @property def unique_name(self): return self.database_name @property def backend(self): url = make_url(self.sqlalchemy_uri_decrypted) return url.get_backend_name() @classmethod def get_password_masked_url_from_uri(cls, uri): url = make_url(uri) return cls.get_password_masked_url(url) @classmethod def get_password_masked_url(cls, url): url_copy = deepcopy(url) if url_copy.password is not None and url_copy.password != PASSWORD_MASK: url_copy.password = PASSWORD_MASK return url_copy def set_sqlalchemy_uri(self, uri): conn = sqla.engine.url.make_url(uri.strip()) if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password conn.password = PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password def get_effective_user(self, url, user_name=None): """ Get the effective user, especially during impersonation. :param url: SQL Alchemy URL object :param user_name: Default username :return: The effective username """ effective_username = None if self.impersonate_user: effective_username = url.username if user_name: effective_username = user_name elif (hasattr(g, 'user') and hasattr(g.user, 'username') and g.user.username is not None): effective_username = g.user.username return effective_username @utils.memoized(watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra')) def get_sqla_engine(self, schema=None, nullpool=True, user_name=None): extra = self.get_extra() url = make_url(self.sqlalchemy_uri_decrypted) url = self.db_engine_spec.adjust_database_uri(url, schema) effective_username = self.get_effective_user(url, user_name) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a # configuration parameter instead. self.db_engine_spec.modify_url_for_impersonation( url, self.impersonate_user, effective_username) masked_url = self.get_password_masked_url(url) logging.info( 'Database.get_sqla_engine(). 
Masked URL: {0}'.format(masked_url)) params = extra.get('engine_params', {}) if nullpool: params['poolclass'] = NullPool # If using Hive, this will set hive.server2.proxy.user=$effective_username configuration = {} configuration.update( self.db_engine_spec.get_configuration_for_impersonation( str(url), self.impersonate_user, effective_username)) if configuration: params['connect_args'] = {'configuration': configuration} return create_engine(url, **params) def get_reserved_words(self): return self.get_dialect().preparer.reserved_words def get_quoter(self): return self.get_dialect().identifier_preparer.quote def get_df(self, sql, schema): sql = sql.strip().strip(';') eng = self.get_sqla_engine(schema=schema) df = pd.read_sql(sql, eng) def needs_conversion(df_series): if df_series.empty: return False if isinstance(df_series[0], (list, dict)): return True return False for k, v in df.dtypes.iteritems(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates) return df def compile_sqla_query(self, qry, schema=None): eng = self.get_sqla_engine(schema=schema) compiled = qry.compile(eng, compile_kwargs={'literal_binds': True}) return '{}'.format(compiled) def select_star(self, table_name, schema=None, limit=100, show_cols=False, indent=True, latest_partition=True): """Generates a ``select *`` statement in the proper dialect""" return self.db_engine_spec.select_star( self, table_name, schema=schema, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition) def wrap_sql_limit(self, sql, limit=1000): qry = (select('*').select_from( TextAsFrom(text(sql), ['*']).alias('inner_qry'), ).limit(limit)) return self.compile_sqla_query(qry) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri @property def inspector(self): engine = self.get_sqla_engine() return sqla.inspect(engine) def all_table_names(self, schema=None, force=False): if not schema: tables_dict = self.db_engine_spec.fetch_result_sets(self, 'table', force=force) return tables_dict.get('', []) return sorted( self.db_engine_spec.get_table_names(schema, self.inspector)) def all_view_names(self, schema=None, force=False): if not schema: views_dict = self.db_engine_spec.fetch_result_sets(self, 'view', force=force) return views_dict.get('', []) views = [] try: views = self.inspector.get_view_names(schema) except Exception: pass return views def all_schema_names(self): return sorted(self.db_engine_spec.get_schema_names(self.inspector)) @property def db_engine_spec(self): return db_engine_specs.engines.get(self.backend, db_engine_specs.BaseEngineSpec) @classmethod def get_db_engine_spec_for_backend(cls, backend): return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) def grains(self): """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain form a datetime (maybe the source grain is arbitrary timestamps, daily or 5 minutes increments) to another, "truncated" datetime. Since each database has slightly different but similar datetime functions, this allows a mapping between database engines and actual functions. 
""" return self.db_engine_spec.time_grains def grains_dict(self): return {grain.name: grain for grain in self.grains()} def get_extra(self): extra = {} if self.extra: try: extra = json.loads(self.extra) except Exception as e: logging.error(e) return extra def get_table(self, table_name, schema=None): extra = self.get_extra() meta = MetaData(**extra.get('metadata_params', {})) return Table(table_name, meta, schema=schema or None, autoload=True, autoload_with=self.get_sqla_engine()) def get_columns(self, table_name, schema=None): return self.inspector.get_columns(table_name, schema) def get_indexes(self, table_name, schema=None): return self.inspector.get_indexes(table_name, schema) def get_pk_constraint(self, table_name, schema=None): return self.inspector.get_pk_constraint(table_name, schema) def get_foreign_keys(self, table_name, schema=None): return self.inspector.get_foreign_keys(table_name, schema) @property def sqlalchemy_uri_decrypted(self): conn = sqla.engine.url.make_url(self.sqlalchemy_uri) if custom_password_store: conn.password = custom_password_store(conn) else: conn.password = self.password return str(conn) @property def sql_url(self): return '/superset/sql/{}/'.format(self.id) def get_perm(self): return ('[{obj.database_name}].(id:{obj.id})').format(obj=self) def has_table(self, table): engine = self.get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) @utils.memoized def get_dialect(self): sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()()
class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = 'table_columns' __table_args__ = (UniqueConstraint('table_id', 'column_name'), ) table_id = Column(Integer, ForeignKey('tables.id')) table = relationship('SqlaTable', backref=backref('columns', cascade='all, delete-orphan'), foreign_keys=[table_id]) is_dttm = Column(Boolean, default=False) expression = Column(Text, default='') python_date_format = Column(String(255)) database_expression = Column(String(255)) export_fields = ( 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', 'expression', 'description', 'python_date_format', 'database_expression', ) export_parent = 'table' @property def sqla_col(self): name = self.column_name if not self.expression: col = column(self.column_name).label(name) else: col = literal_column(self.expression).label(name) return col @property def datasource(self): return self.table def get_time_filter(self, start_dttm, end_dttm): col = self.sqla_col.label('__time') l = [] # noqa: E741 if start_dttm: l.append(col >= text(self.dttm_sql_literal(start_dttm))) if end_dttm: l.append(col <= text(self.dttm_sql_literal(end_dttm))) return and_(*l) def get_timestamp_expression(self, time_grain): """Getting the time component of the query""" expr = self.expression or self.column_name if not self.expression and not time_grain: return column(expr, type_=DateTime).label(DTTM_ALIAS) if time_grain: pdf = self.python_date_format if pdf in ('epoch_s', 'epoch_ms'): # if epoch, translate to DATE using db specific conf db_spec = self.table.database.db_engine_spec if pdf == 'epoch_s': expr = db_spec.epoch_to_dttm().format(col=expr) elif pdf == 'epoch_ms': expr = db_spec.epoch_ms_to_dttm().format(col=expr) grain = self.table.database.grains_dict().get(time_grain, '{col}') expr = grain.function.format(col=expr) return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) @classmethod def import_obj(cls, i_column): def lookup_obj(lookup_column): return db.session.query(TableColumn).filter( TableColumn.table_id == lookup_column.table_id, TableColumn.column_name == lookup_column.column_name).first() return import_util.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal(self, dttm): """Convert datetime object to a SQL expression string If database_expression is empty, the internal dttm will be parsed as the string with the pattern that the user inputted (python_date_format) If database_expression is not empty, the internal dttm will be parsed as the sql sentence for the database to convert """ tf = self.python_date_format if self.database_expression: return self.database_expression.format( dttm.strftime('%Y-%m-%d %H:%M:%S')) elif tf: if tf == 'epoch_s': return str((dttm - datetime(1970, 1, 1)).total_seconds()) elif tf == 'epoch_ms': return str( (dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) return "'{}'".format(dttm.strftime(tf)) else: s = self.table.database.db_engine_spec.convert_dttm( self.type or '', dttm) return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f')) def get_metrics(self): metrics = [] M = SqlMetric # noqa quoted = self.column_name if self.sum: metrics.append( M( metric_name='sum__' + self.column_name, metric_type='sum', expression='SUM({})'.format(quoted), )) if self.avg: metrics.append( M( metric_name='avg__' + self.column_name, metric_type='avg', expression='AVG({})'.format(quoted), )) if self.max: metrics.append( M( 
metric_name='max__' + self.column_name, metric_type='max', expression='MAX({})'.format(quoted), )) if self.min: metrics.append( M( metric_name='min__' + self.column_name, metric_type='min', expression='MIN({})'.format(quoted), )) if self.count_distinct: metrics.append( M( metric_name='count_distinct__' + self.column_name, metric_type='count_distinct', expression='COUNT(DISTINCT {})'.format(quoted), )) return {m.metric_name: m for m in metrics}
class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"), ) table_id = Column(Integer, ForeignKey("tables.id")) table = relationship( "SqlaTable", backref=backref("columns", cascade="all, delete-orphan"), foreign_keys=[table_id], ) is_dttm = Column(Boolean, default=False) expression = Column(Text) python_date_format = Column(String(255)) export_fields = [ "table_id", "column_name", "verbose_name", "is_dttm", "is_active", "type", "groupby", "filterable", "expression", "description", "python_date_format", ] update_from_object_fields = [ s for s in export_fields if s not in ("table_id", ) ] export_parent = "table" @property def is_numeric(self) -> bool: db_engine_spec = self.table.database.db_engine_spec return db_engine_spec.is_db_column_type_match( self.type, utils.DbColumnType.NUMERIC) @property def is_string(self) -> bool: db_engine_spec = self.table.database.db_engine_spec return db_engine_spec.is_db_column_type_match( self.type, utils.DbColumnType.STRING) @property def is_temporal(self) -> bool: db_engine_spec = self.table.database.db_engine_spec return db_engine_spec.is_db_column_type_match( self.type, utils.DbColumnType.TEMPORAL) def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.column_name if self.expression: col = literal_column(self.expression) else: db_engine_spec = self.table.database.db_engine_spec type_ = db_engine_spec.get_sqla_column_type(self.type) col = column(self.column_name, type_=type_) col = self.table.make_sqla_column_compatible(col, label) return col @property def datasource(self) -> RelationshipProperty: return self.table def get_time_filter( self, start_dttm: DateTime, end_dttm: DateTime, time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint, utils.TimeRangeEndpoint]], ) -> ColumnElement: col = self.get_sqla_col(label="__time") l = [] if start_dttm: l.append(col >= text( self.dttm_sql_literal(start_dttm, time_range_endpoints))) if end_dttm: if (time_range_endpoints and time_range_endpoints[1] == utils.TimeRangeEndpoint.EXCLUSIVE): l.append(col < text( self.dttm_sql_literal(end_dttm, time_range_endpoints))) else: l.append(col <= text(self.dttm_sql_literal(end_dttm, None))) return and_(*l) def get_timestamp_expression( self, time_grain: Optional[str]) -> Union[TimestampExpression, Label]: """ Return a SQLAlchemy Core element representation of self to be used in a query. :param time_grain: Optional time grain, e.g. 
P1Y :return: A TimeExpression object wrapped in a Label if supported by db """ label = utils.DTTM_ALIAS db = self.table.database pdf = self.python_date_format is_epoch = pdf in ("epoch_s", "epoch_ms") if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=DateTime) return self.table.make_sqla_column_compatible(sqla_col, label) if self.expression: col = literal_column(self.expression) else: col = column(self.column_name) time_expr = db.db_engine_spec.get_timestamp_expr( col, pdf, time_grain, self.type) return self.table.make_sqla_column_compatible(time_expr, label) @classmethod def import_obj(cls, i_column: "TableColumn") -> "TableColumn": def lookup_obj(lookup_column: TableColumn) -> TableColumn: return (db.session.query(TableColumn).filter( TableColumn.table_id == lookup_column.table_id, TableColumn.column_name == lookup_column.column_name, ).first()) return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal( self, dttm: DateTime, time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint, utils.TimeRangeEndpoint]], ) -> str: """Convert datetime object to a SQL expression string""" sql = (self.table.database.db_engine_spec.convert_dttm( self.type, dttm) if self.type else None) if sql: return sql tf = self.python_date_format # Fallback to the default format (if defined) only if the SIP-15 time range # endpoints, i.e., [start, end) are enabled. if not tf and time_range_endpoints == ( utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE, ): tf = (self.table.database.get_extra().get( "python_date_format_by_column_name", {}).get(self.column_name)) if tf: if tf in ["epoch_ms", "epoch_s"]: seconds_since_epoch = int(dttm.timestamp()) if tf == "epoch_s": return str(seconds_since_epoch) return str(seconds_since_epoch * 1000) return f"'{dttm.strftime(tf)}'" # TODO(john-bodley): SIP-15 will explicitly require a type conversion. return f"""'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'""" @property def data(self) -> Dict[str, Any]: attrs = ( "id", "column_name", "verbose_name", "description", "expression", "filterable", "groupby", "is_dttm", "type", "python_date_format", ) return {s: getattr(self, s) for s in attrs if hasattr(self, s)}
class Dataset(Base):
    """Class to store the information about a data set. """
    __tablename__ = 'datasets'
    __table_args__ = (
        UniqueConstraint('task_id', 'description'),
        # Useless, in theory, because 'id' is already unique. Yet, we
        # need this because it's a target of a foreign key.
        UniqueConstraint('id', 'task_id'),
    )

    # Auto increment primary key.
    id = Column(
        Integer,
        primary_key=True)

    # Task (id and object) owning the dataset.
    task_id = Column(
        Integer,
        ForeignKey(Task.id,
                   onupdate="CASCADE", ondelete="CASCADE"),
        nullable=False)
    task = relationship(
        Task,
        foreign_keys=[task_id],
        # XXX In SQLAlchemy 0.8 we could remove this:
        primaryjoin='Task.id == Dataset.task_id',
        backref=backref('datasets',
                        cascade="all, delete-orphan",
                        passive_deletes=True))

    # A human-readable text describing the dataset.
    description = Column(
        Unicode,
        nullable=False)

    # Whether this dataset will be automatically judged by ES and SS
    # "in background", together with the active dataset of each task.
    autojudge = Column(
        Boolean,
        nullable=False,
        default=False)

    # Time and memory limits for every testcase.
    time_limit = Column(
        Float,
        nullable=True)
    memory_limit = Column(
        Integer,
        nullable=True)

    # Name of the TaskType child class suited for the task.
    task_type = Column(
        String,
        nullable=False)

    # Parameters for the task type class, JSON encoded.
    task_type_parameters = Column(
        String,
        nullable=False)

    # Name of the ScoreType child class suited for the task.
    score_type = Column(
        String,
        nullable=False)

    # Parameters for the score type class, JSON encoded.
    score_type_parameters = Column(
        String,
        nullable=False)
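# The *_parameters columns above are JSON-encoded strings. A small standalone
# sketch of preparing such values; the concrete parameter shapes shown here
# ("Batch" task type, "Sum" score type) are illustrative assumptions, not
# dictated by this model.
import json

task_type_parameters = json.dumps(["alone", ["", ""], "diff"])
score_type_parameters = json.dumps(100)

# e.g. Dataset(description="Default testcases", autojudge=False,
#              time_limit=2.0, memory_limit=256, task_type="Batch",
#              task_type_parameters=task_type_parameters, score_type="Sum",
#              score_type_parameters=score_type_parameters, task=some_task)
# where `some_task` is an existing Task object (hypothetical here).
assert isinstance(json.loads(task_type_parameters), list)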
class Task(Base): __tablename__ = 'task' """ A job that gets executed. Has a unique set of params within its Stage. """ # FIXME causes a problem with mysql? __table_args__ = (UniqueConstraint('stage_id', 'uid', name='_uc1'),) drm_options = {} id = Column(Integer, primary_key=True) uid = Column(String(255), index=True) mem_req = Column(Integer) core_req = Column(Integer) cpu_req = synonym('core_req') time_req = Column(Integer) NOOP = Column(Boolean, nullable=False) params = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False) stage_id = Column(ForeignKey('stage.id', ondelete="CASCADE"), nullable=False, index=True) log_dir = Column(String(255)) # output_dir = Column(String(255)) _status = Column(Enum_ColumnType(TaskStatus, length=255), default=TaskStatus.no_attempt, nullable=False) successful = Column(Boolean, nullable=False) started_on = Column(DateTime) # FIXME this should probably be deleted. Too hard to determine. submitted_on = Column(DateTime) finished_on = Column(DateTime) attempt = Column(Integer, nullable=False) must_succeed = Column(Boolean, nullable=False) drm = Column(String(255)) # FIXME consider making job_class a proper field next time the schema changes # job_class = Column(String(255)) queue = Column(String(255)) max_attempts = Column(Integer) parents = relationship("Task", secondary=TaskEdge.__table__, primaryjoin=id == TaskEdge.parent_id, secondaryjoin=id == TaskEdge.child_id, backref="children", passive_deletes=True, cascade="save-update, merge, delete", ) # input_map = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False) # output_map = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False) @property def input_map(self): d = dict() for key, val in self.params.items(): if key.startswith('in_'): d[key] = val return d @property def output_map(self): d = dict() for key, val in self.params.items(): if key.startswith('out_'): d[key] = val return d @property def input_files(self): return list(self.input_map.values()) @property def output_files(self): return list(self.output_map.values()) # command = Column(Text) drm_native_specification = Column(String(255)) drm_jobID = Column(String(255)) profile_fields = ['wall_time', 'cpu_time', 'percent_cpu', 'user_time', 'system_time', 'io_read_count', 'io_write_count', 'io_read_kb', 'io_write_kb', 'ctx_switch_voluntary', 'ctx_switch_involuntary', 'avg_rss_mem_kb', 'max_rss_mem_kb', 'avg_vms_mem_kb', 'max_vms_mem_kb', 'avg_num_threads', 'max_num_threads', 'avg_num_fds', 'max_num_fds', 'exit_status'] exclude_from_dict = profile_fields + ['command', 'info', 'input_files', 'output_files'] exit_status = Column(Integer) percent_cpu = Column(Integer) wall_time = Column(Integer) cpu_time = Column(Integer) user_time = Column(Integer) system_time = Column(Integer) avg_rss_mem_kb = Column(Integer) max_rss_mem_kb = Column(Integer) avg_vms_mem_kb = Column(Integer) max_vms_mem_kb = Column(Integer) io_read_count = Column(Integer) io_write_count = Column(Integer) io_wait = Column(Integer) io_read_kb = Column(Integer) io_write_kb = Column(Integer) ctx_switch_voluntary = Column(Integer) ctx_switch_involuntary = Column(Integer) avg_num_threads = Column(Integer) max_num_threads = Column(Integer) avg_num_fds = Column(Integer) max_num_fds = Column(Integer) extra = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False) @declared_attr def status(cls): def get_status(self): return self._status def set_status(self, value): if self._status != value: self._status = value signal_task_status_change.send(self) return 
synonym('_status', descriptor=property(get_status, set_status)) @property def workflow(self): return self.stage.workflow @property def log(self): return self.workflow.log @property def finished(self): return self.status in {TaskStatus.successful, TaskStatus.killed, TaskStatus.failed} _cache_profile = None output_profile_path = logplus('profile.json') output_command_script_path = logplus('command.bash') output_stderr_path = logplus('stderr.txt') output_stdout_path = logplus('stdout.txt') @property def stdout_text(self): return readfile(self.output_stdout_path) @property def stderr_text(self): r = readfile(self.output_stderr_path) if r == 'file does not exist': if self.drm == 'lsf' and self.drm_jobID: r += '\n\nbpeek %s output:\n\n' % self.drm_jobID try: r += codecs.decode(sp.check_output('bpeek %s' % self.drm_jobID, shell=True), 'utf-8') except Exception as e: r += str(e) return r @property def command_script_text(self): # return self.command return readfile(self.output_command_script_path).strip() or self.command def descendants(self, include_self=False): """ :return: (list) all stages that descend from this stage in the stage_graph """ x = nx.descendants(self.workflow.task_graph(), self) if include_self: return sorted({self}.union(x), key=lambda task: task.stage.number) else: return x @property def label(self): """Label used for the taskgraph image""" params = '' if len(self.params) == 0 else "\\n {0}".format( "\\n".join(["{0}: {1}".format(k, v) for k, v in self.params.items()])) return "[%s] %s%s" % (self.id, self.stage.name, params) def args_as_query_string(self): import urllib return urllib.urlencode(self.params) def delete(self, descendants=False): if descendants: tasks_to_delete = self.descendants(include_self=True) self.log.debug('Deleting %s and %s of its descendants' % (self, len(tasks_to_delete) - 1)) for t in tasks_to_delete: self.session.delete(t) else: self.log.debug('Deleting %s' % self) self.session.delete(self) self.session.commit() @property def url(self): return url_for('cosmos.task', ex_name=self.workflow.name, stage_name=self.stage.name, task_id=self.id) @property def params_pretty(self): return '%s' % ', '.join( '%s=%s' % (k, "'%s'" % v if isinstance(v, basestring) else v) for k, v in self.params.items()) @property def params_pformat(self): return pprint.pformat(self.params, indent=2, width=1) def __repr__(self): return "<Task[%s] %s(uid='%s')>" % (self.id or 'id_%s' % id(self), self.stage.name if self.stage else '', self.uid ) def __str__(self): return self.__repr__() # FIXME consider making job_class a proper field next time the schema changes def __init__(self, **kwargs): self.job_class = kwargs.pop('job_class', None) _declarative_constructor(self, **kwargs) @reconstructor def init_on_load(self): self.job_class = None
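# Standalone illustration of the params naming convention that the
# input_map / output_map properties above rely on: keys prefixed with "in_"
# are treated as input files, "out_" as output files, everything else is an
# ordinary parameter. The paths below are made up for the example.
params = {'in_bam': '/data/sample.bam', 'out_vcf': '/data/sample.vcf', 'threads': 4}
input_map = {k: v for k, v in params.items() if k.startswith('in_')}
output_map = {k: v for k, v in params.items() if k.startswith('out_')}
assert input_map == {'in_bam': '/data/sample.bam'}
assert output_map == {'out_vcf': '/data/sample.vcf'}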
class Database(Model, AuditMixinNullable, ImportMixin): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" type = "table" __table_args__ = (UniqueConstraint("database_name"), ) id = Column(Integer, primary_key=True) verbose_name = Column(String(250), unique=True) # short unique name, used in permissions database_name = Column(String(250), unique=True, nullable=False) sqlalchemy_uri = Column(String(1024), nullable=False) password = Column(EncryptedType(String(1024), config["SECRET_KEY"])) cache_timeout = Column(Integer) select_as_create_table_as = Column(Boolean, default=False) expose_in_sqllab = Column(Boolean, default=True) allow_run_async = Column(Boolean, default=False) allow_csv_upload = Column(Boolean, default=False) allow_ctas = Column(Boolean, default=False) allow_cvas = Column(Boolean, default=False) allow_dml = Column(Boolean, default=False) force_ctas_schema = Column(String(250)) allow_multi_schema_metadata_fetch = Column( # pylint: disable=invalid-name Boolean, default=False) extra = Column( Text, default=textwrap.dedent("""\ { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_csv_upload": [] } """), ) encrypted_extra = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True) impersonate_user = Column(Boolean, default=False) server_cert = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True) export_fields = [ "database_name", "sqlalchemy_uri", "cache_timeout", "expose_in_sqllab", "allow_run_async", "allow_ctas", "allow_cvas", "allow_csv_upload", "extra", ] export_children = ["tables"] def __repr__(self) -> str: return self.name @property def name(self) -> str: return self.verbose_name if self.verbose_name else self.database_name @property def allows_subquery(self) -> bool: return self.db_engine_spec.allows_subqueries @property def function_names(self) -> List[str]: try: return self.db_engine_spec.get_function_names(self) except Exception as ex: # pylint: disable=broad-except # function_names property is used in bulk APIs and should not hard crash # more info in: https://github.com/apache/incubator-superset/issues/9678 logger.error( "Failed to fetch database function names with error: %s", str(ex)) return [] @property def allows_cost_estimate(self) -> bool: extra = self.get_extra() database_version = extra.get("version") cost_estimate_enabled: bool = extra.get( "cost_estimate_enabled") # type: ignore return (self.db_engine_spec.get_allow_cost_estimate(database_version) and cost_estimate_enabled) @property def allows_virtual_table_explore(self) -> bool: extra = self.get_extra() return bool(extra.get("allows_virtual_table_explore", True)) @property def explore_database_id(self) -> int: return self.get_extra().get("explore_database_id", self.id) @property def data(self) -> Dict[str, Any]: return { "id": self.id, "name": self.database_name, "backend": self.backend, "allow_multi_schema_metadata_fetch": self.allow_multi_schema_metadata_fetch, "allows_subquery": self.allows_subquery, "allows_cost_estimate": self.allows_cost_estimate, "allows_virtual_table_explore": self.allows_virtual_table_explore, "explore_database_id": self.explore_database_id, } @property def unique_name(self) -> str: return self.database_name @property def url_object(self) -> URL: return make_url(self.sqlalchemy_uri_decrypted) @property def backend(self) -> str: sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted) return sqlalchemy_url.get_backend_name() # pylint: 
disable=no-member @property def metadata_cache_timeout(self) -> Dict[str, Any]: return self.get_extra().get("metadata_cache_timeout", {}) @property def schema_cache_enabled(self) -> bool: return "schema_cache_timeout" in self.metadata_cache_timeout @property def schema_cache_timeout(self) -> Optional[int]: return self.metadata_cache_timeout.get("schema_cache_timeout") @property def table_cache_enabled(self) -> bool: return "table_cache_timeout" in self.metadata_cache_timeout @property def table_cache_timeout(self) -> Optional[int]: return self.metadata_cache_timeout.get("table_cache_timeout") @property def default_schemas(self) -> List[str]: return self.get_extra().get("default_schemas", []) @property def connect_args(self) -> Dict[str, Any]: return self.get_extra().get("engine_params", {}).get("connect_args", {}) @classmethod def get_password_masked_url_from_uri( # pylint: disable=invalid-name cls, uri: str) -> URL: sqlalchemy_url = make_url(uri) return cls.get_password_masked_url(sqlalchemy_url) @classmethod def get_password_masked_url( cls, url: URL # pylint: disable=redefined-outer-name ) -> URL: url_copy = deepcopy(url) if url_copy.password is not None: url_copy.password = PASSWORD_MASK return url_copy def set_sqlalchemy_uri(self, uri: str) -> None: conn = sqla.engine.url.make_url(uri.strip()) if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password conn.password = PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password def get_effective_user( self, url: URL, # pylint: disable=redefined-outer-name user_name: Optional[str] = None, ) -> Optional[str]: """ Get the effective user, especially during impersonation. :param url: SQL Alchemy URL object :param user_name: Default username :return: The effective username """ effective_username = None if self.impersonate_user: effective_username = url.username if user_name: effective_username = user_name elif (hasattr(g, "user") and hasattr(g.user, "username") and g.user.username is not None): effective_username = g.user.username return effective_username @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra")) def get_sqla_engine( self, schema: Optional[str] = None, nullpool: bool = True, user_name: Optional[str] = None, source: Optional[utils.QuerySource] = None, ) -> Engine: extra = self.get_extra() sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted) self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) effective_username = self.get_effective_user(sqlalchemy_url, user_name) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a # configuration parameter instead. self.db_engine_spec.modify_url_for_impersonation( sqlalchemy_url, self.impersonate_user, effective_username) masked_url = self.get_password_masked_url(sqlalchemy_url) logger.debug("Database.get_sqla_engine(). 
Masked URL: %s", str(masked_url)) params = extra.get("engine_params", {}) if nullpool: params["poolclass"] = NullPool connect_args = params.get("connect_args", {}) configuration = connect_args.get("configuration", {}) # If using Hive, this will set hive.server2.proxy.user=$effective_username configuration.update( self.db_engine_spec.get_configuration_for_impersonation( str(sqlalchemy_url), self.impersonate_user, effective_username)) if configuration: connect_args["configuration"] = configuration if connect_args: params["connect_args"] = connect_args params.update(self.get_encrypted_extra()) if DB_CONNECTION_MUTATOR: if not source and request and request.referrer: if "/superset/dashboard/" in request.referrer: source = utils.QuerySource.DASHBOARD elif "/superset/explore/" in request.referrer: source = utils.QuerySource.CHART elif "/superset/sqllab/" in request.referrer: source = utils.QuerySource.SQL_LAB sqlalchemy_url, params = DB_CONNECTION_MUTATOR( sqlalchemy_url, params, effective_username, security_manager, source) return create_engine(sqlalchemy_url, **params) def get_reserved_words(self) -> Set[str]: return self.get_dialect().preparer.reserved_words def get_quoter(self) -> Callable[[str, Any], str]: return self.get_dialect().identifier_preparer.quote def get_df( # pylint: disable=too-many-locals self, sql: str, schema: Optional[str] = None, mutator: Optional[Callable[[pd.DataFrame], None]] = None, ) -> pd.DataFrame: sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)] engine = self.get_sqla_engine(schema=schema) username = utils.get_username() def needs_conversion(df_series: pd.Series) -> bool: return not df_series.empty and isinstance(df_series[0], (list, dict)) def _log_query(sql: str) -> None: if log_query: log_query(engine.url, sql, schema, username, __name__, security_manager) with closing(engine.raw_connection()) as conn: with closing(conn.cursor()) as cursor: for sql_ in sqls[:-1]: _log_query(sql_) self.db_engine_spec.execute(cursor, sql_) cursor.fetchall() _log_query(sqls[-1]) self.db_engine_spec.execute(cursor, sqls[-1]) data = self.db_engine_spec.fetch_data(cursor) result_set = SupersetResultSet(data, cursor.description, self.db_engine_spec) df = result_set.to_pandas_df() if mutator: mutator(df) for k, v in df.dtypes.items(): if v.type == numpy.object_ and needs_conversion(df[k]): df[k] = df[k].apply(utils.json_dumps_w_dates) return df def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str: engine = self.get_sqla_engine(schema=schema) sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) if (engine.dialect.identifier_preparer._double_percents # pylint: disable=protected-access ): sql = sql.replace("%%", "%") return sql def select_star( # pylint: disable=too-many-arguments self, table_name: str, schema: Optional[str] = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = False, cols: Optional[List[Dict[str, Any]]] = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) return self.db_engine_spec.select_star( self, table_name, schema=schema, engine=eng, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols, ) def apply_limit_to_sql(self, sql: str, limit: int = 1000) -> str: return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self) -> str: return self.sqlalchemy_uri @property def inspector(self) -> Inspector: engine = 
self.get_sqla_engine() return sqla.inspect(engine) @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema:None:table_list", attribute_in_key="id", ) def get_all_table_names_in_database( self, cache: bool = False, cache_timeout: Optional[bool] = None, force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] return self.db_engine_spec.get_all_datasource_names(self, "table") @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema:None:view_list", attribute_in_key="id") def get_all_view_names_in_database( self, cache: bool = False, cache_timeout: Optional[bool] = None, force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] return self.db_engine_spec.get_all_datasource_names(self, "view") @cache_util.memoized_func( key=lambda *args, **kwargs: f"db:{{}}:schema:{kwargs.get('schema')}:table_list", # type: ignore attribute_in_key="id", ) def get_all_table_names_in_schema( self, schema: str, cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: list of tables """ try: tables = self.db_engine_spec.get_table_names( database=self, inspector=self.inspector, schema=schema) return [ utils.DatasourceName(table=table, schema=schema) for table in tables ] except Exception as ex: # pylint: disable=broad-except logger.exception(ex) @cache_util.memoized_func( key=lambda *args, **kwargs: f"db:{{}}:schema:{kwargs.get('schema')}:view_list", # type: ignore attribute_in_key="id", ) def get_all_view_names_in_schema( self, schema: str, cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: list of views """ try: views = self.db_engine_spec.get_view_names( database=self, inspector=self.inspector, schema=schema) return [ utils.DatasourceName(table=view, schema=schema) for view in views ] except Exception as ex: # pylint: disable=broad-except logger.exception(ex) @cache_util.memoized_func(key=lambda *args, **kwargs: "db:{}:schema_list", attribute_in_key="id") def get_all_schema_names( self, cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, ) -> List[str]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. 
:param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache :return: schema list """ return self.db_engine_spec.get_schema_names(self.inspector) @property def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: return db_engine_specs.engines.get(self.backend, db_engine_specs.BaseEngineSpec) @classmethod def get_db_engine_spec_for_backend( cls, backend: str) -> Type[db_engine_specs.BaseEngineSpec]: return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) def grains(self) -> Tuple[TimeGrain, ...]: """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain from a datetime (maybe the source grain is arbitrary timestamps, daily or 5 minutes increments) to another, "truncated" datetime. Since each database has slightly different but similar datetime functions, this allows a mapping between database engines and actual functions. """ return self.db_engine_spec.get_time_grains() def get_extra(self) -> Dict[str, Any]: return self.db_engine_spec.get_extra_params(self) def get_encrypted_extra(self) -> Dict[str, Any]: encrypted_extra = {} if self.encrypted_extra: try: encrypted_extra = json.loads(self.encrypted_extra) except json.JSONDecodeError as ex: logger.error(ex) raise ex return encrypted_extra def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) return Table( table_name, meta, schema=schema or None, autoload=True, autoload_with=self.get_sqla_engine(), ) def get_columns(self, table_name: str, schema: Optional[str] = None) -> List[Dict[str, Any]]: return self.db_engine_spec.get_columns(self.inspector, table_name, schema) def get_indexes(self, table_name: str, schema: Optional[str] = None) -> List[Dict[str, Any]]: return self.inspector.get_indexes(table_name, schema) def get_pk_constraint(self, table_name: str, schema: Optional[str] = None) -> Dict[str, Any]: return self.inspector.get_pk_constraint(table_name, schema) def get_foreign_keys(self, table_name: str, schema: Optional[str] = None) -> List[Dict[str, Any]]: return self.inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_csv_upload( # pylint: disable=invalid-name self, ) -> List[str]: allowed_databases = self.get_extra().get( "schemas_allowed_for_csv_upload", []) if hasattr(g, "user"): extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( self, g.user) allowed_databases += extra_allowed_databases return sorted(set(allowed_databases)) @property def sqlalchemy_uri_decrypted(self) -> str: conn = sqla.engine.url.make_url(self.sqlalchemy_uri) if custom_password_store: conn.password = custom_password_store(conn) else: conn.password = self.password return str(conn) @property def sql_url(self) -> str: return f"/superset/sql/{self.id}/" @hybrid_property def perm(self) -> str: return f"[{self.database_name}].(id:{self.id})" @perm.expression # type: ignore def perm(cls) -> str: # pylint: disable=no-self-argument return ("[" + cls.database_name + "].(id:" + expression.cast(cls.id, String) + ")") def get_perm(self) -> str: return self.perm # type: ignore def has_table(self, table: Table) -> bool: engine = self.get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool: engine = self.get_sqla_engine() return 
engine.has_table(table_name, schema) @utils.memoized def get_dialect(self) -> Dialect: sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()() # pylint: disable=no-member
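# A hedged usage sketch for the cached metadata helpers above; `database` is assumed
# to be an existing Database row, and the cache settings are illustrative values.
def tables_in_schema(database, schema_name):
    schemas = database.get_all_schema_names(cache=True, cache_timeout=600, force=False)
    if schema_name not in schemas:
        return []
    # Returns utils.DatasourceName tuples (table, schema) for a single schema.
    return database.get_all_table_names_in_schema(
        schema=schema_name, cache=True, cache_timeout=600, force=False)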
import os

from sqlalchemy import (Boolean, Column, DateTime, Integer, MetaData, String,
                        Table, create_engine)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import func

from databases import Database

DATABASE_URL = os.getenv("DATABASE_URL")

# SQLAlchemy
engine = create_engine(DATABASE_URL)
metadata = MetaData()

persons = Table(
    "persons",
    metadata,
    Column("record", Integer, primary_key=True),
    Column("person_id", UUID(as_uuid=False), nullable=False),
    Column("first_name", String(50), nullable=False),
    Column("middle_name", String(50)),
    Column("last_name", String(50), nullable=False),
    Column("email", String(50), nullable=False),
    Column("age", Integer, nullable=False),
    Column("version", Integer, default=1, nullable=False),
    Column("is_latest", Boolean, default=True, nullable=False),
    Column("created_date", DateTime, default=func.now(), nullable=False),
    UniqueConstraint("person_id", "version", name="id_version_uc"),
)

# databases query builder
database = Database(DATABASE_URL)
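# A hedged sketch of using the "databases" query builder declared above; the payload
# dict and the calling event loop (asyncio.run) are assumptions, not part of the
# original module.
import uuid


async def insert_person(payload):
    """Insert one row and return all current person records."""
    await database.connect()
    try:
        await database.execute(
            persons.insert().values(
                person_id=str(uuid.uuid4()),
                version=1,
                is_latest=True,
                created_date=func.now(),   # explicit, rather than relying on the table default
                **payload,                 # first_name, last_name, email, age, ...
            )
        )
        return await database.fetch_all(persons.select().where(persons.c.is_latest))
    finally:
        await database.disconnect()

# asyncio.run(insert_person({'first_name': 'Ada', 'last_name': 'Lovelace',
#                            'email': 'ada@example.com', 'age': 36}))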
class Folder(MailSyncBase): """ Folders and labels from the remote account backend (IMAP/Exchange). """ # `use_alter` required here to avoid circular dependency w/Account account_id = Column(Integer, ForeignKey('account.id', use_alter=True, name='folder_fk1', ondelete='CASCADE'), nullable=False) # TOFIX this causes an import error due to circular dependencies # from inbox.models.account import Account account = relationship( 'Account', backref=backref('folders', cascade='delete', primaryjoin='and_(' 'Folder.account_id == Account.id, ' 'Folder.deleted_at.is_(None))'), primaryjoin='and_(Folder.account_id==Account.id, ' 'Account.deleted_at==None)') # Explicitly set collation to be case insensitive. This is mysql's default # but never trust defaults! This allows us to store the original casing to # not confuse users when displaying it, but still only allow a single # folder with any specific name, canonicalized to lowercase. name = Column(String(MAX_FOLDER_NAME_LENGTH, collation='utf8mb4_general_ci'), nullable=True) canonical_name = Column(String(MAX_FOLDER_NAME_LENGTH), nullable=True) __table_args__ = (UniqueConstraint('account_id', 'name', 'canonical_name'),) @property def lowercase_name(self): if self.name is None: return None return self.name.lower() @property def namespace(self): return self.account.namespace @classmethod def create(cls, account, name, session, canonical_name=None): if name is not None and len(name) > MAX_FOLDER_NAME_LENGTH: log.warning("Truncating long folder name for account {}; " "original name was '{}'" .format(account.id, name)) name = name[:MAX_FOLDER_NAME_LENGTH] obj = cls(account=account, name=name, canonical_name=canonical_name) session.add(obj) return obj @classmethod def find_or_create(cls, session, account, name, canonical_name=None): try: if name is not None and len(name) > MAX_FOLDER_NAME_LENGTH: name = name[:MAX_FOLDER_NAME_LENGTH] q = session.query(cls).filter_by(account_id=account.id) if name is not None: q = q.filter(func.lower(Folder.name) == func.lower(name)) if canonical_name is not None: q = q.filter_by(canonical_name=canonical_name) obj = q.one() except NoResultFound: obj = cls.create(account, name, session, canonical_name) except MultipleResultsFound: log.info("Duplicate folder rows for folder {} for account {}" .format(name, account.id)) raise return obj def get_associated_tag(self, db_session): if self.canonical_name is not None: try: return db_session.query(Tag). \ filter(Tag.namespace_id == self.namespace.id, Tag.public_id == self.canonical_name).one() except NoResultFound: # Explicitly set the namespace_id instead of the namespace # attribute to avoid autoflush-induced IntegrityErrors where # the namespace_id is null on flush. tag = Tag(namespace_id=self.account.namespace.id, name=self.canonical_name, public_id=self.canonical_name) db_session.add(tag) return tag else: provider_prefix = self.account.provider tag_name = '-'.join((provider_prefix, self.name.lower()))[:MAX_INDEXABLE_LENGTH] try: return db_session.query(Tag). \ filter(Tag.namespace_id == self.namespace.id, Tag.name == tag_name).one() except NoResultFound: # Explicitly set the namespace_id instead of the namespace # attribute to avoid autoflush-induced IntegrityErrors where # the namespace_id is null on flush. tag = Tag(namespace_id=self.account.namespace.id, name=tag_name) db_session.add(tag) return tag
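# Hedged example of the folder helpers above; `db_session` and `account` are assumed
# to come from the application's session and account machinery.
folder = Folder.find_or_create(db_session, account, name='Receipts/2014')
tag = folder.get_associated_tag(db_session)
db_session.commit()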
from ...db.meta import metadata, table_opts jcmt_allocation = Table( 'jcmt_allocation', metadata, Column('id', Integer, primary_key=True), Column('proposal_id', None, ForeignKey('proposal.id', onupdate='RESTRICT', ondelete='RESTRICT'), nullable=False), Column('instrument', Integer, nullable=False), Column('ancillary', Integer, nullable=False), Column('weather', Integer, nullable=False), Column('time', Float, nullable=False), UniqueConstraint('proposal_id', 'instrument', 'ancillary', 'weather'), **table_opts) jcmt_available = Table( 'jcmt_available', metadata, Column('id', Integer, primary_key=True), Column('call_id', None, ForeignKey('call.id', onupdate='RESTRICT', ondelete='RESTRICT'), nullable=False), Column('weather', Integer, nullable=False), Column('time', Float, nullable=False), UniqueConstraint('call_id', 'weather'), **table_opts) jcmt_options = Table(
class SqlaTable(Model, BaseDatasource): """An ORM object for SqlAlchemy table references""" type = 'table' query_language = 'sql' metric_class = SqlMetric column_class = TableColumn owner_class = security_manager.user_model __tablename__ = 'tables' __table_args__ = (UniqueConstraint('database_id', 'table_name'), ) table_name = Column(String(250)) main_dttm_col = Column(String(250)) database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) fetch_values_predicate = Column(String(1000)) owners = relationship(owner_class, secondary=sqlatable_user, backref='tables') database = relationship('Database', backref=backref('tables', cascade='all, delete-orphan'), foreign_keys=[database_id]) schema = Column(String(255)) sql = Column(Text) is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) #add hive_partitions column to store partition info hive_partitions = Column(Text) baselink = 'tablemodelview' export_fields = ( 'table_name', 'main_dttm_col', 'description', 'default_endpoint', 'database_id', 'offset', 'cache_timeout', 'schema', 'sql', 'params', 'template_params', 'filter_select_enabled', 'fetch_values_predicate', 'hive_partitions', ) update_from_object_fields = [ f for f in export_fields if f not in ('table_name', 'database_id') ] export_parent = 'database' export_children = ['metrics', 'columns'] sqla_aggregations = { 'COUNT_DISTINCT': lambda column_name: sa.func.COUNT(sa.distinct(column_name)), 'COUNT': sa.func.COUNT, 'SUM': sa.func.SUM, 'AVG': sa.func.AVG, 'MIN': sa.func.MIN, 'MAX': sa.func.MAX, } def get_label(self, label): """Conditionally mutate a label to conform to db engine requirements and store mapping from mutated label to original label :param label: original label :return: Either a string or sqlalchemy.sql.elements.quoted_name if required by db engine """ db_engine_spec = self.database.db_engine_spec sqla_label = db_engine_spec.make_label_compatible(label) mutated_label = str(sqla_label) if label != mutated_label: self.mutated_labels[mutated_label] = label return sqla_label def __repr__(self): return self.name @property def connection(self): return str(self.database) @property def description_markeddown(self): return utils.markdown(self.description) @property def datasource_name(self): return self.table_name @property def database_name(self): return self.database.name @property def link(self): name = escape(self.name) anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>' return Markup(anchor) @property def schema_perm(self): """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) def get_perm(self): return ('[{obj.database}].[{obj.table_name}]' '(id:{obj.id})').format(obj=self) @property def name(self): if not self.schema: return self.table_name return '{}.{}'.format(self.schema, self.table_name) @property def full_name(self): return utils.get_datasource_full_name(self.database, self.table_name, schema=self.schema) @property def dttm_cols(self): l = [c.column_name for c in self.columns if c.is_dttm] # noqa: E741 if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property def num_cols(self): return [c.column_name for c in self.columns if c.is_num] @property def any_dttm_col(self): cols = self.dttm_cols if cols: return cols[0] @property def html(self): t = ((c.column_name, c.type) for c in self.columns) df = pd.DataFrame(t) df.columns = ['field', 'type'] return df.to_html( index=False, classes=('dataframe table 
table-striped table-bordered ' 'table-condensed')) @property def sql_url(self): return self.database.sql_url + '?table_name=' + str(self.table_name) def external_metadata(self): cols = self.database.get_columns(self.table_name, schema=self.schema) for col in cols: try: col['type'] = str(col['type']) except CompileError: col['type'] = 'UNKNOWN' return cols @property def time_column_grains(self): return { 'time_columns': self.dttm_cols, 'time_grains': [grain.name for grain in self.database.grains()], } @property def select_star(self): # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star(self.name, show_cols=False, latest_partition=False) def get_col(self, col_name): columns = self.columns for col in columns: if col_name == col.column_name: return col @property def data(self): d = super(SqlaTable, self).data if self.type == 'table': grains = self.database.grains() or [] if grains: grains = [(g.duration, g.name) for g in grains] d['granularity_sqla'] = utils.choicify(self.dttm_cols) d['time_grain_sqla'] = grains d['main_dttm_col'] = self.main_dttm_col d['fetch_values_predicate'] = self.fetch_values_predicate return d def values_for_column(self, column_name, limit=10000): """Runs query against sqla to retrieve some sample values for the given column. """ cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() qry = (select([target_col.get_sqla_col() ]).select_from(self.get_from_clause(tp)).distinct()) if limit: qry = qry.limit(limit) if self.fetch_values_predicate: tp = self.get_template_processor() qry = qry.where(tp.process_template(self.fetch_values_predicate)) engine = self.database.get_sqla_engine() sql = '{}'.format( qry.compile(engine, compile_kwargs={'literal_binds': True}), ) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) return [row[0] for row in df.to_records(index=False)] def mutate_query_from_config(self, sql): """Apply config's SQL_QUERY_MUTATOR Typically adds comments to the query with context""" SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR') if SQL_QUERY_MUTATOR: username = utils.get_username() sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database) return sql def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) def get_query_str_extended(self, query_obj): sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) logging.info(sql) sql = sqlparse.format(sql, reindent=True) if query_obj['is_prequery']: query_obj['prequeries'].append(sql) sql = self.mutate_query_from_config(sql) """Apply HIVE_QUERY_GENERATOR """ HIVE_QUERY_GENERATOR = config.get('HIVE_QUERY_GENERATOR') if HIVE_QUERY_GENERATOR: sql = HIVE_QUERY_GENERATOR(sql, query_obj, self.database, self.datasource_name) return QueryStringExtended(labels_expected=sqlaq.labels_expected, sql=sql) def get_query_str(self, query_obj): return self.get_query_str_extended(query_obj).sql def get_sqla_table(self): tbl = table(self.table_name) if self.schema: tbl.schema = self.schema return tbl def get_from_clause(self, template_processor=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) from_sql = sqlparse.format(from_sql, strip_comments=True) return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return 
self.get_sqla_table() def adhoc_metric_to_sqla(self, metric, cols): """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition :param dict cols: Columns for the current table :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ expression_type = metric.get('expressionType') label = utils.get_metric_name(metric) label = self.get_label(label) if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: column_name = metric.get('column').get('column_name') table_column = cols.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() else: sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric.get('aggregate')]( sqla_column) sqla_metric = sqla_metric.label(label) return sqla_metric elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: sqla_metric = literal_column(metric.get('sqlExpression')) sqla_metric = sqla_metric.label(label) return sqla_metric else: return None def get_sqla_query( # sqla self, groupby, metrics, granularity, from_dttm, to_dttm, filter=None, # noqa is_timeseries=True, timeseries_limit=15, timeseries_limit_metric=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, orderby=None, extras=None, columns=None, order_desc=True, prequeries=None, is_prequery=False, ): """Querying any sqla table from this common interface""" template_kwargs = { 'from_dttm': from_dttm, 'groupby': groupby, 'metrics': metrics, 'row_limit': row_limit, 'to_dttm': to_dttm, 'filter': filter, 'columns': {col.column_name: col for col in self.columns}, } template_kwargs.update(self.template_params_dict) template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.database.db_engine_spec # Initialize empty cache to store mutated labels self.mutated_labels = {} orderby = orderby or [] # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline cols = {col.column_name: col for col in self.columns} metrics_dict = {m.metric_name: m for m in self.metrics} if not granularity and is_timeseries: raise Exception( _('Datetime column not provided as part table configuration ' 'and is required by this type of chart')) if not groupby and not metrics and not columns: raise Exception(_('Empty query?')) metrics_exprs = [] for m in metrics: if utils.is_adhoc_metric(m): metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols)) elif m in metrics_dict: metrics_exprs.append(metrics_dict.get(m).get_sqla_col()) else: raise Exception(_("Metric '{}' is not valid".format(m))) if metrics_exprs: main_metric_expr = metrics_exprs[0] else: label = self.get_label('ccount') main_metric_expr = literal_column('COUNT(*)').label(label) select_exprs = [] groupby_exprs_sans_timestamp = OrderedDict() if groupby: select_exprs = [] for s in groupby: if s in cols: outer = cols[s].get_sqla_col() else: outer = literal_column(f'({s})').label(self.get_label(s)) groupby_exprs_sans_timestamp[outer.name] = outer select_exprs.append(outer) elif columns: for s in columns: select_exprs.append(cols[s].get_sqla_col() if s in cols else literal_column(s)) metrics_exprs = [] groupby_exprs_with_timestamp = OrderedDict( groupby_exprs_sans_timestamp.items()) if granularity: dttm_col = cols[granularity] time_grain = extras.get('time_grain_sqla') time_filters = [] if is_timeseries: timestamp = dttm_col.get_timestamp_expression(time_grain) select_exprs += 
[timestamp] groupby_exprs_with_timestamp[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns if db_engine_spec.time_secondary_columns and \ self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col].get_time_filter( from_dttm, to_dttm)) time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) select_exprs += metrics_exprs labels_expected = [str(c.name) for c in select_exprs] select_exprs = db_engine_spec.make_select_compatible( groupby_exprs_with_timestamp.values(), select_exprs) qry = sa.select(select_exprs) tbl = self.get_from_clause(template_processor) if not columns: qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] having_clause_and = [] for flt in filter: if not all([flt.get(s) for s in ['col', 'op']]): continue col = flt['col'] op = flt['op'] col_obj = cols.get(col) if col_obj: is_list_target = op in ('in', 'not in') eq = self.filter_values_handler( flt.get('val'), target_column_is_numeric=col_obj.is_num, is_list_target=is_list_target) if op in ('in', 'not in'): cond = col_obj.get_sqla_col().in_(eq) if '<NULL>' in eq: cond = or_(cond, col_obj.get_sqla_col() == None) # noqa if op == 'not in': cond = ~cond where_clause_and.append(cond) else: if col_obj.is_num: eq = utils.string_to_num(flt['val']) if op == '==': where_clause_and.append(col_obj.get_sqla_col() == eq) elif op == '!=': where_clause_and.append(col_obj.get_sqla_col() != eq) elif op == '>': where_clause_and.append(col_obj.get_sqla_col() > eq) elif op == '<': where_clause_and.append(col_obj.get_sqla_col() < eq) elif op == '>=': where_clause_and.append(col_obj.get_sqla_col() >= eq) elif op == '<=': where_clause_and.append(col_obj.get_sqla_col() <= eq) elif op == 'LIKE': where_clause_and.append( col_obj.get_sqla_col().like(eq)) elif op == 'IS NULL': where_clause_and.append( col_obj.get_sqla_col() == None) # noqa elif op == 'IS NOT NULL': where_clause_and.append( col_obj.get_sqla_col() != None) # noqa if extras: where = extras.get('where') if where: where = template_processor.process_template(where) where_clause_and += [sa.text('({})'.format(where))] having = extras.get('having') if having: having = template_processor.process_template(having) having_clause_and += [sa.text('({})'.format(having))] if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) if not orderby and not columns: orderby = [(main_metric_expr, not order_desc)] for col, ascending in orderby: direction = asc if ascending else desc if utils.is_adhoc_metric(col): col = self.adhoc_metric_to_sqla(col, cols) qry = qry.order_by(direction(col)) if row_limit: qry = qry.limit(row_limit) if is_timeseries and \ timeseries_limit and groupby and not time_groupby_inline: if self.database.db_engine_spec.inner_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. 
vertica, # require a unique inner alias label = self.get_label('mme_inner__') inner_main_metric_expr = main_metric_expr.label(label) inner_groupby_exprs = [] inner_select_exprs = [] for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): inner = gby_obj.label(gby_name + '__') inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) inner_select_exprs += [inner_main_metric_expr] subq = select(inner_select_exprs).select_from(tbl) inner_time_filter = dttm_col.get_time_filter( inner_from_dttm or from_dttm, inner_to_dttm or to_dttm, ) subq = subq.where( and_(*(where_clause_and + [inner_time_filter]))) subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr if timeseries_limit_metric: if utils.is_adhoc_metric(timeseries_limit_metric): ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols) elif timeseries_limit_metric in metrics_dict: timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric, ) ob = timeseries_limit_metric.get_sqla_col() else: raise Exception(_( "Metric '{}' is not valid".format(m))) direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) subq = subq.limit(timeseries_limit) on_clause = [] for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): # in this case the column name, not the alias, needs to be # conditionally mutated, as it refers to the column alias in # the inner query col_name = self.get_label(gby_name + '__') on_clause.append(gby_obj == column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) else: # run subquery to get top groups subquery_obj = { 'prequeries': prequeries, 'is_prequery': True, 'is_timeseries': False, 'row_limit': timeseries_limit, 'groupby': groupby, 'metrics': metrics, 'granularity': granularity, 'from_dttm': inner_from_dttm or from_dttm, 'to_dttm': inner_to_dttm or to_dttm, 'filter': filter, 'orderby': orderby, 'extras': extras, 'columns': columns, 'order_desc': True, } result = self.query(subquery_obj) dimensions = [ c for c in result.df.columns if c not in metrics and c in groupby_exprs_sans_timestamp ] top_groups = self._get_top_groups( result.df, dimensions, groupby_exprs_sans_timestamp) qry = qry.where(top_groups) return SqlaQuery(sqla_query=qry.select_from(tbl), labels_expected=labels_expected) def _get_top_groups(self, df, dimensions, groupby_exprs): groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: group.append(groupby_exprs[dimension] == row[dimension]) groups.append(and_(*group)) return or_(*groups) def query(self, query_obj): qry_start_dttm = datetime.now() query_str_ext = self.get_query_str_extended(query_obj) sql = query_str_ext.sql status = utils.QueryStatus.SUCCESS error_message = None df = None logging.debug( '[PERFORMANCE CHECK] SQL Query formation time {0} '.format( datetime.now() - qry_start_dttm)) db_engine_spec = self.database.db_engine_spec try: df = self.database.get_df(sql, self.schema) if self.mutated_labels: df = df.rename(index=str, columns=self.mutated_labels) db_engine_spec.mutate_df_columns(df, sql, query_str_ext.labels_expected) except Exception as e: status = utils.QueryStatus.FAILED logging.exception(f'Query {sql} on schema {self.schema} failed') error_message = db_engine_spec.extract_error_message(e) # if this is a main query with prequeries, combine them together if not query_obj['is_prequery']: query_obj['prequeries'].append(sql) sql = ';\n\n'.join(query_obj['prequeries']) sql += ';' return QueryResult(status=status, df=df, duration=datetime.now() - qry_start_dttm, query=sql, 
error_message=error_message) def get_sqla_table_object(self): return self.database.get_table(self.table_name, schema=self.schema) def fetch_metadata(self): """Fetches the metadata for the table and merges it in""" try: has_kerberos_ticket() table = self.get_sqla_table_object() except Exception as e: logging.exception(e) raise Exception( _("Table [{}] doesn't seem to exist in the specified database, " "couldn't fetch column information").format(self.table_name)) M = SqlMetric # noqa metrics = [] any_date_col = None db_dialect = self.database.get_dialect() dbcols = (db.session.query(TableColumn).filter( TableColumn.table == self).filter( or_(TableColumn.column_name == col.name for col in table.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} for col in table.columns: try: datatype = col.type.compile(dialect=db_dialect).upper() except Exception as e: datatype = 'UNKNOWN' logging.error('Unrecognized data type in {}.{}'.format( table, col.name)) logging.exception(e) dbcol = dbcols.get(col.name, None) if not dbcol: dbcol = TableColumn(column_name=col.name, type=datatype) dbcol.sum = dbcol.is_num dbcol.avg = dbcol.is_num dbcol.is_dttm = dbcol.is_time else: dbcol.type = datatype dbcol.groupby = True dbcol.filterable = True self.columns.append(dbcol) if not any_date_col and dbcol.is_time: any_date_col = col.name metrics.append( M( metric_name='count', verbose_name='COUNT(*)', metric_type='count', expression='COUNT(*)', )) if not self.main_dttm_col: self.main_dttm_col = any_date_col self.add_missing_metrics(metrics) db.session.merge(self) db.session.commit() @classmethod def import_obj(cls, i_datasource, import_time=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. """ def lookup_sqlatable(table): return db.session.query(SqlaTable).join(Database).filter( SqlaTable.table_name == table.table_name, SqlaTable.schema == table.schema, Database.id == table.database_id, ).first() def lookup_database(table): return db.session.query(Database).filter_by( database_name=table.params_dict['database_name']).one() return import_datasource.import_datasource(db.session, i_datasource, lookup_database, lookup_sqlatable, import_time) @classmethod def query_datasources_by_name(cls, session, database, datasource_name, schema=None): query = (session.query(cls).filter_by( database_id=database.id).filter_by(table_name=datasource_name)) if schema: query = query.filter_by(schema=schema) return query.all() @staticmethod def default_query(qry): return qry.filter_by(is_sqllab_view=False)
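# Hedged sketch of driving the query interface above; `tbl` is assumed to be a
# SqlaTable loaded from the metadata database, and the column, metric and filter
# names are placeholders that would have to exist on that table.
from datetime import datetime, timedelta

query_obj = {
    'groupby': ['country'],
    'metrics': ['count'],
    'granularity': 'created_at',
    'from_dttm': datetime.utcnow() - timedelta(days=7),
    'to_dttm': datetime.utcnow(),
    'filter': [{'col': 'status', 'op': '==', 'val': 'active'}],
    'is_timeseries': True,
    'timeseries_limit': 15,
    'row_limit': 1000,
    'orderby': [],
    'extras': {'time_grain_sqla': 'P1D'},
    'columns': [],
    'order_desc': True,
    'prequeries': [],
    'is_prequery': False,
}
result = tbl.query(query_obj)
# result is a QueryResult: result.df holds the data, result.query the executed SQL,
# and result.status is utils.QueryStatus.SUCCESS or FAILED.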
class Task(Base): """Class to store a task. """ __tablename__ = 'tasks' __table_args__ = ( UniqueConstraint('contest_id', 'num'), UniqueConstraint('contest_id', 'name'), ForeignKeyConstraint( ("id", "active_dataset_id"), ("datasets.task_id", "datasets.id"), onupdate="SET NULL", ondelete="SET NULL", # Use an ALTER query to set this foreign key after # both tables have been CREATEd, to avoid circular # dependencies. use_alter=True, name="fk_active_dataset_id" ), CheckConstraint("token_initial <= token_max"), ) # Auto increment primary key. id = Column( Integer, primary_key=True, # Needed to enable autoincrement on integer primary keys that # are referenced by a foreign key defined on this table. autoincrement='ignore_fk') # Number of the task for sorting. num = Column( Integer, nullable=False) # Contest (id and object) owning the task. contest_id = Column( Integer, ForeignKey(Contest.id, onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True) contest = relationship( Contest, backref=backref('tasks', collection_class=ordering_list('num'), order_by=[num], cascade="all, delete-orphan", passive_deletes=True)) # Short name and long human readable title of the task. name = Column( Unicode, nullable=False) title = Column( Unicode, nullable=False) # A JSON-encoded lists of strings: the language codes of the # statements that will be highlighted to all users for this task. primary_statements = Column( String, nullable=False, default="[]") # Parameter to define the token behaviour. See Contest.py for # details. The only change is that these parameters influence the # contest in a task-per-task behaviour. To play a token on a given # task, a user must satisfy the condition of the contest and the # one of the task. token_initial = Column( Integer, CheckConstraint("token_initial >= 0"), nullable=True) token_max = Column( Integer, CheckConstraint("token_max > 0"), nullable=True) token_total = Column( Integer, CheckConstraint("token_total > 0"), nullable=True) token_min_interval = Column( Interval, CheckConstraint("token_min_interval >= '0 seconds'"), nullable=False, default=timedelta()) token_gen_time = Column( Interval, CheckConstraint("token_gen_time >= '0 seconds'"), nullable=False, default=timedelta()) token_gen_number = Column( Integer, CheckConstraint("token_gen_number >= 0"), nullable=False, default=0) # Maximum number of submissions or user_tests allowed for each user # on this task during the whole contest or None to not enforce # this limitation. max_submission_number = Column( Integer, CheckConstraint("max_submission_number > 0"), nullable=True) max_user_test_number = Column( Integer, CheckConstraint("max_user_test_number > 0"), nullable=True) # Minimum interval between two submissions or user_tests for this # task, or None to not enforce this limitation. min_submission_interval = Column( Interval, CheckConstraint("min_submission_interval > '0 seconds'"), nullable=True) min_user_test_interval = Column( Interval, CheckConstraint("min_user_test_interval > '0 seconds'"), nullable=True) # The scores for this task will be rounded to this number of # decimal places. score_precision = Column( Integer, CheckConstraint("score_precision >= 0"), nullable=False, default=0) # Active Dataset (id and object) currently being used for scoring. # The ForeignKeyConstraint for this column is set at table-level. 
active_dataset_id = Column( Integer, nullable=True) active_dataset = relationship( 'Dataset', foreign_keys=[active_dataset_id], # XXX In SQLAlchemy 0.8 we could remove this: primaryjoin='Task.active_dataset_id == Dataset.id', # Use an UPDATE query *after* an INSERT query (and *before* a # DELETE query) to set (and unset) the column associated to # this relationship. post_update=True)
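# Hedged sketch of creating a task under the schema above; `contest` and `session`
# are assumed to exist, and the token values satisfy the
# "token_initial <= token_max" check constraint.
from datetime import timedelta

task = Task(name='fibonacci',
            title='Fibonacci Numbers',
            contest=contest,
            num=len(contest.tasks),
            primary_statements='["en"]',
            token_initial=2,
            token_max=10,
            token_min_interval=timedelta(minutes=1),
            score_precision=2)
session.add(task)
session.commit()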
class PriceHistory(db.Model):
    __table_args__ = (UniqueConstraint('ticker_id', 'date', name='_ticker__date'), )

    id = db.Column(db.Integer, primary_key=True)
    ticker_id = db.Column(db.Integer, db.ForeignKey('ticker.id'), nullable=False)
    date = db.Column(db.Date, nullable=False, server_default=func.current_date())
    open = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    close = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    high = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    low = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    volume = db.Column(db.Integer, nullable=False)

    @classmethod
    def get_or_create(cls, ticker=None, date=None, **kwargs):
        instance = cls.query.filter_by(ticker_id=ticker.id, date=date).first()
        if instance:
            return instance, False
        else:
            kwargs['volume'] = kwargs['volume'].replace(',', '')
            instance = cls(ticker=ticker, date=date, **kwargs)
            db.session.add(instance)
            return instance, True

    def update(self, commit=False, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
        if commit:
            db.session.commit()

    @classmethod
    def get_analytics(cls, ticker_name, date_from, date_to):
        PricesA = aliased(cls)
        PricesB = aliased(cls)
        result = db.session.query(
            label('open', func.abs(PricesA.open - PricesB.open)),
            label('close', func.abs(PricesA.close - PricesB.close)),
            label('low', func.abs(PricesA.low - PricesB.low)),
            label('high', func.abs(PricesA.high - PricesB.high))).join(Ticker).join(
                PricesB, and_(PricesB.ticker_id == Ticker.id)).filter(
                    Ticker.name == ticker_name,
                    PricesA.date == date_from,
                    PricesB.date == date_to)
        return result

    @classmethod
    def get_delta(cls, ticker_name, type_price, value):
        sql = text(
            DELTA_SELECT.format(ticker_name=ticker_name,
                                type_price=type_price,
                                value_delta=value))
        result = db.engine.execute(sql)
        return result.fetchall()

    def __repr__(self):
        return '<{} = {}, {}>'.format(self.ticker.name, self.volume, self.date)
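# Hedged usage sketch for the helpers above; `ticker` is assumed to be an existing
# Ticker row and the OHLC values to come from an external feed (volume arrives as a
# comma-grouped string, which get_or_create strips before inserting).
from datetime import date

row, created = PriceHistory.get_or_create(
    ticker=ticker,
    date=date(2020, 3, 2),
    open=101.5, close=103.2, high=104.0, low=100.9,
    volume='1,250,300',
)
if created:
    db.session.commit()
else:
    row.update(commit=True, close=103.2, volume=1250300)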
class ImapFolderSyncStatus(MailSyncBase, HasRunState, UpdatedAtMixin, DeletedAtMixin): """ Per-folder status state saving for IMAP folders. """ account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'), nullable=False) account = relationship(ImapAccount, backref=backref('foldersyncstatuses', passive_deletes=True)) folder_id = Column(ForeignKey('folder.id', ondelete='CASCADE'), nullable=False) # We almost always need the folder name too, so eager load by default. folder = relationship('Folder', lazy='joined', backref=backref('imapsyncstatus', uselist=False, passive_deletes=True)) # see state machine in mailsync/backends/imap/imap.py state = Column(Enum('initial', 'initial uidinvalid', 'poll', 'poll uidinvalid', 'finish'), server_default='initial', nullable=False) # stats on messages downloaded etc. _metrics = Column(MutableDict.as_mutable(JSON), default={}, nullable=True) @property def metrics(self): status = dict(name=self.folder.name, state=self.state) status.update(self._metrics or {}) return status def start_sync(self): self._metrics = dict(run_state='running', sync_start_time=datetime.utcnow()) def stop_sync(self): self._metrics['run_state'] = 'stopped' self._metrics['sync_end_time'] = datetime.utcnow() @property def is_killed(self): return self._metrics.get('run_state') == 'killed' def update_metrics(self, metrics): sync_status_metrics = [ 'remote_uid_count', 'delete_uid_count', 'update_uid_count', 'download_uid_count', 'uid_checked_timestamp', 'num_downloaded_since_timestamp', 'queue_checked_at', 'percent' ] assert isinstance(metrics, dict) for k in metrics.iterkeys(): assert k in sync_status_metrics, k if self._metrics is not None: self._metrics.update(metrics) else: self._metrics = metrics @property def sync_enabled(self): # sync is enabled if the folder's run bit is set, and the account's # run bit is set. (this saves us needing to reproduce account-state # transition logic on the folder level, and gives us a comparison bit # against folder heartbeats.) return self.sync_should_run and self.account.sync_should_run __table_args__ = (UniqueConstraint('account_id', 'folder_id'), )
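# Hedged sketch of the sync-status lifecycle above; `status` is assumed to be an
# ImapFolderSyncStatus row already attached to a session, with its folder loaded.
status.start_sync()
status.update_metrics({'remote_uid_count': 1200,
                       'download_uid_count': 75,
                       'percent': 6.25})
if not status.sync_enabled:
    pass  # the account- or folder-level run bit was cleared; stop polling
status.stop_sync()
print(status.metrics)  # folder name, state machine state, plus the stored counters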
class SqlaTable(Model, BaseDatasource): """An ORM object for SqlAlchemy table references""" type = 'table' query_language = 'sql' metric_class = SqlMetric column_class = TableColumn __tablename__ = 'tables' __table_args__ = (UniqueConstraint('database_id', 'table_name'), ) table_name = Column(String(250)) main_dttm_col = Column(String(250)) database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) fetch_values_predicate = Column(String(1000)) user_id = Column(Integer, ForeignKey('ab_user.id')) owner = relationship(sm.user_model, backref='tables', foreign_keys=[user_id]) database = relationship('Database', backref=backref('tables', cascade='all, delete-orphan'), foreign_keys=[database_id]) schema = Column(String(255)) sql = Column(Text) baselink = 'tablemodelview' export_fields = ('table_name', 'main_dttm_col', 'description', 'default_endpoint', 'database_id', 'offset', 'cache_timeout', 'schema', 'sql', 'params') export_parent = 'database' export_children = ['metrics', 'columns'] def __repr__(self): return self.name @property def connection(self): return str(self.database) @property def description_markeddown(self): return utils.markdown(self.description) @property def link(self): name = escape(self.name) return Markup( '<a href="{self.explore_url}">{name}</a>'.format(**locals())) @property def schema_perm(self): """Returns schema permission if present, database one otherwise.""" return utils.get_schema_perm(self.database, self.schema) def get_perm(self): return ('[{obj.database}].[{obj.table_name}]' '(id:{obj.id})').format(obj=self) @property def name(self): if not self.schema: return self.table_name return '{}.{}'.format(self.schema, self.table_name) @property def full_name(self): return utils.get_datasource_full_name(self.database, self.table_name, schema=self.schema) @property def dttm_cols(self): l = [c.column_name for c in self.columns if c.is_dttm] # noqa: E741 if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property def num_cols(self): return [c.column_name for c in self.columns if c.is_num] @property def any_dttm_col(self): cols = self.dttm_cols if cols: return cols[0] @property def html(self): t = ((c.column_name, c.type) for c in self.columns) df = pd.DataFrame(t) df.columns = ['field', 'type'] return df.to_html( index=False, classes=('dataframe table table-striped table-bordered ' 'table-condensed')) @property def sql_url(self): return self.database.sql_url + '?table_name=' + str(self.table_name) @property def time_column_grains(self): return { 'time_columns': self.dttm_cols, 'time_grains': [grain.name for grain in self.database.grains()], } def get_col(self, col_name): columns = self.columns for col in columns: if col_name == col.column_name: return col @property def data(self): d = super(SqlaTable, self).data if self.type == 'table': grains = self.database.grains() or [] if grains: grains = [(g.name, g.name) for g in grains] d['granularity_sqla'] = utils.choicify(self.dttm_cols) d['time_grain_sqla'] = grains return d def values_for_column(self, column_name, limit=10000): """Runs query against sqla to retrieve some sample values for the given column. 
""" cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() db_engine_spec = self.database.db_engine_spec qry = (select([target_col.sqla_col]).select_from( self.get_from_clause(tp, db_engine_spec)).distinct(column_name)) if limit: qry = qry.limit(limit) if self.fetch_values_predicate: tp = self.get_template_processor() qry = qry.where(tp.process_template(self.fetch_values_predicate)) engine = self.database.get_sqla_engine() sql = '{}'.format( qry.compile(engine, compile_kwargs={'literal_binds': True}), ) df = pd.read_sql_query(sql=sql, con=engine) return [row[0] for row in df.to_records(index=False)] def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) def get_query_str(self, query_obj): engine = self.database.get_sqla_engine() qry = self.get_sqla_query(**query_obj) sql = six.text_type( qry.compile( engine, compile_kwargs={'literal_binds': True}, ), ) logging.info(sql) sql = sqlparse.format(sql, reindent=True) if query_obj['is_prequery']: query_obj['prequeries'].append(sql) return sql def get_sqla_table(self): tbl = table(self.table_name) if self.schema: tbl.schema = self.schema return tbl def get_from_clause(self, template_processor=None, db_engine_spec=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) from_sql = sqlparse.format(from_sql, strip_comments=True) return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() def get_sqla_query( # sqla self, groupby, metrics, granularity, from_dttm, to_dttm, filter=None, # noqa is_timeseries=True, timeseries_limit=15, timeseries_limit_metric=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, orderby=None, extras=None, columns=None, order_desc=True, prequeries=None, is_prequery=False, ): """Querying any sqla table from this common interface""" template_kwargs = { 'from_dttm': from_dttm, 'groupby': groupby, 'metrics': metrics, 'row_limit': row_limit, 'to_dttm': to_dttm, } template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.database.db_engine_spec orderby = orderby or [] # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline cols = {col.column_name: col for col in self.columns} metrics_dict = {m.metric_name: m for m in self.metrics} if not granularity and is_timeseries: raise Exception( _('Datetime column not provided as part table configuration ' 'and is required by this type of chart')) if not groupby and not metrics and not columns: raise Exception(_('Empty query?')) for m in metrics: if m not in metrics_dict: raise Exception(_("Metric '{}' is not valid".format(m))) metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics] if metrics_exprs: main_metric_expr = metrics_exprs[0] else: main_metric_expr = literal_column('COUNT(*)').label('ccount') select_exprs = [] groupby_exprs = [] if groupby: select_exprs = [] inner_select_exprs = [] inner_groupby_exprs = [] for s in groupby: col = cols[s] outer = col.sqla_col inner = col.sqla_col.label(col.column_name + '__') groupby_exprs.append(outer) select_exprs.append(outer) inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) elif columns: for s in columns: select_exprs.append(cols[s].sqla_col) 
metrics_exprs = [] if granularity: dttm_col = cols[granularity] time_grain = extras.get('time_grain_sqla') time_filters = [] if is_timeseries: timestamp = dttm_col.get_timestamp_expression(time_grain) select_exprs += [timestamp] groupby_exprs += [timestamp] # Use main dttm column to support index with secondary dttm columns if db_engine_spec.time_secondary_columns and \ self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col].get_time_filter( from_dttm, to_dttm)) time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) select_exprs += metrics_exprs qry = sa.select(select_exprs) tbl = self.get_from_clause(template_processor, db_engine_spec) if not columns: qry = qry.group_by(*groupby_exprs) where_clause_and = [] having_clause_and = [] for flt in filter: if not all([flt.get(s) for s in ['col', 'op', 'val']]): continue col = flt['col'] op = flt['op'] eq = flt['val'] col_obj = cols.get(col) if col_obj: if op in ('in', 'not in'): values = [] for v in eq: # For backwards compatibility and edge cases # where a column data type might have changed if isinstance(v, basestring): v = v.strip("'").strip('"') if col_obj.is_num: v = utils.string_to_num(v) # Removing empty strings and non numeric values # targeting numeric columns if v is not None: values.append(v) cond = col_obj.sqla_col.in_(values) if op == 'not in': cond = ~cond where_clause_and.append(cond) else: if col_obj.is_num: eq = utils.string_to_num(flt['val']) if op == '==': where_clause_and.append(col_obj.sqla_col == eq) elif op == '!=': where_clause_and.append(col_obj.sqla_col != eq) elif op == '>': where_clause_and.append(col_obj.sqla_col > eq) elif op == '<': where_clause_and.append(col_obj.sqla_col < eq) elif op == '>=': where_clause_and.append(col_obj.sqla_col >= eq) elif op == '<=': where_clause_and.append(col_obj.sqla_col <= eq) elif op == 'LIKE': where_clause_and.append(col_obj.sqla_col.like(eq)) if extras: where = extras.get('where') if where: where = template_processor.process_template(where) where_clause_and += [sa.text('({})'.format(where))] having = extras.get('having') if having: having = template_processor.process_template(having) having_clause_and += [sa.text('({})'.format(having))] if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) if not orderby and not columns: orderby = [(main_metric_expr, not order_desc)] for col, ascending in orderby: direction = asc if ascending else desc qry = qry.order_by(direction(col)) if row_limit: qry = qry.limit(row_limit) if is_timeseries and \ timeseries_limit and groupby and not time_groupby_inline: if self.database.db_engine_spec.inner_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. 
vertica, # require a unique inner alias inner_main_metric_expr = main_metric_expr.label('mme_inner__') inner_select_exprs += [inner_main_metric_expr] subq = select(inner_select_exprs) subq = subq.select_from(tbl) inner_time_filter = dttm_col.get_time_filter( inner_from_dttm or from_dttm, inner_to_dttm or to_dttm, ) subq = subq.where( and_(*(where_clause_and + [inner_time_filter]))) subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr if timeseries_limit_metric: timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric) ob = timeseries_limit_metric.sqla_col direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) subq = subq.limit(timeseries_limit) on_clause = [] for i, gb in enumerate(groupby): on_clause.append(groupby_exprs[i] == column(gb + '__')) tbl = tbl.join(subq.alias(), and_(*on_clause)) else: # run subquery to get top groups subquery_obj = { 'prequeries': prequeries, 'is_prequery': True, 'is_timeseries': False, 'row_limit': timeseries_limit, 'groupby': groupby, 'metrics': metrics, 'granularity': granularity, 'from_dttm': inner_from_dttm or from_dttm, 'to_dttm': inner_to_dttm or to_dttm, 'filter': filter, 'orderby': orderby, 'extras': extras, 'columns': columns, 'order_desc': True, } result = self.query(subquery_obj) dimensions = [c for c in result.df.columns if c not in metrics] top_groups = self._get_top_groups(result.df, dimensions) qry = qry.where(top_groups) return qry.select_from(tbl) def _get_top_groups(self, df, dimensions): cols = {col.column_name: col for col in self.columns} groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: col_obj = cols.get(dimension) group.append(col_obj.sqla_col == row[dimension]) groups.append(and_(*group)) return or_(*groups) def query(self, query_obj): qry_start_dttm = datetime.now() sql = self.get_query_str(query_obj) status = QueryStatus.SUCCESS error_message = None df = None try: df = self.database.get_df(sql, self.schema) except Exception as e: status = QueryStatus.FAILED logging.exception(e) error_message = ( self.database.db_engine_spec.extract_error_message(e)) # if this is a main query with prequeries, combine them together if not query_obj['is_prequery']: query_obj['prequeries'].append(sql) sql = ';\n\n'.join(query_obj['prequeries']) sql += ';' return QueryResult(status=status, df=df, duration=datetime.now() - qry_start_dttm, query=sql, error_message=error_message) def get_sqla_table_object(self): return self.database.get_table(self.table_name, schema=self.schema) def fetch_metadata(self): """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() except Exception: raise Exception( _("Table [{}] doesn't seem to exist in the specified database, " "couldn't fetch column information").format(self.table_name)) M = SqlMetric # noqa metrics = [] any_date_col = None db_dialect = self.database.get_dialect() dbcols = (db.session.query(TableColumn).filter( TableColumn.table == self).filter( or_(TableColumn.column_name == col.name for col in table.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} for col in table.columns: try: datatype = col.type.compile(dialect=db_dialect).upper() except Exception as e: datatype = 'UNKNOWN' logging.error('Unrecognized data type in {}.{}'.format( table, col.name)) logging.exception(e) dbcol = dbcols.get(col.name, None) if not dbcol: dbcol = TableColumn(column_name=col.name, type=datatype) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string dbcol.sum = 
dbcol.is_num dbcol.avg = dbcol.is_num dbcol.is_dttm = dbcol.is_time else: dbcol.type = datatype self.columns.append(dbcol) if not any_date_col and dbcol.is_time: any_date_col = col.name metrics += dbcol.get_metrics().values() metrics.append( M( metric_name='count', verbose_name='COUNT(*)', metric_type='count', expression='COUNT(*)', )) if not self.main_dttm_col: self.main_dttm_col = any_date_col self.add_missing_metrics(metrics) db.session.merge(self) db.session.commit() @classmethod def import_obj(cls, i_datasource, import_time=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. """ def lookup_sqlatable(table): return db.session.query(SqlaTable).join(Database).filter( SqlaTable.table_name == table.table_name, SqlaTable.schema == table.schema, Database.id == table.database_id, ).first() def lookup_database(table): return db.session.query(Database).filter_by( database_name=table.params_dict['database_name']).one() return import_util.import_datasource(db.session, i_datasource, lookup_database, lookup_sqlatable, import_time) @classmethod def query_datasources_by_name(cls, session, database, datasource_name, schema=None): query = (session.query(cls).filter_by( database_id=database.id).filter_by(table_name=datasource_name)) if schema: query = query.filter_by(schema=schema) return query.all()
class ImapUid(MailSyncBase, UpdatedAtMixin, DeletedAtMixin): """ Maps UIDs to their IMAP folders and per-UID flag metadata. This table is used solely for bookkeeping by the IMAP mail sync backends. """ account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'), nullable=False) account = relationship(ImapAccount) message_id = Column(ForeignKey(Message.id, ondelete='CASCADE'), nullable=False) message = relationship(Message, backref=backref('imapuids', passive_deletes=True)) msg_uid = Column(BigInteger, nullable=False, index=True) folder_id = Column(ForeignKey(Folder.id, ondelete='CASCADE'), nullable=False) # We almost always need the folder name too, so eager load by default. folder = relationship(Folder, lazy='joined', backref=backref('imapuids', passive_deletes=True)) labels = association_proxy('labelitems', 'label', creator=lambda label: LabelItem(label=label)) # Flags # # Message has not completed composition (marked as a draft). is_draft = Column(Boolean, server_default=false(), nullable=False) # Message has been read is_seen = Column(Boolean, server_default=false(), nullable=False) # Message is "flagged" for urgent/special attention is_flagged = Column(Boolean, server_default=false(), nullable=False) # session is the first session to have been notified about this message is_recent = Column(Boolean, server_default=false(), nullable=False) # Message has been answered is_answered = Column(Boolean, server_default=false(), nullable=False) # things like: ['$Forwarded', 'nonjunk', 'Junk'] extra_flags = Column(LittleJSON, default=[], nullable=False) # labels (Gmail-specific) # TO BE DEPRECATED g_labels = Column(JSON, default=lambda: [], nullable=True) def update_flags(self, new_flags): """ Sets flag and g_labels values based on the new_flags and x_gm_labels parameters. Returns True if any values have changed compared to what we previously stored. """ changed = False new_flags = set(new_flags) col_for_flag = { u'\\Draft': 'is_draft', u'\\Seen': 'is_seen', u'\\Recent': 'is_recent', u'\\Answered': 'is_answered', u'\\Flagged': 'is_flagged', } for flag, col in col_for_flag.iteritems(): prior_flag_value = getattr(self, col) new_flag_value = flag in new_flags if prior_flag_value != new_flag_value: changed = True setattr(self, col, new_flag_value) new_flags.discard(flag) extra_flags = sorted(new_flags) if extra_flags != self.extra_flags: changed = True # Sadly, there's a limit of 255 chars for this # column. while len(json.dumps(extra_flags)) > 255: extra_flags.pop() self.extra_flags = extra_flags return changed def update_labels(self, new_labels): # TODO(emfree): This is all mad complicated. Simplify if possible? # Gmail IMAP doesn't use the normal IMAP \\Draft flag. Silly Gmail # IMAP. 
self.is_draft = '\\Draft' in new_labels self.is_starred = '\\Starred' in new_labels category_map = { '\\Inbox': 'inbox', '\\Important': 'important', '\\Sent': 'sent', '\\Trash': 'trash', '\\Spam': 'spam', '\\All': 'all' } remote_labels = set() for label in new_labels: if label in ('\\Draft', '\\Starred'): continue elif label in category_map: remote_labels.add((category_map[label], category_map[label])) else: remote_labels.add((label, None)) local_labels = {(l.name, l.canonical_name): l for l in self.labels} remove = set(local_labels) - remote_labels add = remote_labels - set(local_labels) with object_session(self).no_autoflush: for key in remove: self.labels.remove(local_labels[key]) for name, canonical_name in add: label = Label.find_or_create(object_session(self), self.account, name, canonical_name) self.labels.add(label) @property def namespace(self): return self.imapaccount.namespace @property def categories(self): categories = set([l.category for l in self.labels]) categories.add(self.folder.category) return categories __table_args__ = (UniqueConstraint( 'folder_id', 'msg_uid', 'account_id', ), )
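# Hedged example of flag and label bookkeeping with the model above; `uid` is assumed
# to be a persisted ImapUid attached to `db_session`, and the flag tuple to come
# straight from an IMAP FETCH response.
flags = ('\\Seen', '\\Flagged', '$Forwarded')
if uid.update_flags(flags):
    # is_seen / is_flagged were toggled and '$Forwarded' went into extra_flags
    db_session.commit()
uid.update_labels(['\\Inbox', '\\Important', 'Receipts'])
db_session.commit()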
class Tag(MailSyncBase, HasRevisions): """Tags represent extra data associated with threads. A note about the schema. The 'public_id' of a tag is immutable. For reserved tags such as the inbox or starred tag, the public_id is a fixed human-readable string. For other tags, the public_id is an autogenerated uid similar to a normal public id, but stored as a string for compatibility. The name of a tag is allowed to be mutable, to allow for the eventuality that users wish to change the name of user-created labels, or that we someday expose localized names ('DAS INBOX'), or that we somehow manage to sync renamed gmail labels, etc. """ API_OBJECT_NAME = 'tag' namespace = relationship( Namespace, backref=backref( 'tags', collection_class=attribute_mapped_collection('public_id')), load_on_pending=True) namespace_id = Column(Integer, ForeignKey( 'namespace.id', ondelete='CASCADE'), nullable=False) public_id = Column(String(MAX_INDEXABLE_LENGTH), nullable=False, default=generate_public_id) name = Column(String(MAX_INDEXABLE_LENGTH), nullable=False) user_created = Column(Boolean, server_default=false(), nullable=False) CANONICAL_TAG_NAMES = ['inbox', 'archive', 'drafts', 'sending', 'sent', 'spam', 'starred', 'trash', 'unread', 'unseen', 'attachment'] RESERVED_TAG_NAMES = ['all', 'archive', 'drafts', 'send', 'replied', 'file', 'attachment', 'unseen'] # Tags that are allowed to be both added and removed via the API. USER_MUTABLE_TAGS = ['unread', 'starred', 'spam', 'trash', 'inbox', 'archive'] @property def user_removable(self): # The 'unseen' tag can only be removed. return (self.user_created or self.public_id in self.USER_MUTABLE_TAGS or self.public_id == 'unseen') @property def user_addable(self): return (self.user_created or self.public_id in self.USER_MUTABLE_TAGS) @property def readonly(self): return not (self.user_removable or self.user_addable) @classmethod def name_available(cls, name, namespace_id, db_session): name = name.lower() if name in cls.RESERVED_TAG_NAMES or name in cls.CANONICAL_TAG_NAMES: return False if (name,) in db_session.query(Tag.name). \ filter(Tag.namespace_id == namespace_id).all(): return False return True __table_args__ = (UniqueConstraint('namespace_id', 'name'), UniqueConstraint('namespace_id', 'public_id'))
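# Hedged sketch of guarding user-created tags with the helper above; `db_session`
# and `namespace` are assumed to exist, and reserved/canonical names are rejected.
requested = 'receipts'
if Tag.name_available(requested, namespace.id, db_session):
    db_session.add(Tag(namespace=namespace, name=requested, user_created=True))
    db_session.commit()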
from trove.db.sqlalchemy.migrate_repo.schema import create_tables from trove.db.sqlalchemy.migrate_repo.schema import DateTime from trove.db.sqlalchemy.migrate_repo.schema import drop_tables from trove.db.sqlalchemy.migrate_repo.schema import Integer from trove.db.sqlalchemy.migrate_repo.schema import String from trove.db.sqlalchemy.migrate_repo.schema import Table meta = MetaData() quotas = Table('quotas', meta, Column('id', String(36), primary_key=True, nullable=False), Column('created', DateTime()), Column('updated', DateTime()), Column('tenant_id', String(36)), Column('resource', String(length=255), nullable=False), Column('hard_limit', Integer()), UniqueConstraint('tenant_id', 'resource')) quota_usages = Table( 'quota_usages', meta, Column('id', String(36), primary_key=True, nullable=False), Column('created', DateTime()), Column('updated', DateTime()), Column('tenant_id', String(36)), Column('in_use', Integer(), default=0), Column('reserved', Integer(), default=0), Column('resource', String(length=255), nullable=False), UniqueConstraint('tenant_id', 'resource')) reservations = Table( 'reservations', meta, Column('created', DateTime()), Column('updated', DateTime()), Column('id', String(36), primary_key=True, nullable=False), Column('usage_id', String(36)), Column('delta', Integer(), nullable=False),
class Shelf(Base): __tablename__ = 'shelf' __table_args__ = (UniqueConstraint('shelf'), ) id = sa.Column(sa.Integer, primary_key=True) shelf = sa.Column(sa.String, nullable=False)
class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"), ) table_id = Column(Integer, ForeignKey("tables.id")) table = relationship( "SqlaTable", backref=backref("columns", cascade="all, delete-orphan"), foreign_keys=[table_id], ) is_dttm = Column(Boolean, default=False) expression = Column(Text) python_date_format = Column(String(255)) export_fields = [ "table_id", "column_name", "verbose_name", "is_dttm", "is_active", "type", "groupby", "filterable", "expression", "description", "python_date_format", ] update_from_object_fields = [ s for s in export_fields if s not in ("table_id", ) ] export_parent = "table" def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.column_name if not self.expression: db_engine_spec = self.table.database.db_engine_spec type_ = db_engine_spec.get_sqla_column_type(self.type) col = column(self.column_name, type_=type_) else: col = literal_column(self.expression) col = self.table.make_sqla_column_compatible(col, label) return col @property def datasource(self) -> RelationshipProperty: return self.table def get_time_filter(self, start_dttm: DateTime, end_dttm: DateTime) -> ColumnElement: col = self.get_sqla_col(label="__time") l = [] if start_dttm: l.append(col >= text(self.dttm_sql_literal(start_dttm))) if end_dttm: l.append(col <= text(self.dttm_sql_literal(end_dttm))) return and_(*l) def get_timestamp_expression( self, time_grain: Optional[str]) -> Union[TimestampExpression, Label]: """ Return a SQLAlchemy Core element representation of self to be used in a query. :param time_grain: Optional time grain, e.g. P1Y :return: A TimeExpression object wrapped in a Label if supported by db """ label = utils.DTTM_ALIAS db = self.table.database pdf = self.python_date_format is_epoch = pdf in ("epoch_s", "epoch_ms") if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=DateTime) return self.table.make_sqla_column_compatible(sqla_col, label) if self.expression: col = literal_column(self.expression) else: col = column(self.column_name) time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain) return self.table.make_sqla_column_compatible(time_expr, label) @classmethod def import_obj(cls, i_column): def lookup_obj(lookup_column): return (db.session.query(TableColumn).filter( TableColumn.table_id == lookup_column.table_id, TableColumn.column_name == lookup_column.column_name, ).first()) return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal(self, dttm: DateTime) -> str: """Convert datetime object to a SQL expression string""" tf = self.python_date_format if tf: seconds_since_epoch = int(dttm.timestamp()) if tf == "epoch_s": return str(seconds_since_epoch) elif tf == "epoch_ms": return str(seconds_since_epoch * 1000) return "'{}'".format(dttm.strftime(tf)) else: s = self.table.database.db_engine_spec.convert_dttm( self.type or "", dttm) # TODO(john-bodley): SIP-15 will explicitly require a type conversion. return s or "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S.%f"))
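# Hedged illustration of how python_date_format drives dttm_sql_literal above; the
# TableColumn instances are free-standing examples, not rows from the metadata
# database.
from datetime import datetime

dttm = datetime(2019, 1, 1, 12, 30, 0)

epoch_col = TableColumn(column_name='ts', python_date_format='epoch_s')
assert epoch_col.dttm_sql_literal(dttm) == str(int(dttm.timestamp()))

fmt_col = TableColumn(column_name='ds', python_date_format='%Y-%m-%d')
assert fmt_col.dttm_sql_literal(dttm) == "'2019-01-01'"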
class Language(Base): __tablename__ = 'language' __table_args__ = (UniqueConstraint('code'), ) id = sa.Column(sa.Integer, primary_key=True) code = sa.Column(sa.String, nullable=False)
class SqlaTable(Model, BaseDatasource): """An ORM object for SqlAlchemy table references""" type = "table" query_language = "sql" metric_class = SqlMetric column_class = TableColumn owner_class = security_manager.user_model __tablename__ = "tables" __table_args__ = (UniqueConstraint("database_id", "table_name"), ) table_name = Column(String(250), nullable=False) main_dttm_col = Column(String(250)) database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False) fetch_values_predicate = Column(String(1000)) owners = relationship(owner_class, secondary=sqlatable_user, backref="tables") database = relationship( "Database", backref=backref("tables", cascade="all, delete-orphan"), foreign_keys=[database_id], ) schema = Column(String(255)) sql = Column(Text) is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) baselink = "tablemodelview" export_fields = [ "table_name", "main_dttm_col", "description", "default_endpoint", "database_id", "offset", "cache_timeout", "schema", "sql", "params", "template_params", "filter_select_enabled", "fetch_values_predicate", ] update_from_object_fields = [ f for f in export_fields if f not in ("table_name", "database_id") ] export_parent = "database" export_children = ["metrics", "columns"] sqla_aggregations = { "COUNT_DISTINCT": lambda column_name: sa.func.COUNT(sa.distinct(column_name)), "COUNT": sa.func.COUNT, "SUM": sa.func.SUM, "AVG": sa.func.AVG, "MIN": sa.func.MIN, "MAX": sa.func.MAX, } def make_sqla_column_compatible(self, sqla_col: Column, label: Optional[str] = None) -> Column: """Takes a sql alchemy column object and adds label info if supported by engine. :param sqla_col: sql alchemy column instance :param label: alias/label that column is expected to have :return: either a sql alchemy column or label instance if supported by engine """ label_expected = label or sqla_col.name db_engine_spec = self.database.db_engine_spec if db_engine_spec.allows_column_aliases: label = db_engine_spec.make_label_compatible(label_expected) sqla_col = sqla_col.label(label) sqla_col._df_label_expected = label_expected return sqla_col def __repr__(self): return self.name @property def connection(self) -> str: return str(self.database) @property def description_markeddown(self) -> str: return utils.markdown(self.description) @property def datasource_name(self) -> str: return self.table_name @property def database_name(self) -> str: return self.database.name @classmethod def get_datasource_by_name( cls, session: Session, datasource_name: str, schema: Optional[str], database_name: str, ) -> Optional["SqlaTable"]: schema = schema or None query = (session.query(cls).join(Database).filter( cls.table_name == datasource_name).filter( Database.database_name == database_name)) # Handling schema being '' or None, which is easier to handle # in python than in the SQLA query in a multi-dialect way for tbl in query.all(): if schema == (tbl.schema or None): return tbl return None @property def link(self) -> Markup: name = escape(self.name) anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>' return Markup(anchor) @property def schema_perm(self) -> Optional[str]: """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) def get_perm(self) -> str: return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self) @property def name(self) -> str: # type: ignore if not self.schema: return self.table_name return "{}.{}".format(self.schema, self.table_name) 
@property def full_name(self) -> str: return utils.get_datasource_full_name(self.database, self.table_name, schema=self.schema) @property def dttm_cols(self) -> List: l = [c.column_name for c in self.columns if c.is_dttm] if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property def num_cols(self) -> List: return [c.column_name for c in self.columns if c.is_num] @property def any_dttm_col(self) -> Optional[str]: cols = self.dttm_cols return cols[0] if cols else None @property def html(self) -> str: t = ((c.column_name, c.type) for c in self.columns) df = pd.DataFrame(t) df.columns = ["field", "type"] return df.to_html( index=False, classes=("dataframe table table-striped table-bordered " "table-condensed"), ) @property def sql_url(self) -> str: return self.database.sql_url + "?table_name=" + str(self.table_name) def external_metadata(self): cols = self.database.get_columns(self.table_name, schema=self.schema) for col in cols: try: col["type"] = str(col["type"]) except CompileError: col["type"] = "UNKNOWN" return cols @property def time_column_grains(self) -> Dict[str, Any]: return { "time_columns": self.dttm_cols, "time_grains": [grain.name for grain in self.database.grains()], } @property def select_star(self) -> str: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star(self.table_name, schema=self.schema, show_cols=False, latest_partition=False) def get_col(self, col_name: str) -> Optional[Column]: columns = self.columns for col in columns: if col_name == col.column_name: return col return None @property def data(self) -> Dict: d = super(SqlaTable, self).data if self.type == "table": grains = self.database.grains() or [] if grains: grains = [(g.duration, g.name) for g in grains] d["granularity_sqla"] = utils.choicify(self.dttm_cols) d["time_grain_sqla"] = grains d["main_dttm_col"] = self.main_dttm_col d["fetch_values_predicate"] = self.fetch_values_predicate d["template_params"] = self.template_params return d def values_for_column(self, column_name: str, limit: int = 10000) -> List: """Runs query against sqla to retrieve some sample values for the given column. 
""" cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() qry = (select([target_col.get_sqla_col() ]).select_from(self.get_from_clause(tp)).distinct()) if limit: qry = qry.limit(limit) if self.fetch_values_predicate: tp = self.get_template_processor() qry = qry.where(tp.process_template(self.fetch_values_predicate)) engine = self.database.get_sqla_engine() sql = "{}".format( qry.compile(engine, compile_kwargs={"literal_binds": True})) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) return df[column_name].to_list() def mutate_query_from_config(self, sql: str) -> str: """Apply config's SQL_QUERY_MUTATOR Typically adds comments to the query with context""" SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR") if SQL_QUERY_MUTATOR: username = utils.get_username() sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database) return sql def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) logging.info(sql) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended(labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries) def get_query_str(self, query_obj: Dict) -> str: query_str_ext = self.get_query_str_extended(query_obj) all_queries = query_str_ext.prequeries + [query_str_ext.sql] return ";\n\n".join(all_queries) + ";" def get_sqla_table(self): tbl = table(self.table_name) if self.schema: tbl.schema = self.schema return tbl def get_from_clause(self, template_processor=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) from_sql = sqlparse.format(from_sql, strip_comments=True) return TextAsFrom(sa.text(from_sql), []).alias("expr_qry") return self.get_sqla_table() def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]: """ Turn an adhoc metric into a sqlalchemy column. 
:param dict metric: Adhoc metric definition :param dict cols: Columns for the current table :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ expression_type = metric.get("expressionType") label = utils.get_metric_name(metric) if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]: column_name = metric["column"].get("column_name") table_column = cols.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() else: sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]]( sqla_column) elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]: sqla_metric = literal_column(metric.get("sqlExpression")) else: return None return self.make_sqla_column_compatible(sqla_metric, label) def get_sqla_query( # sqla self, groupby, metrics, granularity, from_dttm, to_dttm, filter=None, is_timeseries=True, timeseries_limit=15, timeseries_limit_metric=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, orderby=None, extras=None, columns=None, order_desc=True, ) -> SqlaQuery: """Querying any sqla table from this common interface""" template_kwargs = { "from_dttm": from_dttm, "groupby": groupby, "metrics": metrics, "row_limit": row_limit, "to_dttm": to_dttm, "filter": filter, "columns": {col.column_name: col for col in self.columns}, } template_kwargs.update(self.template_params_dict) extra_cache_keys: List[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.database.db_engine_spec prequeries: List[str] = [] orderby = orderby or [] # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline cols: Dict[str, Column] = {col.column_name: col for col in self.columns} metrics_dict: Dict[str, SqlMetric] = { m.metric_name: m for m in self.metrics } if not granularity and is_timeseries: raise Exception( _("Datetime column not provided as part table configuration " "and is required by this type of chart")) if not groupby and not metrics and not columns: raise Exception(_("Empty query?")) metrics_exprs = [] for m in metrics: if utils.is_adhoc_metric(m): metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols)) elif m in metrics_dict: metrics_exprs.append(metrics_dict[m].get_sqla_col()) else: raise Exception( _("Metric '%(metric)s' does not exist", metric=m)) if metrics_exprs: main_metric_expr = metrics_exprs[0] else: main_metric_expr, label = literal_column("COUNT(*)"), "ccount" main_metric_expr = self.make_sqla_column_compatible( main_metric_expr, label) select_exprs: List[Column] = [] groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() if groupby: select_exprs = [] for s in groupby: if s in cols: outer = cols[s].get_sqla_col() else: outer = literal_column(f"({s})") outer = self.make_sqla_column_compatible(outer, s) groupby_exprs_sans_timestamp[outer.name] = outer select_exprs.append(outer) elif columns: for s in columns: select_exprs.append( cols[s].get_sqla_col() if s in cols else self. 
make_sqla_column_compatible(literal_column(s))) metrics_exprs = [] groupby_exprs_with_timestamp = OrderedDict( groupby_exprs_sans_timestamp.items()) if granularity: dttm_col = cols[granularity] time_grain = extras.get("time_grain_sqla") time_filters = [] if is_timeseries: timestamp = dttm_col.get_timestamp_expression(time_grain) select_exprs += [timestamp] groupby_exprs_with_timestamp[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns if (db_engine_spec.time_secondary_columns and self.main_dttm_col in self.dttm_cols and self.main_dttm_col != dttm_col.column_name): time_filters.append(cols[self.main_dttm_col].get_time_filter( from_dttm, to_dttm)) time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) select_exprs += metrics_exprs labels_expected = [c._df_label_expected for c in select_exprs] select_exprs = db_engine_spec.make_select_compatible( groupby_exprs_with_timestamp.values(), select_exprs) qry = sa.select(select_exprs) tbl = self.get_from_clause(template_processor) if not columns: qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] having_clause_and: List = [] for flt in filter: if not all([flt.get(s) for s in ["col", "op"]]): continue col = flt["col"] op = flt["op"] col_obj = cols.get(col) if col_obj: is_list_target = op in ("in", "not in") eq = self.filter_values_handler( flt.get("val"), target_column_is_numeric=col_obj.is_num, is_list_target=is_list_target, ) if op in ("in", "not in"): cond = col_obj.get_sqla_col().in_(eq) if "<NULL>" in eq: cond = or_(cond, col_obj.get_sqla_col() == None) if op == "not in": cond = ~cond where_clause_and.append(cond) else: if col_obj.is_num: eq = utils.string_to_num(flt["val"]) if op == "==": where_clause_and.append(col_obj.get_sqla_col() == eq) elif op == "!=": where_clause_and.append(col_obj.get_sqla_col() != eq) elif op == ">": where_clause_and.append(col_obj.get_sqla_col() > eq) elif op == "<": where_clause_and.append(col_obj.get_sqla_col() < eq) elif op == ">=": where_clause_and.append(col_obj.get_sqla_col() >= eq) elif op == "<=": where_clause_and.append(col_obj.get_sqla_col() <= eq) elif op == "LIKE": where_clause_and.append( col_obj.get_sqla_col().like(eq)) elif op == "IS NULL": where_clause_and.append(col_obj.get_sqla_col() == None) elif op == "IS NOT NULL": where_clause_and.append(col_obj.get_sqla_col() != None) if extras: where = extras.get("where") if where: where = template_processor.process_template(where) where_clause_and += [sa.text("({})".format(where))] having = extras.get("having") if having: having = template_processor.process_template(having) having_clause_and += [sa.text("({})".format(having))] if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) if not orderby and not columns: orderby = [(main_metric_expr, not order_desc)] for col, ascending in orderby: direction = asc if ascending else desc if utils.is_adhoc_metric(col): col = self.adhoc_metric_to_sqla(col, cols) elif col in cols: col = cols[col].get_sqla_col() qry = qry.order_by(direction(col)) if row_limit: qry = qry.limit(row_limit) if is_timeseries and timeseries_limit and groupby and not time_groupby_inline: if self.database.db_engine_spec.allows_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. 
vertica, # require a unique inner alias inner_main_metric_expr = self.make_sqla_column_compatible( main_metric_expr, "mme_inner__") inner_groupby_exprs = [] inner_select_exprs = [] for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): inner = self.make_sqla_column_compatible( gby_obj, gby_name + "__") inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) inner_select_exprs += [inner_main_metric_expr] subq = select(inner_select_exprs).select_from(tbl) inner_time_filter = dttm_col.get_time_filter( inner_from_dttm or from_dttm, inner_to_dttm or to_dttm) subq = subq.where( and_(*(where_clause_and + [inner_time_filter]))) subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr if timeseries_limit_metric: ob = self._get_timeseries_orderby(timeseries_limit_metric, metrics_dict, cols) direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) subq = subq.limit(timeseries_limit) on_clause = [] for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): # in this case the column name, not the alias, needs to be # conditionally mutated, as it refers to the column alias in # the inner query col_name = db_engine_spec.make_label_compatible(gby_name + "__") on_clause.append(gby_obj == column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) else: if timeseries_limit_metric: orderby = [( self._get_timeseries_orderby(timeseries_limit_metric, metrics_dict, cols), False, )] # run prequery to get top groups prequery_obj = { "is_timeseries": False, "row_limit": timeseries_limit, "groupby": groupby, "metrics": metrics, "granularity": granularity, "from_dttm": inner_from_dttm or from_dttm, "to_dttm": inner_to_dttm or to_dttm, "filter": filter, "orderby": orderby, "extras": extras, "columns": columns, "order_desc": True, } result = self.query(prequery_obj) prequeries.append(result.query) dimensions = [ c for c in result.df.columns if c not in metrics and c in groupby_exprs_sans_timestamp ] top_groups = self._get_top_groups( result.df, dimensions, groupby_exprs_sans_timestamp) qry = qry.where(top_groups) return SqlaQuery( extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry.select_from(tbl), prequeries=prequeries, ) def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols): if utils.is_adhoc_metric(timeseries_limit_metric): ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols) elif timeseries_limit_metric in metrics_dict: timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric) ob = timeseries_limit_metric.get_sqla_col() else: raise Exception( _("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)) return ob def _get_top_groups(self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict) -> ColumnElement: groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: group.append(groupby_exprs[dimension] == row[dimension]) groups.append(and_(*group)) return or_(*groups) def query(self, query_obj: Dict) -> QueryResult: qry_start_dttm = datetime.now() query_str_ext = self.get_query_str_extended(query_obj) sql = query_str_ext.sql status = utils.QueryStatus.SUCCESS error_message = None def mutator(df): labels_expected = query_str_ext.labels_expected if df is not None and not df.empty: if len(df.columns) != len(labels_expected): raise Exception(f"For {sql}, df.columns: {df.columns}" f" differs from {labels_expected}") else: df.columns = labels_expected return df try: df = self.database.get_df(sql, self.schema, mutator) except Exception as 
e: df = None status = utils.QueryStatus.FAILED logging.exception(f"Query {sql} on schema {self.schema} failed") db_engine_spec = self.database.db_engine_spec error_message = db_engine_spec.extract_error_message(e) return QueryResult( status=status, df=df, duration=datetime.now() - qry_start_dttm, query=sql, error_message=error_message, ) def get_sqla_table_object(self) -> Table: return self.database.get_table(self.table_name, schema=self.schema) def fetch_metadata(self) -> None: """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() except Exception as e: logging.exception(e) raise Exception( _("Table [{}] doesn't seem to exist in the specified database, " "couldn't fetch column information").format(self.table_name)) M = SqlMetric metrics = [] any_date_col = None db_engine_spec = self.database.db_engine_spec db_dialect = self.database.get_dialect() dbcols = (db.session.query(TableColumn).filter( TableColumn.table == self).filter( or_(TableColumn.column_name == col.name for col in table.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} for col in table.columns: try: datatype = db_engine_spec.column_datatype_to_string( col.type, db_dialect) except Exception as e: datatype = "UNKNOWN" logging.error("Unrecognized data type in {}.{}".format( table, col.name)) logging.exception(e) dbcol = dbcols.get(col.name, None) if not dbcol: dbcol = TableColumn(column_name=col.name, type=datatype) dbcol.sum = dbcol.is_num dbcol.avg = dbcol.is_num dbcol.is_dttm = dbcol.is_time db_engine_spec.alter_new_orm_column(dbcol) else: dbcol.type = datatype dbcol.groupby = True dbcol.filterable = True self.columns.append(dbcol) if not any_date_col and dbcol.is_time: any_date_col = col.name metrics.append( M( metric_name="count", verbose_name="COUNT(*)", metric_type="count", expression="COUNT(*)", )) if not self.main_dttm_col: self.main_dttm_col = any_date_col self.add_missing_metrics(metrics) db.session.merge(self) db.session.commit() @classmethod def import_obj(cls, i_datasource, import_time=None) -> int: """Imports the datasource from the object to the database. Metrics, columns and the datasource are overridden if they already exist. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copied over. """ def lookup_sqlatable(table): return (db.session.query(SqlaTable).join(Database).filter( SqlaTable.table_name == table.table_name, SqlaTable.schema == table.schema, Database.id == table.database_id, ).first()) def lookup_database(table): try: return (db.session.query(Database).filter_by( database_name=table.params_dict["database_name"]).one()) except NoResultFound: raise DatabaseNotFound( _( "Database '%(name)s' is not found", name=table.params_dict["database_name"], )) return import_datasource.import_datasource(db.session, i_datasource, lookup_database, lookup_sqlatable, import_time) @classmethod def query_datasources_by_name(cls, session: Session, database: Database, datasource_name: str, schema=None) -> List["SqlaTable"]: query = (session.query(cls).filter_by( database_id=database.id).filter_by(table_name=datasource_name)) if schema: query = query.filter_by(schema=schema) return query.all() @staticmethod def default_query(qry) -> Query: return qry.filter_by(is_sqllab_view=False) def has_extra_cache_keys(self, query_obj: Dict) -> bool: """ Detects the presence of calls to cache_key_wrapper in items in query_obj that can be templated.
:param query_obj: query object to analyze :return: True if at least one item calls cache_key_wrapper, otherwise False """ regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}") templatable_statements: List[str] = [] if self.sql: templatable_statements.append(self.sql) if self.fetch_values_predicate: templatable_statements.append(self.fetch_values_predicate) extras = query_obj.get("extras", {}) if "where" in extras: templatable_statements.append(extras["where"]) if "having" in extras: templatable_statements.append(extras["having"]) for statement in templatable_statements: if regex.search(statement): return True return False def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]: if self.has_extra_cache_keys(query_obj): sqla_query = self.get_sqla_query(**query_obj) extra_cache_keys = sqla_query.extra_cache_keys return extra_cache_keys return []
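# Hedged sketch (not original code): has_extra_cache_keys() only looks for the
# literal {{ ... cache_key_wrapper(...) ... }} pattern in the templatable SQL
# fragments (self.sql, fetch_values_predicate, and the where/having extras).
# The table and query_obj below are illustrative assumptions.

tbl = SqlaTable(table_name="logs")
tbl.sql = ("SELECT * FROM logs "
           "WHERE org = '{{ cache_key_wrapper(current_username()) }}'")

query_obj = {"extras": {"where": "1 = 1"}}
tbl.has_extra_cache_keys(query_obj)   # True: the Jinja call appears in self.sql

tbl.sql = "SELECT * FROM logs"
tbl.has_extra_cache_keys(query_obj)   # False: nothing templatable calls the wrapper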
class ImapFolderSyncStatus(MailSyncBase): """ Per-folder status state saving for IMAP folders. """ account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'), nullable=False) account = relationship(ImapAccount, backref=backref('foldersyncstatuses')) folder_id = Column(Integer, ForeignKey('folder.id', ondelete='CASCADE'), nullable=False) # We almost always need the folder name too, so eager load by default. folder = relationship('Folder', lazy='joined', backref=backref('imapsyncstatus', passive_deletes=True)) # see state machine in mailsync/backends/imap/imap.py state = Column(Enum('initial', 'initial uidinvalid', 'poll', 'poll uidinvalid', 'finish'), server_default='initial', nullable=False) # stats on messages downloaded etc. _metrics = Column(MutableDict.as_mutable(JSON), default={}, nullable=True) @property def metrics(self): status = dict(name=self.folder.name, state=self.state) status.update(self._metrics or {}) return status def start_sync(self): self._metrics = dict(run_state='running', sync_start_time=datetime.utcnow()) def stop_sync(self): self._metrics['run_state'] = 'stopped' self._metrics['sync_end_time'] = datetime.utcnow() def kill_sync(self, error=None): self._metrics['run_state'] = 'killed' self._metrics['sync_end_time'] = datetime.utcnow() self._metrics['sync_error'] = error @property def is_killed(self): return self._metrics.get('run_state') == 'killed' def update_metrics(self, metrics): sync_status_metrics = [ 'remote_uid_count', 'delete_uid_count', 'update_uid_count', 'download_uid_count', 'uid_checked_timestamp', 'num_downloaded_since_timestamp', 'queue_checked_at', 'percent' ] assert isinstance(metrics, dict) for k in metrics.iterkeys(): assert k in sync_status_metrics, k if self._metrics is not None: self._metrics.update(metrics) else: self._metrics = metrics __table_args__ = (UniqueConstraint('account_id', 'folder_id'), )
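# Hedged usage sketch (not from the original code): the lifecycle helpers above
# keep their state in the mutable JSON _metrics column, and update_metrics()
# only accepts the whitelisted keys. Constructing the object without a session
# or folder is an assumption for illustration, and dict.iterkeys() means this
# model targets Python 2.

status = ImapFolderSyncStatus()
status.start_sync()                        # _metrics = {'run_state': 'running', ...}
status.update_metrics({'remote_uid_count': 1200, 'percent': 42})
status.stop_sync()                         # records run_state='stopped' plus an end time

status.update_metrics({'bogus_key': 1})    # AssertionError: key is not whitelisted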
class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = 'table_columns' __table_args__ = (UniqueConstraint('table_id', 'column_name'), ) table_id = Column(Integer, ForeignKey('tables.id')) table = relationship('SqlaTable', backref=backref('columns', cascade='all, delete-orphan'), foreign_keys=[table_id]) is_dttm = Column(Boolean, default=False) expression = Column(Text, default='') python_date_format = Column(String(255)) database_expression = Column(String(255)) export_fields = ( 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', 'type', 'groupby', 'filterable', 'expression', 'description', 'python_date_format', 'database_expression', ) update_from_object_fields = [ s for s in export_fields if s not in ('table_id', ) ] export_parent = 'table' def get_sqla_col(self, label=None): label = label if label else self.column_name label = self.table.get_label(label) if not self.expression: db_engine_spec = self.table.database.db_engine_spec type_ = db_engine_spec.get_sqla_column_type(self.type) col = column(self.column_name, type_=type_).label(label) else: col = literal_column(self.expression).label(label) return col @property def datasource(self): return self.table def get_time_filter(self, start_dttm, end_dttm): is_epoch_in_utc = config.get('IS_EPOCH_S_TRULY_UTC', False) col = self.get_sqla_col(label='__time') l = [] # noqa: E741 if start_dttm: l.append( col >= text(self.dttm_sql_literal(start_dttm, is_epoch_in_utc)) ) if end_dttm: l.append( col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc))) return and_(*l) def get_timestamp_expression(self, time_grain): """Getting the time component of the query""" label = self.table.get_label(utils.DTTM_ALIAS) db = self.table.database pdf = self.python_date_format is_epoch = pdf in ('epoch_s', 'epoch_ms') if not self.expression and not time_grain and not is_epoch: return column(self.column_name, type_=DateTime).label(label) grain = None if time_grain: grain = db.grains_dict().get(time_grain) if not grain: raise NotImplementedError( f'No grain spec for {time_grain} for database {db.database_name}' ) expr = db.db_engine_spec.get_time_expr( self.expression or self.column_name, pdf, time_grain, grain) return literal_column(expr, type_=DateTime).label(label) @classmethod def import_obj(cls, i_column): def lookup_obj(lookup_column): return db.session.query(TableColumn).filter( TableColumn.table_id == lookup_column.table_id, TableColumn.column_name == lookup_column.column_name).first() return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) def dttm_sql_literal(self, dttm, is_epoch_in_utc): """Convert datetime object to a SQL expression string If database_expression is empty, the internal dttm will be parsed as the string with the pattern that the user inputted (python_date_format) If database_expression is not empty, the internal dttm will be parsed as the sql sentence for the database to convert """ tf = self.python_date_format if self.database_expression: return self.database_expression.format( dttm.strftime('%Y-%m-%d %H:%M:%S')) elif tf: if is_epoch_in_utc: seconds_since_epoch = dttm.timestamp() else: seconds_since_epoch = (dttm - datetime(1970, 1, 1)).total_seconds() seconds_since_epoch = int(seconds_since_epoch) if tf == 'epoch_s': return str(seconds_since_epoch) elif tf == 'epoch_ms': return str(seconds_since_epoch * 1000) return "'{}'".format(dttm.strftime(tf)) else: s = self.table.database.db_engine_spec.convert_dttm( self.type or '', dttm) 
return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))
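# Hedged illustration (not part of the model above): in this older variant,
# database_expression takes precedence inside dttm_sql_literal() and is used as
# a format template around the '%Y-%m-%d %H:%M:%S' rendering of the datetime.
# The TO_DATE template below is an example value, not taken from the original.

from datetime import datetime

col = TableColumn(column_name="ds")
col.database_expression = "TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')"
col.dttm_sql_literal(datetime(2019, 1, 1), is_epoch_in_utc=False)
# -> "TO_DATE('2019-01-01 00:00:00', 'YYYY-MM-DD HH24:MI:SS')"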
class IBMIKEPolicy(db.Model): ID_KEY = "id" NAME_KEY = "name" DH_GROUP_KEY = "dh_group" ENCRYPTION_ALGORITHM_KEY = "encryption_algorithm" AUTHENTICATION_ALGORITHM_KEY = "authentication_algorithm" KEY_LIFETIME_KEY = "key_lifetime" IKE_VERSION_KEY = "ike_version" RESOURCE_GROUP_KEY = "resource_group" STATUS_KEY = "status" REGION_KEY = "region" CLOUD_ID_KEY = "cloud_id" MESSAGE_KEY = "message" __tablename__ = "ibm_ike_policy" id = Column(String(32), primary_key=True) resource_id = Column(String(64)) name = Column(String(255), nullable=False) region = Column(String(255), nullable=False) status = Column(String(50)) ike_version = Column(Integer, default=2) key_lifetime = Column(Integer, default=28800) authentication_algorithm = Column(Enum("md5", "sha1", "sha256"), default="sha1") encryption_algorithm = Column(Enum("triple_des", "aes128", "aes256"), default="aes128") dh_group = Column(Integer, default=2) cloud_id = Column(String(32), ForeignKey("ibm_clouds.id")) resource_group_id = Column(String(32), ForeignKey("ibm_resource_groups.id")) vpn_connections = relationship("IBMVpnConnection", backref="ibm_ike_policy", lazy="dynamic") __table_args__ = (UniqueConstraint( name, region, cloud_id, name="uix_ibm_ike_policy_name_region_cloud_id"), ) def __init__( self, name, region, key_lifetime=None, status=CREATION_PENDING, ike_version=None, authentication_algorithm=None, encryption_algorithm=None, dh_group=None, resource_id=None, cloud_id=None, ): self.id = uuid.uuid4().hex self.name = name self.region = region self.status = status self.ike_version = ike_version self.authentication_algorithm = authentication_algorithm self.encryption_algorithm = encryption_algorithm self.key_lifetime = key_lifetime self.dh_group = dh_group self.resource_id = resource_id self.cloud_id = cloud_id def make_copy(self): obj = IBMIKEPolicy( name=self.name, region=self.region, key_lifetime=self.key_lifetime, status=self.status, ike_version=self.ike_version, authentication_algorithm=self.authentication_algorithm, encryption_algorithm=self.encryption_algorithm, dh_group=self.dh_group, resource_id=self.resource_id, cloud_id=self.cloud_id, ) if self.ibm_resource_group: obj.ibm_resource_group = self.ibm_resource_group.make_copy() return obj def get_existing_from_db(self): return (db.session.query( self.__class__).filter_by(name=self.name, cloud_id=self.cloud_id, region=self.region).first()) def params_eq(self, other): if not type(self) == type(other): return False if not ((self.name == other.name) and (self.region == other.region) and (self.resource_id == other.resource_id) and (self.ike_version == other.ike_version) and (self.key_lifetime == other.key_lifetime) and (self.authentication_algorithm == other.authentication_algorithm) and (self.encryption_algorithm == other.encryption_algorithm) and (self.dh_group == other.dh_group) and (self.status == other.status)): return False if (self.ibm_resource_group and not other.ibm_resource_group) or ( not self.ibm_resource_group and other.ibm_resource_group): return False if self.ibm_resource_group and other.ibm_resource_group: if not self.ibm_resource_group.params_eq(other.ibm_resource_group): return False return True def add_update_db(self): existing = self.get_existing_from_db() if not existing: ibm_resource_group = self.ibm_resource_group self.ibm_resource_group = None db.session.add(self) db.session.commit() if ibm_resource_group: self.ibm_resource_group = ibm_resource_group.add_update_db() db.session.commit() return self if not self.params_eq(existing): existing.name = self.name 
existing.region = self.region existing.status = self.status existing.resource_id = self.resource_id existing.ike_version = self.ike_version existing.key_lifetime = self.key_lifetime existing.authentication_algorithm = self.authentication_algorithm existing.encryption_algorithm = self.encryption_algorithm existing.dh_group = self.dh_group db.session.commit() ibm_resource_group = self.ibm_resource_group self.ibm_resource_group = None if ibm_resource_group: existing.ibm_resource_group = ibm_resource_group.add_update_db( ) else: existing.ibm_resource_group = None db.session.commit() return existing def to_json(self): return { self.ID_KEY: self.id, self.NAME_KEY: self.name, self.REGION_KEY: self.region, self.AUTHENTICATION_ALGORITHM_KEY: self.authentication_algorithm, self.ENCRYPTION_ALGORITHM_KEY: self.encryption_algorithm, self.KEY_LIFETIME_KEY: self.key_lifetime, self.IKE_VERSION_KEY: self.ike_version, self.DH_GROUP_KEY: self.dh_group, self.STATUS_KEY: self.status, self.RESOURCE_GROUP_KEY: self.ibm_resource_group.to_json() if self.ibm_resource_group else "", self.CLOUD_ID_KEY: self.cloud_id, } def to_json_body(self): return { "name": self.name, "authentication_algorithm": self.authentication_algorithm, "encryption_algorithm": self.encryption_algorithm, "key_lifetime": self.key_lifetime, "ike_version": self.ike_version, "dh_group": self.dh_group, "resource_group": { "id": self.ibm_resource_group.resource_id if self.ibm_resource_group else "" }, } def to_report_json(self): return { self.ID_KEY: self.id, self.NAME_KEY: self.name, self.STATUS_KEY: PENDING, self.MESSAGE_KEY: "" } @classmethod def from_ibm_json_body(cls, region, json_body): # TODO: Verify Schema ibm_ike_policy = IBMIKEPolicy( name=json_body["name"], region=region, key_lifetime=json_body["key_lifetime"], status="CREATED", ike_version=json_body["ike_version"], authentication_algorithm=json_body["authentication_algorithm"], encryption_algorithm=json_body["encryption_algorithm"], dh_group=json_body["dh_group"], resource_id=json_body["id"], ) return ibm_ike_policy
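# Hedged sketch (not part of the original module): building a policy from an
# IBM API response with from_ibm_json_body() and reading it back. The json_body
# keys mirror what the classmethod reads; the concrete values and the region
# are illustrative assumptions.

json_body = {
    "id": "example-resource-id",
    "name": "ike-policy-1",
    "key_lifetime": 28800,
    "ike_version": 2,
    "authentication_algorithm": "sha256",
    "encryption_algorithm": "aes256",
    "dh_group": 14,
}

policy = IBMIKEPolicy.from_ibm_json_body(region="us-south", json_body=json_body)
policy.status              # "CREATED"
policy.resource_id         # provider-side id taken from json_body["id"]
policy.to_report_json()    # {"id": ..., "name": "ike-policy-1", "status": PENDING, "message": ""}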
class UserTestResult(Base): """Class to store the execution results of a user_test. """ __tablename__ = 'user_test_results' __table_args__ = (UniqueConstraint('user_test_id', 'dataset_id'), ) # Primary key is (user_test_id, dataset_id). user_test_id = Column(Integer, ForeignKey(UserTest.id, onupdate="CASCADE", ondelete="CASCADE"), primary_key=True) user_test = relationship(UserTest, backref=backref("results", cascade="all, delete-orphan", passive_deletes=True)) dataset_id = Column(Integer, ForeignKey(Dataset.id, onupdate="CASCADE", ondelete="CASCADE"), primary_key=True) dataset = relationship(Dataset) # Now below follow the actual result fields. # Output file's digest for this test output = Column(String, nullable=True) # Compilation outcome (can be None = yet to compile, "ok" = # compilation successful and we can evaluate, "fail" = # compilation unsuccessful, throw it away). compilation_outcome = Column(String, nullable=True) # String containing output from the sandbox. compilation_text = Column(String, nullable=True) # Number of attempts of compilation. compilation_tries = Column(Integer, nullable=False, default=0) # The compiler stdout and stderr. compilation_stdout = Column(Unicode, nullable=True) compilation_stderr = Column(Unicode, nullable=True) # Other information about the compilation. compilation_time = Column(Float, nullable=True) compilation_wall_clock_time = Column(Float, nullable=True) compilation_memory = Column(Integer, nullable=True) # Worker shard and sandbox where the compilation was performed. compilation_shard = Column(Integer, nullable=True) compilation_sandbox = Column(String, nullable=True) # Evaluation outcome (can be None = yet to evaluate, "ok" = # evaluation successful). evaluation_outcome = Column(String, nullable=True) evaluation_text = Column(String, nullable=True) # Number of attempts of evaluation. evaluation_tries = Column(Integer, nullable=False, default=0) # Other information about the execution. execution_time = Column(Float, nullable=True) execution_wall_clock_time = Column(Float, nullable=True) execution_memory = Column(Integer, nullable=True) # Worker shard and sandbox where the evaluation was performed. evaluation_shard = Column(Integer, nullable=True) evaluation_sandbox = Column(String, nullable=True) # Follows the description of the fields automatically added by # SQLAlchemy. # executables (dict of UserTestExecutable objects indexed by filename) def compiled(self): """Return whether the user test result has been compiled. return (bool): True if compiled, False otherwise. """ return self.compilation_outcome is not None def compilation_failed(self): """Return whether the user test result did not compile. return (bool): True if the compilation failed (in the sense that there is a problem in the user's source), False if not yet compiled or compilation was successful. """ return self.compilation_outcome == "fail" def compilation_succeeded(self): """Return whether the user test compiled. return (bool): True if the compilation succeeded (in the sense that an executable was created), False if not yet compiled or compilation was unsuccessful. """ return self.compilation_outcome == "ok" def evaluated(self): """Return whether the user test result has been evaluated. return (bool): True if evaluated, False otherwise. """ return self.evaluation_outcome is not None def invalidate_compilation(self): """Blank all compilation and evaluation outcomes. 
""" self.invalidate_evaluation() self.compilation_outcome = None self.compilation_text = None self.compilation_tries = 0 self.compilation_time = None self.compilation_wall_clock_time = None self.compilation_memory = None self.compilation_shard = None self.compilation_sandbox = None self.executables = {} def invalidate_evaluation(self): """Blank the evaluation outcome. """ self.evaluation_outcome = None self.evaluation_text = None self.evaluation_tries = 0 self.execution_time = None self.execution_wall_clock_time = None self.execution_memory = None self.evaluation_shard = None self.evaluation_sandbox = None self.output = None def set_compilation_outcome(self, success): """Set the compilation outcome based on the success. success (bool): if the compilation was successful. """ self.compilation_outcome = "ok" if success else "fail" def set_evaluation_outcome(self): """Set the evaluation outcome (always ok now). """ self.evaluation_outcome = "ok"