def _Eval(self, node):
        """Evaluate an expression AST node against the current document.

        Args:
          node: The expression AST node to evaluate.

        Returns:
          The result of the registered function for FN nodes (looked up by the
          node's text), the result of _EvalBinaryOp/_EvalUnaryOp for
          arithmetic nodes, a float for INT and FLOAT literals, the text of a
          PHRASE with surrounding double quotes stripped, and for a NAME either
          the document score ('_score') or the value of the named field.

        Raises:
          _ExpressionError: if the node type is not handled, or a NAME refers
            to a field that is not present in the document.
        """
        # Token types are plain ints: compare them with ==. Identity (`is`)
        # only appears to work because CPython caches small ints.
        node_type = node.getType()

        if node_type == ExpressionParser.FN:
            func = self._function_table[query_parser.GetQueryNodeText(node)]
            return func(*node.children)

        if node_type == ExpressionParser.PLUS:
            return self._EvalBinaryOp(lambda a, b: a + b, 'addition', node)
        if node_type == ExpressionParser.MINUS:
            return self._EvalBinaryOp(lambda a, b: a - b, 'subtraction', node)
        if node_type == ExpressionParser.DIV:
            return self._EvalBinaryOp(lambda a, b: a / b, 'division', node)
        if node_type == ExpressionParser.TIMES:
            return self._EvalBinaryOp(lambda a, b: a * b, 'multiplication',
                                      node)
        if node_type == ExpressionParser.NEG:
            return self._EvalUnaryOp(lambda a: -a, 'negation', node)

        if node_type in (ExpressionParser.INT, ExpressionParser.FLOAT):
            return float(query_parser.GetQueryNodeText(node))
        if node_type == ExpressionParser.PHRASE:
            return query_parser.GetQueryNodeText(node).strip('"')

        if node_type == ExpressionParser.NAME:
            name = query_parser.GetQueryNodeText(node)
            if name == '_score':
                return self._doc.score
            field = search_util.GetFieldInDocument(self._doc_pb, name)
            if field:
                return search_util.GetFieldValue(field)
            raise _ExpressionError('No field %s in document' % name)

        raise _ExpressionError('Unable to handle node %s' % node)
예제 #2
0
    def _MatchTextField(self, field, match, document):
        """Check if a textual field matches a query tree node.

        FUZZY nodes are unwrapped to their first child. A VALUE node matches
        as a phrase when it is one, verbatim for ATOM fields, and otherwise
        token by token against this field's postings. CONJUNCTION and
        DISJUNCTION nodes combine their children with all()/any(), after
        reducing any "GLOBAL EQ/HAS value" child to its value node. Any other
        node type does not match.

        Raises:
          ExpressionTreeException: if match is a NEGATION node.
        """
        node_type = match.getType()

        if node_type == QueryParser.FUZZY:
            return self._MatchTextField(field, match.getChild(0), document)

        if node_type == QueryParser.VALUE:
            if query_parser.IsPhrase(match):
                return self._MatchPhrase(field, match, document)

            if field.value().type() == document_pb.FieldValue.ATOM:
                # Atom fields are compared verbatim, without tokenizing.
                return (field.value().string_value() ==
                        query_parser.GetQueryNodeText(match))

            query_tokens = self._parser.TokenizeText(
                query_parser.GetQueryNodeText(match))
            if not query_tokens:
                # A value without tokens places no constraint on the field.
                return True

            if len(query_tokens) == 1:
                doc_ids = [
                    post.doc_id for post in self._PostingsForFieldToken(
                        field.name(), query_tokens[0].chars)
                ]
                return document.id() in doc_ids

            # Several tokens: each must match on its own as a TEXT node.
            for query_token in query_tokens:
                text_node = query_parser.CreateQueryNode(
                    query_token.chars, QueryParser.TEXT)
                if not self._MatchTextField(field, text_node, document):
                    return False
            return True

        def reduce_global(child):
            # "GLOBAL = value" / "GLOBAL : value" children collapse to value.
            child_type = child.getType()
            if (child_type in (QueryParser.EQ, QueryParser.HAS)
                    and len(child.children) >= 2
                    and child.children[0].getType() == QueryParser.GLOBAL):
                return child.children[1]
            return child

        if node_type in (QueryParser.CONJUNCTION, QueryParser.DISJUNCTION):
            combine = all if node_type == QueryParser.CONJUNCTION else any
            return combine(
                self._MatchTextField(field, reduce_global(child), document)
                for child in match.children)

        if node_type == QueryParser.NEGATION:
            raise ExpressionTreeException('Unable to compare "' +
                                          field.name() + '" with negation')

        return False
