def clamp_expression(self, ne, relations, scope, do_clamp=True):
    """
    Push every column referenced by a named expression into ``scope``,
    substituting a clamping CASE expression for bounded numeric columns.

    Args:
        ne: Named expression whose ``expression`` tree is scanned for
            ``Column`` nodes.  The column nodes are renamed in place.
        relations: Passed to ``Column.symbol`` to resolve each column.
        scope: Object whose ``push_name(expression, proposed)`` returns the
            name under which the expression can be referenced.
        do_clamp: When false, columns are always pushed through unchanged.

    Returns:
        The same ``ne`` object, with its column nodes renamed.
    """
    root = ne.expression
    targets = root.find_nodes(Column)
    if type(root) is Column:
        # A bare column at the root is added explicitly
        # (presumably find_nodes only reports descendants -- TODO confirm).
        targets += [root]

    clamp_template = "CASE WHEN {0} < {1} THEN {1} WHEN {0} > {2} THEN {2} ELSE {0} END"

    for column in targets:
        label = column.name
        symbol = column.symbol(relations)

        pushed = None
        if do_clamp and symbol.valtype in ["float", "int"] and not symbol.unbounded:
            lower = symbol.minval
            upper = symbol.maxval
            # Key columns and columns without a lower bound are not clamped.
            # NOTE(review): only the lower bound is checked for None here;
            # confirm that `unbounded` already rules out a missing upper bound.
            if lower is not None and not symbol.is_key:
                pushed = Expression(clamp_template.format(str(label), lower, upper))

        if pushed is None:
            # Pass the column through to the scope as-is.
            pushed = Column(label)

        # Refer to the column by whatever name the scope handed back.
        column.name = scope.push_name(pushed, str(label))
    return ne
def exact_aggregates(self, query):
    """
    Build the query that computes the exact aggregates plus a ``keycount``
    column over a per-key subquery.

    With ``options.row_privacy`` the key count is ``COUNT(*)``; otherwise it
    is ``COUNT(DISTINCT <key column>)`` and the key column is pushed into
    the subquery's scope.  When ``options.reservoir_sample`` is set (and
    row privacy is off) the subquery comes from ``per_key_random`` and is
    filtered on ``per_key_random.row_num <= options.max_contrib``;
    otherwise it comes from ``per_key_clamped`` with no filter.

    Args:
        query: Source query; its select list, source, where and agg
            clauses are reused.

    Returns:
        A new ``Query`` selecting ``keycount`` followed by the original
        named expressions.
    """
    inner_scope = Scope()

    if self.options.row_privacy:
        count_of_keys = AggFunction("COUNT", None, AllColumns())
    else:
        distinct_key = Column(self.key_col(query))
        count_of_keys = AggFunction("COUNT", "DISTINCT", distinct_key)
        # The key column must be available from the inner query.
        inner_scope.push_name(count_of_keys.expression)

    for item in query.select.namedExpressions:
        inner_scope.push_name(item.expression)

    # Outer select list: keycount first, then the caller's expressions.
    outer_items = [NamedExpression("keycount", count_of_keys)]
    outer_items.extend(query.select.namedExpressions)
    outer_select = Select(None, Seq(outer_items))

    per_key = Query(inner_scope.select(), query.source, query.where, query.agg, None, None, None)

    if self.options.reservoir_sample and not self.options.row_privacy:
        sampled = [AliasedRelation(self.per_key_random(per_key), "per_key_random")]
        row_num = Column("per_key_random.row_num")
        cap = Literal(str(self.options.max_contrib), self.options.max_contrib)
        contrib_filter = Where(BooleanCompare(row_num, "<=", cap))
        return Query(outer_select, From(sampled), contrib_filter, query.agg, None, None, None)

    clamped = [AliasedRelation(self.per_key_clamped(per_key), "per_key_all")]
    return Query(outer_select, From(clamped), None, query.agg, None, None, None)
