Example 1
def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[str, Any],
        runtime_configuration: Dict,
    ):
        """SqlAlchemy median implementation."""
        (
            selectable,
            compute_domain_kwargs,
            accessor_domain_kwargs,
        ) = execution_engine.get_compute_domain(
            metric_domain_kwargs, MetricDomainTypes.COLUMN
        )
        column_name = accessor_domain_kwargs["column"]
        column = sa.column(column_name)
        sqlalchemy_engine = execution_engine.engine
        nonnull_count = metrics.get("column_values.nonnull.count")
        if not nonnull_count:
            return None
        # Fetch the (at most) two rows straddling the center of the sorted non-null values.
        element_values = sqlalchemy_engine.execute(
            sa.select([column])
            .order_by(column)
            .where(column != None)  # noqa: E711 -- SQLAlchemy renders this as IS NOT NULL
            .offset(max(nonnull_count // 2 - 1, 0))
            .limit(2)
            .select_from(selectable)
        )

        column_values = list(element_values.fetchall())

        if len(column_values) == 0:
            column_median = None
        elif len(column_values) == 1:
            # Only one non-null value was fetched (nonnull_count == 1): it is the median.
            column_median = column_values[0][0]
        elif nonnull_count % 2 == 0:
            # An even number of column values: average the two center values.
            column_median = float(column_values[0][0] + column_values[1][0]) / 2.0
        else:
            # An odd number of column values: the second fetched row is the center value.
            column_median = column_values[1][0]
        return column_median
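The offset/limit arithmetic above is the subtle part of this implementation. A minimal pure-Python sketch of the same index math (illustrative only, no database involved):

# Pure-Python sketch of the OFFSET/LIMIT 2 window used by the median metric above.
def median_from_sorted(values):
    n = len(values)
    if n == 0:
        return None
    offset = max(n // 2 - 1, 0)
    window = values[offset:offset + 2]  # mirrors .offset(...).limit(2)
    if len(window) == 1:
        return window[0]  # a single value is its own median
    if n % 2 == 0:
        return (window[0] + window[1]) / 2.0  # average the two center values
    return window[1]  # odd count: the second fetched value is the center

assert median_from_sorted([1, 2, 3, 4]) == 2.5
assert median_from_sorted([1, 2, 3]) == 2
assert median_from_sorted([7]) == 7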
Example 2
def test_get_table_metric_provider_metric_dependencies(empty_sqlite_db):
    mp = ColumnMax()
    metric = MetricConfiguration("column.max", dict(), dict())
    dependencies = mp.get_evaluation_dependencies(
        metric,
        execution_engine=SqlAlchemyExecutionEngine(engine=empty_sqlite_db))
    assert dependencies["metric_partial_fn"].id[0] == "column.max.aggregate_fn"

    mp = ColumnMax()
    metric = MetricConfiguration("column.max", dict(), dict())
    dependencies = mp.get_evaluation_dependencies(
        metric, execution_engine=PandasExecutionEngine())

    table_columns_metric: MetricConfiguration = dependencies["table.columns"]
    assert dependencies == {
        "table.columns": table_columns_metric,
    }
    assert dependencies["table.columns"].id == (
        "table.columns",
        (),
        (),
    )
Example 3
    def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs: dict,
        metric_value_kwargs: dict,
        metrics: Dict[str, Any],
        runtime_configuration: dict,
    ) -> List[sqlalchemy_engine_Row]:
        query: Optional[str] = metric_value_kwargs.get(
            "query"
        ) or cls.default_kwarg_values.get("query")

        selectable: Union[sa.sql.Selectable, str]
        selectable, _, _ = execution_engine.get_compute_domain(
            metric_domain_kwargs, domain_type=MetricDomainTypes.TABLE
        )

        column: str = metric_value_kwargs.get("column")
        if isinstance(selectable, sa.Table):
            query = query.format(col=column, active_batch=selectable)
        elif isinstance(selectable, sa.sql.Subquery):
            # Specifying a runtime query in a RuntimeBatchRequest returns the active
            # batch as a Subquery; wrapping it in parentheses preserves the order of
            # operations when it is interpolated into the parameterized query.
            query = query.format(col=column, active_batch=f"({selectable})")
        elif isinstance(selectable, sa.sql.Select):
            # Specifying a row_condition returns the active batch as a Select object,
            # which must be compiled and aliased before interpolation.
            query = query.format(
                col=column,
                active_batch=f'({selectable.compile(compile_kwargs={"literal_binds": True})}) AS subselect',
            )
        else:
            query = query.format(col=column, active_batch=f"({selectable})")

        engine: sqlalchemy_engine_Engine = execution_engine.engine
        result: List[sqlalchemy_engine_Row] = engine.execute(sa.text(query)).fetchall()

        return result
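The branching above only changes how the active batch is rendered; the substitution itself is plain str.format. A small illustration with a hypothetical template and subquery (the names are made up for the example):

# Hypothetical template and subquery, mirroring the str.format substitution above.
template = "SELECT {col} FROM {active_batch} WHERE {col} IS NOT NULL"
active_batch = "(SELECT * FROM taxi_data WHERE fare > 0) AS subselect"
print(template.format(col="passenger_count", active_batch=active_batch))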
Example 4
    def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs,
        metric_value_kwargs,
        metrics,
        runtime_configuration,
    ):
        (
            selectable,
            compute_domain_kwargs,
            accessor_domain_kwargs,
        ) = execution_engine.get_compute_domain(
            metric_domain_kwargs, MetricDomainTypes.COLUMN
        )

        column_name = accessor_domain_kwargs["column"]
        column = sa.column(column_name)
        sqlalchemy_engine = execution_engine.engine

        query = sa.select(sa.func.max(column)).select_from(selectable)
        result = sqlalchemy_engine.execute(query).fetchone()

        return result[0]
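The same statement can be exercised standalone. A sketch against an in-memory SQLite engine, assuming the SQLAlchemy 1.x API used throughout these snippets:

import sqlalchemy as sa

# Standalone sketch of the max-aggregate statement above (SQLAlchemy 1.x API).
engine = sa.create_engine("sqlite://")
engine.execute(sa.text("CREATE TABLE demo (x INTEGER)"))
engine.execute(sa.text("INSERT INTO demo VALUES (1), (5), (3)"))
query = sa.select([sa.func.max(sa.column("x"))]).select_from(sa.table("demo"))
print(engine.execute(query).fetchone()[0])  # 5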
Example 5
def test_sample_using_random(sqlite_view_engine, test_df):
    my_execution_engine: SqlAlchemyExecutionEngine = SqlAlchemyExecutionEngine(
        engine=sqlite_view_engine
    )

    p: float
    batch_spec: SqlAlchemyDatasourceBatchSpec
    batch_data: SqlAlchemyBatchData
    num_rows: int
    rows_0: List[tuple]
    rows_1: List[tuple]

    # First, verify the degenerate case: sampling a single-row table with p=1.0
    # must return the same rows on every draw.

    test_df_0: pd.DataFrame = test_df.iloc[:1]
    test_df_0.to_sql("test_table_0", con=my_execution_engine.engine)

    p = 1.0
    batch_spec = SqlAlchemyDatasourceBatchSpec(
        table_name="test_table_0",
        schema_name="main",
        sampling_method="_sample_using_random",
        sampling_kwargs={"p": p},
    )

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_0.shape[0])

    rows_0: List[tuple] = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_0.shape[0])

    rows_1: List[tuple] = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    assert len(rows_0) == len(rows_1) == 1

    assert rows_0 == rows_1

    # Second, verify that a realistic case returns a different random sample of rows on each draw.

    test_df_1: pd.DataFrame = test_df
    test_df_1.to_sql("test_table_1", con=my_execution_engine.engine)

    p = 2.0e-1
    batch_spec = SqlAlchemyDatasourceBatchSpec(
        table_name="test_table_1",
        schema_name="main",
        sampling_method="_sample_using_random",
        sampling_kwargs={"p": p},
    )

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_1.shape[0])

    rows_0 = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_1.shape[0])

    rows_1 = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    assert len(rows_0) == len(rows_1)

    assert rows_0 != rows_1
Example 6
def test_instantiation_with_and_without_temp_table(sqlite_view_engine, sa):
    print(get_sqlite_temp_table_names(sqlite_view_engine))
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 1
    assert get_sqlite_temp_table_names(sqlite_view_engine) == {
        "test_temp_view"
    }

    execution_engine: SqlAlchemyExecutionEngine = SqlAlchemyExecutionEngine(
        engine=sqlite_view_engine)
    # When the SqlAlchemyBatchData object is based on a table, a new temp table is NOT created, even if create_temp_table=True
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        table_name="test_table",
        create_temp_table=True,
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 1

    selectable = sa.select("*").select_from(sa.text("main.test_table"))

    # If create_temp_table=False, a new temp table should NOT be created
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        create_temp_table=False,
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 1

    # If create_temp_table=True, a new temp table should be created
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        create_temp_table=True,
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 2

    # If create_temp_table=True, a new temp table should be created
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        # create_temp_table defaults to True
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 3

    # testing whether schema is supported
    selectable = sa.select("*").select_from(
        sa.table(name="test_table", schema="main"))
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        # create_temp_table defaults to True
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 4

    # test schema with execution engine
    # TODO : Will20210222 Add tests for specifying schema with non-sqlite backend that actually supports new schema creation
    my_batch_spec = SqlAlchemyDatasourceBatchSpec(
        **{
            "table_name": "test_table",
            "batch_identifiers": {},
            "schema_name": "main",
        })
    res = execution_engine.get_batch_data_and_markers(batch_spec=my_batch_spec)
    assert len(res) == 2
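get_sqlite_temp_table_names is a helper from the surrounding test suite; a hedged sketch of how temp tables can be listed directly in SQLite, which is presumably what it inspects:

import sqlite3

# Sketch: listing SQLite temp tables via sqlite_temp_master (stdlib only).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TEMP TABLE scratch (x INTEGER)")
rows = conn.execute(
    "SELECT name FROM sqlite_temp_master WHERE type = 'table'"
).fetchall()
print({name for (name,) in rows})  # {'scratch'}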
Example 7
def get_sqlalchemy_runtime_validator_postgresql(df,
                                                schemas=None,
                                                caching=True,
                                                table_name=None):
    sa_engine_name = "postgresql"
    db_hostname = os.getenv("GE_TEST_LOCAL_DB_HOSTNAME", "localhost")
    try:
        engine = connection_manager.get_engine(
            f"postgresql://postgres@{db_hostname}/test_ci")
    except sqlalchemy.exc.OperationalError:
        return None

    sql_dtypes = {}

    if (schemas and sa_engine_name in schemas
            and isinstance(engine.dialect, postgresqltypes.dialect)):
        schema = schemas[sa_engine_name]
        sql_dtypes = {
            col: POSTGRESQL_TYPES[dtype]
            for (col, dtype) in schema.items()
        }

        for col in schema:
            type_ = schema[col]
            if type_ in ["INTEGER", "SMALLINT", "BIGINT"]:
                df[col] = pd.to_numeric(df[col], downcast="signed")
            elif type_ in ["FLOAT", "DOUBLE", "DOUBLE_PRECISION"]:
                df[col] = pd.to_numeric(df[col])
                min_value_dbms = get_sql_dialect_floating_point_infinity_value(
                    schema=sa_engine_name, negative=True)
                max_value_dbms = get_sql_dialect_floating_point_infinity_value(
                    schema=sa_engine_name, negative=False)
                for api_schema_type in ["api_np", "api_cast"]:
                    min_value_api = get_sql_dialect_floating_point_infinity_value(
                        schema=api_schema_type, negative=True)
                    max_value_api = get_sql_dialect_floating_point_infinity_value(
                        schema=api_schema_type, negative=False)
                    df.replace(
                        to_replace=[min_value_api, max_value_api],
                        value=[min_value_dbms, max_value_dbms],
                        inplace=True,
                    )
            elif type_ in ["DATETIME", "TIMESTAMP"]:
                df[col] = pd.to_datetime(df[col])

    if table_name is None:
        table_name = "test_data_" + "".join([
            random.choice(string.ascii_letters + string.digits)
            for _ in range(8)
        ])
    df.to_sql(
        name=table_name,
        con=engine,
        index=False,
        dtype=sql_dtypes,
        if_exists="replace",
    )
    execution_engine = SqlAlchemyExecutionEngine(caching=caching, engine=engine)
    batch_data = SqlAlchemyBatchData(
        execution_engine=execution_engine, table_name=table_name
    )
    batch = Batch(data=batch_data)

    return Validator(execution_engine=execution_engine, batches=(batch,))
Example 8
    def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[str, Any],
        runtime_configuration: Dict,
    ):
        min_value = metric_value_kwargs.get("min_value")
        max_value = metric_value_kwargs.get("max_value")
        strict_min = metric_value_kwargs.get("strict_min")
        strict_max = metric_value_kwargs.get("strict_max")
        if min_value is not None and max_value is not None and min_value > max_value:
            raise ValueError("min_value cannot be greater than max_value")

        if min_value is None and max_value is None:
            raise ValueError("min_value and max_value cannot both be None")
        dialect_name = execution_engine.engine.dialect.name.lower()

        # Translate API-level +/- infinity sentinels into this dialect's representation.
        for negative in (True, False):
            api_infinity_values = (
                get_sql_dialect_floating_point_infinity_value(
                    schema="api_np", negative=negative
                ),
                get_sql_dialect_floating_point_infinity_value(
                    schema="api_cast", negative=negative
                ),
            )
            if min_value in api_infinity_values:
                min_value = get_sql_dialect_floating_point_infinity_value(
                    schema=dialect_name, negative=negative
                )
            if max_value in api_infinity_values:
                max_value = get_sql_dialect_floating_point_infinity_value(
                    schema=dialect_name, negative=negative
                )

        (
            selectable,
            compute_domain_kwargs,
            accessor_domain_kwargs,
        ) = execution_engine.get_compute_domain(
            domain_kwargs=metric_domain_kwargs, domain_type=MetricDomainTypes.COLUMN
        )
        column = sa.column(accessor_domain_kwargs["column"])

        if min_value is None:
            if strict_max:
                condition = column < max_value
            else:
                condition = column <= max_value

        elif max_value is None:
            if strict_min:
                condition = column > min_value
            else:
                condition = column >= min_value

        else:
            if strict_min and strict_max:
                condition = sa.and_(column > min_value, column < max_value)
            elif strict_min:
                condition = sa.and_(column > min_value, column <= max_value)
            elif strict_max:
                condition = sa.and_(column >= min_value, column < max_value)
            else:
                condition = sa.and_(column >= min_value, column <= max_value)

        return execution_engine.engine.execute(
            sa.select([sa.func.count()]).select_from(selectable).where(condition)
        ).scalar()
Example 9
    def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[str, Any],
        runtime_configuration: Dict,
    ):
        """return a list of counts corresponding to bins

        Args:
            column: the name of the column for which to get the histogram
            bins: tuple of bin edges for which to get histogram values; *must* be tuple to support caching
        """
        selectable, _, accessor_domain_kwargs = execution_engine.get_compute_domain(
            domain_kwargs=metric_domain_kwargs, domain_type=MetricDomainTypes.COLUMN
        )
        column = accessor_domain_kwargs["column"]
        bins = metric_value_kwargs["bins"]

        case_conditions = []
        idx = 0
        if isinstance(bins, np.ndarray):
            bins = bins.tolist()
        else:
            bins = list(bins)

        # If the lower bound is infinite, do not express it in the SQL.
        if bins[0] == get_sql_dialect_floating_point_infinity_value(
            schema="api_np", negative=True
        ) or bins[0] == get_sql_dialect_floating_point_infinity_value(
            schema="api_cast", negative=True
        ):
            case_conditions.append(
                sa.func.sum(
                    sa.case([(sa.column(column) < bins[idx + 1], 1)], else_=0)
                ).label(f"bin_{idx}")
            )
            idx += 1

        # Interior bins are closed on the left and open on the right.
        for idx in range(idx, len(bins) - 2):
            case_conditions.append(
                sa.func.sum(
                    sa.case(
                        [
                            (
                                sa.and_(
                                    bins[idx] <= sa.column(column),
                                    sa.column(column) < bins[idx + 1],
                                ),
                                1,
                            )
                        ],
                        else_=0,
                    )
                ).label(f"bin_{idx}")
            )

        # If the upper bound is infinite, leave the final bin unbounded above;
        # otherwise the final bin is closed on both sides.
        if bins[-1] == get_sql_dialect_floating_point_infinity_value(
            schema="api_np", negative=False
        ) or bins[-1] == get_sql_dialect_floating_point_infinity_value(
            schema="api_cast", negative=False
        ):
            case_conditions.append(
                sa.func.sum(
                    sa.case([(bins[-2] <= sa.column(column), 1)], else_=0)
                ).label(f"bin_{len(bins) - 1}")
            )
        else:
            case_conditions.append(
                sa.func.sum(
                    sa.case(
                        [
                            (
                                sa.and_(
                                    bins[-2] <= sa.column(column),
                                    sa.column(column) <= bins[-1],
                                ),
                                1,
                            )
                        ],
                        else_=0,
                    )
                ).label(f"bin_{len(bins) - 1}")
            )

        query = (
            sa.select(case_conditions)
            .where(sa.column(column) != None)  # noqa: E711 -- SQLAlchemy renders this as IS NOT NULL
            .select_from(selectable)
        )

        # Run the data through convert_to_json_serializable to ensure we do not have Decimal types
        hist = convert_to_json_serializable(
            list(execution_engine.engine.execute(query).fetchone()))
        return hist
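The CASE expressions implement left-closed/right-open bins, with the final bin closed on both sides (unless its upper edge is infinite). A pure-Python sketch of those bin semantics:

# Pure-Python sketch of the bin semantics above: [edge, next_edge) for interior
# bins, and [second_to_last_edge, last_edge] for the final bin.
def histogram(values, bins):
    counts = [0] * (len(bins) - 1)
    for v in values:
        for i in range(len(bins) - 1):
            last = i == len(bins) - 2
            upper_ok = v <= bins[i + 1] if last else v < bins[i + 1]
            if bins[i] <= v and upper_ok:
                counts[i] += 1
                break
    return counts

assert histogram([1, 2, 2, 3], bins=[1, 2, 3]) == [1, 3]  # 3 falls in the closed final bin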
Example 10
    def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[str, Any],
        runtime_configuration: Dict,
    ):
        (
            selectable,
            compute_domain_kwargs,
            accessor_domain_kwargs,
        ) = execution_engine.get_compute_domain(
            metric_domain_kwargs, domain_type=MetricDomainTypes.COLUMN
        )
        column_name = accessor_domain_kwargs["column"]
        column = sa.column(column_name)
        sqlalchemy_engine = execution_engine.engine
        dialect = sqlalchemy_engine.dialect
        quantiles = metric_value_kwargs["quantiles"]
        allow_relative_error = metric_value_kwargs.get("allow_relative_error", False)
        table_row_count = metrics.get("table.row_count")
        dialect_name = dialect.name.lower()
        if dialect_name == "mssql":
            return _get_column_quantiles_mssql(
                column=column,
                quantiles=quantiles,
                selectable=selectable,
                sqlalchemy_engine=sqlalchemy_engine,
            )
        elif dialect_name == "bigquery":
            return _get_column_quantiles_bigquery(
                column=column,
                quantiles=quantiles,
                selectable=selectable,
                sqlalchemy_engine=sqlalchemy_engine,
            )
        elif dialect_name == "mysql":
            return _get_column_quantiles_mysql(
                column=column,
                quantiles=quantiles,
                selectable=selectable,
                sqlalchemy_engine=sqlalchemy_engine,
            )
        elif dialect_name == "snowflake":
            # NOTE: 20201216 - JPC - Snowflake has a representation/precision limitation
            # in its percentile_disc implementation that causes an error when we do
            # not round. It is unclear to me *how* the call to round affects the
            # behavior -- the binary representation should be identical before and
            # after, and I do not observe a type difference. However, the issue is
            # replicable in the Snowflake console and directly observable in
            # side-by-side comparisons with and without the call to round().
            quantiles = [round(x, 10) for x in quantiles]
            return _get_column_quantiles_generic_sqlalchemy(
                column=column,
                quantiles=quantiles,
                allow_relative_error=allow_relative_error,
                dialect=dialect,
                selectable=selectable,
                sqlalchemy_engine=sqlalchemy_engine,
            )
        elif dialect_name == "sqlite":
            return _get_column_quantiles_sqlite(
                column=column,
                quantiles=quantiles,
                selectable=selectable,
                sqlalchemy_engine=sqlalchemy_engine,
                table_row_count=table_row_count,
            )
        else:
            return _get_column_quantiles_generic_sqlalchemy(
                column=column,
                quantiles=quantiles,
                allow_relative_error=allow_relative_error,
                dialect=dialect,
                selectable=selectable,
                sqlalchemy_engine=sqlalchemy_engine,
            )
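The if/elif chain above dispatches on the dialect name; the same shape can be expressed as a lookup table with a generic fallback. A sketch with stand-in handlers (the real _get_column_quantiles_* helpers take column/quantiles/selectable/engine arguments, and the Snowflake branch would still need its rounding workaround before falling through to the generic path):

# Sketch: dialect dispatch as a dict lookup with a default (stand-in handlers).
handlers = {
    "mssql": lambda: "mssql-specific path",
    "bigquery": lambda: "bigquery-specific path",
    "mysql": lambda: "mysql-specific path",
    "sqlite": lambda: "sqlite-specific path",
}

def generic():
    return "generic sqlalchemy path"

dialect_name = "sqlite"  # hypothetical; normally dialect.name.lower()
print(handlers.get(dialect_name, generic)())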
Example 11
    def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[str, Any],
        runtime_configuration: Dict,
    ):
        selectable, _, _ = execution_engine.get_compute_domain(
            metric_domain_kwargs, domain_type=MetricDomainTypes.TABLE
        )
        df = None
        table_name = getattr(selectable, "name", None)
        if table_name is None:
            # if a custom query was passed
            try:
                if metric_value_kwargs["fetch_all"]:
                    df = pd.read_sql_query(
                        sql=selectable,
                        con=execution_engine.engine,
                    )
                else:
                    df = next(
                        pd.read_sql_query(
                            sql=selectable,
                            con=execution_engine.engine,
                            chunksize=metric_value_kwargs["n_rows"],
                        ))
            except (ValueError, NotImplementedError):
                # pd.read_sql_query relies on SQLAlchemy MetaData, which cannot
                # reflect a temp table. If it fails, fall back to reading the
                # data with pd.read_sql below.
                df = None
            except StopIteration:
                validator = Validator(execution_engine=execution_engine)
                columns = validator.get_metric(
                    MetricConfiguration("table.columns", metric_domain_kwargs))
                df = pd.DataFrame(columns=columns)
        else:
            try:
                if metric_value_kwargs["fetch_all"]:
                    df = pd.read_sql_table(
                        table_name=getattr(selectable, "name", None),
                        schema=getattr(selectable, "schema", None),
                        con=execution_engine.engine,
                    )
                else:
                    df = next(
                        pd.read_sql_table(
                            table_name=getattr(selectable, "name", None),
                            schema=getattr(selectable, "schema", None),
                            con=execution_engine.engine,
                            chunksize=metric_value_kwargs["n_rows"],
                        ))
            except (ValueError, NotImplementedError):
                # pd.read_sql_table relies on SQLAlchemy MetaData, which cannot
                # reflect a temp table. If it fails, fall back to reading the
                # data with pd.read_sql below.
                df = None
            except StopIteration:
                validator = Validator(execution_engine=execution_engine)
                columns = validator.get_metric(
                    MetricConfiguration("table.columns", metric_domain_kwargs))
                df = pd.DataFrame(columns=columns)

        if df is None:
            # we want to compile our selectable
            stmt = sa.select(["*"]).select_from(selectable)
            if metric_value_kwargs["fetch_all"]:
                sql = stmt.compile(
                    dialect=execution_engine.engine.dialect,
                    compile_kwargs={"literal_binds": True},
                )
            elif execution_engine.engine.dialect.name.lower() == "mssql":
                # limit doesn't compile properly for mssql
                sql = str(
                    stmt.compile(
                        dialect=execution_engine.engine.dialect,
                        compile_kwargs={"literal_binds": True},
                    ))
                sql = f"SELECT TOP {metric_value_kwargs['n_rows']}{sql[6:]}"
            else:
                stmt = stmt.limit(metric_value_kwargs["n_rows"])
                sql = stmt.compile(
                    dialect=execution_engine.engine.dialect,
                    compile_kwargs={"literal_binds": True},
                )

            df = pd.read_sql(sql, con=execution_engine.engine)

        return df
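The next(pd.read_sql_table(..., chunksize=n)) pattern above fetches only the first chunk rather than the whole table. A standalone sketch (pandas plus a SQLAlchemy in-memory engine):

import pandas as pd
import sqlalchemy as sa

# Sketch: with chunksize set, read_sql_table returns an iterator of DataFrames;
# next() pulls just the first n_rows-sized chunk.
engine = sa.create_engine("sqlite://")
pd.DataFrame({"x": range(10)}).to_sql("demo", con=engine, index=False)
df = next(pd.read_sql_table("demo", con=engine, chunksize=5))
print(len(df))  # 5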