예제 #3
0
    def _Snippet(self, query, field, *args):
        """Create a snippet given a query and the field to query on.

    Args:
      query: A query string containing only a bare term (no operators).
      field: The field name to query on.
      *args: Unused optional arguments. These are not used on dev_appserver.

    Returns:
      The _GenerateSnippet result around the first recorded position of the
      first query term (in query order) that occurs in this document's field.
      If no term occurs, the first DEFAULT_MAX_SNIPPET_LENGTH characters of the
      field followed by '...'. '' if the field has no value.

    Raises:
      ExpressionEvaluationError: if this is a sort expression, or if the field
      is a NUMBER field in the schema.
    """
        field = query_parser.GetQueryNodeText(field)
        query_text = query_parser.GetQueryNodeText(query)

        if self._is_sort_expression:
            raise ExpressionEvaluationError(
                'Failed to parse sort expression \'snippet(' +
                query_text + ', ' + field +
                ')\': snippet() is not supported in sort expressions')

        schema = self._inverted_index.GetSchema()
        if schema.IsType(field, document_pb.FieldValue.NUMBER):
            raise ExpressionEvaluationError(
                'Failed to parse field expression \'snippet(' +
                query_text + ', ' + field +
                ')\': snippet() argument 2 must be text')

        # The field value does not depend on the term or posting: fetch once.
        field_val = self._GetFieldValue(
            search_util.GetFieldInDocument(self._doc_pb, field))
        if not field_val:
            return ''

        doc_id = self._doc_pb.id()
        terms = self._tokenizer.TokenizeText(query_text.strip('"'))
        for term in terms:
            search_token = tokens.Token(chars=u'%s:%s' % (field, term.chars))
            postings = self._inverted_index.GetPostingsForToken(search_token)
            for posting in postings:
                if posting.doc_id != doc_id or not posting.positions:
                    continue
                doc_words = [
                    token.chars for token in
                    self._case_preserving_tokenizer.TokenizeText(field_val)
                ]
                return self._GenerateSnippet(
                    doc_words, posting.positions[0],
                    search_util.DEFAULT_MAX_SNIPPET_LENGTH)

        # Only reached after every term has been tried (a for/else on the
        # inner loop would give up after the first term): fall back to a
        # truncated copy of the field value.
        return '%s...' % field_val[:search_util.DEFAULT_MAX_SNIPPET_LENGTH]
예제 #4
0
 def _ResolveDistanceArg(self, node):
   """Resolve one argument of a distance() call.

   Returns:
     The node's text for a VALUE node, a geo_util.LatLng built from the two
     float arguments of a geopoint(...) FUNCTION node, and None otherwise.
   """
   node_type = node.getType()
   if node_type == QueryParser.VALUE:
     return query_parser.GetQueryNodeText(node)
   if node_type != QueryParser.FUNCTION:
     return None
   func_name, func_args = node.children
   if func_name.getText() != 'geopoint':
     return None
   # Unpack from a generator so conversion and arity errors surface lazily.
   lat, lng = (float(query_parser.GetQueryNodeText(arg))
               for arg in func_args.children)
   return geo_util.LatLng(lat, lng)
예제 #5
0
    def _MatchComparableField(self, field, match, cast_to_type,
                              query_node_types, document):
        """A generic method to test matching for comparable types.

    Comparable types are defined to be anything that supports <, >, <=, >=, ==
    and !=. For our purposes, this is numbers and dates.

    Args:
      field: The document_pb.Field to test
      match: The query node to match against: either a bare value node whose
        type is in query_node_types (an implicit equality test), or an
        operator node whose first child holds the value.
      cast_to_type: The type to cast the node string values to
      query_node_types: The query node types that would be valid matches
      document: The document that the field is in

    Returns:
      True iff the field matches the query. Also False when the query value
      cannot be cast with cast_to_type, or when match is neither a value node
      nor a node with children.

    Raises:
      UnsupportedOnDevError: Raised when an unsupported operator is used.
      ValueError: if the field's own string value cannot be cast (not caught
      here).
    """
        field_val = cast_to_type(field.value().string_value())

        # Node types are ints: compare them with ==, never with `is`.
        match_type = match.getType()
        if match_type in query_node_types:
            op = QueryParser.EQ
            value_node = match
        elif match.children:
            op = match_type
            value_node = match.children[0]
        else:
            return False

        try:
            match_val = cast_to_type(query_parser.GetQueryNodeText(value_node))
        except ValueError:
            return False

        if op == QueryParser.EQ:
            return field_val == match_val
        if op == QueryParser.NE:
            return field_val != match_val
        if op == QueryParser.GT:
            return field_val > match_val
        if op == QueryParser.GE:
            return field_val >= match_val
        if op == QueryParser.LT:
            return field_val < match_val
        if op == QueryParser.LE:
            return field_val <= match_val
        raise search_util.UnsupportedOnDevError(
            'Operator %s not supported for numerical fields on development server.'
            % match.getText())
예제 #6
0
    def _Eval(self, node):
        """Evaluate an expression node on the document.

    Args:
      node: The expression AST node representing an expression subtree.

    Returns:
      The Python value of the subtree:
      - for node types registered in self._function_table, whatever that
        function returns;
      - for arithmetic nodes, the result of _EvalBinaryOp/_EvalUnaryOp;
      - a float for INT and FLOAT literals;
      - the text of a PHRASE with surrounding double quotes stripped;
      - for a NAME, the document score ('_score') or the value of the named
        field as converted by _GetFieldValue.

    Raises:
      _ExpressionError: The node type is not supported, or a NAME refers to a
      field the document does not contain. This is not fatal: callers catch
      it and simply do not set the expression on this document.
    """
        node_type = node.getType()

        # Registered functions take precedence over every built-in node type.
        if node_type in self._function_table:
            handler = self._function_table[node_type]
            return handler(*node.children)

        arithmetic = {
            ExpressionParser.PLUS: (lambda a, b: a + b, 'addition'),
            ExpressionParser.MINUS: (lambda a, b: a - b, 'subtraction'),
            ExpressionParser.DIV: (lambda a, b: a / b, 'division'),
            ExpressionParser.TIMES: (lambda a, b: a * b, 'multiplication'),
        }
        if node_type in arithmetic:
            op_func, op_name = arithmetic[node_type]
            return self._EvalBinaryOp(op_func, op_name, node)
        if node_type == ExpressionParser.NEG:
            return self._EvalUnaryOp(lambda a: -a, 'negation', node)

        if node_type in (ExpressionParser.INT, ExpressionParser.FLOAT):
            return float(query_parser.GetQueryNodeText(node))
        if node_type == ExpressionParser.PHRASE:
            return query_parser.GetQueryNodeText(node).strip('"')

        if node_type != ExpressionParser.NAME:
            raise _ExpressionError('Unable to handle node %s' % node)

        name = query_parser.GetQueryNodeText(node)
        if name == '_score':
            return self._doc.score
        field = search_util.GetFieldInDocument(self._doc_pb, name)
        if not field:
            raise _ExpressionError('No field %s in document' % name)
        return self._GetFieldValue(field)
