def testExtractUnencryptedQueries(self):
  """Tests _ExtractUnencryptedQueries keeps only unencrypted stacks.

  Six postfix stacks are supplied; the probabilistic field (index 3) and
  the pseudonym GROUP_CONCAT (index 4) must be dropped, and the survivors
  aliased in order with UNENCRYPTED_ALIAS_PREFIX.
  """
  stacks = [
      [util.FieldToken('Year')],
      [1],
      [util.FieldToken('Year'), 1, util.OperatorToken('+', 2)],
      [util.ProbabilisticToken('Price')],
      [util.FieldToken('GROUP_CONCAT(%sModel)' % util.PSEUDONYM_PREFIX)],
      [util.FieldToken('SUM(Year + 1)')]
  ]
  unencrypted_expression_list = [
      'Year AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '0_',
      '1 AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '1_',
      '(Year + 1) AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '2_',
      'SUM(Year + 1) AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '3_'
  ]
  self.assertEqual(query_lib._ExtractUnencryptedQueries(stacks, {}),
                   unencrypted_expression_list)
  # Same stacks, but with WITHIN clauses keyed by original stack index:
  # only the surviving aggregation (original index 5) picks up its WITHIN
  # text; the dropped index-4 entry's 'w1' never appears.
  stacks = [
      [util.FieldToken('Year')],
      [1],
      [util.FieldToken('Year'), 1, util.OperatorToken('+', 2)],
      [util.ProbabilisticToken('Price')],
      [util.FieldToken('GROUP_CONCAT(%sModel)' % util.PSEUDONYM_PREFIX)],
      [util.FieldToken('SUM(Year + 1)')]
  ]
  within = {4: 'w1', 5: 'w2'}
  unencrypted_expression_list = [
      'Year AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '0_',
      '1 AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '1_',
      '(Year + 1) AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '2_',
      'SUM(Year + 1) WITHIN w2 AS ' + util.UNENCRYPTED_ALIAS_PREFIX + '3_'
  ]
  self.assertEqual(query_lib._ExtractUnencryptedQueries(stacks, within),
                   unencrypted_expression_list)
def testFailOperationsOnEncryptions(self):
  """Rewriting selection criteria must reject operations on ciphertexts."""
  schema = test_util.GetCarsSchema()
  key = test_util.GetMasterKey()
  invalid_stacks = [
      # Arithmetic/comparison on a pseudonym-encrypted field.
      [util.PseudonymToken('Year'), 1, util.OperatorToken('+', 2),
       2000, util.OperatorToken('>=', 2)],
      # String function applied to a probabilistic-encrypted field.
      [util.ProbabilisticToken('Model'), 2,
       util.BuiltInFunctionToken('left')],
      # Numeric predicate applied to a homomorphic-encrypted field.
      [util.HomomorphicIntToken('Invoice_Price'),
       util.BuiltInFunctionToken('is_nan')],
  ]
  for stack in invalid_stacks:
    self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                      interpreter.RewriteSelectionCriteria,
                      stack, schema, key, _TABLE_ID)
def testGetSingleValue(self):
  """GetSingleValue peels one complete postfix value off the stack end."""
  stack = [1, 1, 1, util.OperatorToken('+', 2)]
  split_index, value_expr = interpreter.GetSingleValue(stack)
  # The trailing '+' consumes the last two 1s; the split lands at index 1.
  self.assertEqual(split_index, 1)
  self.assertEqual(value_expr, [1, 1, util.OperatorToken('+', 2)])
  # A binary operator with only one operand is a malformed expression.
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.GetSingleValue,
                    [1, util.OperatorToken('+', 2)])
def testWhereRewriteWithRelated(self):
  """Test WHERE when pseudonym value exists in two different tables.

  With a 'related' attribute on the schema field, the pseudonym
  ciphertext is keyed off the related name instead of the table id, so
  the same literal must encrypt identically across different tables.
  """
  schema = test_util.GetCarsSchema()
  # add 'related' field just for this test
  for field in schema:
    if field['name'] == 'Make':
      field['related'] = 'cars_name'
      break
  # this value determined by running the test, not by manual calc
  ciphertext = 'sspWKAH/NKuUyX8ji1mmSw=='
  # test 1, use table_id
  table_id = _TABLE_ID
  master_key = test_util.GetMasterKey()
  as_clause = query_lib._AsClause({})
  stack = [
      util.FieldToken('Make'),
      util.StringLiteralToken('"Hello"'),
      util.OperatorToken('==', 2)
  ]
  where_clause_1 = query_lib._WhereClause(stack, as_clause=as_clause,
                                          schema=schema,
                                          nsquare=_TEST_NSQUARE,
                                          master_key=master_key,
                                          table_id=table_id)
  rewritten_sql_1 = where_clause_1.Rewrite()
  self.assertEqual(
      rewritten_sql_1,
      'WHERE (%sMake == "%s")' % (util.PSEUDONYM_PREFIX, ciphertext))
  # test 2, change table_id, query should be same as test #1
  table_id = _TABLE_ID + '_other'
  master_key = test_util.GetMasterKey()
  as_clause = query_lib._AsClause({})
  stack = [
      util.FieldToken('Make'),
      util.StringLiteralToken('"Hello"'),
      util.OperatorToken('==', 2)
  ]
  where_clause_2 = query_lib._WhereClause(stack, as_clause=as_clause,
                                          schema=schema,
                                          nsquare=_TEST_NSQUARE,
                                          master_key=master_key,
                                          table_id=table_id)
  rewritten_sql_2 = where_clause_2.Rewrite()
  self.assertEqual(
      rewritten_sql_2,
      'WHERE (%sMake == "%s")' % (util.PSEUDONYM_PREFIX, ciphertext))
  # verify different tables were used
  self.assertNotEqual(where_clause_1.table_id, where_clause_2.table_id)
  # and verify that same WHERE query="literal" was generated
  self.assertEqual(rewritten_sql_1, rewritten_sql_2)
def testComputeRowsEvaluate1(self):
  """Constant-only stacks produce a single computed row of strings."""
  # Query is 'SELECT 1 + 1, 1 * 1' with no server-queried values.
  stacks = [
      [1, 1, util.OperatorToken('+', 2)],
      [1, 1, util.OperatorToken('*', 2)],
  ]
  computed = encrypted_bigquery_client._ComputeRows(stacks, {})
  self.assertEqual(computed, [['2', '1']])
