def testFailOperationsOnEncryptions(self):
  """Arithmetic and built-in functions on encrypted columns are rejected."""
  cars_schema = test_util.GetCarsSchema()
  master_key = test_util.GetMasterKey()
  invalid_stacks = (
      # Pseudonym column used in arithmetic: (Year + 1) >= 2000.
      [util.PseudonymToken('Year'), 1, util.OperatorToken('+', 2), 2000,
       util.OperatorToken('>=', 2)],
      # Probabilistic column passed to a string function: left(Model, 2).
      [util.ProbabilisticToken('Model'), 2,
       util.BuiltInFunctionToken('left')],
      # Homomorphic int column passed to is_nan().
      [util.HomomorphicIntToken('Invoice_Price'),
       util.BuiltInFunctionToken('is_nan')],
  )
  for bad_stack in invalid_stacks:
    self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                      interpreter.RewriteSelectionCriteria, bad_stack,
                      cars_schema, master_key, _TABLE_ID)
def testOneArgumentFunction(self):
  """Nested single-argument built-ins render as infix and evaluate."""
  # Postfix form of sqrt(ln(cos(0))).
  postfix = [0]
  postfix.extend(
      util.BuiltInFunctionToken(name) for name in ('cos', 'ln', 'sqrt'))
  self.assertEqual(interpreter.ToInfix(postfix[:]), 'sqrt(ln(cos(0)))')
  # cos(0) is 1, ln(1) is 0 and sqrt(0) is 0.
  self.assertEqual(interpreter.Evaluate(postfix), 0)
def testMultipleArgumentFunction(self):
  """Built-ins taking several arguments, including a nested call."""
  # if(True, 3, 0) yields 3 and pow(2, 1) yields 2, so the result is 3 ** 2.
  if_call = ['True', 3, 0, util.BuiltInFunctionToken('if')]
  inner_pow = [2, 1, util.BuiltInFunctionToken('pow')]
  postfix = if_call + inner_pow + [util.BuiltInFunctionToken('pow')]
  self.assertEqual(interpreter.ToInfix(postfix[:]),
                   'pow(if(True, 3, 0), pow(2, 1))')
  self.assertEqual(interpreter.Evaluate(postfix), 9)
def PushFunction(tokens):
  """Append a built-in function token for the parsed function name.

  Args:
    tokens: list of all tokens; tokens[0] is the function name str.
  """
  # NOTE(review): math_stack is defined outside this block (presumably the
  # enclosing parser's postfix stack) -- confirm against the caller.
  function_name = tokens[0]
  math_stack.append(util.BuiltInFunctionToken(function_name))
def testToBase64(self):
  """TO_BASE64 evaluated on a plain string literal."""
  literal = util.StringLiteralToken('"hello test"')
  postfix = [literal, util.BuiltInFunctionToken('to_base64')]
  expected = 'aGVsbG8gdGVzdA=='  # base64 of "hello test"
  self.assertEqual(interpreter.Evaluate(postfix), expected)
def testString(self):
  """left(str, n) renders as infix and returns the first n characters."""
  text = util.StringLiteralToken('"TESTING IS FUN."')
  postfix = [text, 4, util.BuiltInFunctionToken('left')]
  self.assertEqual(interpreter.ToInfix(postfix[:]),
                   'left("TESTING IS FUN.", 4)')
  self.assertEqual(interpreter.Evaluate(postfix), 'TEST')
def testFromBase64(self):
  """FROM_BASE64() is not implemented until BigQuery has a BYTES type.

  Instead of decoding locally, Evaluate() is expected to hand the call back
  unchanged as query text.
  """
  encoded = 'aGVsbG8gdGVzdA=='  # base64 of "hello test"
  quoted = '"%s"' % encoded
  postfix = [util.StringLiteralToken(quoted),
             util.BuiltInFunctionToken('from_base64')]
  self.assertEqual(interpreter.Evaluate(postfix),
                   'FROM_BASE64(' + quoted + ')')
