def testRewrittenAggregation(self):
     """Checks that DISTINCTCOUNT over a field renders as COUNT(DISTINCT ...)."""
     postfix = [
         util.FieldToken('Year'),
         util.AggregationFunctionToken('DISTINCTCOUNT', 1),
     ]
     # Render a copy so the locally built token list is left untouched.
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, 'COUNT(DISTINCT Year)')
 def testString(self):
     """Checks rendering and evaluation of left() applied to a string literal."""
     postfix = [
         util.StringLiteralToken('"TESTING IS FUN."'),
         4,
         util.BuiltInFunctionToken('left'),
     ]
     # ToInfix receives a copy; Evaluate then works on the original list.
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, 'left("TESTING IS FUN.", 4)')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, 'TEST')
 def testOneArgumentFunction(self):
     """Checks a chain of nested one-argument built-ins: sqrt(ln(cos(0)))."""
     postfix = [0]
     # Functions are pushed innermost first, as postfix order requires.
     for function_name in ('cos', 'ln', 'sqrt'):
         postfix.append(util.BuiltInFunctionToken(function_name))
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, 'sqrt(ln(cos(0)))')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, 0)
 def testUnary(self):
     """Checks unary operators ('~' and 'not') mixed with a binary comparison."""
     postfix = [
         1,
         2,
         util.OperatorToken('~', 1),
         util.OperatorToken('<', 2),
         util.OperatorToken('not', 1),
     ]
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, 'not (1 < ~ 2)')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, True)
 def testMultipleArgumentFunction(self):
     """Checks nested multi-argument built-ins: pow(if(...), pow(...))."""
     # First argument of the outer pow: if(True, 3, 0).
     condition_part = ['True', 3, 0, util.BuiltInFunctionToken('if')]
     # Second argument of the outer pow: pow(2, 1).
     exponent_part = [2, 1, util.BuiltInFunctionToken('pow')]
     postfix = condition_part + exponent_part
     postfix.append(util.BuiltInFunctionToken('pow'))
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, 'pow(if(True, 3, 0), pow(2, 1))')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, 9)
 def testBooleanLiterals(self):
     """Checks boolean literal tokens combined through 'or' and '='."""
     postfix = [
         util.LiteralToken('True', True),
         util.LiteralToken('False', False),
         util.OperatorToken('or', 2),
         1,
         2,
         util.OperatorToken('=', 2),
         util.OperatorToken('or', 2),
     ]
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, '((True or False) or (1 = 2))')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, True)
 def testBinary(self):
     """Checks a left-nested chain of the four binary arithmetic operators."""
     postfix = [1, 2, util.OperatorToken('+', 2)]
     # Each step pushes one more operand followed by the operator applied to it.
     for operand, symbol in ((3, '*'), (9, '/'), (2, '-')):
         postfix.append(operand)
         postfix.append(util.OperatorToken(symbol, 2))
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, '((((1 + 2) * 3) / 9) - 2)')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, -1)
Exemplo n.º 8
0
 def ConstructColumnNames(self, column_names):
     """Builds one {'name': ...} entry per column, preferring stored aliases.

     Args:
       column_names: Sequence of column expressions. A column whose position
         has no entry in self._argument is rendered with interpreter.ToInfix
         (on a copy of the expression) and converted with str().

     Returns:
       A list of single-key dictionaries, in the same order as column_names.
     """
     def _DisplayName(position, column):
         # An entry in self._argument for this position wins over rendering.
         if position in self._argument:
             return self._argument[position]
         return str(interpreter.ToInfix(copy(column)))

     return [{'name': _DisplayName(position, column)}
             for position, column in enumerate(column_names)]
Exemplo n.º 9
0
def _ExtractUnencryptedQueries(postfix_stacks, within):
    """Extracts expressions (not a single term) that are unencrypted.

  If the expression was modified by a within clause, then the within clause
  is prepended to the expression.

  Args:
    postfix_stacks: List of postfix expressions with a potentially unencrypted
      expression.
    within: Dictionary of index of expressions to nodes/records to aggregate
      over.

  Returns:
    List of unencrypted expressions that are not a single term.
  """
    def _IsEncryptedExpression(stack):
        for token in stack:
            if (not isinstance(token, util.FieldToken)
                    and not isinstance(token, util.AggregationQueryToken)):
                continue
            if (util.HOMOMORPHIC_FLOAT_PREFIX in token
                    or util.HOMOMORPHIC_INT_PREFIX in token
                    or util.PSEUDONYM_PREFIX in token
                    or util.PROBABILISTIC_PREFIX in token
                    or util.SEARCHWORDS_PREFIX in token):
                return True
        return False

    unencrypted_expressions = []
    counter = 0
    for i in range(len(postfix_stacks)):
        if not _IsEncryptedExpression(postfix_stacks[i]):
            expression = interpreter.ToInfix(list(postfix_stacks[i]))
            if i in within:
                expression += ' WITHIN %s' % within[i]
            expression += ' AS %s%d_' % (util.UNENCRYPTED_ALIAS_PREFIX,
                                         counter)
            unencrypted_expressions.append(expression)
            postfix_stacks[i] = [
                util.UnencryptedQueryToken(
                    '%s%d_' % (util.UNENCRYPTED_ALIAS_PREFIX, counter))
            ]
            counter += 1

    return unencrypted_expressions