def testUnary(self):
  """Unary operators (~, not) are rendered and evaluated correctly."""
  postfix = [
      1,
      2,
      util.OperatorToken('~', 1),
      util.OperatorToken('<', 2),
      util.OperatorToken('not', 1),
  ]
  # ToInfix consumes its argument, so hand it a copy.
  self.assertEqual(interpreter.ToInfix(list(postfix)), 'not (1 < ~ 2)')
  self.assertEqual(interpreter.Evaluate(postfix), True)
def testBinary(self):
  """Chained binary arithmetic renders left-nested and evaluates exactly."""
  postfix = [
      1, 2, util.OperatorToken('+', 2),
      3, util.OperatorToken('*', 2),
      9, util.OperatorToken('/', 2),
      2, util.OperatorToken('-', 2),
  ]
  # ToInfix consumes its argument, so hand it a copy.
  self.assertEqual(interpreter.ToInfix(list(postfix)),
                   '((((1 + 2) * 3) / 9) - 2)')
  # ((1 + 2) * 3) / 9 - 2 == -1
  self.assertEqual(interpreter.Evaluate(postfix), -1)
def testBooleanLiterals(self):
  """Boolean literal tokens combine with comparisons under 'or'."""
  postfix = [
      util.LiteralToken('True', True),
      util.LiteralToken('False', False),
      util.OperatorToken('or', 2),
      1,
      2,
      util.OperatorToken('=', 2),
      util.OperatorToken('or', 2),
  ]
  # ToInfix consumes its argument, so hand it a copy.
  self.assertEqual(interpreter.ToInfix(list(postfix)),
                   '((True or False) or (1 = 2))')
  self.assertEqual(interpreter.Evaluate(postfix), True)
def PushUnaryOperators(tokens):
  """Pushes parsed unary operator tokens onto the module math stack.

  The list must be reversed since unary operations are unwrapped in the
  other direction. An example is ~-1. The negation occurs before the bit
  inversion.

  Arguments:
    tokens: parsed unary operator symbols ('-', '~', 'not'), outermost
      first; unrecognized symbols are ignored.
  """
  for i in reversed(range(0, len(tokens))):
    if tokens[i] == '-':
      # Unary minus is rewritten as multiplication by -1.
      # (Was int('-1') — a needless runtime string parse.)
      math_stack.append(-1)
      math_stack.append(util.OperatorToken('*', 2))
    elif tokens[i] == '~':
      math_stack.append(util.OperatorToken('~', 1))
    elif tokens[i].lower() == 'not':
      math_stack.append(util.OperatorToken('not', 1))
def testHavingRewrite(self):
  """Tests _HavingClause.Rewrite on plain, homomorphic and pseudonym input."""
  schema = test_util.GetCarsSchema()
  master_key = test_util.GetMasterKey()
  as_clause = query_lib._AsClause({})
  # Unencrypted aggregation: rewritten verbatim to infix SQL.
  stack = [util.FieldToken('SUM(Year)'), 1, util.OperatorToken('<', 2)]
  having_clause = query_lib._HavingClause(stack, as_clause=as_clause,
                                          schema=schema,
                                          nsquare=_TEST_NSQUARE,
                                          master_key=master_key,
                                          table_id=_TABLE_ID)
  self.assertEqual(having_clause.Rewrite(), 'HAVING (SUM(Year) < 1)')
  # Comparing against a PAILLIER_SUM (homomorphic) result is rejected:
  # the server only sees ciphertext, so the comparison cannot be done there.
  stack = [
      1000,
      util.AggregationQueryToken(
          'TO_BASE64(BYTES(PAILLIER_SUM(FROM_BASE64(' +
          util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price), \'0\')))'),
      util.OperatorToken('==', 2)
  ]
  having_clause = query_lib._HavingClause(stack, as_clause=as_clause,
                                          schema=schema,
                                          nsquare=_TEST_NSQUARE,
                                          master_key=master_key,
                                          table_id=_TABLE_ID)
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    having_clause.Rewrite)
  # Applying a function to GROUP_CONCAT of a pseudonym field is rejected.
  stack = [
      util.FieldToken('GROUP_CONCAT(' + util.PSEUDONYM_PREFIX + 'Model)'),
      util.BuiltInFunctionToken('len'),
      5,
      util.OperatorToken('>', 2)
  ]
  having_clause = query_lib._HavingClause(stack, as_clause=as_clause,
                                          schema=schema,
                                          nsquare=_TEST_NSQUARE,
                                          master_key=master_key,
                                          table_id=_TABLE_ID)
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    having_clause.Rewrite)
  # An empty stack yields an empty HAVING clause.
  stack = []
  having_clause = query_lib._HavingClause(stack, as_clause=as_clause,
                                          schema=schema,
                                          nsquare=_TEST_NSQUARE,
                                          master_key=master_key,
                                          table_id=_TABLE_ID)
  self.assertEqual(having_clause.Rewrite(), '')
def testReplaceAliasWhenNested(self):
  """Aliases that reference each other are expanded transitively."""
  # Query is 'SELECT a + b as a, a + b as b'. When substituting into the
  # second column, its 'a' expands to the first column's definition first.
  def MakeSum():
    return [util.FieldToken('a'), util.FieldToken('b'),
            util.OperatorToken('+', 2)]

  resolved = query_lib._ReplaceAlias([MakeSum(), MakeSum()],
                                     {0: 'a', 1: 'b'})
  self.assertEqual(resolved,
                   [['a', 'b', '+'], ['a', 'b', '+', 'b', '+']])
def testComputeRowsEvaluate2(self):
  """Server-queried values are substituted row by row before evaluation."""
  # Query is 'SELECT 1 + a, 1 * b, "hello"' over two rows of a/b values:
  #   1 + a | 1 * b | "hello"
  #     2       3     hello
  #     4       5     hello
  stacks = [
      [1, util.FieldToken('a'), util.OperatorToken('+', 2)],
      [1, util.FieldToken('b'), util.OperatorToken('*', 2)],
      [util.StringLiteralToken('"hello"')],
  ]
  queried_values = {'a': [1, 3], 'b': [3, 5]}
  computed = encrypted_bigquery_client._ComputeRows(stacks, queried_values)
  self.assertEqual(computed,
                   [['2', '3', 'hello'], ['4', '5', 'hello']])
