Example #1
class Relation(Base):
    """
  Container class for event relations
  """
    event_id = Column('event_id',
                      BigInteger,
                      ForeignKey('events.event_id',
                                 onupdate='cascade',
                                 ondelete='cascade'),
                      nullable=False,
                      index=True)
    event = relationship("Event",
                         uselist=False,
                         primaryjoin='Event.identifier==Relation.event_id')
    rel_event_id = Column('rel_event_id',
                          BigInteger,
                          ForeignKey('events.event_id',
                                     onupdate='cascade',
                                     ondelete='cascade'),
                          nullable=False,
                          index=True)
    rel_event = relationship(
        "Event",
        uselist=False,
        primaryjoin='Event.identifier==Relation.rel_event_id')
    attribute_id = Column('attribute_id',
                          BigInteger,
                          ForeignKey('attributes.attribute_id',
                                     onupdate='cascade',
                                     ondelete='cascade'),
                          nullable=False,
                          index=True)
    attribute = relationship(
        "Attribute",
        uselist=False,
        primaryjoin='Attribute.identifier==Relation.attribute_id')
    rel_attribute_id = Column('rel_attribute_id',
                              BigInteger,
                              ForeignKey('attributes.attribute_id',
                                         onupdate='cascade',
                                         ondelete='cascade'),
                              nullable=False,
                              index=True)
    rel_attribute = relationship(
        'Attribute',
        uselist=False,
        primaryjoin='Attribute.identifier==Relation.rel_attribute_id')

    __table_args__ = (UniqueConstraint('event_id', 'attribute_id',
                                       'rel_event_id', 'rel_attribute_id'), )

    def to_dict(self,
                complete=True,
                inflated=False,
                event_permissions=None,
                user=None):
        return {
            'event':
            self.event.to_dict(complete, inflated, event_permissions, user),
            'rel_event':
            self.rel_event.to_dict(complete, inflated, event_permissions,
                                   user),
            'attribute':
            self.attribute.to_dict(complete, inflated, event_permissions,
                                   user),
            'rel_attribute':
            self.rel_attribute.to_dict(complete, inflated, event_permissions,
                                       user),
        }

    def validate(self):
        """
    Returns true if the object is valid
    """
        return True
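
A minimal usage sketch for the Relation container above; session, event_a/event_b, and attr_a/attr_b are placeholders for a configured SQLAlchemy session and previously persisted Event and Attribute rows:

# Sketch only: event_a, event_b, attr_a, attr_b and session are placeholders.
relation = Relation(event_id=event_a.identifier,
                    rel_event_id=event_b.identifier,
                    attribute_id=attr_a.identifier,
                    rel_attribute_id=attr_b.identifier)
if relation.validate():
    session.add(relation)
    session.commit()

# to_dict() serializes the relation together with both linked events and
# attributes, forwarding the permission context to each of them.
payload = relation.to_dict(complete=False, inflated=False)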
Example #2
class Database(Model, AuditMixinNullable, ImportMixin):
    """An ORM object that stores Database related information"""

    __tablename__ = "dbs"
    type = "table"
    __table_args__ = (UniqueConstraint("database_name"), )

    id = Column(Integer, primary_key=True)
    verbose_name = Column(String(250), unique=True)
    # short unique name, used in permissions
    database_name = Column(String(250), unique=True)
    sqlalchemy_uri = Column(String(1024))
    password = Column(EncryptedType(String(1024), config.get("SECRET_KEY")))
    cache_timeout = Column(Integer)
    select_as_create_table_as = Column(Boolean, default=False)
    expose_in_sqllab = Column(Boolean, default=True)
    allow_run_async = Column(Boolean, default=False)
    allow_csv_upload = Column(Boolean, default=False)
    allow_ctas = Column(Boolean, default=False)
    allow_dml = Column(Boolean, default=False)
    force_ctas_schema = Column(String(250))
    allow_multi_schema_metadata_fetch = Column(Boolean, default=False)
    extra = Column(
        Text,
        default=textwrap.dedent("""\
    {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_csv_upload": []
    }
    """),
    )
    perm = Column(String(1000))
    impersonate_user = Column(Boolean, default=False)
    export_fields = (
        "database_name",
        "sqlalchemy_uri",
        "cache_timeout",
        "expose_in_sqllab",
        "allow_run_async",
        "allow_ctas",
        "allow_csv_upload",
        "extra",
    )
    export_children = ["tables"]

    def __repr__(self):
        return self.verbose_name if self.verbose_name else self.database_name

    @property
    def name(self):
        return self.verbose_name if self.verbose_name else self.database_name

    @property
    def allows_subquery(self):
        return self.db_engine_spec.allows_subquery

    @property
    def data(self):
        return {
            "id": self.id,
            "name": self.database_name,
            "backend": self.backend,
            "allow_multi_schema_metadata_fetch":
            self.allow_multi_schema_metadata_fetch,
            "allows_subquery": self.allows_subquery,
        }

    @property
    def unique_name(self):
        return self.database_name

    @property
    def url_object(self):
        return make_url(self.sqlalchemy_uri_decrypted)

    @property
    def backend(self):
        url = make_url(self.sqlalchemy_uri_decrypted)
        return url.get_backend_name()

    @property
    def metadata_cache_timeout(self):
        return self.get_extra().get("metadata_cache_timeout", {})

    @property
    def schema_cache_enabled(self):
        return "schema_cache_timeout" in self.metadata_cache_timeout

    @property
    def schema_cache_timeout(self):
        return self.metadata_cache_timeout.get("schema_cache_timeout")

    @property
    def table_cache_enabled(self):
        return "table_cache_timeout" in self.metadata_cache_timeout

    @property
    def table_cache_timeout(self):
        return self.metadata_cache_timeout.get("table_cache_timeout")

    @property
    def default_schemas(self):
        return self.get_extra().get("default_schemas", [])

    @classmethod
    def get_password_masked_url_from_uri(cls, uri):
        url = make_url(uri)
        return cls.get_password_masked_url(url)

    @classmethod
    def get_password_masked_url(cls, url):
        url_copy = deepcopy(url)
        if url_copy.password is not None and url_copy.password != PASSWORD_MASK:
            url_copy.password = PASSWORD_MASK
        return url_copy

    def set_sqlalchemy_uri(self, uri):
        conn = sqla.engine.url.make_url(uri.strip())
        if conn.password != PASSWORD_MASK and not custom_password_store:
            # do not over-write the password with the password mask
            self.password = conn.password
        conn.password = PASSWORD_MASK if conn.password else None
        self.sqlalchemy_uri = str(conn)  # hides the password

    def get_effective_user(self, url, user_name=None):
        """
        Get the effective user, especially during impersonation.
        :param url: SQL Alchemy URL object
        :param user_name: Default username
        :return: The effective username
        """
        effective_username = None
        if self.impersonate_user:
            effective_username = url.username
            if user_name:
                effective_username = user_name
            elif (hasattr(g, "user") and hasattr(g.user, "username")
                  and g.user.username is not None):
                effective_username = g.user.username
        return effective_username

    @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted",
                           "extra"))
    def get_sqla_engine(self,
                        schema=None,
                        nullpool=True,
                        user_name=None,
                        source=None):
        extra = self.get_extra()
        url = make_url(self.sqlalchemy_uri_decrypted)
        url = self.db_engine_spec.adjust_database_uri(url, schema)
        effective_username = self.get_effective_user(url, user_name)
        # If using MySQL or Presto for example, will set url.username
        # If using Hive, will not do anything yet since that relies on a
        # configuration parameter instead.
        self.db_engine_spec.modify_url_for_impersonation(
            url, self.impersonate_user, effective_username)

        masked_url = self.get_password_masked_url(url)
        logging.info(
            "Database.get_sqla_engine(). Masked URL: {0}".format(masked_url))

        params = extra.get("engine_params", {})
        if nullpool:
            params["poolclass"] = NullPool

        # If using Hive, this will set hive.server2.proxy.user=$effective_username
        configuration = {}
        configuration.update(
            self.db_engine_spec.get_configuration_for_impersonation(
                str(url), self.impersonate_user, effective_username))
        if configuration:
            d = params.get("connect_args", {})
            d["configuration"] = configuration
            params["connect_args"] = d

        DB_CONNECTION_MUTATOR = config.get("DB_CONNECTION_MUTATOR")
        if DB_CONNECTION_MUTATOR:
            url, params = DB_CONNECTION_MUTATOR(url, params,
                                                effective_username,
                                                security_manager, source)
        return create_engine(url, **params)

    def get_reserved_words(self):
        return self.get_dialect().preparer.reserved_words

    def get_quoter(self):
        return self.get_dialect().identifier_preparer.quote

    def get_df(self, sql, schema, mutator=None):
        sqls = [str(s).strip().strip(";") for s in sqlparse.parse(sql)]
        source_key = None
        if request and request.referrer:
            if "/superset/dashboard/" in request.referrer:
                source_key = "dashboard"
            elif "/superset/explore/" in request.referrer:
                source_key = "chart"
        engine = self.get_sqla_engine(schema=schema,
                                      source=utils.sources.get(
                                          source_key, None))
        username = utils.get_username()

        def needs_conversion(df_series):
            if df_series.empty:
                return False
            if isinstance(df_series[0], (list, dict)):
                return True
            return False

        def _log_query(sql):
            if log_query:
                log_query(engine.url, sql, schema, username, __name__,
                          security_manager)

        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                for sql in sqls[:-1]:
                    _log_query(sql)
                    self.db_engine_spec.execute(cursor, sql)
                    cursor.fetchall()

                _log_query(sqls[-1])
                self.db_engine_spec.execute(cursor, sqls[-1])

                if cursor.description is not None:
                    columns = [col_desc[0] for col_desc in cursor.description]
                else:
                    columns = []

                df = pd.DataFrame.from_records(data=list(cursor.fetchall()),
                                               columns=columns,
                                               coerce_float=True)

                if mutator:
                    df = mutator(df)

                for k, v in df.dtypes.items():
                    if v.type == numpy.object_ and needs_conversion(df[k]):
                        df[k] = df[k].apply(utils.json_dumps_w_dates)
                return df

    def compile_sqla_query(self, qry, schema=None):
        engine = self.get_sqla_engine(schema=schema)

        sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))

        if engine.dialect.identifier_preparer._double_percents:
            sql = sql.replace("%%", "%")

        return sql

    def select_star(
        self,
        table_name,
        schema=None,
        limit=100,
        show_cols=False,
        indent=True,
        latest_partition=False,
        cols=None,
    ):
        """Generates a ``select *`` statement in the proper dialect"""
        eng = self.get_sqla_engine(schema=schema,
                                   source=utils.sources.get("sql_lab", None))
        return self.db_engine_spec.select_star(
            self,
            table_name,
            schema=schema,
            engine=eng,
            limit=limit,
            show_cols=show_cols,
            indent=indent,
            latest_partition=latest_partition,
            cols=cols,
        )

    def apply_limit_to_sql(self, sql, limit=1000):
        return self.db_engine_spec.apply_limit_to_sql(sql, limit, self)

    def safe_sqlalchemy_uri(self):
        return self.sqlalchemy_uri

    @property
    def inspector(self):
        engine = self.get_sqla_engine()
        return sqla.inspect(engine)

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{}:schema:None:table_list",
        attribute_in_key="id",
    )
    def get_all_table_names_in_database(
            self,
            cache: bool = False,
            cache_timeout: int = None,
            force=False) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments."""
        if not self.allow_multi_schema_metadata_fetch:
            return []
        return self.db_engine_spec.get_all_datasource_names(self, "table")

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
        attribute_in_key="id")
    def get_all_view_names_in_database(
            self,
            cache: bool = False,
            cache_timeout: int = None,
            force: bool = False) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments."""
        if not self.allow_multi_schema_metadata_fetch:
            return []
        return self.db_engine_spec.get_all_datasource_names(self, "view")

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{{}}:schema:{}:table_list".format(
            kwargs.get("schema")),
        attribute_in_key="id",
    )
    def get_all_table_names_in_schema(
        self,
        schema: str,
        cache: bool = False,
        cache_timeout: int = None,
        force: bool = False,
    ):
        """Parameters need to be passed as keyword arguments.

        For unused parameters, they are referenced in
        cache_util.memoized_func decorator.

        :param schema: schema name
        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: list of tables
        """
        try:
            tables = self.db_engine_spec.get_table_names(
                inspector=self.inspector, schema=schema)
            return [
                utils.DatasourceName(table=table, schema=schema)
                for table in tables
            ]
        except Exception as e:
            logging.exception(e)

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{{}}:schema:{}:view_list".format(
            kwargs.get("schema")),
        attribute_in_key="id",
    )
    def get_all_view_names_in_schema(
        self,
        schema: str,
        cache: bool = False,
        cache_timeout: int = None,
        force: bool = False,
    ):
        """Parameters need to be passed as keyword arguments.

        For unused parameters, they are referenced in
        cache_util.memoized_func decorator.

        :param schema: schema name
        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: list of views
        """
        try:
            views = self.db_engine_spec.get_view_names(
                inspector=self.inspector, schema=schema)
            return [
                utils.DatasourceName(table=view, schema=schema)
                for view in views
            ]
        except Exception as e:
            logging.exception(e)

    @cache_util.memoized_func(key=lambda *args, **kwargs: "db:{}:schema_list",
                              attribute_in_key="id")
    def get_all_schema_names(self,
                             cache: bool = False,
                             cache_timeout: int = None,
                             force: bool = False) -> List[str]:
        """Parameters need to be passed as keyword arguments.

        For unused parameters, they are referenced in
        cache_util.memoized_func decorator.

        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: schema list
        """
        return self.db_engine_spec.get_schema_names(self.inspector)

    @property
    def db_engine_spec(self):
        return db_engine_specs.engines.get(self.backend,
                                           db_engine_specs.BaseEngineSpec)

    @classmethod
    def get_db_engine_spec_for_backend(cls, backend):
        return db_engine_specs.engines.get(backend,
                                           db_engine_specs.BaseEngineSpec)

    def grains(self):
        """Defines time granularity database-specific expressions.

        The idea here is to make it easy for users to change the time grain
        from a datetime (maybe the source grain is arbitrary timestamps, daily
        or 5 minutes increments) to another, "truncated" datetime. Since
        each database has slightly different but similar datetime functions,
        this allows a mapping between database engines and actual functions.
        """
        return self.db_engine_spec.get_time_grains()

    def get_extra(self):
        extra = {}
        if self.extra:
            try:
                extra = json.loads(self.extra)
            except Exception as e:
                logging.error(e)
                raise e
        return extra

    def get_table(self, table_name, schema=None):
        extra = self.get_extra()
        meta = MetaData(**extra.get("metadata_params", {}))
        return Table(
            table_name,
            meta,
            schema=schema or None,
            autoload=True,
            autoload_with=self.get_sqla_engine(),
        )

    def get_columns(self, table_name, schema=None):
        return self.db_engine_spec.get_columns(self.inspector, table_name,
                                               schema)

    def get_indexes(self, table_name, schema=None):
        return self.inspector.get_indexes(table_name, schema)

    def get_pk_constraint(self, table_name, schema=None):
        return self.inspector.get_pk_constraint(table_name, schema)

    def get_foreign_keys(self, table_name, schema=None):
        return self.inspector.get_foreign_keys(table_name, schema)

    def get_schema_access_for_csv_upload(self):
        return self.get_extra().get("schemas_allowed_for_csv_upload", [])

    @property
    def sqlalchemy_uri_decrypted(self):
        conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
        if custom_password_store:
            conn.password = custom_password_store(conn)
        else:
            conn.password = self.password
        return str(conn)

    @property
    def sql_url(self):
        return "/superset/sql/{}/".format(self.id)

    def get_perm(self):
        return ("[{obj.database_name}].(id:{obj.id})").format(obj=self)

    def has_table(self, table):
        engine = self.get_sqla_engine()
        return engine.has_table(table.table_name, table.schema or None)

    @utils.memoized
    def get_dialect(self):
        sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
        return sqla_url.get_dialect()()
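
A brief sketch of the URI-masking helpers above; the connection string is a throwaway placeholder and nothing here opens a real connection:

db = Database(database_name="examples")
db.set_sqlalchemy_uri("mysql://scott:tiger@localhost:3306/examples")

# The stored URI carries PASSWORD_MASK in place of the password, while the
# plaintext password lives on the encrypted password column.
print(db.sqlalchemy_uri)            # password replaced by the mask
print(db.sqlalchemy_uri_decrypted)  # mask swapped back for the real password

# A URL object safe to log, with the password masked again.
masked = Database.get_password_masked_url_from_uri(db.sqlalchemy_uri_decrypted)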
Example #3
class ImapUid(MailSyncBase):
    """ Maps UIDs to their IMAP folders and per-UID flag metadata.

    This table is used solely for bookkeeping by the IMAP mail sync backends.
    """
    account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'),
                        nullable=False)
    account = relationship(ImapAccount)

    message_id = Column(Integer,
                        ForeignKey(Message.id, ondelete='CASCADE'),
                        nullable=False)
    message = relationship(Message,
                           backref=backref('imapuids', passive_deletes=True))
    msg_uid = Column(BigInteger, nullable=False, index=True)

    folder_id = Column(Integer,
                       ForeignKey(Folder.id, ondelete='CASCADE'),
                       nullable=False)
    # We almost always need the folder name too, so eager load by default.
    folder = relationship(Folder,
                          lazy='joined',
                          backref=backref('imapuids', passive_deletes=True))

    # Flags #
    # Message has not completed composition (marked as a draft).
    is_draft = Column(Boolean, server_default=false(), nullable=False)
    # Message has been read
    is_seen = Column(Boolean, server_default=false(), nullable=False)
    # Message is "flagged" for urgent/special attention
    is_flagged = Column(Boolean, server_default=false(), nullable=False)
    # This session is the first to have been notified about this message.
    is_recent = Column(Boolean, server_default=false(), nullable=False)
    # Message has been answered
    is_answered = Column(Boolean, server_default=false(), nullable=False)
    # things like: ['$Forwarded', 'nonjunk', 'Junk']
    extra_flags = Column(LittleJSON, default=[], nullable=False)
    # labels (Gmail-specific)
    g_labels = Column(JSON, default=lambda: [], nullable=True)

    def update_flags_and_labels(self, new_flags, x_gm_labels=None):
        """Sets flag and g_labels values based on the new_flags and x_gm_labels
        parameters. Returns True if any values have changed compared to what we
        previously stored."""
        changed = False
        new_flags = set(new_flags)
        col_for_flag = {
            u'\\Draft': 'is_draft',
            u'\\Seen': 'is_seen',
            u'\\Recent': 'is_recent',
            u'\\Answered': 'is_answered',
            u'\\Flagged': 'is_flagged',
        }
        for flag, col in col_for_flag.iteritems():
            prior_flag_value = getattr(self, col)
            new_flag_value = flag in new_flags
            if prior_flag_value != new_flag_value:
                changed = True
                setattr(self, col, new_flag_value)
            new_flags.discard(flag)
        extra_flags = sorted(new_flags)
        if extra_flags != self.extra_flags:
            changed = True
        self.extra_flags = extra_flags

        if x_gm_labels is not None:
            new_labels = sorted(x_gm_labels)
            if new_labels != self.g_labels:
                changed = True
            self.g_labels = new_labels

            # Gmail doesn't use the \Draft flag. Go figure.
            if '\\Draft' in x_gm_labels:
                if not self.is_draft:
                    changed = True
                self.is_draft = True
        return changed

    @property
    def namespace(self):
        return self.account.namespace

    __table_args__ = (UniqueConstraint(
        'folder_id',
        'msg_uid',
        'account_id',
    ), )
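
A small sketch of the flag-diffing logic in update_flags_and_labels; uid stands in for an ImapUid row loaded from a session:

changed = uid.update_flags_and_labels(
    new_flags=[u'\\Seen', u'\\Flagged', u'$Forwarded'],
    x_gm_labels=[u'\\Important', u'\\Draft'])

# Afterwards: is_seen and is_flagged are True, extra_flags holds the sorted
# leftovers ([u'$Forwarded']), g_labels is the sorted label list, and
# is_draft is True because Gmail signals drafts via the \Draft label.
assert changed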
Ejemplo n.º 4
0
class Publisher(Base):
    __tablename__ = 'publisher'
    __table_args__ = (UniqueConstraint('name'), )
    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String, nullable=False)
    is_active = sa.Column(sa.Boolean, default=False)
Ejemplo n.º 5
class Subject(Base):
    __tablename__ = 'subject'
    __table_args__ = (UniqueConstraint('name'), )
    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String, nullable=False)
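
Publisher and Subject above are intentionally minimal; the UniqueConstraint('name') in each means the database itself rejects duplicate names. A sketch of that behavior, assuming a configured session:

from sqlalchemy.exc import IntegrityError

session.add(Publisher(name='ACME Press'))
session.commit()

session.add(Publisher(name='ACME Press'))  # duplicate name
try:
    session.commit()
except IntegrityError:
    # UniqueConstraint('name') rejects the second row at the database level.
    session.rollback()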
Ejemplo n.º 6
class Database(Model, AuditMixinNullable, ImportMixin):
    """An ORM object that stores Database related information"""

    __tablename__ = 'dbs'
    type = 'table'
    __table_args__ = (UniqueConstraint('database_name'), )

    id = Column(Integer, primary_key=True)
    verbose_name = Column(String(250), unique=True)
    # short unique name, used in permissions
    database_name = Column(String(250), unique=True)
    sqlalchemy_uri = Column(String(1024))
    password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
    cache_timeout = Column(Integer)
    select_as_create_table_as = Column(Boolean, default=False)
    expose_in_sqllab = Column(Boolean, default=False)
    allow_run_sync = Column(Boolean, default=True)
    allow_run_async = Column(Boolean, default=False)
    allow_ctas = Column(Boolean, default=False)
    allow_dml = Column(Boolean, default=False)
    force_ctas_schema = Column(String(250))
    extra = Column(Text,
                   default=textwrap.dedent("""\
    {
        "metadata_params": {},
        "engine_params": {}
    }
    """))
    perm = Column(String(1000))

    impersonate_user = Column(Boolean, default=False)
    export_fields = ('database_name', 'sqlalchemy_uri', 'cache_timeout',
                     'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
                     'allow_ctas', 'extra')
    export_children = ['tables']

    def __repr__(self):
        return self.verbose_name if self.verbose_name else self.database_name

    @property
    def name(self):
        return self.verbose_name if self.verbose_name else self.database_name

    @property
    def data(self):
        return {
            'name': self.database_name,
            'backend': self.backend,
        }

    @property
    def unique_name(self):
        return self.database_name

    @property
    def backend(self):
        url = make_url(self.sqlalchemy_uri_decrypted)
        return url.get_backend_name()

    @classmethod
    def get_password_masked_url_from_uri(cls, uri):
        url = make_url(uri)
        return cls.get_password_masked_url(url)

    @classmethod
    def get_password_masked_url(cls, url):
        url_copy = deepcopy(url)
        if url_copy.password is not None and url_copy.password != PASSWORD_MASK:
            url_copy.password = PASSWORD_MASK
        return url_copy

    def set_sqlalchemy_uri(self, uri):
        conn = sqla.engine.url.make_url(uri.strip())
        if conn.password != PASSWORD_MASK and not custom_password_store:
            # do not over-write the password with the password mask
            self.password = conn.password
        conn.password = PASSWORD_MASK if conn.password else None
        self.sqlalchemy_uri = str(conn)  # hides the password

    def get_effective_user(self, url, user_name=None):
        """
        Get the effective user, especially during impersonation.
        :param url: SQL Alchemy URL object
        :param user_name: Default username
        :return: The effective username
        """
        effective_username = None
        if self.impersonate_user:
            effective_username = url.username
            if user_name:
                effective_username = user_name
            elif (hasattr(g, 'user') and hasattr(g.user, 'username')
                  and g.user.username is not None):
                effective_username = g.user.username
        return effective_username

    @utils.memoized(watch=('impersonate_user', 'sqlalchemy_uri_decrypted',
                           'extra'))
    def get_sqla_engine(self, schema=None, nullpool=True, user_name=None):
        extra = self.get_extra()
        url = make_url(self.sqlalchemy_uri_decrypted)
        url = self.db_engine_spec.adjust_database_uri(url, schema)
        effective_username = self.get_effective_user(url, user_name)
        # If using MySQL or Presto for example, will set url.username
        # If using Hive, will not do anything yet since that relies on a
        # configuration parameter instead.
        self.db_engine_spec.modify_url_for_impersonation(
            url, self.impersonate_user, effective_username)

        masked_url = self.get_password_masked_url(url)
        logging.info(
            'Database.get_sqla_engine(). Masked URL: {0}'.format(masked_url))

        params = extra.get('engine_params', {})
        if nullpool:
            params['poolclass'] = NullPool

        # If using Hive, this will set hive.server2.proxy.user=$effective_username
        configuration = {}
        configuration.update(
            self.db_engine_spec.get_configuration_for_impersonation(
                str(url), self.impersonate_user, effective_username))
        if configuration:
            params['connect_args'] = {'configuration': configuration}

        return create_engine(url, **params)

    def get_reserved_words(self):
        return self.get_dialect().preparer.reserved_words

    def get_quoter(self):
        return self.get_dialect().identifier_preparer.quote

    def get_df(self, sql, schema):
        sql = sql.strip().strip(';')
        eng = self.get_sqla_engine(schema=schema)
        df = pd.read_sql(sql, eng)

        def needs_conversion(df_series):
            if df_series.empty:
                return False
            if isinstance(df_series[0], (list, dict)):
                return True
            return False

        for k, v in df.dtypes.iteritems():
            if v.type == numpy.object_ and needs_conversion(df[k]):
                df[k] = df[k].apply(utils.json_dumps_w_dates)
        return df

    def compile_sqla_query(self, qry, schema=None):
        eng = self.get_sqla_engine(schema=schema)
        compiled = qry.compile(eng, compile_kwargs={'literal_binds': True})
        return '{}'.format(compiled)

    def select_star(self,
                    table_name,
                    schema=None,
                    limit=100,
                    show_cols=False,
                    indent=True,
                    latest_partition=True):
        """Generates a ``select *`` statement in the proper dialect"""
        return self.db_engine_spec.select_star(
            self,
            table_name,
            schema=schema,
            limit=limit,
            show_cols=show_cols,
            indent=indent,
            latest_partition=latest_partition)

    def wrap_sql_limit(self, sql, limit=1000):
        qry = (select('*').select_from(
            TextAsFrom(text(sql), ['*']).alias('inner_qry'), ).limit(limit))
        return self.compile_sqla_query(qry)

    def safe_sqlalchemy_uri(self):
        return self.sqlalchemy_uri

    @property
    def inspector(self):
        engine = self.get_sqla_engine()
        return sqla.inspect(engine)

    def all_table_names(self, schema=None, force=False):
        if not schema:
            tables_dict = self.db_engine_spec.fetch_result_sets(self,
                                                                'table',
                                                                force=force)
            return tables_dict.get('', [])
        return sorted(
            self.db_engine_spec.get_table_names(schema, self.inspector))

    def all_view_names(self, schema=None, force=False):
        if not schema:
            views_dict = self.db_engine_spec.fetch_result_sets(self,
                                                               'view',
                                                               force=force)
            return views_dict.get('', [])
        views = []
        try:
            views = self.inspector.get_view_names(schema)
        except Exception:
            pass
        return views

    def all_schema_names(self):
        return sorted(self.db_engine_spec.get_schema_names(self.inspector))

    @property
    def db_engine_spec(self):
        return db_engine_specs.engines.get(self.backend,
                                           db_engine_specs.BaseEngineSpec)

    @classmethod
    def get_db_engine_spec_for_backend(cls, backend):
        return db_engine_specs.engines.get(backend,
                                           db_engine_specs.BaseEngineSpec)

    def grains(self):
        """Defines time granularity database-specific expressions.

        The idea here is to make it easy for users to change the time grain
        from a datetime (maybe the source grain is arbitrary timestamps, daily
        or 5 minutes increments) to another, "truncated" datetime. Since
        each database has slightly different but similar datetime functions,
        this allows a mapping between database engines and actual functions.
        """
        return self.db_engine_spec.time_grains

    def grains_dict(self):
        return {grain.name: grain for grain in self.grains()}

    def get_extra(self):
        extra = {}
        if self.extra:
            try:
                extra = json.loads(self.extra)
            except Exception as e:
                logging.error(e)
        return extra

    def get_table(self, table_name, schema=None):
        extra = self.get_extra()
        meta = MetaData(**extra.get('metadata_params', {}))
        return Table(table_name,
                     meta,
                     schema=schema or None,
                     autoload=True,
                     autoload_with=self.get_sqla_engine())

    def get_columns(self, table_name, schema=None):
        return self.inspector.get_columns(table_name, schema)

    def get_indexes(self, table_name, schema=None):
        return self.inspector.get_indexes(table_name, schema)

    def get_pk_constraint(self, table_name, schema=None):
        return self.inspector.get_pk_constraint(table_name, schema)

    def get_foreign_keys(self, table_name, schema=None):
        return self.inspector.get_foreign_keys(table_name, schema)

    @property
    def sqlalchemy_uri_decrypted(self):
        conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
        if custom_password_store:
            conn.password = custom_password_store(conn)
        else:
            conn.password = self.password
        return str(conn)

    @property
    def sql_url(self):
        return '/superset/sql/{}/'.format(self.id)

    def get_perm(self):
        return ('[{obj.database_name}].(id:{obj.id})').format(obj=self)

    def has_table(self, table):
        engine = self.get_sqla_engine()
        return engine.has_table(table.table_name, table.schema or None)

    @utils.memoized
    def get_dialect(self):
        sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
        return sqla_url.get_dialect()()
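
This earlier revision of the class wraps arbitrary SQL in an outer limited query via wrap_sql_limit instead of delegating to the engine spec; a sketch of the shape it produces, with db standing in for a Database row:

sql = db.wrap_sql_limit("SELECT name, num FROM birth_names", limit=10)
# Compiled against the engine's dialect, roughly:
#   SELECT * FROM (SELECT name, num FROM birth_names) AS inner_qry LIMIT 10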
Example #7
class TableColumn(Model, BaseColumn):
    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = 'table_columns'
    __table_args__ = (UniqueConstraint('table_id', 'column_name'), )
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship('SqlaTable',
                         backref=backref('columns',
                                         cascade='all, delete-orphan'),
                         foreign_keys=[table_id])
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text, default='')
    python_date_format = Column(String(255))
    database_expression = Column(String(255))

    export_fields = (
        'table_id',
        'column_name',
        'verbose_name',
        'is_dttm',
        'is_active',
        'type',
        'groupby',
        'count_distinct',
        'sum',
        'avg',
        'max',
        'min',
        'filterable',
        'expression',
        'description',
        'python_date_format',
        'database_expression',
    )
    export_parent = 'table'

    @property
    def sqla_col(self):
        name = self.column_name
        if not self.expression:
            col = column(self.column_name).label(name)
        else:
            col = literal_column(self.expression).label(name)
        return col

    @property
    def datasource(self):
        return self.table

    def get_time_filter(self, start_dttm, end_dttm):
        col = self.sqla_col.label('__time')
        l = []  # noqa: E741
        if start_dttm:
            l.append(col >= text(self.dttm_sql_literal(start_dttm)))
        if end_dttm:
            l.append(col <= text(self.dttm_sql_literal(end_dttm)))
        return and_(*l)

    def get_timestamp_expression(self, time_grain):
        """Getting the time component of the query"""
        expr = self.expression or self.column_name
        if not self.expression and not time_grain:
            return column(expr, type_=DateTime).label(DTTM_ALIAS)
        if time_grain:
            pdf = self.python_date_format
            if pdf in ('epoch_s', 'epoch_ms'):
                # if epoch, translate to DATE using db specific conf
                db_spec = self.table.database.db_engine_spec
                if pdf == 'epoch_s':
                    expr = db_spec.epoch_to_dttm().format(col=expr)
                elif pdf == 'epoch_ms':
                    expr = db_spec.epoch_ms_to_dttm().format(col=expr)
            grain = self.table.database.grains_dict().get(time_grain, '{col}')
            expr = grain.function.format(col=expr)
        return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)

    @classmethod
    def import_obj(cls, i_column):
        def lookup_obj(lookup_column):
            return db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name).first()

        return import_util.import_simple_obj(db.session, i_column, lookup_obj)

    def dttm_sql_literal(self, dttm):
        """Convert datetime object to a SQL expression string

        If database_expression is empty, the internal dttm
        will be parsed as the string with the pattern that
        the user inputted (python_date_format)
        If database_expression is not empty, the internal dttm
        will be parsed as the sql sentence for the database to convert
        """
        tf = self.python_date_format
        if self.database_expression:
            return self.database_expression.format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        elif tf:
            if tf == 'epoch_s':
                return str((dttm - datetime(1970, 1, 1)).total_seconds())
            elif tf == 'epoch_ms':
                return str(
                    (dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0)
            return "'{}'".format(dttm.strftime(tf))
        else:
            s = self.table.database.db_engine_spec.convert_dttm(
                self.type or '', dttm)
            return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))

    def get_metrics(self):
        metrics = []
        M = SqlMetric  # noqa
        quoted = self.column_name
        if self.sum:
            metrics.append(
                M(
                    metric_name='sum__' + self.column_name,
                    metric_type='sum',
                    expression='SUM({})'.format(quoted),
                ))
        if self.avg:
            metrics.append(
                M(
                    metric_name='avg__' + self.column_name,
                    metric_type='avg',
                    expression='AVG({})'.format(quoted),
                ))
        if self.max:
            metrics.append(
                M(
                    metric_name='max__' + self.column_name,
                    metric_type='max',
                    expression='MAX({})'.format(quoted),
                ))
        if self.min:
            metrics.append(
                M(
                    metric_name='min__' + self.column_name,
                    metric_type='min',
                    expression='MIN({})'.format(quoted),
                ))
        if self.count_distinct:
            metrics.append(
                M(
                    metric_name='count_distinct__' + self.column_name,
                    metric_type='count_distinct',
                    expression='COUNT(DISTINCT {})'.format(quoted),
                ))
        return {m.metric_name: m for m in metrics}
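
The epoch branches of dttm_sql_literal above are plain datetime arithmetic; the same computation as a standalone check:

from datetime import datetime

dttm = datetime(2021, 6, 1, 12, 0, 0)
epoch = datetime(1970, 1, 1)

str((dttm - epoch).total_seconds())           # epoch_s  -> '1622548800.0'
str((dttm - epoch).total_seconds() * 1000.0)  # epoch_ms -> '1622548800000.0'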
Example #8
class TableColumn(Model, BaseColumn):
    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = "table_columns"
    __table_args__ = (UniqueConstraint("table_id", "column_name"), )
    table_id = Column(Integer, ForeignKey("tables.id"))
    table = relationship(
        "SqlaTable",
        backref=backref("columns", cascade="all, delete-orphan"),
        foreign_keys=[table_id],
    )
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text)
    python_date_format = Column(String(255))

    export_fields = [
        "table_id",
        "column_name",
        "verbose_name",
        "is_dttm",
        "is_active",
        "type",
        "groupby",
        "filterable",
        "expression",
        "description",
        "python_date_format",
    ]

    update_from_object_fields = [
        s for s in export_fields if s not in ("table_id", )
    ]
    export_parent = "table"

    @property
    def is_numeric(self) -> bool:
        db_engine_spec = self.table.database.db_engine_spec
        return db_engine_spec.is_db_column_type_match(
            self.type, utils.DbColumnType.NUMERIC)

    @property
    def is_string(self) -> bool:
        db_engine_spec = self.table.database.db_engine_spec
        return db_engine_spec.is_db_column_type_match(
            self.type, utils.DbColumnType.STRING)

    @property
    def is_temporal(self) -> bool:
        db_engine_spec = self.table.database.db_engine_spec
        return db_engine_spec.is_db_column_type_match(
            self.type, utils.DbColumnType.TEMPORAL)

    def get_sqla_col(self, label: Optional[str] = None) -> Column:
        label = label or self.column_name
        if self.expression:
            col = literal_column(self.expression)
        else:
            db_engine_spec = self.table.database.db_engine_spec
            type_ = db_engine_spec.get_sqla_column_type(self.type)
            col = column(self.column_name, type_=type_)
        col = self.table.make_sqla_column_compatible(col, label)
        return col

    @property
    def datasource(self) -> RelationshipProperty:
        return self.table

    def get_time_filter(
        self,
        start_dttm: DateTime,
        end_dttm: DateTime,
        time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint,
                                             utils.TimeRangeEndpoint]],
    ) -> ColumnElement:
        col = self.get_sqla_col(label="__time")
        l = []
        if start_dttm:
            l.append(col >= text(
                self.dttm_sql_literal(start_dttm, time_range_endpoints)))
        if end_dttm:
            if (time_range_endpoints and time_range_endpoints[1]
                    == utils.TimeRangeEndpoint.EXCLUSIVE):
                l.append(col < text(
                    self.dttm_sql_literal(end_dttm, time_range_endpoints)))
            else:
                l.append(col <= text(self.dttm_sql_literal(end_dttm, None)))
        return and_(*l)

    def get_timestamp_expression(
            self,
            time_grain: Optional[str]) -> Union[TimestampExpression, Label]:
        """
        Return a SQLAlchemy Core element representation of self to be used in a query.

        :param time_grain: Optional time grain, e.g. P1Y
        :return: A TimeExpression object wrapped in a Label if supported by db
        """
        label = utils.DTTM_ALIAS

        db = self.table.database
        pdf = self.python_date_format
        is_epoch = pdf in ("epoch_s", "epoch_ms")
        if not self.expression and not time_grain and not is_epoch:
            sqla_col = column(self.column_name, type_=DateTime)
            return self.table.make_sqla_column_compatible(sqla_col, label)
        if self.expression:
            col = literal_column(self.expression)
        else:
            col = column(self.column_name)
        time_expr = db.db_engine_spec.get_timestamp_expr(
            col, pdf, time_grain, self.type)
        return self.table.make_sqla_column_compatible(time_expr, label)

    @classmethod
    def import_obj(cls, i_column: "TableColumn") -> "TableColumn":
        def lookup_obj(lookup_column: TableColumn) -> TableColumn:
            return (db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name,
            ).first())

        return import_datasource.import_simple_obj(db.session, i_column,
                                                   lookup_obj)

    def dttm_sql_literal(
        self,
        dttm: DateTime,
        time_range_endpoints: Optional[Tuple[utils.TimeRangeEndpoint,
                                             utils.TimeRangeEndpoint]],
    ) -> str:
        """Convert datetime object to a SQL expression string"""
        sql = (self.table.database.db_engine_spec.convert_dttm(
            self.type, dttm) if self.type else None)

        if sql:
            return sql

        tf = self.python_date_format

        # Fallback to the default format (if defined) only if the SIP-15 time range
        # endpoints, i.e., [start, end) are enabled.
        if not tf and time_range_endpoints == (
                utils.TimeRangeEndpoint.INCLUSIVE,
                utils.TimeRangeEndpoint.EXCLUSIVE,
        ):
            tf = (self.table.database.get_extra().get(
                "python_date_format_by_column_name", {}).get(self.column_name))

        if tf:
            if tf in ["epoch_ms", "epoch_s"]:
                seconds_since_epoch = int(dttm.timestamp())
                if tf == "epoch_s":
                    return str(seconds_since_epoch)
                return str(seconds_since_epoch * 1000)
            return f"'{dttm.strftime(tf)}'"

        # TODO(john-bodley): SIP-15 will explicitly require a type conversion.
        return f"""'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'"""

    @property
    def data(self) -> Dict[str, Any]:
        attrs = (
            "id",
            "column_name",
            "verbose_name",
            "description",
            "expression",
            "filterable",
            "groupby",
            "is_dttm",
            "type",
            "python_date_format",
        )
        return {s: getattr(self, s) for s in attrs if hasattr(self, s)}
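
A sketch of the SIP-15 fallback in dttm_sql_literal above: when no per-column format is set and the endpoints are [inclusive, exclusive), the format is looked up in the parent database's extra JSON. Here col is a placeholder TableColumn whose column_name is "ds", and the extra payload shown is hypothetical:

from datetime import datetime

# Hypothetical database extra enabling the fallback:
#   {"python_date_format_by_column_name": {"ds": "%Y-%m-%d"}}
endpoints = (utils.TimeRangeEndpoint.INCLUSIVE,
             utils.TimeRangeEndpoint.EXCLUSIVE)
literal = col.dttm_sql_literal(datetime(2021, 6, 1), endpoints)
# With the extra above (and no engine-spec conversion): "'2021-06-01'"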
Example #9
class Dataset(Base):
    """Class to store the information about a data set.

    """
    __tablename__ = 'datasets'
    __table_args__ = (
        UniqueConstraint('task_id', 'description'),
        # Useless, in theory, because 'id' is already unique. Yet, we
        # need this because it's a target of a foreign key.
        UniqueConstraint('id', 'task_id'),
    )

    # Auto increment primary key.
    id = Column(
        Integer,
        primary_key=True)

    # Task (id and object) owning the dataset.
    task_id = Column(
        Integer,
        ForeignKey(Task.id,
                   onupdate="CASCADE", ondelete="CASCADE"),
        nullable=False)
    task = relationship(
        Task,
        foreign_keys=[task_id],
        # XXX In SQLAlchemy 0.8 we could remove this:
        primaryjoin='Task.id == Dataset.task_id',
        backref=backref('datasets',
                        cascade="all, delete-orphan",
                        passive_deletes=True))

    # A human-readable text describing the dataset.
    description = Column(
        Unicode,
        nullable=False)

    # Whether this dataset will be automatically judged by ES and SS
    # "in background", together with the active dataset of each task.
    autojudge = Column(
        Boolean,
        nullable=False,
        default=False)

    # Time and memory limits for every testcase.
    time_limit = Column(
        Float,
        nullable=True)
    memory_limit = Column(
        Integer,
        nullable=True)

    # Name of the TaskType child class suited for the task.
    task_type = Column(
        String,
        nullable=False)

    # Parameters for the task type class, JSON encoded.
    task_type_parameters = Column(
        String,
        nullable=False)

    # Name of the ScoreType child class suited for the task.
    score_type = Column(
        String,
        nullable=False)

    # Parameters for the score type class, JSON encoded.
    score_type_parameters = Column(
        String,
        nullable=False)
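
The seemingly redundant UniqueConstraint('id', 'task_id') exists so that other tables can declare a composite foreign key against (id, task_id), guaranteeing a child row references a dataset and its owning task consistently. A minimal sketch of such a referencing table; the Testcase name and columns are hypothetical:

from sqlalchemy import Column, ForeignKeyConstraint, Integer


class Testcase(Base):
    __tablename__ = 'testcases'
    __table_args__ = (
        # Composite FK: only legal because datasets carries a unique
        # constraint covering exactly (id, task_id).
        ForeignKeyConstraint(
            ('dataset_id', 'task_id'),
            ('datasets.id', 'datasets.task_id'),
            onupdate='CASCADE', ondelete='CASCADE'),
    )
    id = Column(Integer, primary_key=True)
    dataset_id = Column(Integer, nullable=False)
    task_id = Column(Integer, nullable=False)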
Example #10
class Task(Base):
    """
    A job that gets executed.  Has a unique set of params within its Stage.
    """
    __tablename__ = 'task'
    # FIXME causes a problem with mysql?
    __table_args__ = (UniqueConstraint('stage_id', 'uid', name='_uc1'),)

    drm_options = {}

    id = Column(Integer, primary_key=True)
    uid = Column(String(255), index=True)

    mem_req = Column(Integer)
    core_req = Column(Integer)
    cpu_req = synonym('core_req')
    time_req = Column(Integer)
    NOOP = Column(Boolean, nullable=False)
    params = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False)
    stage_id = Column(ForeignKey('stage.id', ondelete="CASCADE"), nullable=False, index=True)
    log_dir = Column(String(255))
    # output_dir = Column(String(255))
    _status = Column(Enum_ColumnType(TaskStatus, length=255), default=TaskStatus.no_attempt, nullable=False)
    successful = Column(Boolean, nullable=False)
    started_on = Column(DateTime)  # FIXME this should probably be deleted.  Too hard to determine.
    submitted_on = Column(DateTime)
    finished_on = Column(DateTime)
    attempt = Column(Integer, nullable=False)
    must_succeed = Column(Boolean, nullable=False)
    drm = Column(String(255))
    # FIXME consider making job_class a proper field next time the schema changes
    # job_class = Column(String(255))
    queue = Column(String(255))
    max_attempts = Column(Integer)
    parents = relationship("Task",
                           secondary=TaskEdge.__table__,
                           primaryjoin=id == TaskEdge.parent_id,
                           secondaryjoin=id == TaskEdge.child_id,
                           backref="children",
                           passive_deletes=True,
                           cascade="save-update, merge, delete",
                           )

    # input_map = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False)
    # output_map = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False)

    @property
    def input_map(self):
        d = dict()
        for key, val in self.params.items():
            if key.startswith('in_'):
                d[key] = val
        return d

    @property
    def output_map(self):
        d = dict()
        for key, val in self.params.items():
            if key.startswith('out_'):
                d[key] = val
        return d

    @property
    def input_files(self):
        return list(self.input_map.values())

    @property
    def output_files(self):
        return list(self.output_map.values())

    # command = Column(Text)

    drm_native_specification = Column(String(255))
    drm_jobID = Column(String(255))

    profile_fields = ['wall_time', 'cpu_time', 'percent_cpu', 'user_time', 'system_time', 'io_read_count',
                      'io_write_count', 'io_read_kb', 'io_write_kb',
                      'ctx_switch_voluntary', 'ctx_switch_involuntary', 'avg_rss_mem_kb', 'max_rss_mem_kb',
                      'avg_vms_mem_kb', 'max_vms_mem_kb', 'avg_num_threads',
                      'max_num_threads',
                      'avg_num_fds', 'max_num_fds', 'exit_status']
    exclude_from_dict = profile_fields + ['command', 'info', 'input_files', 'output_files']

    exit_status = Column(Integer)

    percent_cpu = Column(Integer)
    wall_time = Column(Integer)

    cpu_time = Column(Integer)
    user_time = Column(Integer)
    system_time = Column(Integer)

    avg_rss_mem_kb = Column(Integer)
    max_rss_mem_kb = Column(Integer)
    avg_vms_mem_kb = Column(Integer)
    max_vms_mem_kb = Column(Integer)

    io_read_count = Column(Integer)
    io_write_count = Column(Integer)
    io_wait = Column(Integer)
    io_read_kb = Column(Integer)
    io_write_kb = Column(Integer)

    ctx_switch_voluntary = Column(Integer)
    ctx_switch_involuntary = Column(Integer)

    avg_num_threads = Column(Integer)
    max_num_threads = Column(Integer)

    avg_num_fds = Column(Integer)
    max_num_fds = Column(Integer)

    extra = Column(MutableDict.as_mutable(JSONEncodedDict), nullable=False)

    @declared_attr
    def status(cls):
        def get_status(self):
            return self._status

        def set_status(self, value):
            if self._status != value:
                self._status = value
                signal_task_status_change.send(self)

        return synonym('_status', descriptor=property(get_status, set_status))

    @property
    def workflow(self):
        return self.stage.workflow

    @property
    def log(self):
        return self.workflow.log

    @property
    def finished(self):
        return self.status in {TaskStatus.successful, TaskStatus.killed, TaskStatus.failed}

    _cache_profile = None

    output_profile_path = logplus('profile.json')
    output_command_script_path = logplus('command.bash')
    output_stderr_path = logplus('stderr.txt')
    output_stdout_path = logplus('stdout.txt')

    @property
    def stdout_text(self):
        return readfile(self.output_stdout_path)

    @property
    def stderr_text(self):
        r = readfile(self.output_stderr_path)
        if r == 'file does not exist':
            if self.drm == 'lsf' and self.drm_jobID:
                r += '\n\nbpeek %s output:\n\n' % self.drm_jobID
                try:
                    r += codecs.decode(sp.check_output('bpeek %s' % self.drm_jobID, shell=True), 'utf-8')
                except Exception as e:
                    r += str(e)
        return r

    @property
    def command_script_text(self):
        # return self.command
        return readfile(self.output_command_script_path).strip() or self.command

    def descendants(self, include_self=False):
        """
        :return: (list) all stages that descend from this stage in the stage_graph
        """
        x = nx.descendants(self.workflow.task_graph(), self)
        if include_self:
            return sorted({self}.union(x), key=lambda task: task.stage.number)
        else:
            return x

    @property
    def label(self):
        """Label used for the taskgraph image"""
        params = '' if len(self.params) == 0 else "\\n {0}".format(
            "\\n".join(["{0}: {1}".format(k, v) for k, v in self.params.items()]))

        return "[%s] %s%s" % (self.id, self.stage.name, params)

    def args_as_query_string(self):
        import urllib

        return urllib.urlencode(self.params)

    def delete(self, descendants=False):
        if descendants:
            tasks_to_delete = self.descendants(include_self=True)
            self.log.debug('Deleting %s and %s of its descendants' % (self, len(tasks_to_delete) - 1))
            for t in tasks_to_delete:
                self.session.delete(t)
        else:
            self.log.debug('Deleting %s' % self)
            self.session.delete(self)

        self.session.commit()

    @property
    def url(self):
        return url_for('cosmos.task', ex_name=self.workflow.name, stage_name=self.stage.name, task_id=self.id)

    @property
    def params_pretty(self):
        return '%s' % ', '.join(
            '%s=%s' % (k, "'%s'" % v if isinstance(v, basestring) else v) for k, v in self.params.items())

    @property
    def params_pformat(self):
        return pprint.pformat(self.params, indent=2, width=1)

    def __repr__(self):
        return "<Task[%s] %s(uid='%s')>" % (self.id or 'id_%s' % id(self),
                                            self.stage.name if self.stage else '',
                                            self.uid
                                            )

    def __str__(self):
        return self.__repr__()

    # FIXME consider making job_class a proper field next time the schema changes
    def __init__(self, **kwargs):
        self.job_class = kwargs.pop('job_class', None)
        _declarative_constructor(self, **kwargs)

    @reconstructor
    def init_on_load(self):
        self.job_class = None
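
The status synonym above fires signal_task_status_change only when the value actually changes; a sketch of subscribing to it, assuming it is a blinker-style signal as its send(self) call suggests (TaskStatus.submitted is illustrative):

def on_status_change(task):
    task.log.info('%s -> %s' % (task, task.status))

signal_task_status_change.connect(on_status_change)

task.status = TaskStatus.submitted  # value changed: the signal fires
task.status = TaskStatus.submitted  # same value: the setter is a no-op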
Example #11
class Database(Model, AuditMixinNullable, ImportMixin):  # pylint: disable=too-many-public-methods
    """An ORM object that stores Database related information"""

    __tablename__ = "dbs"
    type = "table"
    __table_args__ = (UniqueConstraint("database_name"), )

    id = Column(Integer, primary_key=True)
    verbose_name = Column(String(250), unique=True)
    # short unique name, used in permissions
    database_name = Column(String(250), unique=True, nullable=False)
    sqlalchemy_uri = Column(String(1024), nullable=False)
    password = Column(EncryptedType(String(1024), config["SECRET_KEY"]))
    cache_timeout = Column(Integer)
    select_as_create_table_as = Column(Boolean, default=False)
    expose_in_sqllab = Column(Boolean, default=True)
    allow_run_async = Column(Boolean, default=False)
    allow_csv_upload = Column(Boolean, default=False)
    allow_ctas = Column(Boolean, default=False)
    allow_cvas = Column(Boolean, default=False)
    allow_dml = Column(Boolean, default=False)
    force_ctas_schema = Column(String(250))
    allow_multi_schema_metadata_fetch = Column(  # pylint: disable=invalid-name
        Boolean, default=False)
    extra = Column(
        Text,
        default=textwrap.dedent("""\
    {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_csv_upload": []
    }
    """),
    )
    encrypted_extra = Column(EncryptedType(Text, config["SECRET_KEY"]),
                             nullable=True)
    impersonate_user = Column(Boolean, default=False)
    server_cert = Column(EncryptedType(Text, config["SECRET_KEY"]),
                         nullable=True)
    export_fields = [
        "database_name",
        "sqlalchemy_uri",
        "cache_timeout",
        "expose_in_sqllab",
        "allow_run_async",
        "allow_ctas",
        "allow_cvas",
        "allow_csv_upload",
        "extra",
    ]
    export_children = ["tables"]

    def __repr__(self) -> str:
        return self.name

    @property
    def name(self) -> str:
        return self.verbose_name if self.verbose_name else self.database_name

    @property
    def allows_subquery(self) -> bool:
        return self.db_engine_spec.allows_subqueries

    @property
    def function_names(self) -> List[str]:
        try:
            return self.db_engine_spec.get_function_names(self)
        except Exception as ex:  # pylint: disable=broad-except
            # function_names property is used in bulk APIs and should not hard crash
            # more info in: https://github.com/apache/incubator-superset/issues/9678
            logger.error(
                "Failed to fetch database function names with error: %s",
                str(ex))
        return []

    @property
    def allows_cost_estimate(self) -> bool:
        extra = self.get_extra()

        database_version = extra.get("version")
        cost_estimate_enabled: bool = extra.get(
            "cost_estimate_enabled")  # type: ignore

        return (self.db_engine_spec.get_allow_cost_estimate(database_version)
                and cost_estimate_enabled)

    @property
    def allows_virtual_table_explore(self) -> bool:
        extra = self.get_extra()

        return bool(extra.get("allows_virtual_table_explore", True))

    @property
    def explore_database_id(self) -> int:
        return self.get_extra().get("explore_database_id", self.id)

    @property
    def data(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "name": self.database_name,
            "backend": self.backend,
            "allow_multi_schema_metadata_fetch":
            self.allow_multi_schema_metadata_fetch,
            "allows_subquery": self.allows_subquery,
            "allows_cost_estimate": self.allows_cost_estimate,
            "allows_virtual_table_explore": self.allows_virtual_table_explore,
            "explore_database_id": self.explore_database_id,
        }

    @property
    def unique_name(self) -> str:
        return self.database_name

    @property
    def url_object(self) -> URL:
        return make_url(self.sqlalchemy_uri_decrypted)

    @property
    def backend(self) -> str:
        sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
        return sqlalchemy_url.get_backend_name()  # pylint: disable=no-member

    @property
    def metadata_cache_timeout(self) -> Dict[str, Any]:
        return self.get_extra().get("metadata_cache_timeout", {})

    @property
    def schema_cache_enabled(self) -> bool:
        return "schema_cache_timeout" in self.metadata_cache_timeout

    @property
    def schema_cache_timeout(self) -> Optional[int]:
        return self.metadata_cache_timeout.get("schema_cache_timeout")

    @property
    def table_cache_enabled(self) -> bool:
        return "table_cache_timeout" in self.metadata_cache_timeout

    @property
    def table_cache_timeout(self) -> Optional[int]:
        return self.metadata_cache_timeout.get("table_cache_timeout")

    @property
    def default_schemas(self) -> List[str]:
        return self.get_extra().get("default_schemas", [])

    @property
    def connect_args(self) -> Dict[str, Any]:
        return self.get_extra().get("engine_params",
                                    {}).get("connect_args", {})

    @classmethod
    def get_password_masked_url_from_uri(  # pylint: disable=invalid-name
            cls, uri: str) -> URL:
        sqlalchemy_url = make_url(uri)
        return cls.get_password_masked_url(sqlalchemy_url)

    @classmethod
    def get_password_masked_url(
            cls,
            url: URL  # pylint: disable=redefined-outer-name
    ) -> URL:
        url_copy = deepcopy(url)
        if url_copy.password is not None:
            url_copy.password = PASSWORD_MASK
        return url_copy
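
    # Usage sketch (URI hypothetical): masking returns a copy, so the stored
    # URI is untouched and only the logged form has its password replaced by
    # PASSWORD_MASK.
    #
    #   masked = Database.get_password_masked_url_from_uri(
    #       "postgresql://reader:s3cret@db.example.com/analytics")
    #   str(masked)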

    def set_sqlalchemy_uri(self, uri: str) -> None:
        conn = sqla.engine.url.make_url(uri.strip())
        if conn.password != PASSWORD_MASK and not custom_password_store:
            # do not over-write the password with the password mask
            self.password = conn.password
        conn.password = PASSWORD_MASK if conn.password else None
        self.sqlalchemy_uri = str(conn)  # hides the password

    def get_effective_user(
        self,
        url: URL,  # pylint: disable=redefined-outer-name
        user_name: Optional[str] = None,
    ) -> Optional[str]:
        """
        Get the effective user, especially during impersonation.
        :param url: SQLAlchemy URL object
        :param user_name: Default username
        :return: The effective username
        """
        effective_username = None
        if self.impersonate_user:
            effective_username = url.username
            if user_name:
                effective_username = user_name
            elif (hasattr(g, "user") and hasattr(g.user, "username")
                  and g.user.username is not None):
                effective_username = g.user.username
        return effective_username
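
    # Resolution order when impersonation is on: an explicit user_name wins,
    # then flask.g.user.username, then the username already on the URL.
    # E.g. (names hypothetical):
    #
    #   db.get_effective_user(url, user_name="alice")  # -> "alice"
    #   db.get_effective_user(url)  # -> g.user.username or url.username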

    @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted",
                           "extra"))
    def get_sqla_engine(
        self,
        schema: Optional[str] = None,
        nullpool: bool = True,
        user_name: Optional[str] = None,
        source: Optional[utils.QuerySource] = None,
    ) -> Engine:
        extra = self.get_extra()
        sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
        self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
        effective_username = self.get_effective_user(sqlalchemy_url, user_name)
        # If using MySQL or Presto for example, will set url.username
        # If using Hive, will not do anything yet since that relies on a
        # configuration parameter instead.
        self.db_engine_spec.modify_url_for_impersonation(
            sqlalchemy_url, self.impersonate_user, effective_username)

        masked_url = self.get_password_masked_url(sqlalchemy_url)
        logger.debug("Database.get_sqla_engine(). Masked URL: %s",
                     str(masked_url))

        params = extra.get("engine_params", {})
        if nullpool:
            params["poolclass"] = NullPool

        connect_args = params.get("connect_args", {})
        configuration = connect_args.get("configuration", {})

        # If using Hive, this will set hive.server2.proxy.user=$effective_username
        configuration.update(
            self.db_engine_spec.get_configuration_for_impersonation(
                str(sqlalchemy_url), self.impersonate_user,
                effective_username))
        if configuration:
            connect_args["configuration"] = configuration
        if connect_args:
            params["connect_args"] = connect_args

        params.update(self.get_encrypted_extra())

        if DB_CONNECTION_MUTATOR:
            if not source and request and request.referrer:
                if "/superset/dashboard/" in request.referrer:
                    source = utils.QuerySource.DASHBOARD
                elif "/superset/explore/" in request.referrer:
                    source = utils.QuerySource.CHART
                elif "/superset/sqllab/" in request.referrer:
                    source = utils.QuerySource.SQL_LAB

            sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
                sqlalchemy_url, params, effective_username, security_manager,
                source)

        return create_engine(sqlalchemy_url, **params)
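
    # Usage sketch (schema name hypothetical): a NullPool engine for a
    # one-off query, bypassing any shared connection pool.
    #
    #   engine = db.get_sqla_engine(schema="analytics", nullpool=True)
    #   with engine.connect() as conn:
    #       conn.execute("SELECT 1")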

    def get_reserved_words(self) -> Set[str]:
        return self.get_dialect().preparer.reserved_words

    def get_quoter(self) -> Callable[[str, Any], str]:
        return self.get_dialect().identifier_preparer.quote

    def get_df(  # pylint: disable=too-many-locals
        self,
        sql: str,
        schema: Optional[str] = None,
        mutator: Optional[Callable[[pd.DataFrame], None]] = None,
    ) -> pd.DataFrame:
        sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]

        engine = self.get_sqla_engine(schema=schema)
        username = utils.get_username()

        def needs_conversion(df_series: pd.Series) -> bool:
            return not df_series.empty and isinstance(df_series[0],
                                                      (list, dict))

        def _log_query(sql: str) -> None:
            if log_query:
                log_query(engine.url, sql, schema, username, __name__,
                          security_manager)

        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                for sql_ in sqls[:-1]:
                    _log_query(sql_)
                    self.db_engine_spec.execute(cursor, sql_)
                    cursor.fetchall()

                _log_query(sqls[-1])
                self.db_engine_spec.execute(cursor, sqls[-1])

                data = self.db_engine_spec.fetch_data(cursor)
                result_set = SupersetResultSet(data, cursor.description,
                                               self.db_engine_spec)
                df = result_set.to_pandas_df()
                if mutator:
                    mutator(df)

                for k, v in df.dtypes.items():
                    if v.type == numpy.object_ and needs_conversion(df[k]):
                        df[k] = df[k].apply(utils.json_dumps_w_dates)

                return df

    def compile_sqla_query(self,
                           qry: Select,
                           schema: Optional[str] = None) -> str:
        engine = self.get_sqla_engine(schema=schema)

        sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))

        if (engine.dialect.identifier_preparer._double_percents  # pylint: disable=protected-access
            ):
            sql = sql.replace("%%", "%")

        return sql

    def select_star(  # pylint: disable=too-many-arguments
        self,
        table_name: str,
        schema: Optional[str] = None,
        limit: int = 100,
        show_cols: bool = False,
        indent: bool = True,
        latest_partition: bool = False,
        cols: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Generates a ``select *`` statement in the proper dialect"""
        eng = self.get_sqla_engine(schema=schema,
                                   source=utils.QuerySource.SQL_LAB)
        return self.db_engine_spec.select_star(
            self,
            table_name,
            schema=schema,
            engine=eng,
            limit=limit,
            show_cols=show_cols,
            indent=indent,
            latest_partition=latest_partition,
            cols=cols,
        )

    def apply_limit_to_sql(self, sql: str, limit: int = 1000) -> str:
        return self.db_engine_spec.apply_limit_to_sql(sql, limit, self)

    def safe_sqlalchemy_uri(self) -> str:
        return self.sqlalchemy_uri

    @property
    def inspector(self) -> Inspector:
        engine = self.get_sqla_engine()
        return sqla.inspect(engine)

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{}:schema:None:table_list",
        attribute_in_key="id",
    )
    def get_all_table_names_in_database(
        self,
        cache: bool = False,
        cache_timeout: Optional[int] = None,
        force: bool = False,
    ) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments."""
        if not self.allow_multi_schema_metadata_fetch:
            return []
        return self.db_engine_spec.get_all_datasource_names(self, "table")

    @cache_util.memoized_func(
        key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
        attribute_in_key="id")
    def get_all_view_names_in_database(
        self,
        cache: bool = False,
        cache_timeout: Optional[int] = None,
        force: bool = False,
    ) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments."""
        if not self.allow_multi_schema_metadata_fetch:
            return []
        return self.db_engine_spec.get_all_datasource_names(self, "view")

    @cache_util.memoized_func(
        key=lambda *args, **kwargs:
        f"db:{{}}:schema:{kwargs.get('schema')}:table_list",  # type: ignore
        attribute_in_key="id",
    )
    def get_all_table_names_in_schema(
        self,
        schema: str,
        cache: bool = False,
        cache_timeout: Optional[int] = None,
        force: bool = False,
    ) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments.

        For unused parameters, they are referenced in
        cache_util.memoized_func decorator.

        :param schema: schema name
        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: list of tables
        """
        try:
            tables = self.db_engine_spec.get_table_names(
                database=self, inspector=self.inspector, schema=schema)
            return [
                utils.DatasourceName(table=table, schema=schema)
                for table in tables
            ]
        except Exception as ex:  # pylint: disable=broad-except
            logger.exception(ex)
            return []

    @cache_util.memoized_func(
        key=lambda *args, **kwargs:
        f"db:{{}}:schema:{kwargs.get('schema')}:view_list",  # type: ignore
        attribute_in_key="id",
    )
    def get_all_view_names_in_schema(
        self,
        schema: str,
        cache: bool = False,
        cache_timeout: Optional[int] = None,
        force: bool = False,
    ) -> List[utils.DatasourceName]:
        """Parameters need to be passed as keyword arguments.

        For unused parameters, they are referenced in
        cache_util.memoized_func decorator.

        :param schema: schema name
        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: list of views
        """
        try:
            views = self.db_engine_spec.get_view_names(
                database=self, inspector=self.inspector, schema=schema)
            return [
                utils.DatasourceName(table=view, schema=schema)
                for view in views
            ]
        except Exception as ex:  # pylint: disable=broad-except
            logger.exception(ex)
            return []

    @cache_util.memoized_func(key=lambda *args, **kwargs: "db:{}:schema_list",
                              attribute_in_key="id")
    def get_all_schema_names(
        self,
        cache: bool = False,
        cache_timeout: Optional[int] = None,
        force: bool = False,
    ) -> List[str]:
        """Parameters need to be passed as keyword arguments.

        For unused parameters, they are referenced in
        cache_util.memoized_func decorator.

        :param cache: whether cache is enabled for the function
        :param cache_timeout: timeout in seconds for the cache
        :param force: whether to force refresh the cache
        :return: schema list
        """
        return self.db_engine_spec.get_schema_names(self.inspector)

    @property
    def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
        return db_engine_specs.engines.get(self.backend,
                                           db_engine_specs.BaseEngineSpec)

    @classmethod
    def get_db_engine_spec_for_backend(
            cls, backend: str) -> Type[db_engine_specs.BaseEngineSpec]:
        return db_engine_specs.engines.get(backend,
                                           db_engine_specs.BaseEngineSpec)

    def grains(self) -> Tuple[TimeGrain, ...]:
        """Defines time granularity database-specific expressions.

        The idea here is to make it easy for users to change the time grain
        from a datetime (maybe the source grain is arbitrary timestamps, daily
        or 5 minutes increments) to another, "truncated" datetime. Since
        each database has slightly different but similar datetime functions,
        this allows a mapping between database engines and actual functions.
        """
        return self.db_engine_spec.get_time_grains()
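
    # For example, each grain returned here is a TimeGrain tuple carrying at
    # least a name and a duration (as used elsewhere in this listing):
    #
    #   for grain in db.grains():
    #       print(grain.name, grain.duration)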

    def get_extra(self) -> Dict[str, Any]:
        return self.db_engine_spec.get_extra_params(self)

    def get_encrypted_extra(self) -> Dict[str, Any]:
        encrypted_extra = {}
        if self.encrypted_extra:
            try:
                encrypted_extra = json.loads(self.encrypted_extra)
            except json.JSONDecodeError as ex:
                logger.error(ex)
                raise ex
        return encrypted_extra

    def get_table(self,
                  table_name: str,
                  schema: Optional[str] = None) -> Table:
        extra = self.get_extra()
        meta = MetaData(**extra.get("metadata_params", {}))
        return Table(
            table_name,
            meta,
            schema=schema or None,
            autoload=True,
            autoload_with=self.get_sqla_engine(),
        )

    def get_columns(self,
                    table_name: str,
                    schema: Optional[str] = None) -> List[Dict[str, Any]]:
        return self.db_engine_spec.get_columns(self.inspector, table_name,
                                               schema)

    def get_indexes(self,
                    table_name: str,
                    schema: Optional[str] = None) -> List[Dict[str, Any]]:
        return self.inspector.get_indexes(table_name, schema)

    def get_pk_constraint(self,
                          table_name: str,
                          schema: Optional[str] = None) -> Dict[str, Any]:
        return self.inspector.get_pk_constraint(table_name, schema)

    def get_foreign_keys(self,
                         table_name: str,
                         schema: Optional[str] = None) -> List[Dict[str, Any]]:
        return self.inspector.get_foreign_keys(table_name, schema)

    def get_schema_access_for_csv_upload(  # pylint: disable=invalid-name
            self, ) -> List[str]:
        allowed_databases = self.get_extra().get(
            "schemas_allowed_for_csv_upload", [])
        if hasattr(g, "user"):
            extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
                self, g.user)
            allowed_databases += extra_allowed_databases
        return sorted(set(allowed_databases))

    @property
    def sqlalchemy_uri_decrypted(self) -> str:
        conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
        if custom_password_store:
            conn.password = custom_password_store(conn)
        else:
            conn.password = self.password
        return str(conn)

    @property
    def sql_url(self) -> str:
        return f"/superset/sql/{self.id}/"

    @hybrid_property
    def perm(self) -> str:
        return f"[{self.database_name}].(id:{self.id})"

    @perm.expression  # type: ignore
    def perm(cls) -> str:  # pylint: disable=no-self-argument
        return ("[" + cls.database_name + "].(id:" +
                expression.cast(cls.id, String) + ")")

    def get_perm(self) -> str:
        return self.perm  # type: ignore
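
    # Because `perm` is a hybrid property, the same attribute works on an
    # instance (a plain string such as "[examples].(id:1)", values
    # hypothetical) and as a SQL expression in filters:
    #
    #   session.query(Database).filter(Database.perm == some_perm)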

    def has_table(self, table: Table) -> bool:
        engine = self.get_sqla_engine()
        return engine.has_table(table.table_name, table.schema or None)

    def has_table_by_name(self,
                          table_name: str,
                          schema: Optional[str] = None) -> bool:
        engine = self.get_sqla_engine()
        return engine.has_table(table_name, schema)

    @utils.memoized
    def get_dialect(self) -> Dialect:
        sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
        return sqla_url.get_dialect()()  # pylint: disable=no-member
Example #12
import os

from sqlalchemy import (
    Boolean,
    Column,
    create_engine,
    DateTime,
    Integer,
    MetaData,
    String,
    Table,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import func

from databases import Database

DATABASE_URL = os.getenv("DATABASE_URL")

# SQLAlchemy
engine = create_engine(DATABASE_URL)
metadata = MetaData()
persons = Table(
    "persons",
    metadata,
    Column("record", Integer, primary_key=True),
    Column("person_id", UUID(as_uuid=False), nullable=False),
    Column("first_name", String(50), nullable=False),
    Column("middle_name", String(50)),
    Column("last_name", String(50), nullable=False),
    Column("email", String(50), nullable=False),
    Column("age", Integer, nullable=False),
    Column("version", Integer, default=1, nullable=False),
    Column("is_latest", Boolean, default=True, nullable=False),
    Column("created_date", DateTime, default=func.now(), nullable=False),
    UniqueConstraint("person_id", "version", name="id_version_uc"),
)

# databases query builder
database = Database(DATABASE_URL)
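
# A minimal async usage sketch for the query builder above; the person
# values are hypothetical, and the defaulted columns (version, is_latest,
# created_date) are set explicitly since `databases` compiles the core
# insert itself rather than running it through a SQLAlchemy engine.
import asyncio
import uuid


async def _demo():
    await database.connect()
    await database.execute(persons.insert().values(
        person_id=str(uuid.uuid4()),
        first_name="Ada",
        last_name="Lovelace",
        email="ada@example.com",
        age=36,
        version=1,
        is_latest=True,
        created_date=func.now(),
    ))
    print(await database.fetch_all(persons.select()))
    await database.disconnect()

# asyncio.run(_demo())  # requires a reachable DATABASE_URL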
Example #13
class Folder(MailSyncBase):
    """ Folders and labels from the remote account backend (IMAP/Exchange). """
    # `use_alter` required here to avoid circular dependency w/Account
    account_id = Column(Integer,
                        ForeignKey('account.id', use_alter=True,
                                   name='folder_fk1',
                                   ondelete='CASCADE'), nullable=False)

    # TOFIX this causes an import error due to circular dependencies
    # from inbox.models.account import Account
    account = relationship(
        'Account', backref=backref('folders',
                                   cascade='delete',
                                   primaryjoin='and_('
                                   'Folder.account_id == Account.id, '
                                   'Folder.deleted_at.is_(None))'),
        primaryjoin='and_(Folder.account_id==Account.id, '
        'Account.deleted_at==None)')

    # Explicitly set collation to be case insensitive. This is mysql's default
    # but never trust defaults! This allows us to store the original casing to
    # not confuse users when displaying it, but still only allow a single
    # folder with any specific name, canonicalized to lowercase.
    name = Column(String(MAX_FOLDER_NAME_LENGTH,
                         collation='utf8mb4_general_ci'), nullable=True)
    canonical_name = Column(String(MAX_FOLDER_NAME_LENGTH), nullable=True)

    __table_args__ = (UniqueConstraint('account_id', 'name',
                                       'canonical_name'),)

    @property
    def lowercase_name(self):
        if self.name is None:
            return None
        return self.name.lower()

    @property
    def namespace(self):
        return self.account.namespace

    @classmethod
    def create(cls, account, name, session, canonical_name=None):
        if name is not None and len(name) > MAX_FOLDER_NAME_LENGTH:
            log.warning("Truncating long folder name for account {}; "
                        "original name was '{}'" .format(account.id, name))
            name = name[:MAX_FOLDER_NAME_LENGTH]
        obj = cls(account=account, name=name,
                  canonical_name=canonical_name)
        session.add(obj)
        return obj

    @classmethod
    def find_or_create(cls, session, account, name, canonical_name=None):
        try:
            if name is not None and len(name) > MAX_FOLDER_NAME_LENGTH:
                name = name[:MAX_FOLDER_NAME_LENGTH]
            q = session.query(cls).filter_by(account_id=account.id)
            if name is not None:
                q = q.filter(func.lower(Folder.name) == func.lower(name))
            if canonical_name is not None:
                q = q.filter_by(canonical_name=canonical_name)
            obj = q.one()
        except NoResultFound:
            obj = cls.create(account, name, session, canonical_name)
        except MultipleResultsFound:
            log.info("Duplicate folder rows for folder {} for account {}"
                     .format(name, account.id))
            raise
        return obj
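
    # Hypothetical usage: thanks to the case-insensitive name match above,
    # "INBOX" and "Inbox" resolve to the same row while the stored name
    # keeps its original casing.
    #
    #   folder = Folder.find_or_create(db_session, account, "INBOX")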

    def get_associated_tag(self, db_session):
        if self.canonical_name is not None:
            try:
                return db_session.query(Tag). \
                    filter(Tag.namespace_id == self.namespace.id,
                           Tag.public_id == self.canonical_name).one()
            except NoResultFound:
                # Explicitly set the namespace_id instead of the namespace
                # attribute to avoid autoflush-induced IntegrityErrors where
                # the namespace_id is null on flush.
                tag = Tag(namespace_id=self.account.namespace.id,
                          name=self.canonical_name,
                          public_id=self.canonical_name)
                db_session.add(tag)
                return tag

        else:
            provider_prefix = self.account.provider
            tag_name = '-'.join((provider_prefix,
                                 self.name.lower()))[:MAX_INDEXABLE_LENGTH]
            try:
                return db_session.query(Tag). \
                    filter(Tag.namespace_id == self.namespace.id,
                           Tag.name == tag_name).one()
            except NoResultFound:
                # Explicitly set the namespace_id instead of the namespace
                # attribute to avoid autoflush-induced IntegrityErrors where
                # the namespace_id is null on flush.
                tag = Tag(namespace_id=self.account.namespace.id,
                          name=tag_name)
                db_session.add(tag)
                return tag
Example #14
from ...db.meta import metadata, table_opts


jcmt_allocation = Table(
    'jcmt_allocation',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('proposal_id', None,
           ForeignKey('proposal.id', onupdate='RESTRICT', ondelete='RESTRICT'),
           nullable=False),
    Column('instrument', Integer, nullable=False),
    Column('ancillary', Integer, nullable=False),
    Column('weather', Integer, nullable=False),
    Column('time', Float, nullable=False),
    UniqueConstraint('proposal_id', 'instrument', 'ancillary', 'weather'),
    **table_opts)

jcmt_available = Table(
    'jcmt_available',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('call_id', None,
           ForeignKey('call.id', onupdate='RESTRICT', ondelete='RESTRICT'),
           nullable=False),
    Column('weather', Integer, nullable=False),
    Column('time', Float, nullable=False),
    UniqueConstraint('call_id', 'weather'),
    **table_opts)

jcmt_options = Table(
Example #15
class SqlaTable(Model, BaseDatasource):
    """An ORM object for SqlAlchemy table references"""

    type = 'table'
    query_language = 'sql'
    metric_class = SqlMetric
    column_class = TableColumn
    owner_class = security_manager.user_model

    __tablename__ = 'tables'
    __table_args__ = (UniqueConstraint('database_id', 'table_name'), )

    table_name = Column(String(250))
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
    fetch_values_predicate = Column(String(1000))
    owners = relationship(owner_class,
                          secondary=sqlatable_user,
                          backref='tables')
    database = relationship('Database',
                            backref=backref('tables',
                                            cascade='all, delete-orphan'),
                            foreign_keys=[database_id])
    schema = Column(String(255))
    sql = Column(Text)
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)
    # Add hive_partitions column to store partition info
    hive_partitions = Column(Text)

    baselink = 'tablemodelview'

    export_fields = (
        'table_name',
        'main_dttm_col',
        'description',
        'default_endpoint',
        'database_id',
        'offset',
        'cache_timeout',
        'schema',
        'sql',
        'params',
        'template_params',
        'filter_select_enabled',
        'fetch_values_predicate',
        'hive_partitions',
    )
    update_from_object_fields = [
        f for f in export_fields if f not in ('table_name', 'database_id')
    ]
    export_parent = 'database'
    export_children = ['metrics', 'columns']

    sqla_aggregations = {
        'COUNT_DISTINCT':
        lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
        'COUNT':
        sa.func.COUNT,
        'SUM':
        sa.func.SUM,
        'AVG':
        sa.func.AVG,
        'MIN':
        sa.func.MIN,
        'MAX':
        sa.func.MAX,
    }

    def get_label(self, label):
        """Conditionally mutate a label to conform to db engine requirements
        and store mapping from mutated label to original label

        :param label: original label
        :return: Either a string or sqlalchemy.sql.elements.quoted_name if required
        by db engine
        """
        db_engine_spec = self.database.db_engine_spec
        sqla_label = db_engine_spec.make_label_compatible(label)
        mutated_label = str(sqla_label)
        if label != mutated_label:
            self.mutated_labels[mutated_label] = label
        return sqla_label

    def __repr__(self):
        return self.name

    @property
    def connection(self):
        return str(self.database)

    @property
    def description_markeddown(self):
        return utils.markdown(self.description)

    @property
    def datasource_name(self):
        return self.table_name

    @property
    def database_name(self):
        return self.database.name

    @property
    def link(self):
        name = escape(self.name)
        anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
        return Markup(anchor)

    @property
    def schema_perm(self):
        """Returns schema permission if present, database one otherwise."""
        return security_manager.get_schema_perm(self.database, self.schema)

    def get_perm(self):
        return ('[{obj.database}].[{obj.table_name}]'
                '(id:{obj.id})').format(obj=self)

    @property
    def name(self):
        if not self.schema:
            return self.table_name
        return '{}.{}'.format(self.schema, self.table_name)

    @property
    def full_name(self):
        return utils.get_datasource_full_name(self.database,
                                              self.table_name,
                                              schema=self.schema)

    @property
    def dttm_cols(self):
        l = [c.column_name for c in self.columns if c.is_dttm]  # noqa: E741
        if self.main_dttm_col and self.main_dttm_col not in l:
            l.append(self.main_dttm_col)
        return l

    @property
    def num_cols(self):
        return [c.column_name for c in self.columns if c.is_num]

    @property
    def any_dttm_col(self):
        cols = self.dttm_cols
        if cols:
            return cols[0]

    @property
    def html(self):
        t = ((c.column_name, c.type) for c in self.columns)
        df = pd.DataFrame(t)
        df.columns = ['field', 'type']
        return df.to_html(
            index=False,
            classes=('dataframe table table-striped table-bordered '
                     'table-condensed'))

    @property
    def sql_url(self):
        return self.database.sql_url + '?table_name=' + str(self.table_name)

    def external_metadata(self):
        cols = self.database.get_columns(self.table_name, schema=self.schema)
        for col in cols:
            try:
                col['type'] = str(col['type'])
            except CompileError:
                col['type'] = 'UNKNOWN'
        return cols

    @property
    def time_column_grains(self):
        return {
            'time_columns': self.dttm_cols,
            'time_grains': [grain.name for grain in self.database.grains()],
        }

    @property
    def select_star(self):
        # show_cols and latest_partition set to false to avoid
        # the expense of inspecting the DB
        return self.database.select_star(self.name,
                                         show_cols=False,
                                         latest_partition=False)

    def get_col(self, col_name):
        columns = self.columns
        for col in columns:
            if col_name == col.column_name:
                return col

    @property
    def data(self):
        d = super(SqlaTable, self).data
        if self.type == 'table':
            grains = self.database.grains() or []
            if grains:
                grains = [(g.duration, g.name) for g in grains]
            d['granularity_sqla'] = utils.choicify(self.dttm_cols)
            d['time_grain_sqla'] = grains
            d['main_dttm_col'] = self.main_dttm_col
            d['fetch_values_predicate'] = self.fetch_values_predicate
        return d

    def values_for_column(self, column_name, limit=10000):
        """Runs query against sqla to retrieve some
        sample values for the given column.
        """
        cols = {col.column_name: col for col in self.columns}
        target_col = cols[column_name]
        tp = self.get_template_processor()

        qry = (select([target_col.get_sqla_col()
                       ]).select_from(self.get_from_clause(tp)).distinct())
        if limit:
            qry = qry.limit(limit)

        if self.fetch_values_predicate:
            tp = self.get_template_processor()
            qry = qry.where(tp.process_template(self.fetch_values_predicate))

        engine = self.database.get_sqla_engine()
        sql = '{}'.format(
            qry.compile(engine, compile_kwargs={'literal_binds': True}), )
        sql = self.mutate_query_from_config(sql)

        df = pd.read_sql_query(sql=sql, con=engine)
        return [row[0] for row in df.to_records(index=False)]

    def mutate_query_from_config(self, sql):
        """Apply config's SQL_QUERY_MUTATOR

        Typically adds comments to the query with context"""
        SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
        if SQL_QUERY_MUTATOR:
            username = utils.get_username()
            sql = SQL_QUERY_MUTATOR(sql, username, security_manager,
                                    self.database)
        return sql
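
    # A minimal sketch of what a SQL_QUERY_MUTATOR might look like (the
    # comment format is illustrative, not prescribed by the config):
    #
    #   def SQL_QUERY_MUTATOR(sql, username, security_manager, database):
    #       return '-- issued by {}\n{}'.format(username, sql)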

    def get_template_processor(self, **kwargs):
        return get_template_processor(table=self,
                                      database=self.database,
                                      **kwargs)

    def get_query_str_extended(self, query_obj):
        sqlaq = self.get_sqla_query(**query_obj)
        sql = self.database.compile_sqla_query(sqlaq.sqla_query)
        logging.info(sql)
        sql = sqlparse.format(sql, reindent=True)
        if query_obj['is_prequery']:
            query_obj['prequeries'].append(sql)
        sql = self.mutate_query_from_config(sql)
        """Apply HIVE_QUERY_GENERATOR """
        HIVE_QUERY_GENERATOR = config.get('HIVE_QUERY_GENERATOR')
        if HIVE_QUERY_GENERATOR:
            sql = HIVE_QUERY_GENERATOR(sql, query_obj, self.database,
                                       self.datasource_name)

        return QueryStringExtended(labels_expected=sqlaq.labels_expected,
                                   sql=sql)

    def get_query_str(self, query_obj):
        return self.get_query_str_extended(query_obj).sql

    def get_sqla_table(self):
        tbl = table(self.table_name)
        if self.schema:
            tbl.schema = self.schema
        return tbl

    def get_from_clause(self, template_processor=None):
        # Supporting arbitrary SQL statements in place of tables
        if self.sql:
            from_sql = self.sql
            if template_processor:
                from_sql = template_processor.process_template(from_sql)
            from_sql = sqlparse.format(from_sql, strip_comments=True)
            return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
        return self.get_sqla_table()

    def adhoc_metric_to_sqla(self, metric, cols):
        """
        Turn an adhoc metric into a sqlalchemy column.

        :param dict metric: Adhoc metric definition
        :param dict cols: Columns for the current table
        :returns: The metric defined as a sqlalchemy column
        :rtype: sqlalchemy.sql.column
        """
        expression_type = metric.get('expressionType')
        label = utils.get_metric_name(metric)
        label = self.get_label(label)

        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
            column_name = metric.get('column').get('column_name')
            table_column = cols.get(column_name)
            if table_column:
                sqla_column = table_column.get_sqla_col()
            else:
                sqla_column = column(column_name)
            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](
                sqla_column)
            sqla_metric = sqla_metric.label(label)
            return sqla_metric
        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
            sqla_metric = literal_column(metric.get('sqlExpression'))
            sqla_metric = sqla_metric.label(label)
            return sqla_metric
        else:
            return None
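
    # A hypothetical SIMPLE adhoc metric as consumed above; it compiles to
    # SUM(num_purchases) labeled with the (possibly mutated) metric name:
    #
    #   metric = {
    #       'expressionType': 'SIMPLE',
    #       'column': {'column_name': 'num_purchases'},
    #       'aggregate': 'SUM',
    #       'label': 'sum__num_purchases',
    #   }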

    def get_sqla_query(  # sqla
        self,
        groupby,
        metrics,
        granularity,
        from_dttm,
        to_dttm,
        filter=None,  # noqa
        is_timeseries=True,
        timeseries_limit=15,
        timeseries_limit_metric=None,
        row_limit=None,
        inner_from_dttm=None,
        inner_to_dttm=None,
        orderby=None,
        extras=None,
        columns=None,
        order_desc=True,
        prequeries=None,
        is_prequery=False,
    ):
        """Querying any sqla table from this common interface"""
        template_kwargs = {
            'from_dttm': from_dttm,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'to_dttm': to_dttm,
            'filter': filter,
            'columns': {col.column_name: col
                        for col in self.columns},
        }
        template_kwargs.update(self.template_params_dict)
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec

        # Initialize empty cache to store mutated labels
        self.mutated_labels = {}

        orderby = orderby or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols = {col.column_name: col for col in self.columns}
        metrics_dict = {m.metric_name: m for m in self.metrics}

        if not granularity and is_timeseries:
            raise Exception(
                _('Datetime column not provided as part of table configuration '
                  'and is required by this type of chart'))
        if not groupby and not metrics and not columns:
            raise Exception(_('Empty query?'))
        metrics_exprs = []
        for m in metrics:
            if utils.is_adhoc_metric(m):
                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
            elif m in metrics_dict:
                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
            else:
                raise Exception(_("Metric '{}' is not valid".format(m)))
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            label = self.get_label('ccount')
            main_metric_expr = literal_column('COUNT(*)').label(label)

        select_exprs = []
        groupby_exprs_sans_timestamp = OrderedDict()

        if groupby:
            select_exprs = []
            for s in groupby:
                if s in cols:
                    outer = cols[s].get_sqla_col()
                else:
                    outer = literal_column(f'({s})').label(self.get_label(s))

                groupby_exprs_sans_timestamp[outer.name] = outer
                select_exprs.append(outer)
        elif columns:
            for s in columns:
                select_exprs.append(
                    cols[s].get_sqla_col() if s in cols else literal_column(s))
            metrics_exprs = []

        groupby_exprs_with_timestamp = OrderedDict(
            groupby_exprs_sans_timestamp.items())
        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get('time_grain_sqla')
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs_with_timestamp[timestamp.name] = timestamp

            # Use main dttm column to support index with secondary dttm columns
            if db_engine_spec.time_secondary_columns and \
                    self.main_dttm_col in self.dttm_cols and \
                    self.main_dttm_col != dttm_col.column_name:
                time_filters.append(cols[self.main_dttm_col].get_time_filter(
                    from_dttm, to_dttm))
            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))

        select_exprs += metrics_exprs

        labels_expected = [str(c.name) for c in select_exprs]

        select_exprs = db_engine_spec.make_select_compatible(
            groupby_exprs_with_timestamp.values(), select_exprs)
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor)

        if not columns:
            qry = qry.group_by(*groupby_exprs_with_timestamp.values())

        where_clause_and = []
        having_clause_and = []
        for flt in filter:
            if not all([flt.get(s) for s in ['col', 'op']]):
                continue
            col = flt['col']
            op = flt['op']
            col_obj = cols.get(col)
            if col_obj:
                is_list_target = op in ('in', 'not in')
                eq = self.filter_values_handler(
                    flt.get('val'),
                    target_column_is_numeric=col_obj.is_num,
                    is_list_target=is_list_target)
                if op in ('in', 'not in'):
                    cond = col_obj.get_sqla_col().in_(eq)
                    if '<NULL>' in eq:
                        cond = or_(cond,
                                   col_obj.get_sqla_col() == None)  # noqa
                    if op == 'not in':
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_num:
                        eq = utils.string_to_num(flt['val'])
                    if op == '==':
                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                    elif op == '!=':
                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                    elif op == '>':
                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                    elif op == '<':
                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                    elif op == '>=':
                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                    elif op == '<=':
                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                    elif op == 'LIKE':
                        where_clause_and.append(
                            col_obj.get_sqla_col().like(eq))
                    elif op == 'IS NULL':
                        where_clause_and.append(
                            col_obj.get_sqla_col() == None)  # noqa
                    elif op == 'IS NOT NULL':
                        where_clause_and.append(
                            col_obj.get_sqla_col() != None)  # noqa
        if extras:
            where = extras.get('where')
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text('({})'.format(where))]
            having = extras.get('having')
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text('({})'.format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and not columns:
            orderby = [(main_metric_expr, not order_desc)]

        for col, ascending in orderby:
            direction = asc if ascending else desc
            if utils.is_adhoc_metric(col):
                col = self.adhoc_metric_to_sqla(col, cols)
            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

        if is_timeseries and \
                timeseries_limit and groupby and not time_groupby_inline:
            if self.database.db_engine_spec.inner_joins:
                # some sql dialects require for order by expressions
                # to also be in the select clause -- others, e.g. vertica,
                # require a unique inner alias
                label = self.get_label('mme_inner__')
                inner_main_metric_expr = main_metric_expr.label(label)
                inner_groupby_exprs = []
                inner_select_exprs = []
                for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
                    inner = gby_obj.label(gby_name + '__')
                    inner_groupby_exprs.append(inner)
                    inner_select_exprs.append(inner)

                inner_select_exprs += [inner_main_metric_expr]
                subq = select(inner_select_exprs).select_from(tbl)
                inner_time_filter = dttm_col.get_time_filter(
                    inner_from_dttm or from_dttm,
                    inner_to_dttm or to_dttm,
                )
                subq = subq.where(
                    and_(*(where_clause_and + [inner_time_filter])))
                subq = subq.group_by(*inner_groupby_exprs)

                ob = inner_main_metric_expr
                if timeseries_limit_metric:
                    if utils.is_adhoc_metric(timeseries_limit_metric):
                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric,
                                                       cols)
                    elif timeseries_limit_metric in metrics_dict:
                        timeseries_limit_metric = metrics_dict.get(
                            timeseries_limit_metric, )
                        ob = timeseries_limit_metric.get_sqla_col()
                    else:
                        raise Exception(_(
                            "Metric '{}' is not valid".format(
                                timeseries_limit_metric)))
                direction = desc if order_desc else asc
                subq = subq.order_by(direction(ob))
                subq = subq.limit(timeseries_limit)

                on_clause = []
                for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
                    # in this case the column name, not the alias, needs to be
                    # conditionally mutated, as it refers to the column alias in
                    # the inner query
                    col_name = self.get_label(gby_name + '__')
                    on_clause.append(gby_obj == column(col_name))

                tbl = tbl.join(subq.alias(), and_(*on_clause))
            else:
                # run subquery to get top groups
                subquery_obj = {
                    'prequeries': prequeries,
                    'is_prequery': True,
                    'is_timeseries': False,
                    'row_limit': timeseries_limit,
                    'groupby': groupby,
                    'metrics': metrics,
                    'granularity': granularity,
                    'from_dttm': inner_from_dttm or from_dttm,
                    'to_dttm': inner_to_dttm or to_dttm,
                    'filter': filter,
                    'orderby': orderby,
                    'extras': extras,
                    'columns': columns,
                    'order_desc': True,
                }
                result = self.query(subquery_obj)
                dimensions = [
                    c for c in result.df.columns
                    if c not in metrics and c in groupby_exprs_sans_timestamp
                ]
                top_groups = self._get_top_groups(
                    result.df, dimensions, groupby_exprs_sans_timestamp)
                qry = qry.where(top_groups)

        return SqlaQuery(sqla_query=qry.select_from(tbl),
                         labels_expected=labels_expected)
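
    # A hypothetical query_obj for this interface; the keys mirror the
    # parameters above and compile to SQL via get_query_str(query_obj):
    #
    #   query_obj = {
    #       'groupby': ['gender'], 'metrics': ['count'], 'granularity': 'ds',
    #       'from_dttm': from_dttm, 'to_dttm': to_dttm, 'filter': [],
    #       'is_timeseries': False, 'extras': {}, 'columns': [],
    #       'prequeries': [], 'is_prequery': False,
    #   }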

    def _get_top_groups(self, df, dimensions, groupby_exprs):
        groups = []
        for unused, row in df.iterrows():
            group = []
            for dimension in dimensions:
                group.append(groupby_exprs[dimension] == row[dimension])
            groups.append(and_(*group))

        return or_(*groups)

    def query(self, query_obj):
        qry_start_dttm = datetime.now()
        query_str_ext = self.get_query_str_extended(query_obj)
        sql = query_str_ext.sql
        status = utils.QueryStatus.SUCCESS
        error_message = None
        df = None
        logging.debug(
            '[PERFORMANCE CHECK] SQL Query formation time {0} '.format(
                datetime.now() - qry_start_dttm))
        db_engine_spec = self.database.db_engine_spec
        try:
            df = self.database.get_df(sql, self.schema)
            if self.mutated_labels:
                df = df.rename(index=str, columns=self.mutated_labels)
            db_engine_spec.mutate_df_columns(df, sql,
                                             query_str_ext.labels_expected)
        except Exception as e:
            status = utils.QueryStatus.FAILED
            logging.exception(f'Query {sql} on schema {self.schema} failed')
            error_message = db_engine_spec.extract_error_message(e)

        # if this is a main query with prequeries, combine them together
        if not query_obj['is_prequery']:
            query_obj['prequeries'].append(sql)
            sql = ';\n\n'.join(query_obj['prequeries'])
        sql += ';'

        return QueryResult(status=status,
                           df=df,
                           duration=datetime.now() - qry_start_dttm,
                           query=sql,
                           error_message=error_message)

    def get_sqla_table_object(self):
        return self.database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self):
        """Fetches the metadata for the table and merges it in"""
        try:
            has_kerberos_ticket()
            table = self.get_sqla_table_object()
        except Exception as e:
            logging.exception(e)
            raise Exception(
                _("Table [{}] doesn't seem to exist in the specified database, "
                  "couldn't fetch column information").format(self.table_name))

        M = SqlMetric  # noqa
        metrics = []
        any_date_col = None
        db_dialect = self.database.get_dialect()
        dbcols = (db.session.query(TableColumn).filter(
            TableColumn.table == self).filter(
                or_(TableColumn.column_name == col.name
                    for col in table.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}

        for col in table.columns:
            try:
                datatype = col.type.compile(dialect=db_dialect).upper()
            except Exception as e:
                datatype = 'UNKNOWN'
                logging.error('Unrecognized data type in {}.{}'.format(
                    table, col.name))
                logging.exception(e)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                dbcol = TableColumn(column_name=col.name, type=datatype)
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.is_dttm = dbcol.is_time
            else:
                dbcol.type = datatype
            dbcol.groupby = True
            dbcol.filterable = True
            self.columns.append(dbcol)
            if not any_date_col and dbcol.is_time:
                any_date_col = col.name

        metrics.append(
            M(
                metric_name='count',
                verbose_name='COUNT(*)',
                metric_type='count',
                expression='COUNT(*)',
            ))
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        self.add_missing_metrics(metrics)
        db.session.merge(self)
        db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None):
        """Imports the datasource from the object to the database.

         Metrics and columns and datasource will be overrided if exists.
         This function can be used to import/export dashboards between multiple
         superset instances. Audit metadata isn't copies over.
        """
        def lookup_sqlatable(table):
            return db.session.query(SqlaTable).join(Database).filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            ).first()

        def lookup_database(table):
            return db.session.query(Database).filter_by(
                database_name=table.params_dict['database_name']).one()

        return import_datasource.import_datasource(db.session, i_datasource,
                                                   lookup_database,
                                                   lookup_sqlatable,
                                                   import_time)

    @classmethod
    def query_datasources_by_name(cls,
                                  session,
                                  database,
                                  datasource_name,
                                  schema=None):
        query = (session.query(cls).filter_by(
            database_id=database.id).filter_by(table_name=datasource_name))
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()

    @staticmethod
    def default_query(qry):
        return qry.filter_by(is_sqllab_view=False)
Example #16
class Task(Base):
    """Class to store a task.

    """
    __tablename__ = 'tasks'
    __table_args__ = (
        UniqueConstraint('contest_id', 'num'),
        UniqueConstraint('contest_id', 'name'),
        ForeignKeyConstraint(
            ("id", "active_dataset_id"),
            ("datasets.task_id", "datasets.id"),
            onupdate="SET NULL", ondelete="SET NULL",
            # Use an ALTER query to set this foreign key after
            # both tables have been CREATEd, to avoid circular
            # dependencies.
            use_alter=True,
            name="fk_active_dataset_id"
        ),
        CheckConstraint("token_initial <= token_max"),
    )

    # Auto increment primary key.
    id = Column(
        Integer,
        primary_key=True,
        # Needed to enable autoincrement on integer primary keys that
        # are referenced by a foreign key defined on this table.
        autoincrement='ignore_fk')

    # Number of the task for sorting.
    num = Column(
        Integer,
        nullable=False)

    # Contest (id and object) owning the task.
    contest_id = Column(
        Integer,
        ForeignKey(Contest.id,
                   onupdate="CASCADE", ondelete="CASCADE"),
        nullable=False,
        index=True)
    contest = relationship(
        Contest,
        backref=backref('tasks',
                        collection_class=ordering_list('num'),
                        order_by=[num],
                        cascade="all, delete-orphan",
                        passive_deletes=True))

    # Short name and long human readable title of the task.
    name = Column(
        Unicode,
        nullable=False)
    title = Column(
        Unicode,
        nullable=False)

    # A JSON-encoded lists of strings: the language codes of the
    # statements that will be highlighted to all users for this task.
    primary_statements = Column(
        String,
        nullable=False,
        default="[]")

    # Parameters to define the token behaviour. See Contest.py for
    # details. The only difference is that these parameters apply on a
    # task-by-task basis. To play a token on a given task, a user must
    # satisfy both the contest's conditions and the task's.
    token_initial = Column(
        Integer,
        CheckConstraint("token_initial >= 0"),
        nullable=True)
    token_max = Column(
        Integer,
        CheckConstraint("token_max > 0"),
        nullable=True)
    token_total = Column(
        Integer,
        CheckConstraint("token_total > 0"),
        nullable=True)
    token_min_interval = Column(
        Interval,
        CheckConstraint("token_min_interval >= '0 seconds'"),
        nullable=False,
        default=timedelta())
    token_gen_time = Column(
        Interval,
        CheckConstraint("token_gen_time >= '0 seconds'"),
        nullable=False,
        default=timedelta())
    token_gen_number = Column(
        Integer,
        CheckConstraint("token_gen_number >= 0"),
        nullable=False,
        default=0)

    # Maximum number of submissions or user_tests allowed for each user
    # on this task during the whole contest or None to not enforce
    # this limitation.
    max_submission_number = Column(
        Integer,
        CheckConstraint("max_submission_number > 0"),
        nullable=True)
    max_user_test_number = Column(
        Integer,
        CheckConstraint("max_user_test_number > 0"),
        nullable=True)

    # Minimum interval between two submissions or user_tests for this
    # task, or None to not enforce this limitation.
    min_submission_interval = Column(
        Interval,
        CheckConstraint("min_submission_interval > '0 seconds'"),
        nullable=True)
    min_user_test_interval = Column(
        Interval,
        CheckConstraint("min_user_test_interval > '0 seconds'"),
        nullable=True)

    # The scores for this task will be rounded to this number of
    # decimal places.
    score_precision = Column(
        Integer,
        CheckConstraint("score_precision >= 0"),
        nullable=False,
        default=0)

    # Active Dataset (id and object) currently being used for scoring.
    # The ForeignKeyConstraint for this column is set at table-level.
    active_dataset_id = Column(
        Integer,
        nullable=True)
    active_dataset = relationship(
        'Dataset',
        foreign_keys=[active_dataset_id],
        # XXX In SQLAlchemy 0.8 we could remove this:
        primaryjoin='Task.active_dataset_id == Dataset.id',
        # Use an UPDATE query *after* an INSERT query (and *before* a
        # DELETE query) to set (and unset) the column associated to
        # this relationship.
        post_update=True)
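The circular task/dataset reference (handled by `use_alter` on the constraint and `post_update` on the relationship) dictates the insert order at runtime. A sketch of that flow, assuming a matching `Dataset` model, an existing `contest`, and an open `session`:

# Sketch only: Dataset, contest and session are assumed from context.
task = Task(num=0, name='sum', title='Sum of two numbers', contest=contest)
session.add(task)
session.flush()                   # INSERT task; active_dataset_id stays NULL
dataset = Dataset(task=task)
session.add(dataset)
session.flush()                   # INSERT dataset referencing task.id
task.active_dataset = dataset     # post_update=True emits a separate UPDATE
session.commit()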
Ejemplo n.º 17
0
class PriceHistory(db.Model):

    __table_args__ = (UniqueConstraint('ticker_id',
                                       'date',
                                       name='_ticker__date'), )

    id = db.Column(db.Integer, primary_key=True)
    ticker_id = db.Column(db.Integer,
                          db.ForeignKey('ticker.id'),
                          nullable=False)
    date = db.Column(db.Date,
                     nullable=False,
                     server_default=func.current_date())
    open = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    close = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    high = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    low = db.Column(db.Float(decimal_return_scale=2), nullable=False)
    volume = db.Column(db.Integer, nullable=False)

    @classmethod
    def get_or_create(cls, ticker=None, date=None, **kwargs):
        instance = cls.query.filter_by(ticker_id=ticker.id, date=date).first()
        if instance:
            return instance, False
        else:
            kwargs['volume'] = kwargs['volume'].replace(',', '')
            instance = cls(ticker=ticker, date=date, **kwargs)
            db.session.add(instance)
            return instance, True

    def update(self, commit=False, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

        if commit:
            db.session.commit()

    @classmethod
    def get_analytics(cls, ticker_name, date_from, date_to):
        PricesA = aliased(cls)
        PricesB = aliased(cls)

        result = db.session.query(
            label('open', func.abs(PricesA.open - PricesB.open)),
            label('close', func.abs(PricesA.close - PricesB.close)),
            label('low', func.abs(PricesA.low - PricesB.low)),
            label('high',
                  func.abs(PricesA.high - PricesB.high))).join(Ticker).join(
                      PricesB, and_(PricesB.ticker_id == Ticker.id)).filter(
                          Ticker.name == ticker_name,
                          PricesA.date == date_from, PricesB.date == date_to)
        return result

    @classmethod
    def get_delta(cls, ticker_name, type_price, value):
        sql = text(
            DELTA_SELECT.format(ticker_name=ticker_name,
                                type_price=type_price,
                                value_delta=value))
        result = db.engine.execute(sql)
        return result.fetchall()

    def __repr__(self):
        return '<{} = {}, {}>'.format(self.ticker.name, self.volume, self.date)
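A short usage sketch for `get_or_create` and `update`, assuming an existing `ticker` row (the prices and the volume string are placeholders):

import datetime

# Sketch: get_or_create strips thousands separators from `volume` itself.
row, created = PriceHistory.get_or_create(
    ticker=ticker, date=datetime.date(2020, 1, 2),
    open=10.0, close=10.5, high=10.7, low=9.9, volume='1,000')
if created:
    db.session.commit()
else:
    row.update(commit=True, close=10.6)   # setattr per kwarg, then commit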
Ejemplo n.º 18
0
class ImapFolderSyncStatus(MailSyncBase, HasRunState, UpdatedAtMixin,
                           DeletedAtMixin):
    """ Per-folder status state saving for IMAP folders. """
    account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'),
                        nullable=False)
    account = relationship(ImapAccount,
                           backref=backref('foldersyncstatuses',
                                           passive_deletes=True))

    folder_id = Column(ForeignKey('folder.id', ondelete='CASCADE'),
                       nullable=False)
    # We almost always need the folder name too, so eager load by default.
    folder = relationship('Folder',
                          lazy='joined',
                          backref=backref('imapsyncstatus',
                                          uselist=False,
                                          passive_deletes=True))

    # see state machine in mailsync/backends/imap/imap.py
    state = Column(Enum('initial', 'initial uidinvalid', 'poll',
                        'poll uidinvalid', 'finish'),
                   server_default='initial',
                   nullable=False)

    # stats on messages downloaded etc.
    _metrics = Column(MutableDict.as_mutable(JSON), default={}, nullable=True)

    @property
    def metrics(self):
        status = dict(name=self.folder.name, state=self.state)
        status.update(self._metrics or {})

        return status

    def start_sync(self):
        self._metrics = dict(run_state='running',
                             sync_start_time=datetime.utcnow())

    def stop_sync(self):
        self._metrics['run_state'] = 'stopped'
        self._metrics['sync_end_time'] = datetime.utcnow()

    @property
    def is_killed(self):
        return self._metrics.get('run_state') == 'killed'

    def update_metrics(self, metrics):
        sync_status_metrics = [
            'remote_uid_count', 'delete_uid_count', 'update_uid_count',
            'download_uid_count', 'uid_checked_timestamp',
            'num_downloaded_since_timestamp', 'queue_checked_at', 'percent'
        ]

        assert isinstance(metrics, dict)
        for k in metrics.iterkeys():
            assert k in sync_status_metrics, k

        if self._metrics is not None:
            self._metrics.update(metrics)
        else:
            self._metrics = metrics

    @property
    def sync_enabled(self):
        # sync is enabled if the folder's run bit is set, and the account's
        # run bit is set. (this saves us needing to reproduce account-state
        # transition logic on the folder level, and gives us a comparison bit
        # against folder heartbeats.)
        return self.sync_should_run and self.account.sync_should_run

    __table_args__ = (UniqueConstraint('account_id', 'folder_id'), )
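A brief sketch of the metrics bookkeeping, assuming `status` is a persisted `ImapFolderSyncStatus` (the counter values are placeholders):

# Sketch: update_metrics asserts every key is in sync_status_metrics.
status.start_sync()
status.update_metrics({'remote_uid_count': 1200, 'percent': 42})
print(status.metrics)    # folder name and state merged with stored counters
status.stop_sync()       # records run_state='stopped' and sync_end_time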
Ejemplo n.º 19
0
class SqlaTable(Model, BaseDatasource):
    """An ORM object for SqlAlchemy table references"""

    type = 'table'
    query_language = 'sql'
    metric_class = SqlMetric
    column_class = TableColumn

    __tablename__ = 'tables'
    __table_args__ = (UniqueConstraint('database_id', 'table_name'), )

    table_name = Column(String(250))
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
    fetch_values_predicate = Column(String(1000))
    user_id = Column(Integer, ForeignKey('ab_user.id'))
    owner = relationship(sm.user_model,
                         backref='tables',
                         foreign_keys=[user_id])
    database = relationship('Database',
                            backref=backref('tables',
                                            cascade='all, delete-orphan'),
                            foreign_keys=[database_id])
    schema = Column(String(255))
    sql = Column(Text)

    baselink = 'tablemodelview'

    export_fields = ('table_name', 'main_dttm_col', 'description',
                     'default_endpoint', 'database_id', 'offset',
                     'cache_timeout', 'schema', 'sql', 'params')
    export_parent = 'database'
    export_children = ['metrics', 'columns']

    def __repr__(self):
        return self.name

    @property
    def connection(self):
        return str(self.database)

    @property
    def description_markeddown(self):
        return utils.markdown(self.description)

    @property
    def link(self):
        name = escape(self.name)
        return Markup(
            '<a href="{self.explore_url}">{name}</a>'.format(**locals()))

    @property
    def schema_perm(self):
        """Returns schema permission if present, database one otherwise."""
        return utils.get_schema_perm(self.database, self.schema)

    def get_perm(self):
        return ('[{obj.database}].[{obj.table_name}]'
                '(id:{obj.id})').format(obj=self)

    @property
    def name(self):
        if not self.schema:
            return self.table_name
        return '{}.{}'.format(self.schema, self.table_name)

    @property
    def full_name(self):
        return utils.get_datasource_full_name(self.database,
                                              self.table_name,
                                              schema=self.schema)

    @property
    def dttm_cols(self):
        l = [c.column_name for c in self.columns if c.is_dttm]  # noqa: E741
        if self.main_dttm_col and self.main_dttm_col not in l:
            l.append(self.main_dttm_col)
        return l

    @property
    def num_cols(self):
        return [c.column_name for c in self.columns if c.is_num]

    @property
    def any_dttm_col(self):
        cols = self.dttm_cols
        if cols:
            return cols[0]

    @property
    def html(self):
        t = ((c.column_name, c.type) for c in self.columns)
        df = pd.DataFrame(t)
        df.columns = ['field', 'type']
        return df.to_html(
            index=False,
            classes=('dataframe table table-striped table-bordered '
                     'table-condensed'))

    @property
    def sql_url(self):
        return self.database.sql_url + '?table_name=' + str(self.table_name)

    @property
    def time_column_grains(self):
        return {
            'time_columns': self.dttm_cols,
            'time_grains': [grain.name for grain in self.database.grains()],
        }

    def get_col(self, col_name):
        columns = self.columns
        for col in columns:
            if col_name == col.column_name:
                return col

    @property
    def data(self):
        d = super(SqlaTable, self).data
        if self.type == 'table':
            grains = self.database.grains() or []
            if grains:
                grains = [(g.name, g.name) for g in grains]
            d['granularity_sqla'] = utils.choicify(self.dttm_cols)
            d['time_grain_sqla'] = grains
        return d

    def values_for_column(self, column_name, limit=10000):
        """Runs query against sqla to retrieve some
        sample values for the given column.
        """
        cols = {col.column_name: col for col in self.columns}
        target_col = cols[column_name]
        tp = self.get_template_processor()
        db_engine_spec = self.database.db_engine_spec

        qry = (select([target_col.sqla_col]).select_from(
            self.get_from_clause(tp, db_engine_spec)).distinct(column_name))
        if limit:
            qry = qry.limit(limit)

        if self.fetch_values_predicate:
            tp = self.get_template_processor()
            qry = qry.where(tp.process_template(self.fetch_values_predicate))

        engine = self.database.get_sqla_engine()
        sql = '{}'.format(
            qry.compile(engine, compile_kwargs={'literal_binds': True}), )

        df = pd.read_sql_query(sql=sql, con=engine)
        return [row[0] for row in df.to_records(index=False)]

    def get_template_processor(self, **kwargs):
        return get_template_processor(table=self,
                                      database=self.database,
                                      **kwargs)

    def get_query_str(self, query_obj):
        engine = self.database.get_sqla_engine()
        qry = self.get_sqla_query(**query_obj)
        sql = six.text_type(
            qry.compile(
                engine,
                compile_kwargs={'literal_binds': True},
            ), )
        logging.info(sql)
        sql = sqlparse.format(sql, reindent=True)
        if query_obj['is_prequery']:
            query_obj['prequeries'].append(sql)
        return sql

    def get_sqla_table(self):
        tbl = table(self.table_name)
        if self.schema:
            tbl.schema = self.schema
        return tbl

    def get_from_clause(self, template_processor=None, db_engine_spec=None):
        # Supporting arbitrary SQL statements in place of tables
        if self.sql:
            from_sql = self.sql
            if template_processor:
                from_sql = template_processor.process_template(from_sql)
            from_sql = sqlparse.format(from_sql, strip_comments=True)
            return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
        return self.get_sqla_table()

    def get_sqla_query(  # sqla
        self,
        groupby,
        metrics,
        granularity,
        from_dttm,
        to_dttm,
        filter=None,  # noqa
        is_timeseries=True,
        timeseries_limit=15,
        timeseries_limit_metric=None,
        row_limit=None,
        inner_from_dttm=None,
        inner_to_dttm=None,
        orderby=None,
        extras=None,
        columns=None,
        order_desc=True,
        prequeries=None,
        is_prequery=False,
    ):
        """Querying any sqla table from this common interface"""
        template_kwargs = {
            'from_dttm': from_dttm,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'to_dttm': to_dttm,
        }
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec

        orderby = orderby or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols = {col.column_name: col for col in self.columns}
        metrics_dict = {m.metric_name: m for m in self.metrics}

        if not granularity and is_timeseries:
            raise Exception(
                _('Datetime column not provided as part of the table '
                  'configuration and is required by this type of chart'))
        if not groupby and not metrics and not columns:
            raise Exception(_('Empty query?'))
        for m in metrics:
            if m not in metrics_dict:
                raise Exception(_("Metric '{}' is not valid".format(m)))
        metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics]
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            main_metric_expr = literal_column('COUNT(*)').label('ccount')

        select_exprs = []
        groupby_exprs = []

        if groupby:
            select_exprs = []
            inner_select_exprs = []
            inner_groupby_exprs = []
            for s in groupby:
                col = cols[s]
                outer = col.sqla_col
                inner = col.sqla_col.label(col.column_name + '__')

                groupby_exprs.append(outer)
                select_exprs.append(outer)
                inner_groupby_exprs.append(inner)
                inner_select_exprs.append(inner)
        elif columns:
            for s in columns:
                select_exprs.append(cols[s].sqla_col)
            metrics_exprs = []

        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get('time_grain_sqla')
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs += [timestamp]

            # Use main dttm column to support index with secondary dttm columns
            if db_engine_spec.time_secondary_columns and \
                    self.main_dttm_col in self.dttm_cols and \
                    self.main_dttm_col != dttm_col.column_name:
                time_filters.append(cols[self.main_dttm_col].get_time_filter(
                    from_dttm, to_dttm))
            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))

        select_exprs += metrics_exprs
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor, db_engine_spec)

        if not columns:
            qry = qry.group_by(*groupby_exprs)

        where_clause_and = []
        having_clause_and = []
        for flt in filter:
            if not all([flt.get(s) for s in ['col', 'op', 'val']]):
                continue
            col = flt['col']
            op = flt['op']
            eq = flt['val']
            col_obj = cols.get(col)
            if col_obj:
                if op in ('in', 'not in'):
                    values = []
                    for v in eq:
                        # For backwards compatibility and edge cases
                        # where a column data type might have changed
                        if isinstance(v, basestring):
                            v = v.strip("'").strip('"')
                            if col_obj.is_num:
                                v = utils.string_to_num(v)

                        # Removing empty strings and non numeric values
                        # targeting numeric columns
                        if v is not None:
                            values.append(v)
                    cond = col_obj.sqla_col.in_(values)
                    if op == 'not in':
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_num:
                        eq = utils.string_to_num(flt['val'])
                    if op == '==':
                        where_clause_and.append(col_obj.sqla_col == eq)
                    elif op == '!=':
                        where_clause_and.append(col_obj.sqla_col != eq)
                    elif op == '>':
                        where_clause_and.append(col_obj.sqla_col > eq)
                    elif op == '<':
                        where_clause_and.append(col_obj.sqla_col < eq)
                    elif op == '>=':
                        where_clause_and.append(col_obj.sqla_col >= eq)
                    elif op == '<=':
                        where_clause_and.append(col_obj.sqla_col <= eq)
                    elif op == 'LIKE':
                        where_clause_and.append(col_obj.sqla_col.like(eq))
        if extras:
            where = extras.get('where')
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text('({})'.format(where))]
            having = extras.get('having')
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text('({})'.format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and not columns:
            orderby = [(main_metric_expr, not order_desc)]

        for col, ascending in orderby:
            direction = asc if ascending else desc
            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

        if is_timeseries and \
                timeseries_limit and groupby and not time_groupby_inline:
            if self.database.db_engine_spec.inner_joins:
                # some sql dialects require for order by expressions
                # to also be in the select clause -- others, e.g. vertica,
                # require a unique inner alias
                inner_main_metric_expr = main_metric_expr.label('mme_inner__')
                inner_select_exprs += [inner_main_metric_expr]
                subq = select(inner_select_exprs)
                subq = subq.select_from(tbl)
                inner_time_filter = dttm_col.get_time_filter(
                    inner_from_dttm or from_dttm,
                    inner_to_dttm or to_dttm,
                )
                subq = subq.where(
                    and_(*(where_clause_and + [inner_time_filter])))
                subq = subq.group_by(*inner_groupby_exprs)

                ob = inner_main_metric_expr
                if timeseries_limit_metric:
                    timeseries_limit_metric = metrics_dict.get(
                        timeseries_limit_metric)
                    ob = timeseries_limit_metric.sqla_col
                direction = desc if order_desc else asc
                subq = subq.order_by(direction(ob))
                subq = subq.limit(timeseries_limit)

                on_clause = []
                for i, gb in enumerate(groupby):
                    on_clause.append(groupby_exprs[i] == column(gb + '__'))

                tbl = tbl.join(subq.alias(), and_(*on_clause))
            else:
                # run subquery to get top groups
                subquery_obj = {
                    'prequeries': prequeries,
                    'is_prequery': True,
                    'is_timeseries': False,
                    'row_limit': timeseries_limit,
                    'groupby': groupby,
                    'metrics': metrics,
                    'granularity': granularity,
                    'from_dttm': inner_from_dttm or from_dttm,
                    'to_dttm': inner_to_dttm or to_dttm,
                    'filter': filter,
                    'orderby': orderby,
                    'extras': extras,
                    'columns': columns,
                    'order_desc': True,
                }
                result = self.query(subquery_obj)
                dimensions = [c for c in result.df.columns if c not in metrics]
                top_groups = self._get_top_groups(result.df, dimensions)
                qry = qry.where(top_groups)

        return qry.select_from(tbl)

    def _get_top_groups(self, df, dimensions):
        cols = {col.column_name: col for col in self.columns}
        groups = []
        for unused, row in df.iterrows():
            group = []
            for dimension in dimensions:
                col_obj = cols.get(dimension)
                group.append(col_obj.sqla_col == row[dimension])
            groups.append(and_(*group))

        return or_(*groups)

    def query(self, query_obj):
        qry_start_dttm = datetime.now()
        sql = self.get_query_str(query_obj)
        status = QueryStatus.SUCCESS
        error_message = None
        df = None
        try:
            df = self.database.get_df(sql, self.schema)
        except Exception as e:
            status = QueryStatus.FAILED
            logging.exception(e)
            error_message = (
                self.database.db_engine_spec.extract_error_message(e))

        # if this is a main query with prequeries, combine them together
        if not query_obj['is_prequery']:
            query_obj['prequeries'].append(sql)
            sql = ';\n\n'.join(query_obj['prequeries'])
        sql += ';'

        return QueryResult(status=status,
                           df=df,
                           duration=datetime.now() - qry_start_dttm,
                           query=sql,
                           error_message=error_message)

    def get_sqla_table_object(self):
        return self.database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self):
        """Fetches the metadata for the table and merges it in"""
        try:
            table = self.get_sqla_table_object()
        except Exception:
            raise Exception(
                _("Table [{}] doesn't seem to exist in the specified database, "
                  "couldn't fetch column information").format(self.table_name))

        M = SqlMetric  # noqa
        metrics = []
        any_date_col = None
        db_dialect = self.database.get_dialect()
        dbcols = (db.session.query(TableColumn).filter(
            TableColumn.table == self).filter(
                or_(TableColumn.column_name == col.name
                    for col in table.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}

        for col in table.columns:
            try:
                datatype = col.type.compile(dialect=db_dialect).upper()
            except Exception as e:
                datatype = 'UNKNOWN'
                logging.error('Unrecognized data type in {}.{}'.format(
                    table, col.name))
                logging.exception(e)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                dbcol = TableColumn(column_name=col.name, type=datatype)
                dbcol.groupby = dbcol.is_string
                dbcol.filterable = dbcol.is_string
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.is_dttm = dbcol.is_time
            else:
                dbcol.type = datatype
            self.columns.append(dbcol)
            if not any_date_col and dbcol.is_time:
                any_date_col = col.name
            metrics += dbcol.get_metrics().values()

        metrics.append(
            M(
                metric_name='count',
                verbose_name='COUNT(*)',
                metric_type='count',
                expression='COUNT(*)',
            ))
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        self.add_missing_metrics(metrics)
        db.session.merge(self)
        db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None):
        """Imports the datasource from the object to the database.

         Metrics, columns and the datasource will be overridden if they
         already exist. This function can be used to import/export dashboards
         between multiple superset instances. Audit metadata isn't copied over.
        """
        def lookup_sqlatable(table):
            return db.session.query(SqlaTable).join(Database).filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            ).first()

        def lookup_database(table):
            return db.session.query(Database).filter_by(
                database_name=table.params_dict['database_name']).one()

        return import_util.import_datasource(db.session, i_datasource,
                                             lookup_database, lookup_sqlatable,
                                             import_time)

    @classmethod
    def query_datasources_by_name(cls,
                                  session,
                                  database,
                                  datasource_name,
                                  schema=None):
        query = (session.query(cls).filter_by(
            database_id=database.id).filter_by(table_name=datasource_name))
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()
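For reference, a sketch of the `query_obj` shape consumed by `query()` and `get_query_str()` above; every name and value here is illustrative, with `tbl` standing in for a `SqlaTable` instance whose columns include 'ds', 'gender' and 'state':

from datetime import datetime

query_obj = {
    'groupby': ['gender'],
    'metrics': ['count'],
    'granularity': 'ds',
    'from_dttm': datetime(2020, 1, 1),
    'to_dttm': datetime(2020, 2, 1),
    'filter': [{'col': 'state', 'op': 'in', 'val': ['CA', 'NY']}],
    'is_timeseries': False,
    'timeseries_limit': 0,
    'timeseries_limit_metric': None,
    'row_limit': 100,
    'extras': {'time_grain_sqla': None},
    'columns': None,
    'order_desc': True,
    'prequeries': [],
    'is_prequery': False,
}
result = tbl.query(query_obj)   # QueryResult: df, duration, query, status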
Ejemplo n.º 20
0
class ImapUid(MailSyncBase, UpdatedAtMixin, DeletedAtMixin):
    """
    Maps UIDs to their IMAP folders and per-UID flag metadata.
    This table is used solely for bookkeeping by the IMAP mail sync backends.

    """
    account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'),
                        nullable=False)
    account = relationship(ImapAccount)

    message_id = Column(ForeignKey(Message.id, ondelete='CASCADE'),
                        nullable=False)
    message = relationship(Message,
                           backref=backref('imapuids', passive_deletes=True))
    msg_uid = Column(BigInteger, nullable=False, index=True)

    folder_id = Column(ForeignKey(Folder.id, ondelete='CASCADE'),
                       nullable=False)
    # We almost always need the folder name too, so eager load by default.
    folder = relationship(Folder,
                          lazy='joined',
                          backref=backref('imapuids', passive_deletes=True))

    labels = association_proxy('labelitems',
                               'label',
                               creator=lambda label: LabelItem(label=label))

    # Flags #
    # Message has not completed composition (marked as a draft).
    is_draft = Column(Boolean, server_default=false(), nullable=False)
    # Message has been read
    is_seen = Column(Boolean, server_default=false(), nullable=False)
    # Message is "flagged" for urgent/special attention
    is_flagged = Column(Boolean, server_default=false(), nullable=False)
    # This session is the first to have been notified about this message
    is_recent = Column(Boolean, server_default=false(), nullable=False)
    # Message has been answered
    is_answered = Column(Boolean, server_default=false(), nullable=False)
    # things like: ['$Forwarded', 'nonjunk', 'Junk']
    extra_flags = Column(LittleJSON, default=[], nullable=False)
    # labels (Gmail-specific)
    # TO BE DEPRECATED
    g_labels = Column(JSON, default=lambda: [], nullable=True)

    def update_flags(self, new_flags):
        """
        Sets the flag columns based on the new_flags parameter. Returns True
        if any values have changed compared to what we previously stored.

        """
        changed = False
        new_flags = set(new_flags)
        col_for_flag = {
            u'\\Draft': 'is_draft',
            u'\\Seen': 'is_seen',
            u'\\Recent': 'is_recent',
            u'\\Answered': 'is_answered',
            u'\\Flagged': 'is_flagged',
        }
        for flag, col in col_for_flag.iteritems():
            prior_flag_value = getattr(self, col)
            new_flag_value = flag in new_flags
            if prior_flag_value != new_flag_value:
                changed = True
                setattr(self, col, new_flag_value)
            new_flags.discard(flag)

        extra_flags = sorted(new_flags)

        if extra_flags != self.extra_flags:
            changed = True

        # Sadly, there's a limit of 255 chars for this
        # column.
        while len(json.dumps(extra_flags)) > 255:
            extra_flags.pop()

        self.extra_flags = extra_flags
        return changed

    def update_labels(self, new_labels):
        # TODO(emfree): This is all mad complicated. Simplify if possible?

        # Gmail IMAP doesn't use the normal IMAP \\Draft flag. Silly Gmail
        # IMAP.
        self.is_draft = '\\Draft' in new_labels
        self.is_starred = '\\Starred' in new_labels

        category_map = {
            '\\Inbox': 'inbox',
            '\\Important': 'important',
            '\\Sent': 'sent',
            '\\Trash': 'trash',
            '\\Spam': 'spam',
            '\\All': 'all'
        }

        remote_labels = set()
        for label in new_labels:
            if label in ('\\Draft', '\\Starred'):
                continue
            elif label in category_map:
                remote_labels.add((category_map[label], category_map[label]))
            else:
                remote_labels.add((label, None))

        local_labels = {(l.name, l.canonical_name): l for l in self.labels}

        remove = set(local_labels) - remote_labels
        add = remote_labels - set(local_labels)

        with object_session(self).no_autoflush:
            for key in remove:
                self.labels.remove(local_labels[key])

            for name, canonical_name in add:
                label = Label.find_or_create(object_session(self),
                                             self.account, name,
                                             canonical_name)
                self.labels.add(label)

    @property
    def namespace(self):
        return self.account.namespace

    @property
    def categories(self):
        categories = set([l.category for l in self.labels])
        categories.add(self.folder.category)
        return categories

    __table_args__ = (UniqueConstraint(
        'folder_id',
        'msg_uid',
        'account_id',
    ), )
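A sketch of the flag diffing above, assuming `uid` is a persisted `ImapUid`:

# Sketch: standard flags flip boolean columns; leftovers are sorted into
# extra_flags (truncated to fit the 255-char JSON limit).
changed = uid.update_flags([u'\\Seen', u'\\Answered', u'$Forwarded'])
# -> is_seen and is_answered become True, extra_flags == [u'$Forwarded'],
#    and `changed` reports whether anything differed from the stored state.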
Ejemplo n.º 21
0
class Tag(MailSyncBase, HasRevisions):
    """Tags represent extra data associated with threads.

    A note about the schema. The 'public_id' of a tag is immutable. For
    reserved tags such as the inbox or starred tag, the public_id is a fixed
    human-readable string. For other tags, the public_id is an autogenerated
    uid similar to a normal public id, but stored as a string for
    compatibility.

    The name of a tag is allowed to be mutable, to allow for the eventuality
    that users wish to change the name of user-created labels, or that we
    someday expose localized names ('DAS INBOX'), or that we somehow manage to
    sync renamed gmail labels, etc.
    """
    API_OBJECT_NAME = 'tag'
    namespace = relationship(
        Namespace, backref=backref(
            'tags',
            collection_class=attribute_mapped_collection('public_id')),
        load_on_pending=True)
    namespace_id = Column(Integer, ForeignKey(
        'namespace.id', ondelete='CASCADE'), nullable=False)

    public_id = Column(String(MAX_INDEXABLE_LENGTH), nullable=False,
                       default=generate_public_id)
    name = Column(String(MAX_INDEXABLE_LENGTH), nullable=False)

    user_created = Column(Boolean, server_default=false(), nullable=False)

    CANONICAL_TAG_NAMES = ['inbox', 'archive', 'drafts', 'sending', 'sent',
                           'spam', 'starred', 'trash', 'unread', 'unseen',
                           'attachment']

    RESERVED_TAG_NAMES = ['all', 'archive', 'drafts', 'send', 'replied',
                          'file', 'attachment', 'unseen']

    # Tags that are allowed to be both added and removed via the API.
    USER_MUTABLE_TAGS = ['unread', 'starred', 'spam', 'trash', 'inbox',
                         'archive']

    @property
    def user_removable(self):
        # The 'unseen' tag can only be removed.
        return (self.user_created or self.public_id in self.USER_MUTABLE_TAGS
                or self.public_id == 'unseen')

    @property
    def user_addable(self):
        return (self.user_created or self.public_id in self.USER_MUTABLE_TAGS)

    @property
    def readonly(self):
        return not (self.user_removable or self.user_addable)

    @classmethod
    def name_available(cls, name, namespace_id, db_session):
        name = name.lower()
        if name in cls.RESERVED_TAG_NAMES or name in cls.CANONICAL_TAG_NAMES:
            return False

        if (name,) in db_session.query(Tag.name). \
                filter(Tag.namespace_id == namespace_id).all():
            return False

        return True

    __table_args__ = (UniqueConstraint('namespace_id', 'name'),
                      UniqueConstraint('namespace_id', 'public_id'))
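A short sketch of the availability check, assuming a persisted `namespace` and an open `db_session` ('project-x' is a placeholder name):

# Sketch: rejects reserved/canonical names and per-namespace duplicates.
if Tag.name_available('project-x', namespace.id, db_session):
    db_session.add(Tag(namespace=namespace, name='project-x',
                       user_created=True))
    db_session.commit()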
Ejemplo n.º 22
0
from trove.db.sqlalchemy.migrate_repo.schema import create_tables
from trove.db.sqlalchemy.migrate_repo.schema import DateTime
from trove.db.sqlalchemy.migrate_repo.schema import drop_tables
from trove.db.sqlalchemy.migrate_repo.schema import Integer
from trove.db.sqlalchemy.migrate_repo.schema import String
from trove.db.sqlalchemy.migrate_repo.schema import Table

meta = MetaData()

quotas = Table('quotas', meta,
               Column('id', String(36), primary_key=True, nullable=False),
               Column('created', DateTime()), Column('updated', DateTime()),
               Column('tenant_id', String(36)),
               Column('resource', String(length=255), nullable=False),
               Column('hard_limit', Integer()),
               UniqueConstraint('tenant_id', 'resource'))

quota_usages = Table(
    'quota_usages', meta,
    Column('id', String(36), primary_key=True, nullable=False),
    Column('created', DateTime()), Column('updated', DateTime()),
    Column('tenant_id', String(36)), Column('in_use', Integer(), default=0),
    Column('reserved', Integer(), default=0),
    Column('resource', String(length=255), nullable=False),
    UniqueConstraint('tenant_id', 'resource'))

reservations = Table(
    'reservations', meta, Column('created', DateTime()),
    Column('updated', DateTime()),
    Column('id', String(36), primary_key=True, nullable=False),
    Column('usage_id', String(36)), Column('delta', Integer(), nullable=False))
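Migration scripts in this layout conventionally pair the table definitions with `upgrade`/`downgrade` hooks built on the imported helpers; a sketch of that pairing (the listing above is truncated, so these hooks are assumed rather than quoted):

def upgrade(migrate_engine):
    meta.bind = migrate_engine
    create_tables([quotas, quota_usages, reservations])


def downgrade(migrate_engine):
    meta.bind = migrate_engine
    drop_tables([reservations, quota_usages, quotas])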
Ejemplo n.º 23
0
class Shelf(Base):
    __tablename__ = 'shelf'
    __table_args__ = (UniqueConstraint('shelf'), )
    id = sa.Column(sa.Integer, primary_key=True)
    shelf = sa.Column(sa.String, nullable=False)
Ejemplo n.º 24
0
class TableColumn(Model, BaseColumn):
    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = "table_columns"
    __table_args__ = (UniqueConstraint("table_id", "column_name"), )
    table_id = Column(Integer, ForeignKey("tables.id"))
    table = relationship(
        "SqlaTable",
        backref=backref("columns", cascade="all, delete-orphan"),
        foreign_keys=[table_id],
    )
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text)
    python_date_format = Column(String(255))

    export_fields = [
        "table_id",
        "column_name",
        "verbose_name",
        "is_dttm",
        "is_active",
        "type",
        "groupby",
        "filterable",
        "expression",
        "description",
        "python_date_format",
    ]

    update_from_object_fields = [
        s for s in export_fields if s not in ("table_id", )
    ]
    export_parent = "table"

    def get_sqla_col(self, label: Optional[str] = None) -> Column:
        label = label or self.column_name
        if not self.expression:
            db_engine_spec = self.table.database.db_engine_spec
            type_ = db_engine_spec.get_sqla_column_type(self.type)
            col = column(self.column_name, type_=type_)
        else:
            col = literal_column(self.expression)
        col = self.table.make_sqla_column_compatible(col, label)
        return col

    @property
    def datasource(self) -> RelationshipProperty:
        return self.table

    def get_time_filter(self, start_dttm: DateTime,
                        end_dttm: DateTime) -> ColumnElement:
        col = self.get_sqla_col(label="__time")
        l = []
        if start_dttm:
            l.append(col >= text(self.dttm_sql_literal(start_dttm)))
        if end_dttm:
            l.append(col <= text(self.dttm_sql_literal(end_dttm)))
        return and_(*l)

    def get_timestamp_expression(
            self,
            time_grain: Optional[str]) -> Union[TimestampExpression, Label]:
        """
        Return a SQLAlchemy Core element representation of self to be used in a query.

        :param time_grain: Optional time grain, e.g. P1Y
        :return: A TimestampExpression object wrapped in a Label if supported by db
        """
        label = utils.DTTM_ALIAS

        db = self.table.database
        pdf = self.python_date_format
        is_epoch = pdf in ("epoch_s", "epoch_ms")
        if not self.expression and not time_grain and not is_epoch:
            sqla_col = column(self.column_name, type_=DateTime)
            return self.table.make_sqla_column_compatible(sqla_col, label)
        if self.expression:
            col = literal_column(self.expression)
        else:
            col = column(self.column_name)
        time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
        return self.table.make_sqla_column_compatible(time_expr, label)

    @classmethod
    def import_obj(cls, i_column):
        def lookup_obj(lookup_column):
            return (db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name,
            ).first())

        return import_datasource.import_simple_obj(db.session, i_column,
                                                   lookup_obj)

    def dttm_sql_literal(self, dttm: DateTime) -> str:
        """Convert datetime object to a SQL expression string"""
        tf = self.python_date_format
        if tf:
            seconds_since_epoch = int(dttm.timestamp())
            if tf == "epoch_s":
                return str(seconds_since_epoch)
            elif tf == "epoch_ms":
                return str(seconds_since_epoch * 1000)
            return "'{}'".format(dttm.strftime(tf))
        else:
            s = self.table.database.db_engine_spec.convert_dttm(
                self.type or "", dttm)

            # TODO(john-bodley): SIP-15 will explicitly require a type conversion.
            return s or "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S.%f"))
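A sketch of the time helpers above, assuming `col` is a `TableColumn` attached to a table whose engine spec supports the requested grain:

from datetime import datetime

ts_expr = col.get_timestamp_expression('P1D')   # labeled utils.DTTM_ALIAS
time_flt = col.get_time_filter(datetime(2020, 1, 1), datetime(2020, 2, 1))
literal = col.dttm_sql_literal(datetime(2020, 1, 1))  # format-dependent str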
Ejemplo n.º 25
0
class Language(Base):
    __tablename__ = 'language'
    __table_args__ = (UniqueConstraint('code'), )
    id = sa.Column(sa.Integer, primary_key=True)
    code = sa.Column(sa.String, nullable=False)
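Both `Shelf` and `Language` rely on a single-column `UniqueConstraint`; a sketch of what it enforces at commit time, assuming an open `session`:

from sqlalchemy.exc import IntegrityError

session.add(Language(code='en'))
session.commit()
session.add(Language(code='en'))          # duplicate code
try:
    session.commit()                      # UniqueConstraint rejects it
except IntegrityError:
    session.rollback()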
Ejemplo n.º 26
0
class SqlaTable(Model, BaseDatasource):
    """An ORM object for SqlAlchemy table references"""

    type = "table"
    query_language = "sql"
    metric_class = SqlMetric
    column_class = TableColumn
    owner_class = security_manager.user_model

    __tablename__ = "tables"
    __table_args__ = (UniqueConstraint("database_id", "table_name"), )

    table_name = Column(String(250), nullable=False)
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
    fetch_values_predicate = Column(String(1000))
    owners = relationship(owner_class,
                          secondary=sqlatable_user,
                          backref="tables")
    database = relationship(
        "Database",
        backref=backref("tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    schema = Column(String(255))
    sql = Column(Text)
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)

    baselink = "tablemodelview"

    export_fields = [
        "table_name",
        "main_dttm_col",
        "description",
        "default_endpoint",
        "database_id",
        "offset",
        "cache_timeout",
        "schema",
        "sql",
        "params",
        "template_params",
        "filter_select_enabled",
        "fetch_values_predicate",
    ]
    update_from_object_fields = [
        f for f in export_fields if f not in ("table_name", "database_id")
    ]
    export_parent = "database"
    export_children = ["metrics", "columns"]

    sqla_aggregations = {
        "COUNT_DISTINCT":
        lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
        "COUNT":
        sa.func.COUNT,
        "SUM":
        sa.func.SUM,
        "AVG":
        sa.func.AVG,
        "MIN":
        sa.func.MIN,
        "MAX":
        sa.func.MAX,
    }

    def make_sqla_column_compatible(self,
                                    sqla_col: Column,
                                    label: Optional[str] = None) -> Column:
        """Takes a sql alchemy column object and adds label info if supported by engine.
        :param sqla_col: sql alchemy column instance
        :param label: alias/label that column is expected to have
        :return: either a sql alchemy column or label instance if supported by engine
        """
        label_expected = label or sqla_col.name
        db_engine_spec = self.database.db_engine_spec
        if db_engine_spec.allows_column_aliases:
            label = db_engine_spec.make_label_compatible(label_expected)
            sqla_col = sqla_col.label(label)
        sqla_col._df_label_expected = label_expected
        return sqla_col

    def __repr__(self):
        return self.name

    @property
    def connection(self) -> str:
        return str(self.database)

    @property
    def description_markeddown(self) -> str:
        return utils.markdown(self.description)

    @property
    def datasource_name(self) -> str:
        return self.table_name

    @property
    def database_name(self) -> str:
        return self.database.name

    @classmethod
    def get_datasource_by_name(
        cls,
        session: Session,
        datasource_name: str,
        schema: Optional[str],
        database_name: str,
    ) -> Optional["SqlaTable"]:
        schema = schema or None
        query = (session.query(cls).join(Database).filter(
            cls.table_name == datasource_name).filter(
                Database.database_name == database_name))
        # Handling schema being '' or None, which is easier to handle
        # in python than in the SQLA query in a multi-dialect way
        for tbl in query.all():
            if schema == (tbl.schema or None):
                return tbl
        return None

    @property
    def link(self) -> Markup:
        name = escape(self.name)
        anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
        return Markup(anchor)

    @property
    def schema_perm(self) -> Optional[str]:
        """Returns schema permission if present, database one otherwise."""
        return security_manager.get_schema_perm(self.database, self.schema)

    def get_perm(self) -> str:
        return ("[{obj.database}].[{obj.table_name}]"
                "(id:{obj.id})").format(obj=self)

    @property
    def name(self) -> str:  # type: ignore
        if not self.schema:
            return self.table_name
        return "{}.{}".format(self.schema, self.table_name)

    @property
    def full_name(self) -> str:
        return utils.get_datasource_full_name(self.database,
                                              self.table_name,
                                              schema=self.schema)

    @property
    def dttm_cols(self) -> List:
        l = [c.column_name for c in self.columns if c.is_dttm]
        if self.main_dttm_col and self.main_dttm_col not in l:
            l.append(self.main_dttm_col)
        return l

    @property
    def num_cols(self) -> List:
        return [c.column_name for c in self.columns if c.is_num]

    @property
    def any_dttm_col(self) -> Optional[str]:
        cols = self.dttm_cols
        return cols[0] if cols else None

    @property
    def html(self) -> str:
        t = ((c.column_name, c.type) for c in self.columns)
        df = pd.DataFrame(t)
        df.columns = ["field", "type"]
        return df.to_html(
            index=False,
            classes=("dataframe table table-striped table-bordered "
                     "table-condensed"),
        )

    @property
    def sql_url(self) -> str:
        return self.database.sql_url + "?table_name=" + str(self.table_name)

    def external_metadata(self):
        cols = self.database.get_columns(self.table_name, schema=self.schema)
        for col in cols:
            try:
                col["type"] = str(col["type"])
            except CompileError:
                col["type"] = "UNKNOWN"
        return cols

    @property
    def time_column_grains(self) -> Dict[str, Any]:
        return {
            "time_columns": self.dttm_cols,
            "time_grains": [grain.name for grain in self.database.grains()],
        }

    @property
    def select_star(self) -> str:
        # show_cols and latest_partition set to false to avoid
        # the expensive cost of inspecting the DB
        return self.database.select_star(self.table_name,
                                         schema=self.schema,
                                         show_cols=False,
                                         latest_partition=False)

    def get_col(self, col_name: str) -> Optional[Column]:
        columns = self.columns
        for col in columns:
            if col_name == col.column_name:
                return col
        return None

    @property
    def data(self) -> Dict:
        d = super(SqlaTable, self).data
        if self.type == "table":
            grains = self.database.grains() or []
            if grains:
                grains = [(g.duration, g.name) for g in grains]
            d["granularity_sqla"] = utils.choicify(self.dttm_cols)
            d["time_grain_sqla"] = grains
            d["main_dttm_col"] = self.main_dttm_col
            d["fetch_values_predicate"] = self.fetch_values_predicate
            d["template_params"] = self.template_params
        return d

    def values_for_column(self, column_name: str, limit: int = 10000) -> List:
        """Runs query against sqla to retrieve some
        sample values for the given column.
        """
        cols = {col.column_name: col for col in self.columns}
        target_col = cols[column_name]
        tp = self.get_template_processor()

        qry = (select([target_col.get_sqla_col()
                       ]).select_from(self.get_from_clause(tp)).distinct())
        if limit:
            qry = qry.limit(limit)

        if self.fetch_values_predicate:
            tp = self.get_template_processor()
            qry = qry.where(tp.process_template(self.fetch_values_predicate))

        engine = self.database.get_sqla_engine()
        sql = "{}".format(
            qry.compile(engine, compile_kwargs={"literal_binds": True}))
        sql = self.mutate_query_from_config(sql)

        df = pd.read_sql_query(sql=sql, con=engine)
        return df[column_name].to_list()

    def mutate_query_from_config(self, sql: str) -> str:
        """Apply config's SQL_QUERY_MUTATOR

        Typically adds comments to the query with context"""
        SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR")
        if SQL_QUERY_MUTATOR:
            username = utils.get_username()
            sql = SQL_QUERY_MUTATOR(sql, username, security_manager,
                                    self.database)
        return sql

    def get_template_processor(self, **kwargs):
        return get_template_processor(table=self,
                                      database=self.database,
                                      **kwargs)

    def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
        sqlaq = self.get_sqla_query(**query_obj)
        sql = self.database.compile_sqla_query(sqlaq.sqla_query)
        logging.info(sql)
        sql = sqlparse.format(sql, reindent=True)
        sql = self.mutate_query_from_config(sql)
        return QueryStringExtended(labels_expected=sqlaq.labels_expected,
                                   sql=sql,
                                   prequeries=sqlaq.prequeries)

    def get_query_str(self, query_obj: Dict) -> str:
        query_str_ext = self.get_query_str_extended(query_obj)
        all_queries = query_str_ext.prequeries + [query_str_ext.sql]
        return ";\n\n".join(all_queries) + ";"

    def get_sqla_table(self):
        tbl = table(self.table_name)
        if self.schema:
            tbl.schema = self.schema
        return tbl

    def get_from_clause(self, template_processor=None):
        # Supporting arbitrary SQL statements in place of tables
        if self.sql:
            from_sql = self.sql
            if template_processor:
                from_sql = template_processor.process_template(from_sql)
            from_sql = sqlparse.format(from_sql, strip_comments=True)
            return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
        return self.get_sqla_table()

    def adhoc_metric_to_sqla(self, metric: Dict,
                             cols: Dict) -> Optional[Column]:
        """
        Turn an adhoc metric into a sqlalchemy column.

        :param dict metric: Adhoc metric definition
        :param dict cols: Columns for the current table
        :returns: The metric defined as a sqlalchemy column
        :rtype: sqlalchemy.sql.column
        """
        expression_type = metric.get("expressionType")
        label = utils.get_metric_name(metric)

        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
            column_name = metric["column"].get("column_name")
            table_column = cols.get(column_name)
            if table_column:
                sqla_column = table_column.get_sqla_col()
            else:
                sqla_column = column(column_name)
            sqla_metric = self.sqla_aggregations[metric["aggregate"]](
                sqla_column)
        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
            sqla_metric = literal_column(metric.get("sqlExpression"))
        else:
            return None

        return self.make_sqla_column_compatible(sqla_metric, label)

    def get_sqla_query(  # sqla
        self,
        groupby,
        metrics,
        granularity,
        from_dttm,
        to_dttm,
        filter=None,
        is_timeseries=True,
        timeseries_limit=15,
        timeseries_limit_metric=None,
        row_limit=None,
        inner_from_dttm=None,
        inner_to_dttm=None,
        orderby=None,
        extras=None,
        columns=None,
        order_desc=True,
    ) -> SqlaQuery:
        """Querying any sqla table from this common interface"""
        template_kwargs = {
            "from_dttm": from_dttm,
            "groupby": groupby,
            "metrics": metrics,
            "row_limit": row_limit,
            "to_dttm": to_dttm,
            "filter": filter,
            "columns": {col.column_name: col
                        for col in self.columns},
        }
        template_kwargs.update(self.template_params_dict)
        extra_cache_keys: List[Any] = []
        template_kwargs["extra_cache_keys"] = extra_cache_keys
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec
        prequeries: List[str] = []

        orderby = orderby or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols: Dict[str, Column] = {
            col.column_name: col for col in self.columns
        }
        metrics_dict: Dict[str, SqlMetric] = {
            m.metric_name: m for m in self.metrics
        }

        if not granularity and is_timeseries:
            raise Exception(
                _("Datetime column not provided as part table configuration "
                  "and is required by this type of chart"))
        if not groupby and not metrics and not columns:
            raise Exception(_("Empty query?"))
        metrics_exprs = []
        for m in metrics:
            if utils.is_adhoc_metric(m):
                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
            elif m in metrics_dict:
                metrics_exprs.append(metrics_dict[m].get_sqla_col())
            else:
                raise Exception(
                    _("Metric '%(metric)s' does not exist", metric=m))
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
            main_metric_expr = self.make_sqla_column_compatible(
                main_metric_expr, label)

        select_exprs: List[Column] = []
        groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()

        if groupby:
            select_exprs = []
            for s in groupby:
                if s in cols:
                    outer = cols[s].get_sqla_col()
                else:
                    outer = literal_column(f"({s})")
                    outer = self.make_sqla_column_compatible(outer, s)

                groupby_exprs_sans_timestamp[outer.name] = outer
                select_exprs.append(outer)
        elif columns:
            for s in columns:
                select_exprs.append(
                    cols[s].get_sqla_col() if s in cols
                    else self.make_sqla_column_compatible(literal_column(s)))
            metrics_exprs = []

        groupby_exprs_with_timestamp = OrderedDict(
            groupby_exprs_sans_timestamp.items())
        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get("time_grain_sqla")
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs_with_timestamp[timestamp.name] = timestamp

            # Filter on the main dttm column as well, so its index can be
            # used even when a secondary dttm column drives the time grouping
            if (db_engine_spec.time_secondary_columns
                    and self.main_dttm_col in self.dttm_cols
                    and self.main_dttm_col != dttm_col.column_name):
                time_filters.append(cols[self.main_dttm_col].get_time_filter(
                    from_dttm, to_dttm))
            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))

        select_exprs += metrics_exprs

        labels_expected = [c._df_label_expected for c in select_exprs]

        select_exprs = db_engine_spec.make_select_compatible(
            groupby_exprs_with_timestamp.values(), select_exprs)
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor)

        if not columns:
            qry = qry.group_by(*groupby_exprs_with_timestamp.values())

        where_clause_and = []
        having_clause_and: List = []
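        # Translate each filter spec ({"col", "op", "val"}) into a SQLAlchemy
        # condition: "in"/"not in" become IN clauses with special <NULL>
        # handling, the rest map to comparisons, LIKE, or NULL checks.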
        for flt in filter:
            if not all([flt.get(s) for s in ["col", "op"]]):
                continue
            col = flt["col"]
            op = flt["op"]
            col_obj = cols.get(col)
            if col_obj:
                is_list_target = op in ("in", "not in")
                eq = self.filter_values_handler(
                    flt.get("val"),
                    target_column_is_numeric=col_obj.is_num,
                    is_list_target=is_list_target,
                )
                if op in ("in", "not in"):
                    cond = col_obj.get_sqla_col().in_(eq)
                    if "<NULL>" in eq:
                        cond = or_(cond, col_obj.get_sqla_col().is_(None))
                    if op == "not in":
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_num:
                        eq = utils.string_to_num(flt["val"])
                    if op == "==":
                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                    elif op == "!=":
                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                    elif op == ">":
                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                    elif op == "<":
                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                    elif op == ">=":
                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                    elif op == "<=":
                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                    elif op == "LIKE":
                        where_clause_and.append(
                            col_obj.get_sqla_col().like(eq))
                    elif op == "IS NULL":
                        where_clause_and.append(col_obj.get_sqla_col() == None)
                    elif op == "IS NOT NULL":
                        where_clause_and.append(col_obj.get_sqla_col() != None)
        if extras:
            where = extras.get("where")
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text("({})".format(where))]
            having = extras.get("having")
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text("({})".format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and not columns:
            orderby = [(main_metric_expr, not order_desc)]

        for col, ascending in orderby:
            direction = asc if ascending else desc
            if utils.is_adhoc_metric(col):
                col = self.adhoc_metric_to_sqla(col, cols)
            elif col in cols:
                col = cols[col].get_sqla_col()
            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

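        # Keep only the top `timeseries_limit` series: engines that allow
        # joins get an inner aggregate subquery joined back on the groupby
        # columns; other engines run a separate prequery and the outer query
        # is filtered down to the groups it returned.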
        if is_timeseries and timeseries_limit and groupby and not time_groupby_inline:
            if self.database.db_engine_spec.allows_joins:
                # some sql dialects require order-by expressions
                # to also appear in the select clause -- others, e.g. vertica,
                # require a unique inner alias
                inner_main_metric_expr = self.make_sqla_column_compatible(
                    main_metric_expr, "mme_inner__")
                inner_groupby_exprs = []
                inner_select_exprs = []
                for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
                    inner = self.make_sqla_column_compatible(
                        gby_obj, gby_name + "__")
                    inner_groupby_exprs.append(inner)
                    inner_select_exprs.append(inner)

                inner_select_exprs += [inner_main_metric_expr]
                subq = select(inner_select_exprs).select_from(tbl)
                inner_time_filter = dttm_col.get_time_filter(
                    inner_from_dttm or from_dttm, inner_to_dttm or to_dttm)
                subq = subq.where(
                    and_(*(where_clause_and + [inner_time_filter])))
                subq = subq.group_by(*inner_groupby_exprs)

                ob = inner_main_metric_expr
                if timeseries_limit_metric:
                    ob = self._get_timeseries_orderby(timeseries_limit_metric,
                                                      metrics_dict, cols)
                direction = desc if order_desc else asc
                subq = subq.order_by(direction(ob))
                subq = subq.limit(timeseries_limit)

                on_clause = []
                for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
                    # in this case the column name, not the alias, needs to be
                    # conditionally mutated, as it refers to the column alias in
                    # the inner query
                    col_name = db_engine_spec.make_label_compatible(gby_name +
                                                                    "__")
                    on_clause.append(gby_obj == column(col_name))

                tbl = tbl.join(subq.alias(), and_(*on_clause))
            else:
                if timeseries_limit_metric:
                    orderby = [(
                        self._get_timeseries_orderby(timeseries_limit_metric,
                                                     metrics_dict, cols),
                        False,
                    )]

                # run prequery to get top groups
                prequery_obj = {
                    "is_timeseries": False,
                    "row_limit": timeseries_limit,
                    "groupby": groupby,
                    "metrics": metrics,
                    "granularity": granularity,
                    "from_dttm": inner_from_dttm or from_dttm,
                    "to_dttm": inner_to_dttm or to_dttm,
                    "filter": filter,
                    "orderby": orderby,
                    "extras": extras,
                    "columns": columns,
                    "order_desc": True,
                }
                result = self.query(prequery_obj)
                prequeries.append(result.query)
                dimensions = [
                    c for c in result.df.columns
                    if c not in metrics and c in groupby_exprs_sans_timestamp
                ]
                top_groups = self._get_top_groups(
                    result.df, dimensions, groupby_exprs_sans_timestamp)
                qry = qry.where(top_groups)

        return SqlaQuery(
            extra_cache_keys=extra_cache_keys,
            labels_expected=labels_expected,
            sqla_query=qry.select_from(tbl),
            prequeries=prequeries,
        )

    def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict,
                                cols):
        if utils.is_adhoc_metric(timeseries_limit_metric):
            ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
        elif timeseries_limit_metric in metrics_dict:
            timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
            ob = timeseries_limit_metric.get_sqla_col()
        else:
            raise Exception(
                _("Metric '%(metric)s' does not exist",
                  metric=timeseries_limit_metric))

        return ob

    def _get_top_groups(self, df: pd.DataFrame, dimensions: List,
                        groupby_exprs: OrderedDict) -> ColumnElement:
        groups = []
        for unused, row in df.iterrows():
            group = []
            for dimension in dimensions:
                group.append(groupby_exprs[dimension] == row[dimension])
            groups.append(and_(*group))

        return or_(*groups)

    def query(self, query_obj: Dict) -> QueryResult:
        qry_start_dttm = datetime.now()
        query_str_ext = self.get_query_str_extended(query_obj)
        sql = query_str_ext.sql
        status = utils.QueryStatus.SUCCESS
        error_message = None

        def mutator(df):
            labels_expected = query_str_ext.labels_expected
            if df is not None and not df.empty:
                if len(df.columns) != len(labels_expected):
                    raise Exception(f"For {sql}, df.columns: {df.columns}"
                                    f" differs from {labels_expected}")
                else:
                    df.columns = labels_expected
            return df

        try:
            df = self.database.get_df(sql, self.schema, mutator)
        except Exception as e:
            df = None
            status = utils.QueryStatus.FAILED
            logging.exception(f"Query {sql} on schema {self.schema} failed")
            db_engine_spec = self.database.db_engine_spec
            error_message = db_engine_spec.extract_error_message(e)

        return QueryResult(
            status=status,
            df=df,
            duration=datetime.now() - qry_start_dttm,
            query=sql,
            error_message=error_message,
        )

    def get_sqla_table_object(self) -> Table:
        return self.database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self) -> None:
        """Fetches the metadata for the table and merges it in"""
        try:
            table = self.get_sqla_table_object()
        except Exception as e:
            logging.exception(e)
            raise Exception(
                _("Table [{}] doesn't seem to exist in the specified database, "
                  "couldn't fetch column information").format(self.table_name))

        metrics = []
        any_date_col = None
        db_engine_spec = self.database.db_engine_spec
        db_dialect = self.database.get_dialect()
        dbcols = (db.session.query(TableColumn).filter(
            TableColumn.table == self).filter(
                or_(TableColumn.column_name == col.name
                    for col in table.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}

        for col in table.columns:
            try:
                datatype = db_engine_spec.column_datatype_to_string(
                    col.type, db_dialect)
            except Exception as e:
                datatype = "UNKNOWN"
                logging.error("Unrecognized data type in {}.{}".format(
                    table, col.name))
                logging.exception(e)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                dbcol = TableColumn(column_name=col.name, type=datatype)
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.is_dttm = dbcol.is_time
                db_engine_spec.alter_new_orm_column(dbcol)
            else:
                dbcol.type = datatype
            dbcol.groupby = True
            dbcol.filterable = True
            self.columns.append(dbcol)
            if not any_date_col and dbcol.is_time:
                any_date_col = col.name

        metrics.append(
            SqlMetric(
                metric_name="count",
                verbose_name="COUNT(*)",
                metric_type="count",
                expression="COUNT(*)",
            ))
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        self.add_missing_metrics(metrics)
        db.session.merge(self)
        db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None) -> int:
        """Imports the datasource from the object to the database.

         Metrics and columns and datasource will be overrided if exists.
         This function can be used to import/export dashboards between multiple
         superset instances. Audit metadata isn't copies over.
        """
        def lookup_sqlatable(table):
            return (db.session.query(SqlaTable).join(Database).filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            ).first())

        def lookup_database(table):
            try:
                return (db.session.query(Database).filter_by(
                    database_name=table.params_dict["database_name"]).one())
            except NoResultFound:
                raise DatabaseNotFound(
                    _(
                        "Database '%(name)s' is not found",
                        name=table.params_dict["database_name"],
                    ))

        return import_datasource.import_datasource(db.session, i_datasource,
                                                   lookup_database,
                                                   lookup_sqlatable,
                                                   import_time)

    @classmethod
    def query_datasources_by_name(cls,
                                  session: Session,
                                  database: Database,
                                  datasource_name: str,
                                  schema=None) -> List["SqlaTable"]:
        query = (session.query(cls).filter_by(
            database_id=database.id).filter_by(table_name=datasource_name))
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()

    @staticmethod
    def default_query(qry) -> Query:
        return qry.filter_by(is_sqllab_view=False)

    def has_extra_cache_keys(self, query_obj: Dict) -> bool:
        """
        Detects the presence of calls to cache_key_wrapper in items in
        query_obj that can be templated.

        :param query_obj: query object to analyze
        :return: True if at least one item calls cache_key_wrapper, otherwise False
        """
        regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")
        templatable_statements: List[str] = []
        if self.sql:
            templatable_statements.append(self.sql)
        if self.fetch_values_predicate:
            templatable_statements.append(self.fetch_values_predicate)
        extras = query_obj.get("extras", {})
        if "where" in extras:
            templatable_statements.append(extras["where"])
        if "having" in extras:
            templatable_statements.append(extras["having"])
        for statement in templatable_statements:
            if regex.search(statement):
                return True
        return False

    def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
        if self.has_extra_cache_keys(query_obj):
            sqla_query = self.get_sqla_query(**query_obj)
            extra_cache_keys = sqla_query.extra_cache_keys
            return extra_cache_keys
        return []
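For context, here is a minimal usage sketch for the interface above. The
datasource instance and the column/metric names ("ds", "country", "count")
are assumptions for illustration, not part of the class:

# Hypothetical usage of SqlaTable.query(); assumes `datasource` is a
# SqlaTable with a datetime column "ds" and a "country" column.
from datetime import datetime

query_obj = {
    "groupby": ["country"],
    "metrics": ["count"],
    "granularity": "ds",
    "from_dttm": datetime(2019, 1, 1),
    "to_dttm": datetime(2019, 2, 1),
    "filter": [{"col": "country", "op": "in", "val": ["US", "CA"]}],
    "is_timeseries": True,
    "extras": {"time_grain_sqla": "P1D"},
}
result = datasource.query(query_obj)  # -> QueryResult
if result.status == utils.QueryStatus.SUCCESS:
    print(result.df.head())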
Example No. 27
class ImapFolderSyncStatus(MailSyncBase):
    """ Per-folder status state saving for IMAP folders. """
    account_id = Column(ForeignKey(ImapAccount.id, ondelete='CASCADE'),
                        nullable=False)
    account = relationship(ImapAccount, backref=backref('foldersyncstatuses'))

    folder_id = Column(Integer,
                       ForeignKey('folder.id', ondelete='CASCADE'),
                       nullable=False)
    # We almost always need the folder name too, so eager load by default.
    folder = relationship('Folder',
                          lazy='joined',
                          backref=backref('imapsyncstatus',
                                          passive_deletes=True))

    # see state machine in mailsync/backends/imap/imap.py
    state = Column(Enum('initial', 'initial uidinvalid', 'poll',
                        'poll uidinvalid', 'finish'),
                   server_default='initial',
                   nullable=False)

    # stats on messages downloaded etc.
    _metrics = Column(MutableDict.as_mutable(JSON), default={}, nullable=True)

    @property
    def metrics(self):
        status = dict(name=self.folder.name, state=self.state)
        status.update(self._metrics or {})

        return status

    def start_sync(self):
        self._metrics = dict(run_state='running',
                             sync_start_time=datetime.utcnow())

    def stop_sync(self):
        self._metrics['run_state'] = 'stopped'
        self._metrics['sync_end_time'] = datetime.utcnow()

    def kill_sync(self, error=None):
        self._metrics['run_state'] = 'killed'
        self._metrics['sync_end_time'] = datetime.utcnow()
        self._metrics['sync_error'] = error

    @property
    def is_killed(self):
        return self._metrics.get('run_state') == 'killed'

    def update_metrics(self, metrics):
        sync_status_metrics = [
            'remote_uid_count', 'delete_uid_count', 'update_uid_count',
            'download_uid_count', 'uid_checked_timestamp',
            'num_downloaded_since_timestamp', 'queue_checked_at', 'percent'
        ]

        assert isinstance(metrics, dict)
        for k in metrics:
            assert k in sync_status_metrics, k

        if self._metrics is not None:
            self._metrics.update(metrics)
        else:
            self._metrics = metrics

    __table_args__ = (UniqueConstraint('account_id', 'folder_id'), )
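A plausible lifecycle for this status record, assuming an existing SQLAlchemy
session plus account and folder rows, and that MailSyncBase accepts keyword
construction (all names here are illustrative):

# Hypothetical sketch; `session`, `account` and `folder` are assumed to exist.
status = ImapFolderSyncStatus(account_id=account.id, folder_id=folder.id)
status.start_sync()
status.update_metrics({'remote_uid_count': 1200, 'percent': 42})
status.stop_sync()  # records run_state='stopped' and a sync_end_time
session.add(status)
session.commit()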
Example No. 28
class TableColumn(Model, BaseColumn):
    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = 'table_columns'
    __table_args__ = (UniqueConstraint('table_id', 'column_name'), )
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship('SqlaTable',
                         backref=backref('columns',
                                         cascade='all, delete-orphan'),
                         foreign_keys=[table_id])
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text, default='')
    python_date_format = Column(String(255))
    database_expression = Column(String(255))

    export_fields = (
        'table_id',
        'column_name',
        'verbose_name',
        'is_dttm',
        'is_active',
        'type',
        'groupby',
        'filterable',
        'expression',
        'description',
        'python_date_format',
        'database_expression',
    )

    update_from_object_fields = [
        s for s in export_fields if s not in ('table_id', )
    ]
    export_parent = 'table'

    def get_sqla_col(self, label=None):
        label = label if label else self.column_name
        label = self.table.get_label(label)
        if not self.expression:
            db_engine_spec = self.table.database.db_engine_spec
            type_ = db_engine_spec.get_sqla_column_type(self.type)
            col = column(self.column_name, type_=type_).label(label)
        else:
            col = literal_column(self.expression).label(label)
        return col

    @property
    def datasource(self):
        return self.table

    def get_time_filter(self, start_dttm, end_dttm):
        is_epoch_in_utc = config.get('IS_EPOCH_S_TRULY_UTC', False)
        col = self.get_sqla_col(label='__time')
        conditions = []
        if start_dttm:
            conditions.append(
                col >= text(self.dttm_sql_literal(start_dttm, is_epoch_in_utc))
            )
        if end_dttm:
            conditions.append(
                col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc)))
        return and_(*conditions)

    def get_timestamp_expression(self, time_grain):
        """Getting the time component of the query"""
        label = self.table.get_label(utils.DTTM_ALIAS)

        db = self.table.database
        pdf = self.python_date_format
        is_epoch = pdf in ('epoch_s', 'epoch_ms')
        if not self.expression and not time_grain and not is_epoch:
            return column(self.column_name, type_=DateTime).label(label)
        grain = None
        if time_grain:
            grain = db.grains_dict().get(time_grain)
            if not grain:
                raise NotImplementedError(
                    f'No grain spec for {time_grain} for database {db.database_name}'
                )
        expr = db.db_engine_spec.get_time_expr(
            self.expression or self.column_name, pdf, time_grain, grain)
        return literal_column(expr, type_=DateTime).label(label)

    @classmethod
    def import_obj(cls, i_column):
        def lookup_obj(lookup_column):
            return db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name).first()

        return import_datasource.import_simple_obj(db.session, i_column,
                                                   lookup_obj)

    def dttm_sql_literal(self, dttm, is_epoch_in_utc):
        """Convert datetime object to a SQL expression string

        If database_expression is empty, the internal dttm
        will be parsed as the string with the pattern that
        the user inputted (python_date_format)
        If database_expression is not empty, the internal dttm
        will be parsed as the sql sentence for the database to convert
        """
        tf = self.python_date_format
        if self.database_expression:
            return self.database_expression.format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        elif tf:
            if is_epoch_in_utc:
                seconds_since_epoch = dttm.timestamp()
            else:
                seconds_since_epoch = (dttm -
                                       datetime(1970, 1, 1)).total_seconds()
            seconds_since_epoch = int(seconds_since_epoch)
            if tf == 'epoch_s':
                return str(seconds_since_epoch)
            elif tf == 'epoch_ms':
                return str(seconds_since_epoch * 1000)
            return "'{}'".format(dttm.strftime(tf))
        else:
            s = self.table.database.db_engine_spec.convert_dttm(
                self.type or '', dttm)
            return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))
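A short sketch of how the time helpers compose. All three calls assume the
column is already attached to a parent SqlaTable (labels and engine specs
come from it), and 'P1D' is only a valid grain if the database defines it:

from datetime import datetime

# Hypothetical column; `col.table` must already point at a SqlaTable.
col = TableColumn(column_name='created_at', type='TIMESTAMP')
ts_expr = col.get_timestamp_expression('P1D')  # daily time grain
time_filter = col.get_time_filter(datetime(2019, 1, 1),
                                  datetime(2019, 2, 1))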
Example No. 29
class IBMIKEPolicy(db.Model):
    ID_KEY = "id"
    NAME_KEY = "name"
    DH_GROUP_KEY = "dh_group"
    ENCRYPTION_ALGORITHM_KEY = "encryption_algorithm"
    AUTHENTICATION_ALGORITHM_KEY = "authentication_algorithm"
    KEY_LIFETIME_KEY = "key_lifetime"
    IKE_VERSION_KEY = "ike_version"
    RESOURCE_GROUP_KEY = "resource_group"
    STATUS_KEY = "status"
    REGION_KEY = "region"
    CLOUD_ID_KEY = "cloud_id"
    MESSAGE_KEY = "message"

    __tablename__ = "ibm_ike_policy"

    id = Column(String(32), primary_key=True)
    resource_id = Column(String(64))
    name = Column(String(255), nullable=False)
    region = Column(String(255), nullable=False)
    status = Column(String(50))
    ike_version = Column(Integer, default=2)
    key_lifetime = Column(Integer, default=28800)
    authentication_algorithm = Column(Enum("md5", "sha1", "sha256"),
                                      default="sha1")
    encryption_algorithm = Column(Enum("triple_des", "aes128", "aes256"),
                                  default="aes128")
    dh_group = Column(Integer, default=2)
    cloud_id = Column(String(32), ForeignKey("ibm_clouds.id"))
    resource_group_id = Column(String(32),
                               ForeignKey("ibm_resource_groups.id"))
    vpn_connections = relationship("IBMVpnConnection",
                                   backref="ibm_ike_policy",
                                   lazy="dynamic")
    __table_args__ = (UniqueConstraint(
        name, region, cloud_id,
        name="uix_ibm_ike_policy_name_region_cloud_id"), )

    def __init__(
        self,
        name,
        region,
        key_lifetime=None,
        status=CREATION_PENDING,
        ike_version=None,
        authentication_algorithm=None,
        encryption_algorithm=None,
        dh_group=None,
        resource_id=None,
        cloud_id=None,
    ):
        self.id = uuid.uuid4().hex
        self.name = name
        self.region = region
        self.status = status
        self.ike_version = ike_version
        self.authentication_algorithm = authentication_algorithm
        self.encryption_algorithm = encryption_algorithm
        self.key_lifetime = key_lifetime
        self.dh_group = dh_group
        self.resource_id = resource_id
        self.cloud_id = cloud_id

    def make_copy(self):
        obj = IBMIKEPolicy(
            name=self.name,
            region=self.region,
            key_lifetime=self.key_lifetime,
            status=self.status,
            ike_version=self.ike_version,
            authentication_algorithm=self.authentication_algorithm,
            encryption_algorithm=self.encryption_algorithm,
            dh_group=self.dh_group,
            resource_id=self.resource_id,
            cloud_id=self.cloud_id,
        )
        if self.ibm_resource_group:
            obj.ibm_resource_group = self.ibm_resource_group.make_copy()
        return obj

    def get_existing_from_db(self):
        return (db.session.query(
            self.__class__).filter_by(name=self.name,
                                      cloud_id=self.cloud_id,
                                      region=self.region).first())

    def params_eq(self, other):
        if type(self) is not type(other):
            return False
        if not ((self.name == other.name) and (self.region == other.region) and
                (self.resource_id == other.resource_id) and
                (self.ike_version == other.ike_version) and
                (self.key_lifetime == other.key_lifetime) and
                (self.authentication_algorithm
                 == other.authentication_algorithm) and
                (self.encryption_algorithm == other.encryption_algorithm) and
                (self.dh_group == other.dh_group) and
                (self.status == other.status)):
            return False
        if (self.ibm_resource_group and not other.ibm_resource_group) or (
                not self.ibm_resource_group and other.ibm_resource_group):
            return False
        if self.ibm_resource_group and other.ibm_resource_group:
            if not self.ibm_resource_group.params_eq(other.ibm_resource_group):
                return False
        return True

    def add_update_db(self):
        existing = self.get_existing_from_db()
        if not existing:
            ibm_resource_group = self.ibm_resource_group
            self.ibm_resource_group = None
            db.session.add(self)
            db.session.commit()
            if ibm_resource_group:
                self.ibm_resource_group = ibm_resource_group.add_update_db()
                db.session.commit()
            return self
        if not self.params_eq(existing):
            existing.name = self.name
            existing.region = self.region
            existing.status = self.status
            existing.resource_id = self.resource_id
            existing.ike_version = self.ike_version
            existing.key_lifetime = self.key_lifetime
            existing.authentication_algorithm = self.authentication_algorithm
            existing.encryption_algorithm = self.encryption_algorithm
            existing.dh_group = self.dh_group
            db.session.commit()
            ibm_resource_group = self.ibm_resource_group
            self.ibm_resource_group = None
            if ibm_resource_group:
                existing.ibm_resource_group = (
                    ibm_resource_group.add_update_db())
            else:
                existing.ibm_resource_group = None
            db.session.commit()
        return existing

    def to_json(self):
        return {
            self.ID_KEY: self.id,
            self.NAME_KEY: self.name,
            self.REGION_KEY: self.region,
            self.AUTHENTICATION_ALGORITHM_KEY: self.authentication_algorithm,
            self.ENCRYPTION_ALGORITHM_KEY: self.encryption_algorithm,
            self.KEY_LIFETIME_KEY: self.key_lifetime,
            self.IKE_VERSION_KEY: self.ike_version,
            self.DH_GROUP_KEY: self.dh_group,
            self.STATUS_KEY: self.status,
            self.RESOURCE_GROUP_KEY: (self.ibm_resource_group.to_json()
                                      if self.ibm_resource_group else ""),
            self.CLOUD_ID_KEY: self.cloud_id,
        }

    def to_json_body(self):
        return {
            "name": self.name,
            "authentication_algorithm": self.authentication_algorithm,
            "encryption_algorithm": self.encryption_algorithm,
            "key_lifetime": self.key_lifetime,
            "ike_version": self.ike_version,
            "dh_group": self.dh_group,
            "resource_group": {
                "id":
                self.ibm_resource_group.resource_id
                if self.ibm_resource_group else ""
            },
        }

    def to_report_json(self):
        return {
            self.ID_KEY: self.id,
            self.NAME_KEY: self.name,
            self.STATUS_KEY: PENDING,
            self.MESSAGE_KEY: ""
        }

    @classmethod
    def from_ibm_json_body(cls, region, json_body):
        # TODO: Verify Schema
        ibm_ike_policy = IBMIKEPolicy(
            name=json_body["name"],
            region=region,
            key_lifetime=json_body["key_lifetime"],
            status="CREATED",
            ike_version=json_body["ike_version"],
            authentication_algorithm=json_body["authentication_algorithm"],
            encryption_algorithm=json_body["encryption_algorithm"],
            dh_group=json_body["dh_group"],
            resource_id=json_body["id"],
        )

        return ibm_ike_policy
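A hedged sketch of the typical round trip: parse a policy from an IBM API
response, then upsert it. The JSON body is a guess at the shape
from_ibm_json_body expects, and `cloud` is an assumed existing record:

json_body = {
    "id": "r006-0a1b2c3d",  # illustrative resource id
    "name": "my-ike-policy",
    "key_lifetime": 28800,
    "ike_version": 2,
    "authentication_algorithm": "sha256",
    "encryption_algorithm": "aes256",
    "dh_group": 14,
}
policy = IBMIKEPolicy.from_ibm_json_body("us-south", json_body)
policy.cloud_id = cloud.id  # assumed to be set by the caller
policy = policy.add_update_db()  # insert, or update the matching row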
Example No. 30
class UserTestResult(Base):
    """Class to store the execution results of a user_test.

    """
    __tablename__ = 'user_test_results'
    __table_args__ = (UniqueConstraint('user_test_id', 'dataset_id'), )

    # Primary key is (user_test_id, dataset_id).
    user_test_id = Column(Integer,
                          ForeignKey(UserTest.id,
                                     onupdate="CASCADE",
                                     ondelete="CASCADE"),
                          primary_key=True)
    user_test = relationship(UserTest,
                             backref=backref("results",
                                             cascade="all, delete-orphan",
                                             passive_deletes=True))

    dataset_id = Column(Integer,
                        ForeignKey(Dataset.id,
                                   onupdate="CASCADE",
                                   ondelete="CASCADE"),
                        primary_key=True)
    dataset = relationship(Dataset)

    # Now below follow the actual result fields.

    # Output file's digest for this test
    output = Column(String, nullable=True)

    # Compilation outcome (can be None = yet to compile, "ok" =
    # compilation successful and we can evaluate, "fail" =
    # compilation unsuccessful, throw it away).
    compilation_outcome = Column(String, nullable=True)

    # String containing output from the sandbox.
    compilation_text = Column(String, nullable=True)

    # Number of attempts of compilation.
    compilation_tries = Column(Integer, nullable=False, default=0)

    # The compiler stdout and stderr.
    compilation_stdout = Column(Unicode, nullable=True)
    compilation_stderr = Column(Unicode, nullable=True)

    # Other information about the compilation.
    compilation_time = Column(Float, nullable=True)
    compilation_wall_clock_time = Column(Float, nullable=True)
    compilation_memory = Column(Integer, nullable=True)

    # Worker shard and sandbox where the compilation was performed.
    compilation_shard = Column(Integer, nullable=True)
    compilation_sandbox = Column(String, nullable=True)

    # Evaluation outcome (can be None = yet to evaluate, "ok" =
    # evaluation successful).
    evaluation_outcome = Column(String, nullable=True)
    evaluation_text = Column(String, nullable=True)

    # Number of attempts of evaluation.
    evaluation_tries = Column(Integer, nullable=False, default=0)

    # Other information about the execution.
    execution_time = Column(Float, nullable=True)
    execution_wall_clock_time = Column(Float, nullable=True)
    execution_memory = Column(Integer, nullable=True)

    # Worker shard and sandbox where the evaluation was performed.
    evaluation_shard = Column(Integer, nullable=True)
    evaluation_sandbox = Column(String, nullable=True)

    # Follows the description of the fields automatically added by
    # SQLAlchemy.
    # executables (dict of UserTestExecutable objects indexed by filename)

    def compiled(self):
        """Return whether the user test result has been compiled.

        return (bool): True if compiled, False otherwise.

        """
        return self.compilation_outcome is not None

    def compilation_failed(self):
        """Return whether the user test result did not compile.

        return (bool): True if the compilation failed (in the sense
            that there is a problem in the user's source), False if
            not yet compiled or compilation was successful.

        """
        return self.compilation_outcome == "fail"

    def compilation_succeeded(self):
        """Return whether the user test compiled.

        return (bool): True if the compilation succeeded (in the sense
            that an executable was created), False if not yet compiled
            or compilation was unsuccessful.

        """
        return self.compilation_outcome == "ok"

    def evaluated(self):
        """Return whether the user test result has been evaluated.

        return (bool): True if evaluated, False otherwise.

        """
        return self.evaluation_outcome is not None

    def invalidate_compilation(self):
        """Blank all compilation and evaluation outcomes.

        """
        self.invalidate_evaluation()
        self.compilation_outcome = None
        self.compilation_text = None
        self.compilation_tries = 0
        self.compilation_time = None
        self.compilation_wall_clock_time = None
        self.compilation_memory = None
        self.compilation_shard = None
        self.compilation_sandbox = None
        self.executables = {}

    def invalidate_evaluation(self):
        """Blank the evaluation outcome.

        """
        self.evaluation_outcome = None
        self.evaluation_text = None
        self.evaluation_tries = 0
        self.execution_time = None
        self.execution_wall_clock_time = None
        self.execution_memory = None
        self.evaluation_shard = None
        self.evaluation_sandbox = None
        self.output = None

    def set_compilation_outcome(self, success):
        """Set the compilation outcome based on the success.

        success (bool): if the compilation was successful.

        """
        self.compilation_outcome = "ok" if success else "fail"

    def set_evaluation_outcome(self):
        """Set the evaluation outcome (always ok now).

        """
        self.evaluation_outcome = "ok"