Exemplo n.º 10
0
def _CollapseFunctions(stack):
    """Collapses functions by evaluating them for actual values.

  Replaces a function's postfix expression with a single token. If the function
  can be evaluated (no fields included as arguments), the single token is
  the value of function's evaluation. Otherwise, the function is collapsed
  into a single token without evaluation.

  Arguments:
    stack: The stack whose functions are to be collapsed and resolved.

  Raises:
    bigquery_client.BigqueryInvalidQueryError: If a field exists inside
    the arguments of a function.

  Returns:
    True iff a function is found and collapsed. In other words, another
    potential function can still exist.
  """
    for i in xrange(len(stack)):
        if isinstance(stack[i], util.BuiltInFunctionToken):
            start_idx, postfix_expr = interpreter.GetSingleValue(stack[:i + 1])
            if util.IsEncryptedExpression(postfix_expr):
                raise bigquery_client.BigqueryInvalidQueryError(
                    'Invalid aggregation function argument: Cannot put an encrypted '
                    'field as an argument to a built-in function.', None, None,
                    None)
            # If the expression has no fields, we want to get the actual value.
            # But, if the field has a field, we have to get the infix string instead.
            try:
                result = interpreter.Evaluate(list(postfix_expr))
                if isinstance(result, basestring):
                    result = util.StringLiteralToken('"%s"' % result)
                elif result is None:
                    result = util.LiteralToken('NULL', None)
                elif str(result).lower() in ['true', 'false']:
                    result = util.LiteralToken(str(result).lower(), result)
                stack[start_idx:i + 1] = [result]
            except bigquery_client.BigqueryInvalidQueryError:
                result = interpreter.ToInfix(list(postfix_expr))
                stack[start_idx:i + 1] = [util.FieldToken(result)]
            return True
    return False