def testToBase64RawBytes(self):
  """TO_BASE64 on a literal holding arbitrary, non-text bytes."""
  raw = '\xb4\x00\xb0\x09\xcd\x10'
  postfix = [
      util.StringLiteralToken('"%s"' % raw),
      util.BuiltInFunctionToken('to_base64'),
  ]
  self.assertEqual(interpreter.Evaluate(postfix), 'tACwCc0Q')
def testToBase64Utf8(self):
  """TO_BASE64 on the UTF-8 encoding of a non-ASCII string."""
  munich = u'M\u00fcnchen'
  utf8_literal = '"%s"' % munich.encode('utf-8')
  postfix = [util.StringLiteralToken(utf8_literal),
             util.BuiltInFunctionToken('to_base64')]
  self.assertEqual(interpreter.Evaluate(postfix), 'TcO8bmNoZW4=')
def testToBase64Ascii(self):
  """TO_BASE64 on a plain ASCII string."""
  text, encoded = 'hello test', 'aGVsbG8gdGVzdA=='
  postfix = [util.StringLiteralToken('"%s"' % text),
             util.BuiltInFunctionToken('to_base64')]
  self.assertEqual(interpreter.Evaluate(postfix), encoded)
def testHavingRewrite(self):
  """_HavingClause.Rewrite() on plain, encrypted and empty HAVING stacks."""
  schema = test_util.GetCarsSchema()
  master_key = test_util.GetMasterKey()
  as_clause = query_lib._AsClause({})

  def _MakeHaving(stack):
    # Every clause in this test shares the same AS clause, schema and keys.
    return query_lib._HavingClause(
        stack, as_clause=as_clause, schema=schema, nsquare=_TEST_NSQUARE,
        master_key=master_key, table_id=_TABLE_ID)

  # A comparison on an unencrypted aggregate is passed through.
  plain = _MakeHaving(
      [util.FieldToken('SUM(Year)'), 1, util.OperatorToken('<', 2)])
  self.assertEqual(plain.Rewrite(), 'HAVING (SUM(Year) < 1)')

  # Comparing a Paillier sum of a homomorphic column is rejected.
  paillier_sum = util.AggregationQueryToken(
      'TO_BASE64(BYTES(PAILLIER_SUM(FROM_BASE64(' +
      util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price), \'0\')))')
  homomorphic = _MakeHaving(
      [1000, paillier_sum, util.OperatorToken('==', 2)])
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    homomorphic.Rewrite)

  # Applying len() to an aggregate of a pseudonymized column is rejected.
  concat_len = _MakeHaving([
      util.FieldToken('GROUP_CONCAT(' + util.PSEUDONYM_PREFIX + 'Model)'),
      util.BuiltInFunctionToken('len'), 5, util.OperatorToken('>', 2)])
  self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                    concat_len.Rewrite)

  # An empty stack produces no HAVING clause at all.
  self.assertEqual(_MakeHaving([]).Rewrite(), '')
def testExtractEncryptedQueries(self):
  """Checks which field queries _ExtractFieldQueries collects.

  The stacks are the postfix form of the select list
    a * (111 + a_1), TRUE OR False, null, PI(), FUNC_PI, a1, a.b, <alias>
  where <alias> is an UnencryptedQueryToken built from
  UNENCRYPTED_ALIAS_PREFIX. Only the field references are expected back:
    a, a_1, FUNC_PI, a1, a.b AS a<PERIOD_REPLACEMENT>b
  The literals, the PI() built-in and the unencrypted alias contribute
  nothing to the expected set.
  """
  stacks = [
      [util.FieldToken('a'), 111, util.FieldToken('a_1'),
       util.OperatorToken('+', 2), util.OperatorToken('*', 2)],
      ['TRUE', 'False', util.OperatorToken('OR', 2)],
      ['null'],
      [util.BuiltInFunctionToken('PI')],
      # A field whose name merely resembles a function; it must be kept.
      [util.FieldToken('FUNC_PI')],
      [util.FieldToken('a1')],
      # Dotted names come back aliased with PERIOD_REPLACEMENT.
      [util.FieldToken('a.b')],
      [util.UnencryptedQueryToken('%s0_' % util.UNENCRYPTED_ALIAS_PREFIX)],
  ]
  query_list = query_lib._ExtractFieldQueries(stacks, strize=True)
  expect_query_list = {
      'a', 'a_1', 'FUNC_PI', 'a1',
      'a.b AS a' + util.PERIOD_REPLACEMENT + 'b',
  }
  self.assertEqual(expect_query_list, query_list)
