def _get_column_quantiles_mysql(
    column, quantiles: Iterable, selectable, sqlalchemy_engine
) -> list:
    """Compute the requested quantiles of ``column`` on a MySQL backend.

    MySQL does not support "percentile_disc", so we implement it as a compound query.
    Please see https://stackoverflow.com/questions/19770026/calculate-percentile-value-using-mysql for reference.

    Args:
        column: SQLAlchemy column element to compute quantiles for.
        quantiles: iterable of quantile fractions (e.g. 0.25, 0.5, 0.75).
        selectable: SQLAlchemy selectable (table/subquery) to read from.
        sqlalchemy_engine: engine used to execute the generated query.

    Returns:
        A list with one value per requested quantile, in input order.

    Raises:
        ProgrammingError: re-raised after logging if the generated SQL is invalid.
    """
    # CTE assigning each row its percent_rank over `column`; cast to a fixed-precision
    # DECIMAL so comparisons against the quantile fractions below are exact.
    percent_rank_query: CTE = (
        sa.select(
            [
                column,
                sa.cast(
                    sa.func.percent_rank().over(order_by=column.asc()),
                    sa.dialects.mysql.DECIMAL(18, 15),
                ).label("p"),
            ]
        )
        .order_by(sa.column("p").asc())
        .select_from(selectable)
        .cte("t")
    )
    selects: List[WithinGroup] = []
    for idx, quantile in enumerate(quantiles):
        # pymysql cannot handle conversion of numpy float64 to float; convert just in case.
        # NOTE: np.floating (the abstract float scalar base) replaces np.float_, which
        # was removed in NumPy 2.0; np.float64 subclasses np.floating, so behavior on
        # NumPy 1.x is unchanged.
        if np.issubdtype(type(quantile), np.floating):
            quantile = float(quantile)
        # Pick the first row whose percent_rank does not exceed the quantile fraction:
        # the CASE maps qualifying ranks to themselves (NULL otherwise) and DESC ordering
        # puts the largest qualifying rank first for first_value().
        quantile_column: Label = (
            sa.func.first_value(column)
            .over(
                order_by=sa.case(
                    [
                        (
                            percent_rank_query.c.p
                            <= sa.cast(quantile, sa.dialects.mysql.DECIMAL(18, 15)),
                            percent_rank_query.c.p,
                        )
                    ],
                    else_=None,
                ).desc()
            )
            .label(f"q_{idx}")
        )
        selects.append(quantile_column)
    quantiles_query: Select = (
        sa.select(selects).distinct().order_by(percent_rank_query.c.p.desc())
    )
    try:
        quantiles_results: Row = sqlalchemy_engine.execute(quantiles_query).fetchone()
        return list(quantiles_results)
    except ProgrammingError as pe:
        exception_message: str = "An SQL syntax Exception occurred."
        exception_traceback: str = traceback.format_exc()
        exception_message += (
            f'{type(pe).__name__}: "{str(pe)}". Traceback: "{exception_traceback}".'
        )
        logger.error(exception_message)
        # Bare raise preserves the original exception and traceback.
        raise
def _sqlalchemy(cls, column_A, column_B, **kwargs):
    """Build a row-wise SQL condition that is true when (column_A, column_B)
    matches one of the allowed value pairs.

    If no ``value_pairs_set`` is supplied, the expectation is vacuously true
    and an always-true CASE expression is returned.
    """
    allowed_pairs = kwargs.get("value_pairs_set")
    if allowed_pairs is None:
        # vacuously true
        return sa.case([(column_A == column_B, True)], else_=True)

    # Normalize every entry to a plain (left, right) tuple.
    allowed_pairs = [(left, right) for left, right in allowed_pairs]

    # or_ implementation was required due to mssql issues with in_:
    # each pair becomes its own single-clause OR of an AND condition.
    per_pair_conditions = []
    for left, right in allowed_pairs:
        per_pair_conditions.append(
            sa.or_(sa.and_(column_A == left, column_B == right))
        )

    return sa.or_(*per_pair_conditions)