예제 #7
0
 def __create_query_string(self, query_tree):
     """Creates a URL-encoded SOLR query string from an antlr3 parse tree.

     The whole query is rendered as plain text first and URL-encoded exactly
     once at the end. Encoding inside the recursion would re-encode the
     already-encoded sub-queries ('%' -> '%25').

     Args:
       query_tree: An antlr3.tree.CommonTree.
     Returns:
       A string which can be sent to SOLR. Unknown node types contribute an
       empty string (and log a warning).
     """

     def render(node):
         # Returns the plain (not yet URL-encoded) query text for one node.
         node_type = node.getType()
         if node_type == QueryParser.CONJUNCTION:
             return "({0})".format(
                 " AND ".join(render(child) for child in node.children))
         if node_type == QueryParser.DISJUNCTION:
             return "({0})".format(
                 " OR ".join(render(child) for child in node.children))
         if node_type == QueryParser.NEGATION:
             return "NOT ({0})".format(
                 " AND ".join(render(child) for child in node.children))
         if node_type in query_parser.COMPARISON_TYPES:
             field, match = node.children
             if field.getType() == QueryParser.GLOBAL:
                 value = query_parser.GetQueryNodeText(match)
                 escaped_value = self.__escape_chars(value)
                 return "\"{0}\"".format(escaped_value)
             field_name = query_parser.GetQueryNodeText(field)
             value = query_parser.GetQueryNodeText(match)
             internal_field_name = self.__get_internal_field_name(field_name)
             escaped_value = self.__escape_chars(value)
             oper = self.__get_operator(node_type)
             return "{0}{1}\"{2}\"".format(internal_field_name, oper,
                                           escaped_value)
         logging.warning("No node match for {0}".format(node_type))
         return ""

     q_str = render(query_tree)
     logging.debug("Query string: {0}".format(q_str))
     # Spaces become '+'. A literal '+' inside a value is percent-encoded so
     # the server does not decode it as a space.
     q_str = urllib.quote_plus(q_str)
     logging.debug("Encoded: {0}".format(q_str))
     return q_str
예제 #8
0
    def _MatchTextField(self, field, match, document):
        """Check if a textual field matches a query tree node.

        Args:
          field: The document_pb.Field being matched.
          match: The query node to match against.
          document: The document that contains the field.

        Returns:
          For TEXT, NAME and numeric query nodes, whether this document is in
          the postings for the node's text in this field. PHRASE nodes are
          delegated to _MatchPhrase. CONJUNCTION/DISJUNCTION/NEGATION combine
          the results of their children. Any other node type does not match.
        """
        # Query node types are ints: compare with == / in, never with `is`.
        match_type = match.getType()

        if (match_type in (QueryParser.TEXT, QueryParser.NAME)
                or match_type in search_util.NUMBER_QUERY_TYPES):
            matching_docids = [
                post.doc_id for post in self._PostingsForFieldToken(
                    field.name(), query_parser.GetQueryNodeText(match))
            ]
            return document.id() in matching_docids

        if match_type == QueryParser.PHRASE:
            return self._MatchPhrase(field, match, document)

        if match_type == QueryParser.CONJUNCTION:
            return all(
                self._MatchTextField(field, child, document)
                for child in match.children)

        if match_type == QueryParser.DISJUNCTION:
            return any(
                self._MatchTextField(field, child, document)
                for child in match.children)

        if match_type == QueryParser.NEGATION:
            return not self._MatchTextField(field, match.children[0], document)

        return False
예제 #9
0
    def _MatchPhrase(self, field, match, document):
        """Match a textual field with a phrase query node.

        Args:
          field: The document_pb.Field whose string value is searched.
          match: The query node holding the phrase text.
          document: The document being matched; only postings for its id are
            considered.

        Returns:
          True if the phrase has no tokens, or if its tokens appear
          consecutively in the tokenized field starting at one of the positions
          recorded for the phrase's first token in this document. Otherwise
          False.
        """
        phrase = self._SplitPhrase(query_parser.GetQueryNodeText(match))
        if not phrase:
            return True
        field_text = self._parser.TokenizeText(field.value().string_value())

        # This document's posting for the phrase's first token supplies the
        # candidate start positions.
        posting = None
        for post in self._PostingsForFieldToken(field.name(), phrase[0].chars):
            if post.doc_id == document.id():
                posting = post
                break
        if not posting:
            return False

        phrase_words = [token.chars for token in phrase]
        for position in posting.positions:
            # zip() stops at the shorter input, so a window that runs past the
            # end of the field yields fewer pairs than the phrase has words.
            # list() keeps len() working on Python 3, where zip() is lazy.
            word_pairs = list(zip(
                (token.chars for token in field_text[position:]),
                phrase_words))
            if len(word_pairs) != len(phrase_words):
                continue
            if all(doc_word == phrase_word
                   for doc_word, phrase_word in word_pairs):
                return True
        return False