def testWhereRewrite(self):
  """Tests _WhereClause.Rewrite for pseudonym and searchwords predicates."""
  schema = test_util.GetCarsSchema()
  master_key = test_util.GetMasterKey()
  as_clause = query_lib._AsClause({})
  # Equality on a pseudonym field: the literal is deterministically
  # encrypted so the server compares ciphertext against ciphertext.
  stack = [
      util.FieldToken('Make'),
      util.StringLiteralToken('"Hello"'),
      util.OperatorToken('==', 2)
  ]
  where_clause = query_lib._WhereClause(stack, as_clause=as_clause,
                                        schema=schema,
                                        nsquare=_TEST_NSQUARE,
                                        master_key=master_key,
                                        table_id=_TABLE_ID)
  self.assertEqual(
      where_clause.Rewrite(),
      'WHERE (%sMake == "HS57DHbh2KlkqNJREmu1wQ==")' % util.PSEUDONYM_PREFIX)
  # CONTAINS on a searchwords field becomes a keyed-hash probe against
  # the encrypted searchwords column.
  stack = [
      util.FieldToken('Model'),
      util.StringLiteralToken('"A"'),
      util.OperatorToken('contains', 2)
  ]
  where_clause = query_lib._WhereClause(stack, as_clause=as_clause,
                                        schema=schema,
                                        nsquare=_TEST_NSQUARE,
                                        master_key=master_key,
                                        table_id=_TABLE_ID)
  self.assertEqual(
      where_clause.Rewrite(),
      'WHERE (%sModel contains to_base64(left(bytes(sha1(concat(left('
      '%sModel, 24), \'yB9HY2qv+DI=\'))), 8)))'
      % (util.SEARCHWORDS_PREFIX, util.SEARCHWORDS_PREFIX))
  # An empty stack produces an empty WHERE clause.
  stack = []
  where_clause = query_lib._WhereClause(stack, as_clause=as_clause,
                                        schema=schema,
                                        nsquare=_TEST_NSQUARE,
                                        master_key=master_key,
                                        table_id=_TABLE_ID)
  self.assertEqual(where_clause.Rewrite(), '')
def testAsConstructColumnNames(self):
  """Columns take their alias when present, else their infix rendering."""
  as_clause = query_lib._AsClause({0: 'a'})
  columns = [
      [util.FieldToken('b')],
      [1, 2, util.OperatorToken('+', 2)],
  ]
  expected_names = [{'name': 'a'}, {'name': '(1 + 2)'}]
  self.assertEqual(as_clause.ConstructColumnNames(columns), expected_names)
