def testGetEntryFromSchema(self):
    # Exercises util.GetEntryFromSchema against the cars schema (flat field)
    # and the jobs schema (dotted, nested field paths).
    cars_schema = test_util.GetCarsSchema()
    jobs_schema = test_util.GetJobsSchema()

    # Each case: (field path, schema to search, expected 'name',
    # expected 'encrypt').
    resolvable_cases = (
        ('Year', cars_schema, 'Year', 'none'),
        ('citiesLived.place', jobs_schema, 'place', 'searchwords'),
        ('citiesLived.job.position', jobs_schema, 'position', 'pseudonym'),
    )
    for field_path, schema, expected_name, expected_encrypt in resolvable_cases:
        entry = util.GetEntryFromSchema(field_path, schema)
        self.assertEqual(entry['name'], expected_name)
        self.assertEqual(entry['encrypt'], expected_encrypt)

    # A path that stops at the parent of 'position', and a path that names
    # a field absent from the schema, are both expected to yield None.
    unresolvable_paths = (
        'citiesLived.job',
        'citiesLived.non_existent_field',
    )
    for field_path in unresolvable_paths:
        entry = util.GetEntryFromSchema(field_path, jobs_schema)
        self.assertEqual(entry, None)
예제 #2
0
    def Rewrite(self):
        """Rewrites group by argument to send to BigQuery server.

        Every argument that resolves to a schema entry is checked first:
        columns whose 'encrypt' value starts with 'probabilistic', or equals
        'homomorphic' or 'searchwords', are rejected. Arguments that do not
        resolve to a schema entry are not validated here and are passed on to
        the rewriting step unchanged.

        After rewriting, each expression is replaced by, in order of
        preference: the alias of the matching unencrypted select expression,
        the column alias supplied by an optional `manifest` attribute, or the
        expression itself with '.' replaced by util.PERIOD_REPLACEMENT.

        Returns:
          Rewritten group by clause, or '' when there is no argument.

        Raises:
          ValueError: Invalid clause type or necessary argument not given.
          bigquery_client.BigqueryInvalidQueryError: An argument refers to a
            column whose encryption cannot be grouped on.
        """
        if not self._argument:
            return ''
        self._CheckNecessaryAttributes(['nsquare', 'schema', 'select_clause'])
        if not isinstance(self.select_clause, _SelectClause):
            raise ValueError('Invalid select clause.')

        for argument in self._argument:
            row = util.GetEntryFromSchema(argument, self.schema)
            if not row:
                # The lookup can come back empty (None) for a name that is not
                # a schema field. There is no encryption mode to validate in
                # that case, so skip it instead of failing on the subscript.
                continue
            encryption = row['encrypt']
            if (encryption.startswith('probabilistic')
                    or encryption in ('homomorphic', 'searchwords')):
                raise bigquery_client.BigqueryInvalidQueryError(
                    'Cannot GROUP BY %s encryption.' % encryption, None, None,
                    None)

        # Group by arguments have no alias, so an empty dictionary is adequate.
        rewritten_argument = _RewritePostfixExpressions(
            [self._argument], {}, self.schema, self.nsquare)[0]

        # Only want expressions: drop the last two space-separated tokens of
        # each unencrypted query (the alias part).
        unencrypted_expression_list = [
            ' '.join(query.split(' ')[:-2])
            for query in self.select_clause.GetUnencryptedQueries()
        ]

        # The manifest is optional and does not change between iterations.
        manifest = getattr(self, 'manifest', None)

        # The list is updated in place, exactly one assignment per position.
        for i, expression in enumerate(rewritten_argument):
            if expression in unencrypted_expression_list:
                rewritten_argument[i] = '%s%d_' % (
                    util.UNENCRYPTED_ALIAS_PREFIX,
                    unencrypted_expression_list.index(expression))
                continue
            column_alias = None
            if manifest is not None:
                column_alias = manifest.GetColumnAliasForName(
                    expression, generate=False)
            if column_alias is not None:
                rewritten_argument[i] = column_alias
            else:
                rewritten_argument[i] = expression.replace(
                    '.', util.PERIOD_REPLACEMENT)
        return 'GROUP BY %s' % ', '.join(rewritten_argument)