예제 #10
0
    def _MatchField(self, field, match, operator, document):
        """Check if a field matches a query tree.

    Args:
      field: Either a string containing the name of a field, a query node
        whose text is the name of the field, or a document_pb.Field. A name is
        resolved to every field with that name in the document, and the match
        succeeds if any of them matches.
      match: A query node to match the field with.
      operator: The query node type corresponding to the type of match to
        perform (eg QueryParser.EQ, QueryParser.GT, etc).
      document: The document to match.

    Returns:
      True iff the field (or any field with that name) matches. Text fields
      never match an operator other than QueryParser.EQ.

    Raises:
      search_util.UnsupportedOnDevError: if the field's value type is not a
      text, number or date type.
    """

        if isinstance(field, (basestring, tree.CommonTree)):
            if isinstance(field, tree.CommonTree):
                field = query_parser.GetQueryNodeText(field)
            fields = search_util.GetAllFieldInDocument(document, field)
            return any(
                self._MatchField(f, match, operator, document) for f in fields)

        value_type = field.value().type()

        if value_type in search_util.TEXT_DOCUMENT_FIELD_TYPES:
            if operator != QueryParser.EQ:
                return False
            return self._MatchTextField(field, match, document)

        if value_type in search_util.NUMBER_DOCUMENT_FIELD_TYPES:
            return self._MatchNumericField(field, match, operator, document)

        if value_type == document_pb.FieldValue.DATE:
            return self._MatchDateField(field, match, operator, document)

        type_name = document_pb.FieldValue.ContentType_Name(value_type).lower()
        raise search_util.UnsupportedOnDevError(
            'Matching fields of type %s is unsupported on dev server (searched for '
            'field %s)' % (type_name, field.name()))
예제 #11
0
    def _Count(self, node):
        """Evaluate count(field) for the current document.

        Returns search_util.GetFieldCountInDocument for the field whose name
        is the text of node.

        Raises:
          _ExpressionError: if node is not a plain NAME node.
        """
        if node.getType() == ExpressionParser.NAME:
            field_name = query_parser.GetQueryNodeText(node)
            return search_util.GetFieldCountInDocument(self._doc_pb, field_name)
        raise _ExpressionError(
            'The argument to count() must be a simple field name')
예제 #12
0
    def ValueOf(self,
                expression,
                default_value=None,
                return_type=None,
                allow_rank=True):
        """Returns the value of an expression on a document.

    Args:
      expression: The expression string.
      default_value: The value to return if the expression cannot be evaluated.
      return_type: The type the expression should evaluate to. Used to create
        multiple sorts for ambiguous expressions. If None, the expression
        evaluates to the inferred type or first type of a field it encounters in
        a document.
      allow_rank: For expressions that will be used in a sort context,
        indicate if rank is allowed.

    Returns:
      The value of the expression on the evaluator's document, or default_value
      if the expression cannot be evaluated on the document. For a bare name of
      a date-only field, a string default_value is first parsed as a date.

    Raises:
      ExpressionEvaluationError: sort expression cannot be evaluated
      because the expression or default value is malformed. Callers of
      ValueOf should catch and return error to user in response.
      QueryExpressionEvaluationError: same as ExpressionEvaluationError but
      these errors should return query as error status to users.
    """
        expression_tree = Parse(expression)
        # An untyped root with children wraps the real expression: unwrap it.
        if not expression_tree.getType() and expression_tree.children:
            expression_tree = expression_tree.children[0]

        name = query_parser.GetQueryNodeText(expression_tree)
        schema = self._inverted_index.GetSchema()
        if (expression_tree.getType() == ExpressionParser.NAME
                and name in schema):
            contains_text_result = any(
                field_type in search_util.TEXT_DOCUMENT_FIELD_TYPES
                for field_type in schema[name].type_list())

            # For a date field with no text type, a string default must parse
            # as a date.
            if (schema.IsType(name, document_pb.FieldValue.DATE)
                    and not contains_text_result):
                if isinstance(default_value, basestring):
                    try:
                        default_value = search_util.DeserializeDate(
                            default_value)
                    except ValueError:
                        raise QueryExpressionEvaluationError(
                            'Default text value is not appropriate for sort expression \''
                            + name + '\': failed to parse date \"' +
                            default_value + '\"')
        result = default_value
        try:
            result = self._Eval(expression_tree,
                                return_type=return_type,
                                allow_rank=allow_rank)
        except _ExpressionError as e:
            # Not fatal: the expression just does not apply to this document,
            # so the default is returned instead.
            logging.debug('Skipping expression %s: %s', expression, e)
        return result