Exemplo n.º 11
0
def _CollapseAggregations(stack, nsquare):
    """Collapses the aggregations by combining arguments and functions.

  During collapses, checks will be done to if aggregations are done on
  encrypted fields. The following aggregations will be rewritten:

  SUM(<homomorphic field>) becomes
  TO_BASE64(PAILLIER_SUM(FROM_BASE64(<homomorphic field>), <nsquare>))

  AVG(<homomorphic field>) becomes
  TO_BASE64(PAILLIER_SUM(FROM_BASE64(<homomorphic field>), <nsquare>)) /
  COUNT(<homomorphic field>)

  Arguments:
    stack: The stack whose aggregations are to be collapsed.
    nsquare: Used for homomorphic addition.

  Returns:
    True iff an aggregation was found and collapsed. In other words, another
    potential aggregation can still exist.
  """
    for i in xrange(len(stack)):
        if isinstance(stack[i], util.AggregationFunctionToken):
            num_args = stack[i].num_args
            function_type = str(stack[i])
            postfix_exprs = []
            infix_exprs = []
            start_idx = i
            rewritten_infix_expr = None
            is_encrypted = False
            # pylint: disable=unused-variable
            for j in xrange(int(num_args)):
                start_idx, postfix_expr = interpreter.GetSingleValue(
                    stack[:start_idx])
                is_encrypted = is_encrypted or util.IsEncryptedExpression(
                    postfix_expr)
                while _CollapseFunctions(postfix_expr):
                    pass
                postfix_exprs.append(postfix_expr)
                infix_exprs.append(interpreter.ToInfix(list(postfix_expr)))
            # Check for proper nested aggregations.
            # PAILLIER_SUM and GROUP_CONCAT on encrypted fields are not legal
            # arguments for an aggregation.
            for expr in postfix_exprs:
                for token in expr:
                    if not isinstance(token, util.AggregationQueryToken):
                        continue
                    if token.startswith(util.PAILLIER_SUM_PREFIX):
                        raise bigquery_client.BigqueryInvalidQueryError(
                            'Cannot use SUM/AVG on homomorphic encryption as argument '
                            'for another aggregation.', None, None, None)
                    elif token.startswith(util.GROUP_CONCAT_PREFIX):
                        fieldname = token.split(
                            util.GROUP_CONCAT_PREFIX)[1][:-1]
                        if util.IsEncrypted(fieldname):
                            raise bigquery_client.BigqueryInvalidQueryError(
                                'Cannot use GROUP_CONCAT on an encrypted field as argument '
                                'for another aggregation.', None, None, None)
            infix_exprs.reverse()
            if function_type in ['COUNT', 'DISTINCTCOUNT']:
                if (function_type == 'DISTINCTCOUNT'
                        and util.IsDeterministicExpression(postfix_exprs[0])):
                    raise bigquery_client.BigqueryInvalidQueryError(
                        'Cannot do distinct count on non-pseudonym encryption.',
                        None, None, None)
                if function_type == 'DISTINCTCOUNT':
                    infix_exprs[0] = 'DISTINCT ' + infix_exprs[0]
                rewritten_infix_expr = [
                    util.AggregationQueryToken('COUNT(%s)' %
                                               ', '.join(infix_exprs))
                ]
            elif function_type == 'TOP':
                if util.IsDeterministicExpression(postfix_exprs[0]):
                    raise bigquery_client.BigqueryInvalidQueryError(
                        'Cannot do TOP on non-deterministic encryption.', None,
                        None, None)
                rewritten_infix_expr = [
                    util.AggregationQueryToken('TOP(%s)' %
                                               ', '.join(infix_exprs))
                ]
            elif function_type in ['AVG', 'SUM'] and is_encrypted:
                list_fields = interpreter.CheckValidSumAverageArgument(
                    postfix_expr)[0]
                rewritten_infix_expr = []
                # The representative label is the field that is going to be used
                # to get constant values. An expression SUM(ax + b) must be rewritten as
                # a * SUM(x) + b * COUNT(x). Represetative label is x (this isn't unique
                # as many fields can be in COUNT).
                representative_label = ''
                for field in list_fields:
                    for token in field:
                        if util.IsLabel(token):
                            representative_label = token
                            break
                    if representative_label:
                        break
                for field in list_fields:
                    expression = interpreter.ExpandExpression(field)
                    queries, constant = expression[0], expression[1]
                    rewritten_infix_expr.append(float(constant))
                    rewritten_infix_expr.append(
                        util.AggregationQueryToken('COUNT(%s)' %
                                                   representative_label))
                    rewritten_infix_expr.append(util.OperatorToken('*', 2))
                    for query in queries:
                        rewritten_infix_expr.append(float(query[0]))
                        if (isinstance(query[1], util.HomomorphicFloatToken)
                                or isinstance(query[1],
                                              util.HomomorphicIntToken)):
                            rewritten_infix_expr.append(
                                util.ConstructPaillierSumQuery(
                                    query[1], nsquare))
                        else:
                            rewritten_infix_expr.append(
                                util.AggregationQueryToken('SUM(%s)' %
                                                           query[1]))
                        rewritten_infix_expr.append(util.OperatorToken('*', 2))
                    for j in range(len(queries)):
                        rewritten_infix_expr.append(util.OperatorToken('+', 2))
                for j in range(len(list_fields) - 1):
                    rewritten_infix_expr.append(util.OperatorToken('+', 2))
                if function_type == 'AVG':
                    rewritten_infix_expr.append(
                        util.AggregationQueryToken('COUNT(%s)' %
                                                   representative_label))
                    rewritten_infix_expr.append(util.OperatorToken('/', 2))
            elif function_type == 'GROUP_CONCAT':
                rewritten_infix_expr = [
                    util.AggregationQueryToken('GROUP_CONCAT(%s)' %
                                               ', '.join(infix_exprs))
                ]
            elif is_encrypted:
                raise bigquery_client.BigqueryInvalidQueryError(
                    'Cannot do %s aggregation on any encrypted fields.' %
                    function_type, None, None, None)
            else:
                rewritten_infix_expr = [
                    util.AggregationQueryToken(
                        '%s(%s)' % (function_type, ', '.join(infix_exprs)))
                ]
            stack[start_idx:i + 1] = rewritten_infix_expr
            return True
    return False
 def testCountStar(self):
     """Checks that a lone CountStarToken renders as '*'."""
     postfix = [util.CountStarToken()]
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, '*')
 def testNoArgumentFunction(self):
     """Checks rendering and evaluation of the zero-argument built-in pi()."""
     postfix = [util.BuiltInFunctionToken('pi')]
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, 'pi()')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, math.pi)
 def testSimpleExpression(self):
     """Checks rendering and evaluation of a single binary addition."""
     postfix = [1, 2]
     postfix.append(util.OperatorToken('+', 2))
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, '(1 + 2)')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, 3)
 def testNull(self):
     """Checks that a null literal renders as 'null' and evaluates to None."""
     postfix = [util.LiteralToken('null', None)]
     rendered = interpreter.ToInfix(postfix[:])
     self.assertEqual(rendered, 'null')
     value = interpreter.Evaluate(postfix)
     self.assertEqual(value, None)