def _CollapseFunctions(stack):
  """Collapses one built-in function in a postfix stack into a single token.

  Scans the stack left to right for the first built-in function token. That
  function's postfix expression (the function token plus its arguments) is
  replaced in place by a single token:
    * If the expression can be evaluated (no fields among its arguments), it
      becomes the evaluated value. Strings become double-quoted string
      literals, None becomes a NULL literal, and values whose string form is
      'true'/'false' (case-insensitive) become lowercase boolean literals. Any
      other value is inserted as is.
    * Otherwise it becomes a FieldToken wrapping the expression's infix
      string, without evaluation.

  Arguments:
    stack: The postfix stack whose functions are to be collapsed. It is
      modified in place.

  Raises:
    bigquery_client.BigqueryInvalidQueryError: If an encrypted field appears in
      the arguments of a built-in function.

  Returns:
    True iff a function was found and collapsed. In other words, another
    function may still exist, so callers loop until this returns False.
  """
  for i in xrange(len(stack)):
    if not isinstance(stack[i], util.BuiltInFunctionToken):
      continue
    start_idx, postfix_expr = interpreter.GetSingleValue(stack[:i + 1])
    if util.IsEncryptedExpression(postfix_expr):
      raise bigquery_client.BigqueryInvalidQueryError(
          'Invalid aggregation function argument: Cannot put an encrypted '
          'field as an argument to a built-in function.', None, None, None)
    # If the expression has no fields, we want its actual value. If it does
    # contain a field, Evaluate() signals this with BigqueryInvalidQueryError
    # and we keep the infix string instead. Only Evaluate() is guarded so that
    # unrelated errors are not mistaken for "expression has a field".
    try:
      result = interpreter.Evaluate(list(postfix_expr))
    except bigquery_client.BigqueryInvalidQueryError:
      result = util.FieldToken(interpreter.ToInfix(list(postfix_expr)))
    else:
      if isinstance(result, basestring):
        result = util.StringLiteralToken('"%s"' % result)
      elif result is None:
        result = util.LiteralToken('NULL', None)
      elif str(result).lower() in ['true', 'false']:
        result = util.LiteralToken(str(result).lower(), result)
    stack[start_idx:i + 1] = [result]
    return True
  return False
def _CheckNestedAggregations(postfix_exprs):
  """Rejects aggregation arguments that contain illegal nested aggregations.

  PAILLIER_SUM queries and GROUP_CONCAT over an encrypted field are not legal
  arguments for another aggregation.

  Arguments:
    postfix_exprs: Postfix expressions, one per aggregation argument.

  Raises:
    bigquery_client.BigqueryInvalidQueryError: If an argument contains a
      PAILLIER_SUM query, or a GROUP_CONCAT over an encrypted field.
  """
  for expr in postfix_exprs:
    for token in expr:
      if not isinstance(token, util.AggregationQueryToken):
        continue
      if token.startswith(util.PAILLIER_SUM_PREFIX):
        raise bigquery_client.BigqueryInvalidQueryError(
            'Cannot use SUM/AVG on homomorphic encryption as argument '
            'for another aggregation.', None, None, None)
      elif token.startswith(util.GROUP_CONCAT_PREFIX):
        # Drop the prefix and the final character (presumably the closing
        # ')') to recover the field name.
        fieldname = token.split(util.GROUP_CONCAT_PREFIX)[1][:-1]
        if util.IsEncrypted(fieldname):
          raise bigquery_client.BigqueryInvalidQueryError(
              'Cannot use GROUP_CONCAT on an encrypted field as argument '
              'for another aggregation.', None, None, None)


def _RewriteEncryptedSumAverage(function_type, postfix_expr, nsquare):
  """Builds the postfix rewrite of SUM/AVG over an encrypted argument.

  The argument is split into fields by CheckValidSumAverageArgument. Each
  field is expanded into a constant plus weighted terms, so that SUM(ax + b)
  becomes a * SUM(x) + b * COUNT(x). Homomorphic int/float fields are summed
  with a Paillier sum query and other fields with a plain SUM. For AVG the
  total is additionally divided by COUNT(<representative label>).

  Arguments:
    function_type: 'SUM' or 'AVG'.
    postfix_expr: Postfix expression of the aggregation's argument.
    nsquare: Used for homomorphic addition.

  Returns:
    A list of postfix tokens computing the rewritten aggregation.
  """
  list_fields = interpreter.CheckValidSumAverageArgument(postfix_expr)[0]
  # The representative label is the field used to scale constant values: an
  # expression SUM(ax + b) must be rewritten as a * SUM(x) + b * COUNT(x), and
  # x is the representative label. It isn't unique, since any of the fields
  # could be used in COUNT.
  representative_label = ''
  for field in list_fields:
    for token in field:
      if util.IsLabel(token):
        representative_label = token
        break
    if representative_label:
      break

  rewritten = []
  for field in list_fields:
    expression = interpreter.ExpandExpression(field)
    queries, constant = expression[0], expression[1]
    # constant * COUNT(label)
    rewritten.append(float(constant))
    rewritten.append(
        util.AggregationQueryToken('COUNT(%s)' % representative_label))
    rewritten.append(util.OperatorToken('*', 2))
    # coefficient * SUM(field) for each term.
    for query in queries:
      rewritten.append(float(query[0]))
      if isinstance(query[1], (util.HomomorphicFloatToken,
                               util.HomomorphicIntToken)):
        rewritten.append(util.ConstructPaillierSumQuery(query[1], nsquare))
      else:
        rewritten.append(util.AggregationQueryToken('SUM(%s)' % query[1]))
      rewritten.append(util.OperatorToken('*', 2))
    # Add the constant term and all weighted terms of this field together.
    for _ in xrange(len(queries)):
      rewritten.append(util.OperatorToken('+', 2))
  # Add the per-field totals together.
  for _ in xrange(len(list_fields) - 1):
    rewritten.append(util.OperatorToken('+', 2))
  if function_type == 'AVG':
    rewritten.append(
        util.AggregationQueryToken('COUNT(%s)' % representative_label))
    rewritten.append(util.OperatorToken('/', 2))
  return rewritten