예제 #13
0
  def _CheckMatch(self, node, document):
    """Check if a document matches a query tree.

    Args:
      node: the query node to match
      document: the document to match

    Returns:
      True iff the query node matches the document.

    Raises:
      ExpressionTreeException: when the != operator is used, or when an
      equality comparison names a "field" that is itself a valid date or
      numeric value. _CheckValidDateComparison may presumably also raise it
      for a non-numeric inequality operand (see that helper).
    """
    node_type = node.getType()

    if node_type == QueryParser.SEQUENCE:
      # Every child must match; failing that, the children may still match
      # together as a global phrase.
      if all(self._CheckMatch(child, document) for child in node.children):
        return True
      return self._MatchGlobalPhrase(node, document)

    if node_type == QueryParser.CONJUNCTION:
      return all(self._CheckMatch(child, document) for child in node.children)

    if node_type == QueryParser.DISJUNCTION:
      return any(self._CheckMatch(child, document) for child in node.children)

    if node_type == QueryParser.NEGATION:
      return not self._CheckMatch(node.children[0], document)

    if node_type == QueryParser.NE:
      raise ExpressionTreeException('!= comparison operator is not available')

    if node_type not in query_parser.COMPARISON_TYPES:
      return False

    lhs, match = node.children
    lhs_type = lhs.getType()
    if lhs_type == QueryParser.GLOBAL:
      return self._MatchGlobal(match, document)
    if lhs_type == QueryParser.FUNCTION:
      return self._MatchFunction(lhs, match, node_type, document)

    field_name = self._GetFieldName(lhs)
    if node_type in INEQUALITY_COMPARISON_TYPES:
      # Non-numeric inequality operands are handed to the date validator.
      try:
        float(query_parser.GetQueryNodeText(match))
      except ValueError:
        self._CheckValidDateComparison(field_name, match)
    elif (self._IsValidDateValue(field_name) or
          self._IsValidNumericValue(field_name)):
      # An equality whose left side looks like a date/number literal is
      # rejected rather than treated as a field name.
      raise ExpressionTreeException('Invalid field name "%s"' % field_name)
    return self._MatchAnyField(lhs, match, node_type, document)
 def _CollectFields(self, node):
   """Collect the left-hand text of every EQ node in the query tree.

   An EQ node with children contributes the text of its first child and is
   not descended into further. Other nodes contribute the union of their
   children's results. Leaves contribute nothing.
   """
   if node.getType() == QueryParser.EQ and node.children:
     return set([query_parser.GetQueryNodeText(node.children[0])])
   names = set()
   for child in node.children or ():
     names.update(self._CollectFields(child))
   return names
예제 #15
0
def _create_query_string(index_name, query_tree):
    """Creates a SOLR query string from an antlr3 parse tree.

  Args:
    index_name: A str representing full index name (appID_namespace_index).
    query_tree: A antlr3.tree.CommonTree.
  Returns:
    A string which can be sent to SOLR.
  Raises:
    ParsingError: if the tree contains a node type that is neither a boolean
      combinator nor a comparison.
  """
    query_tree_type = query_tree.getType()

    # Boolean combinators: render every child and join them.
    if query_tree_type in (QueryParser.CONJUNCTION, QueryParser.DISJUNCTION,
                           QueryParser.NEGATION):
        nested = [
            _create_query_string(index_name, child)
            for child in query_tree.children
        ]
        if query_tree_type == QueryParser.CONJUNCTION:
            return '({})'.format(' AND '.join(nested))
        if query_tree_type == QueryParser.DISJUNCTION:
            return '({})'.format(' OR '.join(nested))
        return 'NOT ({})'.format(' AND '.join(nested))

    if query_tree_type not in query_parser.COMPARISON_TYPES:
        raise ParsingError(
            'Unexpected query tree type: {}'.format(query_tree_type))

    # Leaf: a comparison "field OP value" (or a global value).
    field, match = query_tree.children
    if field.getType() == QueryParser.GLOBAL:
        value = query_parser.GetQueryNodeText(match).strip('"')
        return '"{}"'.format(_escape_phrase_value(value))
    field_name = query_parser.GetQueryNodeText(field)
    value = query_parser.GetQueryNodeText(match).strip('"')
    internal_field_name = '{}_{}'.format(index_name, field_name)
    oper = _get_operator(query_tree_type)
    return '{}{}"{}"'.format(internal_field_name, oper,
                             _escape_phrase_value(value))


def _escape_phrase_value(value):
    """Escape value for use inside a double-quoted Solr phrase.

    Backslashes are escaped before double quotes. A trailing backslash would
    otherwise escape the closing quote, leaving the phrase unterminated.
    """
    return value.replace('\\', '\\\\').replace('"', '\\"')
예제 #16
0
    def _Snippet(self, query, field, *args):
        """Create a snippet given a query and the field to query on.

    Args:
      query: A query string containing only a bare term (no operators).
      field: The field name to query on.
      *args: Unused optional arguments. These are not used on dev_appserver.

    Returns:
      The _GenerateSnippet result around the first recorded position of the
      first query term (in query order) that occurs in this document's field.
      If no term occurs, the first DEFAULT_MAX_SNIPPET_LENGTH characters of the
      field followed by '...'. None if the field has no value.
    """
        field = query_parser.GetQueryNodeText(field)
        # The field value does not depend on the term or posting: fetch once.
        field_val = search_util.GetFieldValue(
            search_util.GetFieldInDocument(self._doc_pb, field))
        if not field_val:
            return None

        doc_id = self._doc_pb.id()
        terms = self._tokenizer.TokenizeText(
            query_parser.GetQueryNodeText(query).strip('"'))
        for term in terms:
            search_token = tokens.Token(chars=u'%s:%s' % (field, term.chars))
            postings = self._inverted_index.GetPostingsForToken(search_token)
            for posting in postings:
                if posting.doc_id != doc_id or not posting.positions:
                    continue
                doc_words = [
                    token.chars for token in
                    self._case_preserving_tokenizer.TokenizeText(field_val)
                ]
                return self._GenerateSnippet(
                    doc_words, posting.positions[0],
                    search_util.DEFAULT_MAX_SNIPPET_LENGTH)

        # Only reached after every term has been tried (a for/else on the
        # inner loop would give up after the first term): fall back to a
        # truncated copy of the field value.
        return '%s...' % field_val[:search_util.DEFAULT_MAX_SNIPPET_LENGTH]