def _sqlalchemy(
    cls,
    execution_engine: SqlAlchemyExecutionEngine,
    metric_domain_kwargs: Dict,
    metric_value_kwargs: Dict,
    metrics: Dict[str, Any],
    runtime_configuration: Dict,
):
    """Return a list of counts corresponding to bins.

    Builds a single SELECT of SUM(CASE ...) expressions — one per histogram
    bin — so the whole histogram is computed in one round trip to the
    database. NULL values in the target column are excluded via the WHERE
    clause.

    Args:
        column: the name of the column for which to get the histogram
        bins: tuple of bin edges for which to get histogram values; *must* be tuple to support caching
    """
    selectable, _, accessor_domain_kwargs = execution_engine.get_compute_domain(
        domain_kwargs=metric_domain_kwargs, domain_type=MetricDomainTypes.COLUMN
    )
    column = accessor_domain_kwargs["column"]
    bins = metric_value_kwargs["bins"]

    case_conditions = []
    idx = 0
    # Normalize bin edges to a plain Python list (numpy arrays are converted
    # so the literals embedded in the SQL are native Python numbers).
    if isinstance(bins, np.ndarray):
        bins = bins.tolist()
    else:
        bins = list(bins)

    # If we have an infinite lower bound, don't express that in sql:
    # the first bin becomes a one-sided "column < bins[1]" condition instead.
    if (
        bins[0]
        == get_sql_dialect_floating_point_infinity_value(
            schema="api_np", negative=True
        )
    ) or (
        bins[0]
        == get_sql_dialect_floating_point_infinity_value(
            schema="api_cast", negative=True
        )
    ):
        case_conditions.append(
            sa.func.sum(
                sa.case([(sa.column(column) < bins[idx + 1], 1)], else_=0)
            ).label("bin_" + str(idx))
        )
        idx += 1

    # Interior bins: half-open intervals [bins[idx], bins[idx+1]).
    # The loop stops at len(bins) - 3 so the final edge pair is handled
    # separately below (it may be closed or one-sided).
    for idx in range(idx, len(bins) - 2):
        case_conditions.append(
            sa.func.sum(
                sa.case(
                    [
                        (
                            sa.and_(
                                bins[idx] <= sa.column(column),
                                sa.column(column) < bins[idx + 1],
                            ),
                            1,
                        )
                    ],
                    else_=0,
                )
            ).label("bin_" + str(idx))
        )

    # Last bin: one-sided "bins[-2] <= column" if the upper edge is infinite,
    # otherwise the closed interval [bins[-2], bins[-1]].
    # NOTE(review): the last bin is labeled "bin_<len(bins)-1>" while the loop's
    # final label is "bin_<len(bins)-3>", leaving a gap in the label sequence;
    # appears harmless because results are consumed positionally via fetchone()
    # below — confirm before relying on label names.
    if (
        bins[-1]
        == get_sql_dialect_floating_point_infinity_value(
            schema="api_np", negative=False
        )
    ) or (
        bins[-1]
        == get_sql_dialect_floating_point_infinity_value(
            schema="api_cast", negative=False
        )
    ):
        case_conditions.append(
            sa.func.sum(
                sa.case([(bins[-2] <= sa.column(column), 1)], else_=0)
            ).label("bin_" + str(len(bins) - 1))
        )
    else:
        case_conditions.append(
            sa.func.sum(
                sa.case(
                    [
                        (
                            sa.and_(
                                bins[-2] <= sa.column(column),
                                sa.column(column) <= bins[-1],
                            ),
                            1,
                        )
                    ],
                    else_=0,
                )
            ).label("bin_" + str(len(bins) - 1))
        )

    # NULLs belong to no bin; the `!= None` comparison is rendered by
    # SQLAlchemy as "IS NOT NULL".
    query = (
        sa.select(case_conditions)
        .where(
            sa.column(column) != None,
        )
        .select_from(selectable)
    )

    # Run the data through convert_to_json_serializable to ensure we do not have Decimal types
    hist = convert_to_json_serializable(
        list(execution_engine.engine.execute(query).fetchone())
    )
    return hist