def testNonexistentFunction(self):
  """An unknown built-in name is rejected by both ToInfix and Evaluate."""
  postfix = [1, util.BuiltInFunctionToken('hi')]
  expected_error = bigquery_client.BigqueryInvalidQueryError
  self.assertRaises(expected_error, interpreter.ToInfix, postfix[:])
  self.assertRaises(expected_error, interpreter.Evaluate, postfix)
def testRewriteAggregations(self):
  """_RewriteAggregations output for plain, encrypted and invalid aggregates.

  Each successful case passes a single postfix stack wrapped in a list and
  expects a single rewritten postfix stack back. Each rejected case expects
  BigqueryInvalidQueryError.
  """
  int_col = util.HOMOMORPHIC_INT_PREFIX + 'Invoice_Price'
  float_col = util.HOMOMORPHIC_FLOAT_PREFIX + 'Holdback_Percentage'

  def _Count(column):
    return 'COUNT(' + column + ')'

  def _PaillierSum(column):
    return ('TO_BASE64(BYTES(PAILLIER_SUM(FROM_BASE64(' + column +
            '), \'0\')))')

  def _Linear(count_expr, coefficient, sum_expr):
    # Postfix for (0.0 * count_expr) + (coefficient * sum_expr): the shape
    # expected below for each term of a SUM touching a homomorphic column.
    return [0.0, count_expr, '*', coefficient, sum_expr, '*', '+']

  def _CheckRewrite(stack, rewritten_stack):
    self.assertEqual(
        query_lib._RewriteAggregations([stack], _TEST_NSQUARE),
        [rewritten_stack])

  def _CheckRejected(stack):
    self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                      query_lib._RewriteAggregations, [stack], _TEST_NSQUARE)

  # COUNT variants on plain and probabilistic columns.
  _CheckRewrite(
      [util.CountStarToken(), util.AggregationFunctionToken('COUNT', 1)],
      ['COUNT(*)'])
  _CheckRewrite(
      [util.ProbabilisticToken('Price'),
       util.AggregationFunctionToken('COUNT', 1)],
      [_Count(util.PROBABILISTIC_PREFIX + 'Price')])
  _CheckRewrite(
      [util.ProbabilisticToken('Price'), 4,
       util.AggregationFunctionToken('COUNT', 2)],
      ['COUNT(' + util.PROBABILISTIC_PREFIX + 'Price, 4)'])
  _CheckRewrite(
      [util.FieldToken('Year'), 5,
       util.AggregationFunctionToken('DISTINCTCOUNT', 2),
       util.FieldToken('Year'), util.AggregationFunctionToken('COUNT', 1),
       util.OperatorToken('+', 2)],
      ['COUNT(DISTINCT Year, 5)', 'COUNT(Year)', '+'])

  # Constant arguments are folded; plain expressions are rendered as infix.
  _CheckRewrite(
      [0, util.BuiltInFunctionToken('cos'),
       util.AggregationFunctionToken('COUNT', 1)],
      ['COUNT(1.0)'])
  _CheckRewrite(
      [util.StringLiteralToken('"Hello"'), 2,
       util.BuiltInFunctionToken('left'), util.StringLiteralToken('"y"'),
       util.BuiltInFunctionToken('concat'),
       util.AggregationFunctionToken('GROUP_CONCAT', 1)],
      ['GROUP_CONCAT("Hey")'])
  _CheckRewrite(
      [util.FieldToken('Year'), util.FieldToken('Year'),
       util.OperatorToken('*', 2), util.AggregationFunctionToken('SUM', 1)],
      ['SUM((Year * Year))'])

  # SUM / AVG over homomorphic columns become Paillier sums.
  _CheckRewrite(
      [util.HomomorphicIntToken('Invoice_Price'),
       util.AggregationFunctionToken('SUM', 1)],
      _Linear(_Count(int_col), 1.0, _PaillierSum(int_col)))
  _CheckRewrite(
      [util.HomomorphicFloatToken('Holdback_Percentage'),
       util.AggregationFunctionToken('AVG', 1)],
      _Linear(_Count(float_col), 1.0, _PaillierSum(float_col)) +
      [_Count(float_col), '/'])
  _CheckRewrite(
      [util.HomomorphicIntToken('Invoice_Price'), 2,
       util.OperatorToken('+', 2), 5, util.OperatorToken('*', 2),
       util.AggregationFunctionToken('SUM', 1)],
      _Linear(_Count(int_col), 5.0, _PaillierSum(int_col)) +
      _Linear(_Count(int_col), 1.0, 'SUM((2 * 5))') + ['+'])

  # Pseudonymized columns support DISTINCTCOUNT and TOP; plain TOP as well.
  _CheckRewrite(
      [util.PseudonymToken('Make'), 2,
       util.AggregationFunctionToken('DISTINCTCOUNT', 2)],
      ['COUNT(DISTINCT ' + util.PSEUDONYM_PREFIX + 'Make, 2)'])
  _CheckRewrite(
      [util.FieldToken('Year'), util.AggregationFunctionToken('TOP', 1)],
      ['TOP(Year)'])
  _CheckRewrite(
      [util.PseudonymToken('Make'), 5, 1,
       util.AggregationFunctionToken('TOP', 3)],
      ['TOP(' + util.PSEUDONYM_PREFIX + 'Make, 5, 1)'])

  # A SUM mixing a plain term and a homomorphic term is split per term.
  _CheckRewrite(
      [util.FieldToken('Year'), util.BuiltInFunctionToken('cos'),
       util.HomomorphicIntToken('Invoice_Price'),
       util.OperatorToken('+', 2), util.AggregationFunctionToken('SUM', 1)],
      _Linear(_Count(int_col), 1.0, 'SUM(cos(Year))') +
      _Linear(_Count(int_col), 1.0, _PaillierSum(int_col)) + ['+'])

  # Unsupported aggregations over encrypted columns are rejected.
  _CheckRejected([util.ProbabilisticToken('Model'),
                  util.AggregationFunctionToken('DISTINCTCOUNT', 1)])
  _CheckRejected([util.ProbabilisticToken('Price'),
                  util.AggregationFunctionToken('SUM', 1)])
  _CheckRejected([util.HomomorphicIntToken('Invoice_Price'),
                  util.HomomorphicFloatToken('Holdback_Percentage'),
                  util.OperatorToken('*', 2),
                  util.AggregationFunctionToken('SUM', 1)])
  _CheckRejected([util.HomomorphicFloatToken('Holdback_Percentage'),
                  util.BuiltInFunctionToken('cos'),
                  util.AggregationFunctionToken('SUM', 1)])
  _CheckRejected([util.HomomorphicIntToken('Invoice_Price'),
                  util.AggregationFunctionToken('TOP', 1)])

  # Nested aggregations: accepted on plain fields, rejected on encrypted ones.
  _CheckRewrite(
      [util.FieldToken('Year'), util.AggregationFunctionToken('SUM', 1),
       util.AggregationFunctionToken('SUM', 1)],
      ['SUM(SUM(Year))'])
  _CheckRejected([util.HomomorphicIntToken('Invoice_Price'),
                  util.AggregationFunctionToken('SUM', 1),
                  util.AggregationFunctionToken('SUM', 1)])
  _CheckRewrite(
      [util.FieldToken('Year'),
       util.AggregationFunctionToken('GROUP_CONCAT', 1),
       util.AggregationFunctionToken('GROUP_CONCAT', 1)],
      ['GROUP_CONCAT(GROUP_CONCAT(Year))'])
  _CheckRejected([util.PseudonymToken('Make'),
                  util.AggregationFunctionToken('GROUP_CONCAT', 1),
                  util.AggregationFunctionToken('GROUP_CONCAT', 1)])