예제 #17
0
    def _Count(self, node):
        """Evaluate count(field) for the current document.

        Returns search_util.GetFieldCountInDocument for the field whose name
        is the text of node.

        Raises:
          _ExpressionError: if node is not a plain NAME node.
          query_parser.QueryException: if used inside a sort expression.
        """
        is_field_name = node.getType() == ExpressionParser.NAME
        if not is_field_name:
            raise _ExpressionError(
                'The argument to count() must be a simple field name')
        if self._is_sort_expression:
            raise query_parser.QueryException(
                "Failed to parse sort expression 'count(" + node.getText() +
                ")': count() is not supported in sort expressions")
        field_name = query_parser.GetQueryNodeText(node)
        return search_util.GetFieldCountInDocument(self._doc_pb, field_name)
예제 #18
0
 def _MatchFunction(self, node, match, operator, document):
   """Match a function comparison such as distance(field, geopoint(...)) < d.

   Only distance() is handled. Its field-name and geopoint arguments may
   appear in either order.

   Returns:
     The result of _MatchGeoField with a DistanceMatcher, or False for any
     other function or argument shape.

   Raises:
     ExpressionTreeException: if the right-hand side is not a number.
   """
   name, args = node.children
   if name.getText() != 'distance':
     return False
   x, y = args.children
   x, y = self._ResolveDistanceArg(x), self._ResolveDistanceArg(y)
   # Normalize to (field name, point) when written as distance(point, field).
   if isinstance(x, geo_util.LatLng) and isinstance(y, basestring):
     x, y = y, x
   if not (isinstance(x, basestring) and isinstance(y, geo_util.LatLng)):
     return False
   match_val = query_parser.GetQueryNodeText(match)
   try:
     distance = float(match_val)
   except ValueError:
     # Report a query error instead of leaking a raw ValueError.
     raise ExpressionTreeException('Unable to compare "%s()" with "%s"' %
                                   (name, match_val))
   matcher = DistanceMatcher(y, distance)
   return self._MatchGeoField(x, matcher, operator, document)
예제 #19
0
  def _MatchGeoField(self, field, matcher, operator, document):
    """Check if a geo field matches a query tree node.

    Collects the values of every GEO-typed field with the given name (a
    string or a query node whose text is the name) and hands them to
    matcher.IsMatch. Matchers that are not DistanceMatcher never match.
    """
    if not isinstance(matcher, DistanceMatcher):
      return False

    field_name = field
    if isinstance(field_name, tree.CommonTree):
      field_name = query_parser.GetQueryNodeText(field_name)

    geo_values = []
    for candidate in search_util.GetAllFieldInDocument(document, field_name):
      if candidate.value().type() == document_pb.FieldValue.GEO:
        geo_values.append(candidate.value())
    return matcher.IsMatch(geo_values, operator)
예제 #20
0
    def _MatchGlobalPhrase(self, node, document):
        """Check if a document matches a parsed global phrase.

        Every child of node must satisfy _IsHasGlobalValue, otherwise False.
        The texts of the children's value nodes (second child) are joined with
        single spaces. The document matches if _MatchRawPhraseWithRawAtom
        accepts that phrase for the string value of any of its fields.
        """
        for child in node.children:
            if not self._IsHasGlobalValue(child):
                return False

        phrase_text = ' '.join(
            [query_parser.GetQueryNodeText(child.children[1])
             for child in node.children])
        return any(
            self._MatchRawPhraseWithRawAtom(doc_field.value().string_value(),
                                            phrase_text)
            for doc_field in document.field_list())
    def _Snippet(self, query, field, *args):
        """Build a snippet of a field around the first query term found in it.

        The query text (outer double quotes stripped) is tokenized. For the
        first term, in query order, that has a posting with positions for this
        document, the field's case-preserved tokens are passed to
        _GenerateSnippet together with the posting's first position. Returns
        None when no term qualifies. Extra positional arguments are ignored.
        """
        field_name = query_parser.GetQueryNodeText(field)
        query_text = query_parser.GetQueryNodeText(query).strip('"')
        for term in self._tokenizer.TokenizeText(query_text):
            lookup_token = tokens.Token(
                chars=u'%s:%s' % (field_name, term.chars))
            for posting in self._inverted_index.GetPostingsForToken(
                    lookup_token):
                if posting.doc_id == self._doc_pb.id() and posting.positions:
                    field_val = search_util.GetFieldValue(
                        search_util.GetFieldInDocument(self._doc_pb,
                                                       field_name))
                    doc_words = [
                        word.chars for word in
                        self._case_preserving_tokenizer.TokenizeText(
                            field_val)
                    ]
                    return self._GenerateSnippet(
                        doc_words, posting.positions[0],
                        search_util.DEFAULT_MAX_SNIPPET_LENGTH)
        return None
    def _CollectTerms(self, node):
        """Get all search terms for scoring.

        A text query node yields its text with surrounding double quotes
        stripped. Other nodes yield the union of their children's terms,
        skipping the first child (the field side) of an EQ node that has more
        than one child.
        """
        if node.getType() in search_util.TEXT_QUERY_TYPES:
            return set([query_parser.GetQueryNodeText(node).strip('"')])
        if not node.children:
            return set()

        children = node.children
        if node.getType() == QueryParser.EQ and len(children) > 1:
            children = children[1:]

        terms = set()
        for child in children:
            terms.update(self._CollectTerms(child))
        return terms