def clamp_expression(self, ne, relations, scope, do_clamp=True):
    """
    Lookup the columns referenced in a named expression and push a clamped
    select expression for each into ``scope``, using the schema bounds.

    Numeric (``float``/``int``) columns that are bounded, have a lower
    bound and are not keys are pushed as
    ``CASE WHEN col < min THEN min WHEN col > max THEN max ELSE col END``.
    Every other column is pushed through unchanged.  Each column node in
    ``ne`` is then renamed to the name returned by ``scope.push_name``.

    Args:
        ne: Named expression whose ``expression`` tree is scanned for
            ``Column`` nodes.  The column nodes are renamed in place.
        relations: Passed to ``Column.symbol`` to resolve each column.
        scope: Object whose ``push_name(expression, proposed)`` returns the
            name under which the expression can be referenced.
        do_clamp: When false, columns are always pushed through unchanged.

    Returns:
        The same ``ne`` object, with its column nodes renamed.
    """
    exp = ne.expression
    cols = exp.find_nodes(Column)
    if type(exp) is Column:
        cols += [exp]

    for col in cols:
        colname = col.name
        sym = col.symbol(relations)

        cexpr = None
        if do_clamp and sym.valtype in ["float", "int"] and not sym.unbounded:
            minval = sym.minval
            maxval = sym.maxval
            if minval is not None and not sym.is_key:
                # Build the CASE over fresh Column nodes rather than `col`
                # itself: `col` is renamed just below, and embedding it in
                # the pushed expression would make that expression follow
                # the rename and refer to its own alias instead of the
                # source column.
                when_min = WhenExpression(
                    BooleanCompare(Column(colname), Op("<"), Literal(minval)),
                    Literal(minval),
                )
                when_max = WhenExpression(
                    BooleanCompare(Column(colname), Op(">"), Literal(maxval)),
                    Literal(maxval),
                )
                cexpr = CaseExpression(None, [when_min, when_max], Column(colname))

        if cexpr is None:
            # Not clampable: pass the column through unchanged.
            cexpr = Column(colname)

        # Reference the pushed expression by the name the scope assigned.
        col.name = scope.push_name(cexpr, str(colname))
    return ne
def rewrite_outer_named_expression(self, ne, scope):
    """
    Rewrite the aggregate functions inside a named expression while
    leaving the rest of the expression intact.

    * A bare ``Column`` is pushed into ``scope`` and renamed in place.
    * A bare ``AggFunction`` is replaced by ``rewrite_agg_expression``.
    * Anything else has the columns reported by
      ``find_nodes(Column, AggFunction)`` pushed and renamed, and every
      attribute holding an ``AggFunction`` replaced by its rewritten form.

    Args:
        ne: Named expression to rewrite; its expression tree may be
            mutated in place.
        scope: Scope receiving the pushed columns.

    Returns:
        A new ``NamedExpression`` carrying the original name and the
        rewritten expression.
    """
    label = ne.name
    rewritten = ne.expression
    kind = type(rewritten)

    if kind is Column:
        rewritten.name = scope.push_name(Column(rewritten.name))
    elif kind is AggFunction:
        rewritten = self.rewrite_agg_expression(rewritten, scope)
    else:
        for node in rewritten.find_nodes(Column, AggFunction):
            node.name = scope.push_name(Column(node.name))

        def swap_aggregates(parent):
            # Walk the instance attributes: recurse into Sql children, then
            # replace any attribute that is itself an aggregate function.
            attrs = parent.__dict__
            for attr, value in attrs.items():
                if isinstance(value, Sql):
                    swap_aggregates(value)
                if isinstance(value, AggFunction):
                    attrs[attr] = self.rewrite_agg_expression(value, scope)

        swap_aggregates(rewritten)

    return NamedExpression(label, rewritten)