def testCheckValidSumAverageArgument(self):
  """Tests CheckValidSumAverageArgument's taint checking and expansion."""
  # (Year * Year) + Invoice_Price: valid — the homomorphic field is only
  # ever added, never multiplied, so the result carries both taints.
  stack = [
      util.FieldToken('Year'),
      util.FieldToken('Year'),
      util.OperatorToken('*', 2),
      util.HomomorphicIntToken('Invoice_Price'),
      util.OperatorToken('+', 2)
  ]
  expected_stack = [[['Year', 'Year', '*'],
                     [util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price']],
                    True, True]
  self.assertEqual(interpreter.CheckValidSumAverageArgument(stack),
                   expected_stack)
  # 2 * ((Year * Year) + Invoice_Price): a constant factor distributes
  # over each expanded term.
  stack = [
      2,
      util.FieldToken('Year'),
      util.FieldToken('Year'),
      util.OperatorToken('*', 2),
      util.HomomorphicIntToken('Invoice_Price'),
      util.OperatorToken('+', 2),
      util.OperatorToken('*', 2)
  ]
  expected_stack = [[
      [2, 'Year', 'Year', '*', '*'],
      [2, util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price', '*']
  ], True, True]
  self.assertEqual(interpreter.CheckValidSumAverageArgument(stack),
                   expected_stack)
  # Probabilistic encryption is not summable at all.
  stack = [util.ProbabilisticToken('Price')]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.CheckValidSumAverageArgument, stack)
  # Encrypted * encrypted is not a linear expression.
  stack = [
      util.HomomorphicIntToken('Invoice_Price'),
      util.HomomorphicFloatToken('Holdback_Percentage'),
      util.OperatorToken('*', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.CheckValidSumAverageArgument, stack)
  # Encrypted * plain field is likewise rejected.
  stack = [
      util.HomomorphicIntToken('Invoice_Price'),
      util.FieldToken('Year'),
      util.OperatorToken('*', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.CheckValidSumAverageArgument, stack)
  # So is encrypted / plain field.
  stack = [
      util.HomomorphicIntToken('Invoice_Price'),
      util.FieldToken('Year'),
      util.OperatorToken('/', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.CheckValidSumAverageArgument, stack)
def testSimpleWhere(self):
  """Plain (unencrypted) predicates are rewritten to infix unchanged."""
  schema = test_util.GetCarsSchema()
  key = test_util.GetMasterKey()
  cases = [
      ([1, 2, util.OperatorToken('>', 2)], '(1 > 2)'),
      ([1, 2, util.OperatorToken('=', 2)], '(1 = 2)'),
      ([util.FieldToken('PI()'), 1, util.OperatorToken('>', 2)],
       '(PI() > 1)'),
      ([util.FieldToken('Year'), 2000, util.OperatorToken('<', 2)],
       '(Year < 2000)'),
  ]
  for stack, expected in cases:
    self.assertEqual(
        interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
        expected)
  # A binary operator missing an operand is a malformed expression.
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    [1, util.OperatorToken('>', 2)],
                    schema, key, _TABLE_ID)
def testExtractEncryptedQueries(self):
  """Tests _ExtractFieldQueries collects only server-side field names."""
  # Original query is 'SELECT (111 + a_1) * a, TRUE OR False, null, PI(),
  # FUNC_PI, a1'
  # Query sent to server becomes 'SELECT a, a_1, FUNC_PI, a1'
  stacks = [[
      util.FieldToken('a'), 111,
      util.FieldToken('a_1'),
      util.OperatorToken('+', 2),
      util.OperatorToken('*', 2)
  ], ['TRUE', 'False', util.OperatorToken('OR', 2)], ['null'],
            [util.BuiltInFunctionToken('PI')],
            [util.FieldToken('FUNC_PI')],
            [util.FieldToken('a1')],
            [util.FieldToken('a.b')],
            [util.UnencryptedQueryToken(
                '%s0_' % util.UNENCRYPTED_ALIAS_PREFIX)]]
  query_list = query_lib._ExtractFieldQueries(stacks, strize=True)
  # Literals, built-in functions and already-extracted unencrypted aliases
  # are skipped; dotted field names are aliased with the period replaced.
  expect_query_list = set([
      'a', 'a_1', 'FUNC_PI', 'a1',
      'a.b AS a' + util.PERIOD_REPLACEMENT + 'b'
  ])
  self.assertEqual(expect_query_list, query_list)
def CheckValidSumAverageArgument(stack):
  """Checks if stack is a proper argument for SUM/AVG.

  This recursive algorithm performs tainting. It uses a special structure to
  store data which is as follows:
  s = [list of postfix expressions, taint1, taint2]
  The list of postfix expressions all added together is equivalent to the
  expanded version of <stack>. taint1 represents whether s contains an
  encrypted field. taint2 represents whether s contains any field
  (encrypted or unencrypted).

  This algorithm fails if any encrypted field is multiplied/divided by any
  other field (either encrypted or unencrypted).

  Arguments:
    stack: The postfix expression that is being checked if valid for SUM/AVG
      argument. Consumed (popped) by this function.

  Returns:
    A list containing a list of postfix expressions, and two types of taints,
    representing whether an encrypted field is in s and whether any field is
    in s.

  Raises:
    bigquery_client.BigqueryInvalidQueryError: Thrown iff <stack> is not a
      valid linear expression (or one we cannot compute) that can be a
      SUM/AVG argument.
  """
  top = stack.pop()
  if ((isinstance(top, util.OperatorToken) and top.num_args == 1) or
      isinstance(top, util.BuiltInFunctionToken) or
      isinstance(top, util.AggregationFunctionToken) or
      isinstance(top, util.LiteralToken)):
    # Unary operators, functions, nested aggregations and literals cannot
    # appear inside a linear SUM/AVG argument.
    raise bigquery_client.BigqueryInvalidQueryError(
        'Invalid SUM arguments. %s is not supported' % top, None, None, None)
  elif top in ['+', '-']:
    op2 = CheckValidSumAverageArgument(stack)
    op1 = CheckValidSumAverageArgument(stack)
    list_fields = list(op1[0])
    if top == '-':
      # Negate every term of the subtrahend by multiplying it with -1.
      for i in xrange(len(op2[0])):
        op2[0][i].extend([-1, util.OperatorToken('*', 2)])
    for i in xrange(len(op2[0])):
      list_fields.append(op2[0][i])
    return [list_fields, op1[1] or op2[1], op1[2] or op2[2]]
  elif top == '*':
    op2 = CheckValidSumAverageArgument(stack)
    op1 = CheckValidSumAverageArgument(stack)
    # An encrypted operand may only be scaled by a pure constant.
    if (op1[1] and (op2[1] or op2[2])) or (op2[1] and (op1[1] or op1[2])):
      raise bigquery_client.BigqueryInvalidQueryError(
          'Invalid AVG/SUM argument. An encrypted field is multiplied by '
          'another field.', None, None, None)
    # Distribute: (a1 + a2 + ...) * (b1 + b2 + ...) expands to the sum of
    # all pairwise products.
    list_fields = []
    for field1 in op1[0]:
      for field2 in op2[0]:
        value = list(field1)
        value.extend(field2)
        value.append(util.OperatorToken('*', 2))
        list_fields.append(value)
    return [list_fields, op1[1] or op2[1], op1[2] or op2[2]]
  elif top == '/':
    op2 = CheckValidSumAverageArgument(stack)
    op1 = CheckValidSumAverageArgument(stack)
    if op2[1] or (op1[1] and op2[2]):
      raise bigquery_client.BigqueryInvalidQueryError(
          'Division by/of an encrypted field: not a linear function.',
          None, None, None)
    # Rebuild the divisor as a single postfix expression (terms re-summed)
    # and append it, with '/', to every term of the dividend.
    append_divisor = []
    for field in op2[0]:
      append_divisor.extend(field)
    for i in xrange(len(op2[0]) - 1):
      append_divisor.append(util.OperatorToken('+', 2))
    append_divisor.append(util.OperatorToken('/', 2))
    list_fields = list(op1[0])
    for i in xrange(len(list_fields)):
      list_fields[i].extend(append_divisor)
    return [list_fields, op1[1], op1[2] or op2[2]]
  else:
    # Base case: a lone field or constant.
    if (isinstance(top, util.PseudonymToken) or
        isinstance(top, util.SearchwordsToken) or
        isinstance(top, util.ProbabilisticToken)):
      raise bigquery_client.BigqueryInvalidQueryError(
          'Cannot do SUM/AVG on non-homomorphic encryption.',
          None, None, None)
    is_encrypted = (isinstance(top, util.HomomorphicIntToken) or
                    isinstance(top, util.HomomorphicFloatToken))
    # taint2: anything that is not a float constant counts as a field.
    return [[[top]], is_encrypted, not util.IsFloat(top)]
def testSimpleExpression(self):
  """A single binary addition renders and evaluates correctly."""
  postfix = [1, 2, util.OperatorToken('+', 2)]
  # ToInfix consumes its argument, so hand it a copy.
  self.assertEqual(interpreter.ToInfix(list(postfix)), '(1 + 2)')
  self.assertEqual(interpreter.Evaluate(postfix), 3)
def testRewriteAggregations(self):
  """Tests _RewriteAggregations across aggregation types and encryptions."""
  # COUNT(*) passes through unchanged.
  stack = [
      util.CountStarToken(),
      util.AggregationFunctionToken('COUNT', 1)
  ]
  rewritten_stack = ['COUNT(*)']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # COUNT over an encrypted field just counts the ciphertext column.
  stack = [
      util.ProbabilisticToken('Price'),
      util.AggregationFunctionToken('COUNT', 1)
  ]
  rewritten_stack = ['COUNT(' + util.PROBABILISTIC_PREFIX + 'Price)']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # A two-argument COUNT keeps its extra argument.
  stack = [
      util.ProbabilisticToken('Price'), 4,
      util.AggregationFunctionToken('COUNT', 2)
  ]
  rewritten_stack = ['COUNT(' + util.PROBABILISTIC_PREFIX + 'Price, 4)']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # DISTINCTCOUNT becomes COUNT(DISTINCT ...); surrounding postfix tokens
  # are preserved.
  stack = [
      util.FieldToken('Year'), 5,
      util.AggregationFunctionToken('DISTINCTCOUNT', 2),
      util.FieldToken('Year'),
      util.AggregationFunctionToken('COUNT', 1),
      util.OperatorToken('+', 2)
  ]
  rewritten_stack = ['COUNT(DISTINCT Year, 5)', 'COUNT(Year)', '+']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # Constant subexpressions are folded before aggregation: cos(0) == 1.0.
  stack = [
      0,
      util.BuiltInFunctionToken('cos'),
      util.AggregationFunctionToken('COUNT', 1)
  ]
  rewritten_stack = ['COUNT(1.0)']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # String folding: concat(left("Hello", 2), "y") == "Hey".
  stack = [
      util.StringLiteralToken('"Hello"'), 2,
      util.BuiltInFunctionToken('left'),
      util.StringLiteralToken('"y"'),
      util.BuiltInFunctionToken('concat'),
      util.AggregationFunctionToken('GROUP_CONCAT', 1)
  ]
  rewritten_stack = ['GROUP_CONCAT("Hey")']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # SUM of an unencrypted expression stays a plain SUM.
  stack = [
      util.FieldToken('Year'),
      util.FieldToken('Year'),
      util.OperatorToken('*', 2),
      util.AggregationFunctionToken('SUM', 1)
  ]
  rewritten_stack = ['SUM((Year * Year))']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # SUM of a homomorphic int becomes
  # 0.0 * COUNT(field) + 1.0 * PAILLIER_SUM(field).
  stack = [
      util.HomomorphicIntToken('Invoice_Price'),
      util.AggregationFunctionToken('SUM', 1)
  ]
  rewritten_stack = [
      0.0,
      'COUNT(' + util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price)',
      '*',
      1.0,
      'TO_BASE64(BYTES(PAILLIER_SUM(FROM_BASE64(' +
      util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price), \'0\')))',
      '*',
      '+'
  ]
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # AVG of a homomorphic float is the rewritten SUM divided by COUNT.
  stack = [
      util.HomomorphicFloatToken('Holdback_Percentage'),
      util.AggregationFunctionToken('AVG', 1)
  ]
  rewritten_stack = [
      0.0,
      'COUNT(' + util.HOMOMORPHIC_FLOAT_PREFIX + 'Holdback_Percentage)',
      '*',
      1.0,
      'TO_BASE64(BYTES(PAILLIER_SUM(FROM_BASE64(' +
      util.HOMOMORPHIC_FLOAT_PREFIX + 'Holdback_Percentage), \'0\')))',
      '*',
      '+',
      'COUNT(' + util.HOMOMORPHIC_FLOAT_PREFIX + 'Holdback_Percentage)',
      '/'
  ]
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # SUM((Invoice_Price + 2) * 5): linear expansion gives
  # 5 * PAILLIER_SUM(field) + SUM(2 * 5).
  stack = [
      util.HomomorphicIntToken('Invoice_Price'), 2,
      util.OperatorToken('+', 2), 5,
      util.OperatorToken('*', 2),
      util.AggregationFunctionToken('SUM', 1)
  ]
  rewritten_stack = [
      0.0,
      'COUNT(' + util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price)',
      '*',
      5.0,
      'TO_BASE64(BYTES(PAILLIER_SUM(FROM_BASE64(' +
      util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price), \'0\')))',
      '*',
      '+',
      0.0,
      'COUNT(' + util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price)',
      '*',
      1.0,
      'SUM((2 * 5))',
      '*',
      '+',
      '+'
  ]
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # DISTINCTCOUNT is allowed on pseudonym (deterministic) encryption.
  stack = [
      util.PseudonymToken('Make'), 2,
      util.AggregationFunctionToken('DISTINCTCOUNT', 2)
  ]
  rewritten_stack = [
      'COUNT(DISTINCT ' + util.PSEUDONYM_PREFIX + 'Make, 2)'
  ]
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # TOP on an unencrypted field passes through.
  stack = [
      util.FieldToken('Year'),
      util.AggregationFunctionToken('TOP', 1)
  ]
  rewritten_stack = ['TOP(Year)']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # TOP on a pseudonym field keeps its extra arguments.
  stack = [
      util.PseudonymToken('Make'), 5, 1,
      util.AggregationFunctionToken('TOP', 3)
  ]
  rewritten_stack = ['TOP(' + util.PSEUDONYM_PREFIX + 'Make, 5, 1)']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # Mixed SUM(cos(Year) + Invoice_Price) splits into a plain SUM term
  # plus a Paillier term.
  stack = [
      util.FieldToken('Year'),
      util.BuiltInFunctionToken('cos'),
      util.HomomorphicIntToken('Invoice_Price'),
      util.OperatorToken('+', 2),
      util.AggregationFunctionToken('SUM', 1)
  ]
  rewritten_stack = [
      0.0,
      'COUNT(' + util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price)',
      '*',
      1.0,
      'SUM(cos(Year))',
      '*',
      '+',
      0.0,
      'COUNT(' + util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price)',
      '*',
      1.0,
      'TO_BASE64(BYTES(PAILLIER_SUM(FROM_BASE64(' +
      util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price),'
      ' \'0\')))',
      '*',
      '+',
      '+'
  ]
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # DISTINCTCOUNT on probabilistic encryption is rejected.
  stack = [
      util.ProbabilisticToken('Model'),
      util.AggregationFunctionToken('DISTINCTCOUNT', 1)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    query_lib._RewriteAggregations,
                    [stack], _TEST_NSQUARE)
  # SUM on probabilistic encryption is rejected.
  stack = [
      util.ProbabilisticToken('Price'),
      util.AggregationFunctionToken('SUM', 1)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    query_lib._RewriteAggregations,
                    [stack], _TEST_NSQUARE)
  # SUM of a product of two encrypted fields is not linear.
  stack = [
      util.HomomorphicIntToken('Invoice_Price'),
      util.HomomorphicFloatToken('Holdback_Percentage'),
      util.OperatorToken('*', 2),
      util.AggregationFunctionToken('SUM', 1)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    query_lib._RewriteAggregations,
                    [stack], _TEST_NSQUARE)
  # SUM of a non-linear function of an encrypted field is rejected.
  stack = [
      util.HomomorphicFloatToken('Holdback_Percentage'),
      util.BuiltInFunctionToken('cos'),
      util.AggregationFunctionToken('SUM', 1)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    query_lib._RewriteAggregations,
                    [stack], _TEST_NSQUARE)
  # TOP on homomorphic encryption is rejected.
  stack = [
      util.HomomorphicIntToken('Invoice_Price'),
      util.AggregationFunctionToken('TOP', 1)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    query_lib._RewriteAggregations,
                    [stack], _TEST_NSQUARE)
  # Nested SUM is fine on unencrypted fields...
  stack = [
      util.FieldToken('Year'),
      util.AggregationFunctionToken('SUM', 1),
      util.AggregationFunctionToken('SUM', 1)
  ]
  rewritten_stack = ['SUM(SUM(Year))']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # ...but not when the inner SUM is a Paillier sum.
  stack = [
      util.HomomorphicIntToken('Invoice_Price'),
      util.AggregationFunctionToken('SUM', 1),
      util.AggregationFunctionToken('SUM', 1)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    query_lib._RewriteAggregations,
                    [stack], _TEST_NSQUARE)
  # Nested GROUP_CONCAT is fine on unencrypted fields...
  stack = [
      util.FieldToken('Year'),
      util.AggregationFunctionToken('GROUP_CONCAT', 1),
      util.AggregationFunctionToken('GROUP_CONCAT', 1)
  ]
  rewritten_stack = ['GROUP_CONCAT(GROUP_CONCAT(Year))']
  self.assertEqual(
      query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
      [rewritten_stack])
  # ...but not on a pseudonym-encrypted field.
  stack = [
      util.PseudonymToken('Make'),
      util.AggregationFunctionToken('GROUP_CONCAT', 1),
      util.AggregationFunctionToken('GROUP_CONCAT', 1)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    query_lib._RewriteAggregations,
                    [stack], _TEST_NSQUARE)