예제 #23
0
  def _MatchAnyField(self, field, match, operator, document):
    """Check whether any field with the given name matches a query tree.

    Args:
      field: the field name, or a query node whose text is the field name.
      match: A query node to match the field with.
      operator: The query node type corresponding to the type of match to
        perform (eg QueryParser.EQ, QueryParser.GT, etc).
      document: The document to match.

    Returns:
      True if _MatchField accepts at least one field of that name.
    """
    if isinstance(field, tree.CommonTree):
      field_name = query_parser.GetQueryNodeText(field)
    else:
      field_name = field
    for candidate in search_util.GetAllFieldInDocument(document, field_name):
      if self._MatchField(candidate, match, operator, document):
        return True
    return False
예제 #24
0
    def _MatchComparableField(self, field, match, cast_to_type, op, document):
        """Compare a number or date field against a query value.

    Comparable types support <, >, <=, >= and ==. In practice these are the
    numeric and date field types.

    Args:
      field: The document_pb.Field to test.
      match: The query node holding the value to compare against.
      cast_to_type: Callable converting both the field's string value and the
        query text into comparable values.
      op: The query node type naming the comparison (EQ, GT, GE, LESSTHAN,
        LE).
      document: The document that the field is in (not used here).

    Returns:
      The comparison result. False if match is not a VALUE node or its text
      cannot be converted by cast_to_type.

    Raises:
      ExpressionTreeException: if op is the != operator.
      UnsupportedOnDevError: for any other unsupported operator.
    """
        field_val = cast_to_type(field.value().string_value())

        if match.getType() != QueryParser.VALUE:
            return False
        try:
            match_val = cast_to_type(query_parser.GetQueryNodeText(match))
        except ValueError:
            return False

        if op == QueryParser.EQ:
            result = field_val == match_val
        elif op == QueryParser.NE:
            raise ExpressionTreeException(
                '!= comparison operator is not available')
        elif op == QueryParser.GT:
            result = field_val > match_val
        elif op == QueryParser.GE:
            result = field_val >= match_val
        elif op == QueryParser.LESSTHAN:
            result = field_val < match_val
        elif op == QueryParser.LE:
            result = field_val <= match_val
        else:
            raise search_util.UnsupportedOnDevError(
                'Operator %s not supported for numerical fields on development server.'
                % match.getText())
        return result
예제 #25
0
 def _MatchFunction(self, node, match, operator, document):
     """Match a function-call query node against a value.

     Only the distance() function is handled; any other function name, or a
     distance() call whose resolved arguments are not one string plus one
     geo_util.LatLng, yields False.

     Args:
       node: Function node whose two children are the name node and the
         argument-list node.
       match: Query node holding the distance to compare against.
       operator: The query node type of the comparison, passed through to
         self._MatchGeoField.
       document: The document being matched.

     Returns:
       The result of self._MatchGeoField for distance() calls, else False.

     Raises:
       ExpressionTreeException: if the match text is not a valid float.
     """
     func_name, func_args = node.children
     if func_name.getText() != 'distance':
         return False

     first, second = func_args.children
     first = self._ResolveDistanceArg(first)
     second = self._ResolveDistanceArg(second)

     # Accept the arguments in either order: (point, field) is swapped into
     # (field, point).
     if (isinstance(first, geo_util.LatLng)
             and isinstance(second, six.string_types)):
         first, second = second, first

     field_name, origin = first, second
     if not (isinstance(field_name, six.string_types)
             and isinstance(origin, geo_util.LatLng)):
         return False

     match_text = query_parser.GetQueryNodeText(match)
     try:
         max_distance = float(match_text)
     except ValueError:
         raise ExpressionTreeException('Unable to compare "%s()" with "%s"' %
                                       (func_name, match_text))
     distance_matcher = DistanceMatcher(origin, max_distance)
     return self._MatchGeoField(field_name, distance_matcher, operator,
                                document)
예제 #26
0
    def _CheckInvalidNumericComparison(self, match, document):
        """Reject a numeric comparison whose right-hand side names a field.

        A valid numeric comparison has the form
        "numeric_field OP numeric_constant" with OP one of
        [>, <, >=, <=, =, :]. This check rejects the case where the
        right-hand side text is the name of a NUMBER-typed field in the
        document.

        Args:
          match: The right-hand side node of the comparison operator.
          document: The document being checked for a match.

        Returns:
          None. The method only raises when the comparison is invalid.

        Raises:
          ExpressionTreeException: when search_util.GetFieldInDocument finds a
            NUMBER field named by the right-hand side text.
        """
        rhs_text = query_parser.GetQueryNodeText(match)
        numeric_field = search_util.GetFieldInDocument(
            document, rhs_text, document_pb.FieldValue.NUMBER)
        if not numeric_field:
            return
        raise ExpressionTreeException(
            'Expected numeric constant, found \"' + rhs_text + '\"')