예제 #3
0
 def RewriteField(field):
     """Wraps a field token in the token type matching its encryption.

     Anything that is not a util.FieldToken, has no schema entry, or uses an
     encryption/type combination not listed below is returned unchanged.
     """
     if not isinstance(field, util.FieldToken):
         return field
     entry = util.GetEntryFromSchema(field, schema)
     if not entry:
         return field

     encryption = entry['encrypt']
     token_class = None
     if encryption.startswith('probabilistic'):
         token_class = util.ProbabilisticToken
     elif encryption == 'pseudonym':
         token_class = util.PseudonymToken
     elif encryption == 'searchwords':
         token_class = util.SearchwordsToken
     elif encryption == 'homomorphic':
         # The column type only matters (and is only read) for homomorphic
         # columns; other types fall through and leave the field as is.
         column_type = entry['type']
         if column_type == 'integer':
             token_class = util.HomomorphicIntToken
         elif column_type == 'float':
             token_class = util.HomomorphicFloatToken

     if token_class is None:
         return field
     return token_class(str(field))
예제 #4
0
    def CheckSearchableField(op1):
        """Checks if the operand is a searchable encrypted field.

        A util.SearchwordsToken is always searchable. A
        util.ProbabilisticToken is searchable only when the schema entry for
        its original field name has an 'encrypt' value of
        'probabilistic_searchwords' or 'searchwords'. Every other operand is
        not searchable.

        Arguments:
          op1: The operand that is being checked if it is searchable.

        Returns:
          True iff op1 is searchable.
        """
        if isinstance(op1, util.SearchwordsToken):
            return True
        if not isinstance(op1, util.ProbabilisticToken):
            return False
        row = util.GetEntryFromSchema(op1.original_name, schema)
        if not row:
            # The lookup can come back empty (None) when the name does not
            # resolve to a schema field; report that as not searchable rather
            # than failing on the subscript below.
            return False
        return row['encrypt'] in ('probabilistic_searchwords', 'searchwords')
예제 #5
0
    def RewriteSearchwordsEncryption(field, literal):
        """Rewrites a field/literal pair so containment can be checked.

        The field's last path component is replaced by
        util.SEARCHWORDS_PREFIX plus the schema entry's 'name'. The literal is
        cleaned into words, re-joined (with the schema entry's
        'searchwords_separator' when present, otherwise a single space),
        hashed with string_hasher, and embedded in a query expression string.

        Arguments:
          field: The field which is being checked if literal is contained
            within.
          literal: Substring being searched for.

        Returns:
          A tuple (rewritten field path, rewritten literal expression).

        Raises:
          ValueError: Try to rewrite non-searchwords encryption.
        """
        searchable_types = (util.SearchwordsToken, util.ProbabilisticToken)
        if not isinstance(field, searchable_types):
            raise ValueError('Invalid encryption to check containment.')

        original_name = field.original_name
        row = util.GetEntryFromSchema(original_name, schema)

        # Swap the leaf component of the dotted path for its prefixed form.
        path_parts = original_name.split('.')
        path_parts[-1] = util.SEARCHWORDS_PREFIX + row['name']
        modified_field = '.'.join(path_parts)

        searchwords_separator = (row['searchwords_separator']
                                 if 'searchwords_separator' in row else None)
        word_list = ecrypto.CleanUnicodeString(
            unicode(literal.value), separator=searchwords_separator)
        joiner = ' ' if searchwords_separator is None else searchwords_separator
        word_seq = joiner.join(word_list)

        hash_key_name = modified_field.split('.')[-1]
        keyed_hash = (u'\'%s\'' %
                      string_hasher.GetStringKeyHash(hash_key_name, word_seq))
        modified_string = (
            u'to_base64(left(bytes(sha1(concat(left(%s, 24), %s))), 8))' %
            (modified_field, keyed_hash))
        return (modified_field, modified_string)