def batch_search(self, queries: List[str], qids: List[str], k: int = 10,
                 threads: int = 1, fields=dict()) -> Dict[str, List[JImpactSearcherResult]]:
    """Search the collection concurrently for multiple queries, using multiple threads.

    Parameters
    ----------
    queries : List[str]
        List of query strings.
    qids : List[str]
        List of corresponding query ids.
    k : int
        Number of hits to return.
    threads : int
        Maximum number of threads to use.
    fields : dict
        Optional map of fields to search with associated boosts.

    Returns
    -------
    Dict[str, List[JImpactSearcherResult]]
        Dictionary holding the search results, with the query ids as keys and the corresponding lists of
        search results as the values.
    """
    query_lst = JArrayList()
    qid_lst = JArrayList()
    for q in queries:
        encoded_query = self.query_encoder.encode(q)
        jquery = JHashMap()
        for (token, weight) in encoded_query.items():
            if token in self.idf and self.idf[token] > self.min_idf:
                jquery.put(token, JFloat(weight))
        query_lst.add(jquery)

    for qid in qids:
        qid_lst.add(qid)

    jfields = JHashMap()
    for (field, boost) in fields.items():
        jfields.put(field, JFloat(boost))

    if not fields:
        results = self.object.batchSearch(query_lst, qid_lst, int(k), int(threads))
    else:
        results = self.object.batchSearchFields(query_lst, qid_lst, int(k), int(threads), jfields)
    return {r.getKey(): r.getValue() for r in results.entrySet().toArray()}
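
# Usage sketch for batch_search() above. Assumes `searcher` is an already-constructed instance of
# the impact-searcher class this method belongs to (e.g., a Pyserini learned-sparse searcher over a
# Lucene impact index); the helper name, queries, and qids below are illustrative, not part of the API.
def _example_impact_batch_search(searcher):
    queries = ['what is a lobster roll', 'when did hawaii become a state']
    qids = ['q1', 'q2']
    # Encode and run both queries concurrently on 4 threads, keeping the top 10 hits per query.
    results = searcher.batch_search(queries, qids, k=10, threads=4)
    for qid in qids:
        for rank, hit in enumerate(results[qid], start=1):
            print(f'{qid} {rank:2d} {hit.docid} {hit.score:.5f}')
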
def search(self, q: str, k: int = 10, fields=dict()) -> List[JImpactSearcherResult]:
    """Search the collection.

    Parameters
    ----------
    q : str
        Query string.
    k : int
        Number of hits to return.
    fields : dict
        Optional map of fields to search with associated boosts.

    Returns
    -------
    List[JImpactSearcherResult]
        List of search results.
    """
    jfields = JHashMap()
    for (field, boost) in fields.items():
        jfields.put(field, JFloat(boost))

    encoded_query = self.query_encoder.encode(q)
    jquery = JHashMap()
    for (token, weight) in encoded_query.items():
        if token in self.idf and self.idf[token] > self.min_idf:
            jquery.put(token, JFloat(weight))

    if not fields:
        hits = self.object.search(jquery, k)
    else:
        hits = self.object.searchFields(jquery, jfields, k)

    return hits
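
# Usage sketch for search() above, under the same assumptions: `searcher` is an existing impact-searcher
# instance and the query string is illustrative.
def _example_impact_search(searcher):
    # The query is encoded into token weights internally; tokens whose idf does not exceed the
    # searcher's min_idf threshold are dropped before the Lucene search is issued.
    hits = searcher.search('what is a lobster roll', k=10)
    for rank, hit in enumerate(hits, start=1):
        print(f'{rank:2d} {hit.docid} {hit.score:.5f}')
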
def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1,
                 query_generator: JQueryGenerator = None, fields=dict()) -> Dict[str, List[JSimpleSearcherResult]]:
    """Search the collection concurrently for multiple queries, using multiple threads.

    Parameters
    ----------
    queries : List[str]
        List of query strings.
    qids : List[str]
        List of corresponding query ids.
    k : int
        Number of hits to return.
    threads : int
        Maximum number of threads to use.
    query_generator : JQueryGenerator
        Generator to build queries. Set to ``None`` by default to use Anserini default.
    fields : dict
        Optional map of fields to search with associated boosts.

    Returns
    -------
    Dict[str, List[JSimpleSearcherResult]]
        Dictionary holding the search results, with the query ids as keys and the corresponding lists of
        search results as the values.
    """
    query_strings = JArrayList()
    qid_strings = JArrayList()
    for query in queries:
        jq = JString(query.encode('utf8'))
        query_strings.add(jq)

    for qid in qids:
        jqid = JString(qid)
        qid_strings.add(jqid)

    jfields = JHashMap()
    for (field, boost) in fields.items():
        jfields.put(JString(field), JFloat(boost))

    if query_generator:
        if not fields:
            results = self.object.batchSearch(query_generator, query_strings, qid_strings, int(k), int(threads))
        else:
            results = self.object.batchSearchFields(query_generator, query_strings, qid_strings,
                                                    int(k), int(threads), jfields)
    else:
        if not fields:
            results = self.object.batchSearch(query_strings, qid_strings, int(k), int(threads))
        else:
            results = self.object.batchSearchFields(query_strings, qid_strings, int(k), int(threads), jfields)
    return {r.getKey(): r.getValue() for r in results.entrySet().toArray()}
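
# Usage sketch for batch_search() above. Assumes `searcher` wraps an Anserini SimpleSearcher-style
# backend; the queries, qids, and the field names/boosts shown are illustrative, not required values.
def _example_batch_search_with_fields(searcher):
    queries = ['black bear attacks', 'hubble space telescope']
    qids = ['301', '302']
    # Passing a non-empty fields map routes the call through batchSearchFields() with per-field boosts;
    # leaving fields at its empty default uses the plain batchSearch() path instead.
    results = searcher.batch_search(queries, qids, k=100, threads=8,
                                    fields={'title': 2.0, 'contents': 1.0})
    for qid in qids:
        top = results[qid][0]
        print(qid, top.docid, round(top.score, 5))
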
def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None,
           fields=dict(), strip_segment_id=False, remove_dups=False) -> List[JLuceneSearcherResult]:
    """Search the collection.

    Parameters
    ----------
    q : Union[str, JQuery]
        Query string or a ``JQuery`` object.
    k : int
        Number of hits to return.
    query_generator : JQueryGenerator
        Generator to build queries. Set to ``None`` by default to use Anserini default.
    fields : dict
        Optional map of fields to search with associated boosts.
    strip_segment_id : bool
        Remove the .XXXXX suffix used to denote different segments of a document.
    remove_dups : bool
        Remove duplicate docids when writing final run output.

    Returns
    -------
    List[JLuceneSearcherResult]
        List of search results.
    """
    jfields = JHashMap()
    for (field, boost) in fields.items():
        jfields.put(field, JFloat(boost))

    hits = None
    if query_generator:
        if not fields:
            hits = self.object.search(query_generator, q, k)
        else:
            hits = self.object.searchFields(query_generator, q, jfields, k)
    elif isinstance(q, JQuery):
        # Note that RM3 requires the notion of a query (string) to estimate the appropriate models. If we're just
        # given a Lucene query, it's unclear what the "query" is for this estimation. One possibility is to extract
        # all the query terms from the Lucene query, although this might yield unexpected behavior from the user's
        # perspective. Until we think through what exactly is the "right thing to do", we'll raise an exception
        # here explicitly.
        if self.is_using_rm3():
            raise NotImplementedError('RM3 incompatible with search using a Lucene query.')
        if fields:
            raise NotImplementedError('Cannot specify fields to search when using a Lucene query.')
        hits = self.object.search(q, k)
    else:
        if not fields:
            hits = self.object.search(q, k)
        else:
            hits = self.object.searchFields(q, jfields, k)

    docids = set()
    filtered_hits = []

    for hit in hits:
        if strip_segment_id is True:
            hit.docid = hit.docid.split('.')[0]

        if hit.docid in docids:
            continue

        filtered_hits.append(hit)

        if remove_dups is True:
            docids.add(hit.docid)

    return filtered_hits
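
# Usage sketch for search() above. Assumes `searcher` is an instance of the Lucene searcher class that
# defines this method; the query and the segmented docid convention shown are illustrative.
def _example_search_dedup(searcher):
    # With strip_segment_id=True, a segmented docid such as 'doc123.00002' is collapsed to 'doc123';
    # remove_dups=True then keeps only the first hit (typically the highest-scoring) per collapsed docid.
    hits = searcher.search('hubble space telescope', k=20, strip_segment_id=True, remove_dups=True)
    for rank, hit in enumerate(hits, start=1):
        print(f'{rank:2d} {hit.docid} {hit.score:.5f}')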