def testNoArgumentFunction(self):
  """A zero-argument built-in renders with empty parentheses."""
  postfix = [util.BuiltInFunctionToken('pi')]
  self.assertEqual(interpreter.ToInfix(postfix[:]), 'pi()')
  self.assertEqual(interpreter.Evaluate(postfix), math.pi)
def testTooManyArgumentsFunction(self):
  """Surplus operands before COS are rejected by ToInfix and Evaluate."""
  postfix = [1, 2, 3, util.BuiltInFunctionToken('COS')]
  expected_error = bigquery_client.BigqueryInvalidQueryError
  self.assertRaises(expected_error, interpreter.ToInfix, postfix[:])
  self.assertRaises(expected_error, interpreter.Evaluate, postfix)
def testEncryptedContains(self):
  """CONTAINS rewriting for plain, searchwords and invalid operands."""
  cars_schema = test_util.GetCarsSchema()
  key = test_util.GetMasterKey()

  def _Rewrite(stack, schema):
    return interpreter.RewriteSelectionCriteria(stack, schema, key, _TABLE_ID)

  def _HashedContains(column, salt):
    # Expected text of a rewritten CONTAINS on a searchwords column, given
    # the column name and the base64 constant embedded in the expectation.
    return (column + ' contains to_base64(left(bytes(sha1(concat(left(' +
            column + ', 24), \'' + salt + '\'))), 8))')

  searchwords_model = util.SEARCHWORDS_PREFIX + 'Model'

  # Plain CONTAINS is passed through, with the operator lower-cased.
  self.assertEqual(
      _Rewrite([util.FieldToken('Year'), util.BuiltInFunctionToken('string'),
                util.StringLiteralToken('"1"'),
                util.OperatorToken('CONTAINS', 2)], cars_schema),
      '(string(Year) contains "1")')

  # A searchwords column against a string literal is rewritten.
  self.assertEqual(
      _Rewrite([util.SearchwordsToken('Model'),
                util.StringLiteralToken('"A"'),
                util.OperatorToken('contains', 2)], cars_schema),
      '(' + _HashedContains(searchwords_model, 'yB9HY2qv+DI=') + ')')

  # Invalid operand combinations are rejected.
  rejected_stacks = (
      [util.SearchwordsToken('Model'), util.FieldToken('Year'),
       util.OperatorToken('contains', 2)],
      [util.PseudonymToken('Make'), 'A', util.OperatorToken('contains', 2)],
      [util.SearchwordsToken('Model'), util.SearchwordsToken('Model'),
       util.OperatorToken('contains', 2)],
      ['Hello', util.SearchwordsToken('Model'),
       util.OperatorToken('contains', 2)],
  )
  for bad_stack in rejected_stacks:
    self.assertRaises(bigquery_client.BigqueryInvalidQueryError,
                      _Rewrite, bad_stack, cars_schema)

  # NOT wraps the rewritten CONTAINS expression.
  self.assertEqual(
      _Rewrite([util.SearchwordsToken('Model'),
                util.StringLiteralToken('"A"'),
                util.OperatorToken('contains', 2),
                util.OperatorToken('not', 1)], cars_schema),
      'not (' + _HashedContains(searchwords_model, 'yB9HY2qv+DI=') + ')')

  # For a nested column the searchwords prefix goes on the leaf field name.
  places_schema = test_util.GetPlacesSchema()
  nested_column = 'citiesLived.' + util.SEARCHWORDS_PREFIX + 'place'
  self.assertEqual(
      _Rewrite([util.SearchwordsToken('citiesLived.place'),
                util.StringLiteralToken('"A"'),
                util.OperatorToken('contains', 2)], places_schema),
      '(' + _HashedContains(nested_column, 'cBKPKGiY2cg=') + ')')