def PushBinaryOperator(tokens):
  """Pushes the parsed binary (two-argument) operator onto the math stack."""
  operator = util.OperatorToken(tokens[0], 2)
  math_stack.append(operator)
def testEncryptedEquality(self):
  """Tests equality rewrites involving pseudonym-encrypted fields."""
  schema = test_util.GetCarsSchema()
  key = test_util.GetMasterKey()
  # Plain fields and literals pass through untouched.
  stack = [
      util.FieldToken('Year'), 1,
      util.OperatorToken('+', 2), 2000,
      util.OperatorToken('=', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '((Year + 1) = 2000)')
  # Pseudonym field compared to a plain field: only the column is renamed.
  stack = [
      util.FieldToken('Year'),
      util.PseudonymToken('Make'),
      util.OperatorToken('=', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '(Year = ' + util.PSEUDONYM_PREFIX + 'Make)')
  # Pseudonym field compared to a literal: the literal is deterministically
  # encrypted so the server compares ciphertexts.
  stack = [
      util.PseudonymToken('Make'),
      util.StringLiteralToken('"Hello"'),
      util.OperatorToken('==', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '(' + util.PSEUDONYM_PREFIX + 'Make == "HS57DHbh2KlkqNJREmu1wQ==")')
  # begin: tests about 'related' schema option
  schema2 = test_util.GetCarsSchema()
  for field in schema2:
    if field['name'] == 'Make':
      field['related'] = 'cars_name'
  # value is deterministic calc with related instead of _TABLE_ID
  stack = [
      util.PseudonymToken('Make', related=_RELATED),
      util.StringLiteralToken('"Hello"'),
      util.OperatorToken('==', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema2, key, _TABLE_ID),
      '(' + util.PSEUDONYM_PREFIX + 'Make == "sspWKAH/NKuUyX8ji1mmSw==")')
  # token with related attribute makes no sense if schema doesn't have it
  stack = [
      util.StringLiteralToken('"Hello"'),
      util.PseudonymToken('Make', related=_RELATED),
      util.OperatorToken('==', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    stack, schema, key, _TABLE_ID)
  # end: tests about 'related' schema option
  # The literal may appear on either side of the comparison.
  stack = [
      util.StringLiteralToken('"Hello"'),
      util.PseudonymToken('Make'),
      util.OperatorToken('==', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '("HS57DHbh2KlkqNJREmu1wQ==" == ' + util.PSEUDONYM_PREFIX + 'Make)')
  # Inequality works the same way.
  stack = [
      util.StringLiteralToken('"Hello"'),
      util.PseudonymToken('Make'),
      util.OperatorToken('!=', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '("HS57DHbh2KlkqNJREmu1wQ==" != ' + util.PSEUDONYM_PREFIX + 'Make)')
  # Two pseudonym columns can be compared directly.
  stack = [
      util.PseudonymToken('Make'),
      util.PseudonymToken('Make2'),
      util.OperatorToken('==', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '(' + util.PSEUDONYM_PREFIX + 'Make == ' + util.PSEUDONYM_PREFIX +
      'Make2)')
  # Equality on homomorphic encryption is rejected.
  stack = [
      util.HomomorphicIntToken('Invoice_Price'), 2,
      util.OperatorToken('==', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    stack, schema, key, _TABLE_ID)
  # Equality with a probabilistic field is rejected.
  stack = [
      util.PseudonymToken('Make'),
      util.ProbabilisticToken('Price'),
      util.OperatorToken('=', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    stack, schema, key, _TABLE_ID)
  # Nested (record) pseudonym fields keep their parent path.
  schema = test_util.GetPlacesSchema()
  stack = [
      util.PseudonymToken('spouse.spouseName'),
      util.StringLiteralToken('"Hello"'),
      util.OperatorToken('=', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '(spouse.' + util.PSEUDONYM_PREFIX +
      'spouseName = "HS57DHbh2KlkqNJREmu1wQ==")')
def testExpandExpression(self):
  """_ExpandExpression flattens a linear postfix stack into coeff terms."""
  # ((x + y) + x) - 2*z - 5 + 3  ->  2x + y - 2z - 2
  postfix = [
      util.FieldToken('x'), util.FieldToken('y'),
      util.OperatorToken('+', 2),
      util.FieldToken('x'), util.OperatorToken('+', 2),
      2, util.FieldToken('z'), util.OperatorToken('*', 2),
      util.OperatorToken('-', 2),
      5, util.OperatorToken('-', 2),
      3, util.OperatorToken('+', 2),
  ]
  terms, constant = interpreter._ExpandExpression(postfix)
  self.assertEqual(
      terms,
      [[2.0, util.FieldToken('x')], [1.0, util.FieldToken('y')],
       [-2.0, util.FieldToken('z')]])
  self.assertEqual(constant, -2.0)
  # (x + 4) * 6 / 2  ->  3x + 12
  postfix = [
      util.FieldToken('x'), 4, util.OperatorToken('+', 2),
      6, util.OperatorToken('*', 2),
      2, util.OperatorToken('/', 2),
  ]
  terms, constant = interpreter._ExpandExpression(postfix)
  self.assertEqual(terms, [[3.0, util.FieldToken('x')]])
  self.assertEqual(constant, 12.0)
  # field * field and field / field are not linear and must be rejected.
  nonlinear_stacks = [
      [util.FieldToken('x'), 1, util.OperatorToken('+', 2),
       util.FieldToken('y'), util.OperatorToken('*', 2)],
      [util.FieldToken('x'), util.FieldToken('y'),
       util.OperatorToken('/', 2)],
  ]
  for bad_stack in nonlinear_stacks:
    self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                      interpreter._ExpandExpression, bad_stack)
def testEncryptedContains(self):
  """Tests CONTAINS rewrites for searchwords-encrypted fields."""
  schema = test_util.GetCarsSchema()
  key = test_util.GetMasterKey()
  # CONTAINS on an unencrypted expression passes through (lowercased).
  stack = [
      util.FieldToken('Year'),
      util.BuiltInFunctionToken('string'),
      util.StringLiteralToken('"1"'),
      util.OperatorToken('CONTAINS', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '(string(Year) contains "1")')
  # Searchwords field CONTAINS literal: rewritten to a keyed-hash probe
  # so the server can match without seeing the plaintext.
  stack = [
      util.SearchwordsToken('Model'),
      util.StringLiteralToken('"A"'),
      util.OperatorToken('contains', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '(' + util.SEARCHWORDS_PREFIX + 'Model contains '
      'to_base64(left(bytes(sha1(concat(left(' + util.SEARCHWORDS_PREFIX +
      'Model, 24), \'yB9HY2qv+DI=\'))), 8)))')
  # The right-hand side must be a string literal, not a field.
  stack = [
      util.SearchwordsToken('Model'),
      util.FieldToken('Year'),
      util.OperatorToken('contains', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    stack, schema, key, _TABLE_ID)
  # CONTAINS is not defined for pseudonym encryption.
  stack = [
      util.PseudonymToken('Make'), 'A',
      util.OperatorToken('contains', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    stack, schema, key, _TABLE_ID)
  # Searchwords CONTAINS searchwords is not supported.
  stack = [
      util.SearchwordsToken('Model'),
      util.SearchwordsToken('Model'),
      util.OperatorToken('contains', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    stack, schema, key, _TABLE_ID)
  # Nor may a searchwords field appear on the right-hand side.
  stack = [
      'Hello',
      util.SearchwordsToken('Model'),
      util.OperatorToken('contains', 2)
  ]
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    interpreter.RewriteSelectionCriteria,
                    stack, schema, key, _TABLE_ID)
  # NOT of a rewritten CONTAINS keeps the negation outside.
  stack = [
      util.SearchwordsToken('Model'),
      util.StringLiteralToken('"A"'),
      util.OperatorToken('contains', 2),
      util.OperatorToken('not', 1)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      'not (' + util.SEARCHWORDS_PREFIX + 'Model contains '
      'to_base64(left(bytes(sha1(concat(left(' + util.SEARCHWORDS_PREFIX +
      'Model, 24), \'yB9HY2qv+DI=\'))), 8)))')
  # Nested (record) searchwords fields keep their parent path.
  schema = test_util.GetPlacesSchema()
  stack = [
      util.SearchwordsToken('citiesLived.place'),
      util.StringLiteralToken('"A"'),
      util.OperatorToken('contains', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID),
      '(citiesLived.' + util.SEARCHWORDS_PREFIX + 'place contains '
      'to_base64(left(bytes(sha1(concat(left(citiesLived.' +
      util.SEARCHWORDS_PREFIX + 'place, 24), \'cBKPKGiY2cg=\'))), 8)))')
def _CollapseAggregations(stack, nsquare):
  """Collapses the aggregations by combining arguments and functions.

  During collapses, checks will be done to see if aggregations are done on
  encrypted fields. The following aggregations will be rewritten:

  SUM(<homomorphic field>) becomes
  TO_BASE64(PAILLIER_SUM(FROM_BASE64(<homomorphic field>), <nsquare>))

  AVG(<homomorphic field>) becomes
  TO_BASE64(PAILLIER_SUM(FROM_BASE64(<homomorphic field>), <nsquare>)) /
  COUNT(<homomorphic field>)

  Arguments:
    stack: The stack whose aggregations are to be collapsed. Modified in
      place: the argument tokens and aggregation token are spliced out and
      replaced with the rewritten expression.
    nsquare: Used for homomorphic addition.

  Returns:
    True iff an aggregation was found and collapsed. In other words, another
    potential aggregation can still exist.
  """
  for i in xrange(len(stack)):
    if isinstance(stack[i], util.AggregationFunctionToken):
      num_args = stack[i].num_args
      function_type = str(stack[i])
      postfix_exprs = []
      infix_exprs = []
      start_idx = i
      rewritten_infix_expr = None
      is_encrypted = False
      # pylint: disable=unused-variable
      for j in xrange(int(num_args)):
        # Peel each argument off the stack slice preceding the aggregation
        # token; arguments come off right to left.
        start_idx, postfix_expr = interpreter.GetSingleValue(
            stack[:start_idx])
        is_encrypted = is_encrypted or util.IsEncryptedExpression(
            postfix_expr)
        # Fold built-in functions on constants before aggregating.
        while _CollapseFunctions(postfix_expr):
          pass
        postfix_exprs.append(postfix_expr)
        infix_exprs.append(interpreter.ToInfix(list(postfix_expr)))
      # Check for proper nested aggregations.
      # PAILLIER_SUM and GROUP_CONCAT on encrypted fields are not legal
      # arguments for an aggregation.
      for expr in postfix_exprs:
        for token in expr:
          if not isinstance(token, util.AggregationQueryToken):
            continue
          if token.startswith(util.PAILLIER_SUM_PREFIX):
            raise bigquery_client.BigqueryInvalidQueryError(
                'Cannot use SUM/AVG on homomorphic encryption as argument '
                'for another aggregation.', None, None, None)
          elif token.startswith(util.GROUP_CONCAT_PREFIX):
            fieldname = token.split(util.GROUP_CONCAT_PREFIX)[1][:-1]
            if util.IsEncrypted(fieldname):
              raise bigquery_client.BigqueryInvalidQueryError(
                  'Cannot use GROUP_CONCAT on an encrypted field as argument '
                  'for another aggregation.', None, None, None)
      # Arguments were popped right to left; restore query order.
      infix_exprs.reverse()
      if function_type in ['COUNT', 'DISTINCTCOUNT']:
        if (function_type == 'DISTINCTCOUNT' and
            util.IsDeterministicExpression(postfix_exprs[0])):
          raise bigquery_client.BigqueryInvalidQueryError(
              'Cannot do distinct count on non-pseudonym encryption.',
              None, None, None)
        if function_type == 'DISTINCTCOUNT':
          infix_exprs[0] = 'DISTINCT ' + infix_exprs[0]
        rewritten_infix_expr = [
            util.AggregationQueryToken('COUNT(%s)' % ', '.join(infix_exprs))
        ]
      elif function_type == 'TOP':
        if util.IsDeterministicExpression(postfix_exprs[0]):
          raise bigquery_client.BigqueryInvalidQueryError(
              'Cannot do TOP on non-deterministic encryption.',
              None, None, None)
        rewritten_infix_expr = [
            util.AggregationQueryToken('TOP(%s)' % ', '.join(infix_exprs))
        ]
      elif function_type in ['AVG', 'SUM'] and is_encrypted:
        # NOTE(review): postfix_expr is the last argument popped above, i.e.
        # the first argument in query order. SUM/AVG take a single argument
        # so the two coincide — confirm if multi-arg SUM is ever allowed.
        list_fields = interpreter.CheckValidSumAverageArgument(
            postfix_expr)[0]
        rewritten_infix_expr = []
        # The representative label is the field that is going to be used
        # to get constant values. An expression SUM(ax + b) must be rewritten
        # as a * SUM(x) + b * COUNT(x). Representative label is x (this isn't
        # unique as many fields can be in COUNT).
        representative_label = ''
        for field in list_fields:
          for token in field:
            if util.IsLabel(token):
              representative_label = token
              break
          if representative_label:
            break
        for field in list_fields:
          expression = interpreter.ExpandExpression(field)
          queries, constant = expression[0], expression[1]
          # Each term contributes: constant * COUNT(x) plus, per field,
          # coefficient * SUM(field) (Paillier sum if homomorphic).
          rewritten_infix_expr.append(float(constant))
          rewritten_infix_expr.append(
              util.AggregationQueryToken(
                  'COUNT(%s)' % representative_label))
          rewritten_infix_expr.append(util.OperatorToken('*', 2))
          for query in queries:
            rewritten_infix_expr.append(float(query[0]))
            if (isinstance(query[1], util.HomomorphicFloatToken) or
                isinstance(query[1], util.HomomorphicIntToken)):
              rewritten_infix_expr.append(
                  util.ConstructPaillierSumQuery(query[1], nsquare))
            else:
              rewritten_infix_expr.append(
                  util.AggregationQueryToken('SUM(%s)' % query[1]))
            rewritten_infix_expr.append(util.OperatorToken('*', 2))
          for j in range(len(queries)):
            rewritten_infix_expr.append(util.OperatorToken('+', 2))
        # Join the per-term expressions with additions.
        for j in range(len(list_fields) - 1):
          rewritten_infix_expr.append(util.OperatorToken('+', 2))
        if function_type == 'AVG':
          # AVG = rewritten SUM divided by the row count.
          rewritten_infix_expr.append(
              util.AggregationQueryToken(
                  'COUNT(%s)' % representative_label))
          rewritten_infix_expr.append(util.OperatorToken('/', 2))
      elif function_type == 'GROUP_CONCAT':
        rewritten_infix_expr = [
            util.AggregationQueryToken(
                'GROUP_CONCAT(%s)' % ', '.join(infix_exprs))
        ]
      elif is_encrypted:
        raise bigquery_client.BigqueryInvalidQueryError(
            'Cannot do %s aggregation on any encrypted fields.'
            % function_type, None, None, None)
      else:
        rewritten_infix_expr = [
            util.AggregationQueryToken(
                '%s(%s)' % (function_type, ', '.join(infix_exprs)))
        ]
      # Splice the rewritten expression over the consumed argument tokens
      # and the aggregation token itself.
      stack[start_idx:i + 1] = rewritten_infix_expr
      return True
  return False
def testEncryptedEquality(self):
  """Checks rewriting of equality predicates over encrypted fields.

  Pseudonym fields may be compared for equality (literals are replaced by
  their pseudonym ciphertext); comparisons involving homomorphic or
  probabilistic encryption must raise BigqueryInvalidQueryError.
  """
  schema = test_util.GetCarsSchema()
  key = test_util.GetMasterKey()
  # Pairs of (postfix token stack, expected rewritten predicate), checked
  # in the same order the original assertions ran.
  rewrite_cases = [
      ([util.FieldToken('Year'), 1, util.OperatorToken('+', 2), 2000,
        util.OperatorToken('=', 2)],
       '((Year + 1) = 2000)'),
      ([util.FieldToken('Year'), util.PseudonymToken('Make'),
        util.OperatorToken('=', 2)],
       '(Year = ' + util.PSEUDONYM_PREFIX + 'Make)'),
      ([util.PseudonymToken('Make'), util.StringLiteralToken('"Hello"'),
        util.OperatorToken('==', 2)],
       '(' + util.PSEUDONYM_PREFIX + 'Make == "HS57DHbh2KlkqNJREmu1wQ==")'),
      ([util.StringLiteralToken('"Hello"'), util.PseudonymToken('Make'),
        util.OperatorToken('==', 2)],
       '("HS57DHbh2KlkqNJREmu1wQ==" == ' + util.PSEUDONYM_PREFIX + 'Make)'),
      ([util.StringLiteralToken('"Hello"'), util.PseudonymToken('Make'),
        util.OperatorToken('!=', 2)],
       '("HS57DHbh2KlkqNJREmu1wQ==" != ' + util.PSEUDONYM_PREFIX + 'Make)'),
      ([util.PseudonymToken('Make'), util.PseudonymToken('Make2'),
        util.OperatorToken('==', 2)],
       '(' + util.PSEUDONYM_PREFIX + 'Make == ' + util.PSEUDONYM_PREFIX +
       'Make2)'),
  ]
  for postfix_stack, expected in rewrite_cases:
    self.assertEqual(
        interpreter.RewriteSelectionCriteria(
            postfix_stack, schema, key, _TABLE_ID),
        expected)
  # Stacks that mix in homomorphic/probabilistic encryption: each must be
  # rejected with BigqueryInvalidQueryError.
  failure_cases = [
      [util.HomomorphicIntToken('Invoice_Price'), 2,
       util.OperatorToken('==', 2)],
      [util.PseudonymToken('Make'), util.ProbabilisticToken('Price'),
       util.OperatorToken('=', 2)],
  ]
  for postfix_stack in failure_cases:
    self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                      interpreter.RewriteSelectionCriteria,
                      postfix_stack, schema, key, _TABLE_ID)
  # Nested record field: the pseudonym prefix is applied to the leaf name
  # ('spouse.spouseName' -> 'spouse.<prefix>spouseName').
  schema = test_util.GetPlacesSchema()
  nested_stack = [
      util.PseudonymToken('spouse.spouseName'),
      util.StringLiteralToken('"Hello"'),
      util.OperatorToken('=', 2)
  ]
  self.assertEqual(
      interpreter.RewriteSelectionCriteria(
          nested_stack, schema, key, _TABLE_ID),
      '(spouse.' + util.PSEUDONYM_PREFIX +
      'spouseName = "HS57DHbh2KlkqNJREmu1wQ==")')