def _get_column_quantiles_mysql(
    column, quantiles: Iterable, selectable, sqlalchemy_engine
) -> list:
    """Compute discrete quantiles of ``column`` on MySQL.

    MySQL does not support "percentile_disc", so it is implemented as a compound query:
    a CTE ranks the rows with ``PERCENT_RANK()`` and, for every requested quantile, a
    ``FIRST_VALUE`` window orders rows by that rank (descending), with ranks above the
    quantile mapped to NULL so that they do not compete for first place.
    Please see https://stackoverflow.com/questions/19770026/calculate-percentile-value-using-mysql
    for reference.

    Args:
        column: SQLAlchemy column expression whose quantiles are requested.
        quantiles: iterable of quantile values; numpy floating scalars are accepted.
        selectable: SQLAlchemy selectable the column is read from.
        sqlalchemy_engine: object exposing ``execute(query)`` used to run the query.

    Returns:
        The first result row as a list, one entry per requested quantile, in order.

    Raises:
        ProgrammingError: re-raised after being logged together with its traceback.
    """
    percent_rank_query: CTE = (
        sa.select(
            [
                column,
                sa.cast(
                    sa.func.percent_rank().over(order_by=column.asc()),
                    sa.dialects.mysql.DECIMAL(18, 15),
                ).label("p"),
            ]
        )
        .order_by(sa.column("p").asc())
        .select_from(selectable)
        .cte("t")
    )

    selects: List[WithinGroup] = []
    for idx, quantile in enumerate(quantiles):
        # pymysql cannot handle conversion of numpy floating scalars; hand it a builtin
        # float instead. isinstance() against np.floating is used because the np.float_
        # alias no longer exists in NumPy 2.0.
        if isinstance(quantile, np.floating):
            quantile = float(quantile)
        quantile_column: Label = (
            sa.func.first_value(column)
            .over(
                order_by=sa.case(
                    [
                        (
                            percent_rank_query.c.p
                            <= sa.cast(quantile, sa.dialects.mysql.DECIMAL(18, 15)),
                            percent_rank_query.c.p,
                        )
                    ],
                    else_=None,
                ).desc()
            )
            .label(f"q_{idx}")
        )
        selects.append(quantile_column)
    quantiles_query: Select = (
        sa.select(selects).distinct().order_by(percent_rank_query.c.p.desc())
    )

    try:
        quantiles_results: Row = sqlalchemy_engine.execute(quantiles_query).fetchone()
        return list(quantiles_results)
    except ProgrammingError as pe:
        exception_traceback: str = traceback.format_exc()
        exception_message: str = (
            "An SQL syntax Exception occurred.  "
            f'{type(pe).__name__}: "{str(pe)}".  Traceback: "{exception_traceback}".'
        )
        logger.error(exception_message)
        # Bare raise keeps the original exception and traceback untouched.
        raise
Exemple #2
0
    def _sqlalchemy(cls, column_A, column_B, **kwargs):
        """Build the row-wise SQL condition "(column_A, column_B) is an allowed pair".

        Args:
            column_A: SQLAlchemy column expression for the first member of the pair.
            column_B: SQLAlchemy column expression for the second member of the pair.
            **kwargs: only ``value_pairs_set`` is read; an iterable of ``(x, y)``
                pairs, or ``None``/absent to accept every row.

        Returns:
            A SQLAlchemy boolean expression. With no ``value_pairs_set`` it is a CASE
            that is always true; with an empty one it is a constant false; otherwise
            it is the OR over ``column_A == x AND column_B == y`` for each pair.
        """
        value_pairs_set = kwargs.get("value_pairs_set")

        if value_pairs_set is None:
            # vacuously true
            return sa.case([(column_A == column_B, True)], else_=True)

        # or_ implementation was required due to mssql issues with in_
        conditions = [
            sa.and_(column_A == x, column_B == y) for x, y in value_pairs_set
        ]

        # Seed the disjunction with false(): an argument-less or_() is deprecated and
        # renders as an empty expression, whereas an empty set of allowed pairs must
        # match no row. SQLAlchemy drops the false() again when conditions exist.
        return sa.or_(sa.false(), *conditions)
    def _sqlalchemy(
        cls,
        execution_engine: SqlAlchemyExecutionEngine,
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[str, Any],
        runtime_configuration: Dict,
    ):
        """Return the per-bin counts of the non-null values of a column.

        The column name comes from the compute domain resolved by
        ``execution_engine``; the bin edges come from ``metric_value_kwargs["bins"]``.
        Each bin is half-open ``[lower, upper)`` except the last one, which also
        includes its upper edge. An infinite first or last edge is not expressed
        in SQL, which makes the corresponding bin one-sided.

        NOTE(review): with exactly two edges and an infinite lower edge, both the
        first-bin and the last-bin branch append a count, so two values come back
        for a single bin -- confirm callers never pass such bins.

        Args:
            execution_engine: resolves the compute domain and executes the query.
            metric_domain_kwargs: forwarded to ``get_compute_domain``.
            metric_value_kwargs: must contain ``"bins"``, the bin edges as a
                numpy array or any other iterable.
            metrics: not used by this method.
            runtime_configuration: not used by this method.

        Returns:
            The single result row as a list, passed through
            ``convert_to_json_serializable``.
        """
        selectable, _, accessor_domain_kwargs = execution_engine.get_compute_domain(
            domain_kwargs=metric_domain_kwargs,
            domain_type=MetricDomainTypes.COLUMN,
        )
        column_name = accessor_domain_kwargs["column"]
        raw_bins = metric_value_kwargs["bins"]
        edges = raw_bins.tolist() if isinstance(raw_bins, np.ndarray) else list(raw_bins)
        value = sa.column(column_name)

        def is_infinite_edge(edge, negative):
            # The edge counts as infinite if it equals either infinity
            # representation ("api_np" or "api_cast") returned by the helper.
            return (
                edge
                == get_sql_dialect_floating_point_infinity_value(
                    schema="api_np", negative=negative
                )
            ) or (
                edge
                == get_sql_dialect_floating_point_infinity_value(
                    schema="api_cast", negative=negative
                )
            )

        def count_where(predicate, position):
            # SUM(CASE WHEN predicate THEN 1 ELSE 0 END) AS bin_<position>
            return sa.func.sum(sa.case([(predicate, 1)], else_=0)).label(
                f"bin_{position}"
            )

        bin_counts = []
        first_bounded = 0

        # If we have an infinite lower bound, don't express that in sql
        if is_infinite_edge(edges[0], negative=True):
            bin_counts.append(count_where(value < edges[1], 0))
            first_bounded = 1

        # Interior bins: lower edge inclusive, upper edge exclusive.
        bin_counts.extend(
            count_where(
                sa.and_(edges[position] <= value, value < edges[position + 1]),
                position,
            )
            for position in range(first_bounded, len(edges) - 2)
        )

        # Last bin: one-sided for an infinite upper edge, otherwise closed on both ends.
        if is_infinite_edge(edges[-1], negative=False):
            last_predicate = edges[-2] <= value
        else:
            last_predicate = sa.and_(edges[-2] <= value, value <= edges[-1])
        bin_counts.append(count_where(last_predicate, len(edges) - 1))

        # "!= None" is deliberate: SQLAlchemy overloads it to emit IS NOT NULL.
        not_null = value != None  # noqa: E711
        query = sa.select(bin_counts).where(not_null).select_from(selectable)

        # Run the data through convert_to_json_serializable to ensure we do not have Decimal types
        row = execution_engine.engine.execute(query).fetchone()
        return convert_to_json_serializable(list(row))