def _CollapseAggregations(stack, nsquare):
  """Collapses one aggregation by combining its arguments and function.

  Finds the first aggregation function token in the stack and collapses any
  built-in functions inside its arguments. It then checks that the
  aggregation is legal for the (possibly encrypted) arguments. Finally, the
  aggregation and its arguments are replaced in place by the rewritten tokens.
  The following aggregations will be rewritten:

    SUM(<homomorphic field>) becomes
      TO_BASE64(PAILLIER_SUM(FROM_BASE64(<homomorphic field>), <nsquare>))
    AVG(<homomorphic field>) becomes
      TO_BASE64(PAILLIER_SUM(FROM_BASE64(<homomorphic field>), <nsquare>))
      / COUNT(<homomorphic field>)

  DISTINCTCOUNT is emitted as COUNT(DISTINCT <first arg>, ...). COUNT, TOP,
  GROUP_CONCAT and aggregations over unencrypted arguments are re-emitted as
  a single aggregation query token.

  Arguments:
    stack: The postfix stack whose aggregations are to be collapsed. It is
      modified in place.
    nsquare: Used for homomorphic addition.

  Raises:
    bigquery_client.BigqueryInvalidQueryError: If a built-in function inside
      an argument has encrypted arguments; if an argument nests an illegal
      aggregation; if util.IsDeterministicExpression flags the first argument
      of a DISTINCTCOUNT or TOP; or if an aggregation other than COUNT,
      DISTINCTCOUNT, TOP, SUM, AVG or GROUP_CONCAT is applied to encrypted
      fields.

  Returns:
    True iff an aggregation was found and collapsed. In other words, another
    aggregation may still exist.
  """
  for i in xrange(len(stack)):
    if not isinstance(stack[i], util.AggregationFunctionToken):
      continue
    num_args = stack[i].num_args
    function_type = str(stack[i])
    postfix_exprs = []
    infix_exprs = []
    start_idx = i
    is_encrypted = False
    # The arguments precede the aggregation token in postfix order. Walking
    # backwards from it therefore yields them last argument first.
    for _ in xrange(int(num_args)):
      start_idx, postfix_expr = interpreter.GetSingleValue(stack[:start_idx])
      is_encrypted = is_encrypted or util.IsEncryptedExpression(postfix_expr)
      while _CollapseFunctions(postfix_expr):
        pass
      postfix_exprs.append(postfix_expr)
      infix_exprs.append(interpreter.ToInfix(list(postfix_expr)))

    _CheckNestedAggregations(postfix_exprs)

    # Restore source order for BOTH lists. Previously only infix_exprs was
    # reversed, so postfix_exprs[0] was the last argument and, e.g., the TOP
    # check inspected the record count instead of the field.
    postfix_exprs.reverse()
    infix_exprs.reverse()

    if function_type in ['COUNT', 'DISTINCTCOUNT']:
      if function_type == 'DISTINCTCOUNT':
        if util.IsDeterministicExpression(postfix_exprs[0]):
          raise bigquery_client.BigqueryInvalidQueryError(
              'Cannot do distinct count on non-pseudonym encryption.',
              None, None, None)
        infix_exprs[0] = 'DISTINCT ' + infix_exprs[0]
      rewritten_infix_expr = [
          util.AggregationQueryToken('COUNT(%s)' % ', '.join(infix_exprs))
      ]
    elif function_type == 'TOP':
      if util.IsDeterministicExpression(postfix_exprs[0]):
        raise bigquery_client.BigqueryInvalidQueryError(
            'Cannot do TOP on non-deterministic encryption.',
            None, None, None)
      rewritten_infix_expr = [
          util.AggregationQueryToken('TOP(%s)' % ', '.join(infix_exprs))
      ]
    elif function_type in ['AVG', 'SUM'] and is_encrypted:
      rewritten_infix_expr = _RewriteEncryptedSumAverage(
          function_type, postfix_exprs[0], nsquare)
    elif function_type == 'GROUP_CONCAT':
      rewritten_infix_expr = [
          util.AggregationQueryToken(
              'GROUP_CONCAT(%s)' % ', '.join(infix_exprs))
      ]
    elif is_encrypted:
      raise bigquery_client.BigqueryInvalidQueryError(
          'Cannot do %s aggregation on any encrypted fields.' % function_type,
          None, None, None)
    else:
      rewritten_infix_expr = [
          util.AggregationQueryToken(
              '%s(%s)' % (function_type, ', '.join(infix_exprs)))
      ]
    stack[start_idx:i + 1] = rewritten_infix_expr
    return True
  return False
def FailIfEncrypted(tokens):
  """Rejects a WHERE/HAVING expression that references encrypted fields.

  Arguments:
    tokens: The expression tokens to inspect.

  Raises:
    bigquery_client.BigqueryInvalidQueryError: If util.IsEncryptedExpression
      reports that the tokens contain an encrypted field.
  """
  if not util.IsEncryptedExpression(tokens):
    return
  raise bigquery_client.BigqueryInvalidQueryError(
      'Invalid where/having expression.', None, None, None)