예제 #27
0
    def ValueOf(self, expression, default_value=None):
        """Returns the value of an expression on a document.

        Args:
          expression: The expression string.
          default_value: The value to return if the expression cannot be
            evaluated. If the expression is a bare DATE field name and this is
            a string, it is first converted with search_util.DeserializeDate.

        Returns:
          The value of the expression on the evaluator's document, or
          default_value if the expression cannot be evaluated on the document.

        Raises:
          ExpressionEvaluationError: sort expression cannot be evaluated
            because the expression or default value is malformed. Callers of
            ValueOf should catch and return error to user in response.
          QueryExpressionEvaluationError: same as ExpressionEvaluationError but
            these errors should return query as error status to users (raised
            here when a string default for a date field cannot be parsed).
        """
        expression_tree = Parse(expression)
        # A root node with a falsy type but with children is unwrapped to its
        # first child, which holds the actual expression.
        if not expression_tree.getType() and expression_tree.children:
            expression_tree = expression_tree.children[0]

        name = query_parser.GetQueryNodeText(expression_tree)
        schema = self._inverted_index.GetSchema()
        if (expression_tree.getType() == ExpressionParser.NAME
                and schema.IsType(name, document_pb.FieldValue.DATE)):
            # Presumably so that the default has the same kind of value that
            # _Eval produces for date fields -- a string default is
            # deserialized.
            if isinstance(default_value, basestring):
                try:
                    default_value = search_util.DeserializeDate(default_value)
                except ValueError:
                    raise QueryExpressionEvaluationError(
                        'failed to parse date \"' + default_value + '\"')
        result = default_value
        try:
            result = self._Eval(expression_tree)
        except _ExpressionError as e:
            # Not fatal: the expression just doesn't apply to this document,
            # so fall back to default_value.
            logging.debug('Skipping expression %s: %s', expression, e)
        # Previously missing: without this the method always returned None.
        return result
예제 #28
0
 def _Count(self, node):
     """Return search_util.GetFieldCountInDocument for the field named by node.

     The field name is the node's query text. The result is presumably the
     number of fields with that name in self._doc_pb.
     """
     field_name = query_parser.GetQueryNodeText(node)
     return search_util.GetFieldCountInDocument(self._doc_pb, field_name)
예제 #29
0
 def _GetFieldName(self, field):
     """Return the field name for a field given as a string or a tree node.

     A tree.CommonTree is converted to its query text. Any other value is
     assumed to already be the name and is returned unchanged.
     """
     if not isinstance(field, tree.CommonTree):
         return field
     return query_parser.GetQueryNodeText(field)
예제 #30
0
    def _Eval(self, node, return_type=None, allow_rank=True):
        """Evaluate an expression AST node against the current document.

        Dispatch order:
          1. Node types registered in self._function_table are called as
             func(return_type, *node.children).
          2. PLUS/MINUS/DIV/TIMES go through self._EvalNumericBinaryOp and
             NEG through self._EvalNumericUnaryOp.
          3. INT/FLOAT literals become floats and PHRASE literals become their
             text without surrounding double quotes.
          4. NAME nodes resolve '_score', '_rank', or a document field.

        Args:
          node: The expression AST node representing an expression subtree.
          return_type: The type to retrieve for fields with multiple types in
            the expression. Used when the field type is ambiguous and cannot
            be inferred from the context. If None, the first field type found
            in the document is used.
          allow_rank: Whether '_rank' may be evaluated (for sort contexts).

        Returns:
          The Python value of the node. Numeric results are int/long/float,
          textual results are strings, and dates are datetimes.

        Raises:
          _ExpressionError: The expression cannot be evaluated on this
            document, either because it is malformed or because the document
            lacks a required field. This is non-fatal. Callers should catch it
            and optionally log it, then skip setting the expression on this
            document.
          QueryExpressionEvaluationError: '_rank' was used while allow_rank is
            False. This error should be returned to the user as the query
            status.
        """
        node_type = node.getType()

        if node_type in self._function_table:
            function = self._function_table[node_type]
            return function(return_type, *node.children)

        arithmetic = {
            ExpressionParser.PLUS: (lambda a, b: a + b, 'addition'),
            ExpressionParser.MINUS: (lambda a, b: a - b, 'subtraction'),
            ExpressionParser.DIV: (lambda a, b: a / b, 'division'),
            ExpressionParser.TIMES: (lambda a, b: a * b, 'multiplication'),
        }
        if node_type in arithmetic:
            combine, op_name = arithmetic[node_type]
            return self._EvalNumericBinaryOp(combine, op_name, node,
                                             return_type)

        if node_type == ExpressionParser.NEG:
            return self._EvalNumericUnaryOp(lambda a: -a, 'negation', node,
                                            return_type)

        if node_type in (ExpressionParser.INT, ExpressionParser.FLOAT):
            return float(query_parser.GetQueryNodeText(node))

        if node_type == ExpressionParser.PHRASE:
            return query_parser.GetQueryNodeText(node).strip('"')

        if node_type != ExpressionParser.NAME:
            raise _ExpressionError('Unable to handle node %s' % node)

        identifier = query_parser.GetQueryNodeText(node)
        if identifier == '_score':
            return self._doc.score
        if identifier == '_rank':
            if not allow_rank:
                raise QueryExpressionEvaluationError(
                    'SortSpec order must be descending in \'_rank\'')
            return self._doc.document.order_id()

        doc_field = search_util.GetFieldInDocument(self._doc_pb, identifier,
                                                   return_type)
        if not doc_field:
            raise _ExpressionError('No field %s in document' % identifier)
        return self._GetFieldValue(doc_field)