def per_key_random(self, query):
    """
    Wrap the per-key clamped query in a select that adds a ``row_num``
    column computed as ``ROW_NUMBER()`` over the key column, ordered by
    ``RANDOM()``.

    The inner query is aliased ``clamped`` when ``options.clamp_columns``
    is truthy and ``not_clamped`` otherwise.

    Args:
        query: Query handed to ``key_col`` and ``per_key_clamped``.

    Returns:
        A ``Query`` selecting all columns plus ``row_num`` from the
        aliased subquery.
    """
    partition_key = self.key_col(query)

    # Select list: "*" followed by the row number expression.
    every_column = NamedExpression(None, AllColumns())
    row_num_name = Identifier("row_num")
    ranking_name = FuncName("ROW_NUMBER")
    partition = Column(partition_key)
    random_order = Order([SortItem(BareFunction(FuncName("RANDOM")), None)])
    window = OverClause(partition, random_order)
    row_number = NamedExpression(row_num_name, RankingFunction(ranking_name, window))
    outer_select = Select(None, [every_column, row_number])

    inner = self.per_key_clamped(query)
    alias = Identifier("clamped" if self.options.clamp_columns else "not_clamped")
    sources = [Relation(AliasedSubquery(inner, alias), None)]

    return Query(outer_select, From(sources), None, None, None, None, None)
def push_sum_or_count(self, exp, scope):
    """
    Push a sum or count expression to the child scope.

    A new ``AggFunction`` is built from the name, quantifier and inner
    expression of ``exp`` and registered with ``scope``.

    Args:
        exp: Aggregate expression providing ``name``, ``quantifier`` and
            ``expression``.
        scope: Scope receiving the pushed aggregate.

    Returns:
        A ``Column`` referring to the name returned by ``scope.push_name``.
    """
    pushed = AggFunction(exp.name, exp.quantifier, exp.expression)
    return Column(scope.push_name(pushed))
def rewrite_outer_named_expression(self, ne, scope):
    """
    Rewrite AVG, VAR, STD and friends, and push every SUM or COUNT to the
    child scope, preserving all other portions of the expression.

    Columns reported by ``find_nodes(Column, AggFunction)`` (plus the root
    when it is a bare column) are pushed into ``scope`` and renamed in
    place.  Each aggregate function node is then turned, in place, into a
    nameless wrapper around its rewritten expression.

    Args:
        ne: Named expression to rewrite; its expression tree is mutated.
        scope: Scope receiving the pushed columns and aggregates.

    Returns:
        A new ``NamedExpression`` with the original name and the (mutated)
        original expression.

    Raises:
        ValueError: If an aggregate contains another aggregate, or if the
            aggregate function name is not one of the supported ones.
    """
    label = ne.name
    root = ne.expression
    root_type = type(root)

    # Columns that are referenced directly rather than through an aggregate.
    if root_type is AggFunction:
        bare_columns = []
    else:
        bare_columns = root.find_nodes(Column, AggFunction)
    if root_type is Column:
        bare_columns += [root]
    for column in bare_columns:
        column.name = scope.push_name(Column(column.name))

    aggregates = root.find_nodes(AggFunction)
    if root_type is AggFunction:
        aggregates = aggregates + [root]

    for agg in aggregates:
        if agg.find_nodes(AggFunction):
            raise ValueError("Cannot have nested aggregate functions: " + str(agg))

        func = agg.name
        if func in ["SUM", "COUNT"]:
            replacement = self.push_sum_or_count(agg, scope)
        elif func == "AVG":
            replacement = self.calculate_avg(agg, scope)
        elif func in ["VAR", "VARIANCE"]:
            replacement = self.calculate_variance(agg, scope)
        elif func in ["STD", "STDDEV"]:
            replacement = self.calculate_stddev(agg, scope)
        else:
            raise ValueError("We don't know how to rewrite aggregate function: " + str(agg))

        # Blank out the function name and quantifier so the node only
        # wraps the rewritten expression.
        agg.name = ""
        agg.quantifier = None
        agg.expression = replacement

    return NamedExpression(label, root)