コード例 #1
0
 def __init__(self, index):
     """Open a searcher on ``index`` and start with empty per-query state.

     Args:
         index: Handed straight to ``SimpleSearcher``; presumably the path of a
             Lucene/Anserini index directory (TODO confirm).
     """
     self._index = SimpleSearcher(index)
     # Nothing has been queried yet.
     self._query = None
     # Per-document containers, each a fresh empty list.
     self._docids, self._doc_content = [], []
     self._doc_scores, self._doc_embeddings = [], []
コード例 #2
0
def bm25(qid, query, docs, index_path):
    """Fuse a BM25 passage run with a base document run using RRF.

    Writes three TREC run files into the current directory:
    ``run-passage-{qid}.txt`` (BM25 hits over ``index_path``, first occurrence
    of each docid only), ``run-doc-{qid}.txt`` (the supplied ``docs``) and
    ``run-rrf-{qid}.txt`` (their reciprocal-rank fusion from
    ``pyserini.fusion``).

    Args:
        qid: Topic id written into every run line and used to read the fused run.
        query: Query text searched against ``index_path`` (top 1000 hits).
        docs: Base ranking; each item must provide ``"docid"``, ``"rank"`` and
            ``"score"`` keys.
        index_path: Index opened with ``SimpleSearcher``.

    Returns:
        A list of ``[qid, docid, rank]`` rows taken from the fused run.

    Raises:
        subprocess.CalledProcessError: if the fusion command exits non-zero.
    """
    import subprocess
    import sys

    searcher = SimpleSearcher(index_path)
    hits = searcher.search(query, 1000)

    passage_run = f'run-passage-{qid}.txt'
    doc_run = f'run-doc-{qid}.txt'
    fused_path = f'run-rrf-{qid}.txt'

    # The same docid can appear more than once among the hits; keep only its
    # first occurrence and number the surviving hits 1, 2, 3, ...
    seen_docids = set()
    rank = 1
    with open(passage_run, 'w') as writer:
        for hit in hits:
            if hit.docid in seen_docids:
                continue
            seen_docids.add(hit.docid)
            writer.write(f'{qid} Q0 {hit.docid} {rank} {hit.score:.5f} pyserini\n')
            rank += 1

    with open(doc_run, 'w') as writer:
        for doc in docs:
            writer.write(f'{qid} Q0 {doc["docid"]} {doc["rank"]} {doc["score"]} base\n')

    # Argument list instead of a shell string, so qid is never interpreted by
    # a shell; check=True so a failed fusion raises instead of silently
    # leaving a missing or stale run-rrf file to be read below.
    subprocess.run(
        [sys.executable, '-m', 'pyserini.fusion', '--method', 'rrf',
         '--runs', passage_run, doc_run,
         '--output', fused_path, '--runtag', 'test'],
        check=True,
    )
    fused_run = TrecRun(fused_path)

    return [[qid, r["docid"], r["rank"]]
            for _, r in fused_run.get_docs_by_topic(qid).iterrows()]
コード例 #3
0
 def test_reranking(self):
     """End-to-end regression test for MS MARCO passage LTR reranking.

     Downloads a BM25 candidate run, the pre-trained LTR model and the IBM
     model into ``ltr_test/``. It then converts the dev-subset queries and
     reranks the candidates with ``pyserini.ltr.search_msmarco_passage``.
     Finally it compares the MRR@10 printed by the MS MARCO evaluation script
     with a reference value and removes ``ltr_test/`` again.
     """
     # Always start from an empty ``ltr_test/``. Create it unconditionally,
     # not only after an old copy has been removed, so a first run has it too.
     if os.path.isdir('ltr_test'):
         rmtree('ltr_test')
     os.mkdir('ltr_test')
     inp = 'run.msmarco-passage.bm25tuned.txt'
     outp = 'run.ltr.msmarco-passage.test.tsv'
     # Download the BM25 candidate run.
     os.system('wget https://www.dropbox.com/s/bjyzf65uns2is61/run.msmarco-passage.bm25tuned.txt -P ltr_test')
     # Download the prebuilt index; the hard-coded --index path below
     # presumably points at its local cache location (TODO confirm).
     SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')
     # Pre-trained LTR model.
     model_url = 'https://www.dropbox.com/s/ffl2bfw4cd5ngyz/msmarco-passage-ltr-mrr-v1.tar.gz'
     model_tar_name = 'msmarco-passage-ltr-mrr-v1.tar.gz'
     os.system(f'wget {model_url} -P ltr_test/')
     os.system(f'tar -xzvf ltr_test/{model_tar_name} -C ltr_test')
     # IBM model.
     ibm_model_url = 'https://www.dropbox.com/s/vlrfcz3vmr4nt0q/ibm_model.tar.gz'
     ibm_model_tar_name = 'ibm_model.tar.gz'
     os.system(f'wget {ibm_model_url} -P ltr_test/')
     os.system(f'tar -xzvf ltr_test/{ibm_model_tar_name} -C ltr_test')
     # Convert the dev-subset queries to JSON.
     os.system('python scripts/ltr_msmarco-passage/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ltr_test/queries.dev.small.json')
     os.system(f'python -m pyserini.ltr.search_msmarco_passage --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
     result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
     # The score is printed between the "MRR @10:" header and the
     # QueriesRanked footer; a + 31 skips the header marker and the space
     # that follows it.
     a, b = result.find('#####################\nMRR @10:'), result.find('\nQueriesRanked: 6980\n#####################\n')
     mrr = result[a + 31:b]
     self.assertAlmostEqual(float(mrr), 0.24709612498294367, delta=0.000001)
     rmtree('ltr_test')
コード例 #4
0
ファイル: anserini.py プロジェクト: bpiwowar/experimaestro-ir
class AnseriniRetriever(Retriever):
    """Retriever backed by an Anserini (Lucene) index, accessed through pyserini.

    Attributes:
        index: The Anserini index to search.
        model: The scoring model; only BM25 is supported so far.
        k: Number of results to retrieve per query.
    """

    index: Param[Index]
    model: Param[Model]
    k: Param[int] = 1500

    def initialize(self):
        """Open a searcher on the index and apply the configured model to it."""
        # Imported here rather than at module level, presumably so pyserini is
        # only loaded once a retriever is actually initialized.
        from pyserini.search import SimpleSearcher

        self.searcher = SimpleSearcher(str(self.index.path))

        # Handler appears to dispatch on the annotated argument type of each
        # registered function (TODO confirm). Only BM25 is registered; it
        # copies the model's k1/b onto the searcher.
        dispatcher = Handler()

        @dispatcher()
        def handle(bm25: BM25):
            self.searcher.set_bm25(bm25.k1, bm25.b)

        dispatcher[self.model]

    def getindex(self) -> Index:
        """Returns the associated index (if any)"""
        return self.index

    def retrieve(self, query: str) -> List[ScoredDocument]:
        """Search for ``query`` and wrap the top ``k`` hits as ScoredDocuments."""
        return [
            ScoredDocument(hit.docid, hit.score, hit.contents)
            for hit in self.searcher.search(query, k=self.k)
        ]
コード例 #5
0
 def __init__(
         self,
         index_dir="/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/CAsT_collection_with_meta.index",
         k1=0.82,
         b=0.68,
         **kwargs):
     """Open the CAsT collection index and configure BM25 scoring.

     Args:
         index_dir: Lucene index directory opened with ``SimpleSearcher``.
         k1: BM25 ``k1`` parameter, passed to ``set_bm25``.
         b: BM25 ``b`` parameter, passed to ``set_bm25``.
         **kwargs: Accepted for interface compatibility; not used here.
     """
     self.searcher = SimpleSearcher(index_dir)
     # Apply the requested BM25 parameters. Without this call k1 and b are
     # accepted but have no effect on retrieval.
     self.searcher.set_bm25(k1, b)
コード例 #6
0
def sampling(args):
    """Write BM25 top candidates for every query that has relevance judgements.

    Reads ``qrels.{args.mode}.tsv`` and ``queries.{args.mode}.tsv`` from
    ``args.msmarco_dir`` and searches ``args.index_dir`` with BM25
    (``args.bm25_k1``, ``args.bm25_b``) for each judged query. It writes one
    ``qid<TAB>docid<TAB>score`` line per candidate (up to ``args.topN`` per
    query) to ``top_candidates.{args.mode}.tsv`` in ``args.output_dir``.

    Raises:
        KeyError: if a query id in the qrels has no entry in the queries file.
    """
    # Positive passage ids per query. The qrels file has four tab-separated
    # columns; only the 1st (qid) and 3rd (pid) are used. Only the keys
    # drive retrieval below.
    qrels = defaultdict(list)
    qrels_path = os.path.join(args.msmarco_dir, f"qrels.{args.mode}.tsv")
    with open(qrels_path, 'r') as qrels_file:
        for line in qrels_file:
            qid, _, pid, _ = line.split('\t')
            qrels[qid].append(int(pid))
    qrels = dict(qrels)

    # Query text per query id. maxsplit=1 keeps a tab inside the query text
    # instead of failing the two-way unpacking.
    queries = {}
    queries_path = os.path.join(args.msmarco_dir, f"queries.{args.mode}.tsv")
    with open(queries_path, 'r') as queries_file:
        for line in queries_file:
            qid, query = line.split('\t', 1)
            queries[qid] = query.rstrip()

    searcher = SimpleSearcher(args.index_dir)
    searcher.set_bm25(k1=args.bm25_k1, b=args.bm25_b)

    output_path = os.path.join(args.output_dir, f'top_candidates.{args.mode}.tsv')
    with open(output_path, 'w') as outfile:
        for qid in tqdm(qrels):
            for hit in searcher.search(queries[qid], k=args.topN):
                outfile.write(f"{qid}\t{hit.docid}\t{hit.score}\n")
コード例 #7
0
    def setUp(self):
        """Prepare ``ibm_test/`` with a BM25 run, the IBM model and dev queries.

        Steps:
          - Record the repository root in ``self.pyserini_root`` and recreate
            an empty ``ibm_test/``.
          - Fetch the prebuilt ``msmarco-passage-ltr`` index and write a BM25
            (k1=0.82, b=0.68) TREC run for the MS MARCO passage dev subset.
          - Download and unpack the IBM model.
          - Convert the dev-subset queries to JSON.
        """
        # Tests may be launched from the repository root or from a directory
        # ending in ``sparse`` two levels below it.
        if os.getcwd().endswith('sparse'):
            self.pyserini_root = '../..'
        else:
            self.pyserini_root = '.'

        # Always start from an empty ``ibm_test/``. The search command below
        # writes its run file into it, so it must also exist on a first run.
        if os.path.isdir('ibm_test'):
            rmtree('ibm_test')
        os.mkdir('ibm_test')
        # Download prebuilt index.
        SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')
        inp = 'run.msmarco-passage.bm25tuned.trec'
        os.system(
            f'python -m pyserini.search --topics msmarco-passage-dev-subset  --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3/ --output ibm_test/{inp} --bm25 --output-format trec --hits 1000 --k1 0.82 --b 0.68'
        )
        # IBM model.
        ibm_model_url = 'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/ibm_model_1_bert_tok_20211117.tar.gz'
        ibm_model_tar_name = 'ibm_model_1_bert_tok_20211117.tar.gz'
        os.system(f'wget {ibm_model_url} -P ibm_test/')
        os.system(f'tar -xzvf ibm_test/{ibm_model_tar_name} -C ibm_test')
        # Convert the dev-subset queries to JSON.
        os.system(
            'python scripts/ltr_msmarco/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ibm_test/queries.dev.small.json'
        )
コード例 #8
0
ファイル: run-rpf-rm.py プロジェクト: khalidelhaji/ir
def main():
    """Run every query in ``msmarco-test2019-queries.tsv`` with QLD scoring.

    Scores with query likelihood / Dirichlet (``set_qld``) over the index at
    ``indexes/msmarco-passage``, which must already have been built.
    """
    searcher = SimpleSearcher('indexes/msmarco-passage')
    searcher.set_qld()
    # NOTE(review): the run file name says "bm25" although the scorer set
    # above is QLD; kept unchanged because downstream tooling may expect it.
    run_all_queries('runs/run.msmarco-test2019-queries-bm25.trec',
                    read_topics('msmarco-test2019-queries.tsv'),
                    searcher)
コード例 #9
0
class BM25Retriever:
    """BM25 retriever with RM3 query expansion over a pyserini index."""

    def __init__(self, dataset):
        # ``dataset`` goes straight to SimpleSearcher (presumably an index path).
        searcher = SimpleSearcher(dataset)
        searcher.set_bm25(3.44, 0.87)   # k1, b
        searcher.set_rm3(10, 10, 0.5)   # fb_terms, fb_docs, original_query_weight
        self._searcher = searcher

    def query(self, query_text, k=100):
        """Return a ``(hits, query_text)`` pair for the top ``k`` hits."""
        return self._searcher.search(query_text, k=k), query_text
コード例 #10
0
class KeywordSearchUtil:
    """Keyword retrieval: BM25 (k1=0.9, b=0.4) followed by RM3 expansion."""

    def __init__(self, index_file):
        searcher = SimpleSearcher(index_file)
        searcher.set_bm25(0.9, 0.4)
        searcher.set_rm3(10, 10, 0.5)
        self.searcher = searcher

    def retrieve_top_k_hits(self, query, k=5000):
        """Return the searcher's top ``k`` hits for ``query``."""
        return self.searcher.search(query, k)
コード例 #11
0
def main():
    """Run every query in ``msmarco-test2019-queries.tsv`` with RM3 expansion.

    The index at ``indexes/msmarco-passage`` must already have been built.
    The base scorer is left at the searcher's default.
    """
    searcher = SimpleSearcher('indexes/msmarco-passage')
    # Tuned BM25 (searcher.set_bm25(0.82, 0.68)) is currently not applied.
    searcher.set_rm3(fb_terms=25, fb_docs=50, original_query_weight=0.5)
    run_all_queries('runs/run.msmarco-test2019-queries-bm25.trec',
                    read_topics('msmarco-test2019-queries.tsv'),
                    searcher)
コード例 #12
0
class QAengine():
    """Question answering: BM25 paragraph retrieval followed by a DilBERT reader.

    ``answer`` retrieves ``k`` paragraphs from the index at
    ``PATH_TO_WIKI_INDEX``, runs the reader over them and returns the
    highest-scoring predicted span whose text is not "empty".
    """

    def __init__(self):
        self.searcher = SimpleSearcher(PATH_TO_WIKI_INDEX)
        self.searcher.set_bm25()
        self.searcher.unset_rm3()
        self.processor = SquadV2Processor()
        # Number of paragraphs retrieved per question.
        self.k = 29
        # Passed to process_one_question along with the retrieval scores;
        # presumably the weight for mixing IR and reader scores (TODO confirm).
        self.mu = 0.5
        self.use_ir_score = True
        self.tokenizer = BertTokenizer.from_pretrained(PATH_TO_DILBERT,
                                                       do_lower_case=True)
        self.model = DilBert.from_pretrained(PATH_TO_DILBERT)
        self.device = DEVICE_COMP
        self.model.to(torch.device(self.device))

    @staticmethod
    def _select_answer(scores, texts):
        """Return the text of the highest-scoring candidate that is not "empty".

        Candidates are visited in descending ``scores`` order. If every text
        is "empty", the last one visited is returned, which is "empty". With
        no candidates at all, ``texts[0]`` raises IndexError.
        """
        # argsort()[::-1] already lists candidate indices from best to worst.
        # Applying argsort a second time would give each candidate's rank
        # rather than the index of the n-th best candidate.
        predicted_p_index = 0
        for predicted_p_index in np.argsort(scores)[::-1]:
            if texts[predicted_p_index] != "empty":
                break
        return texts[predicted_p_index]

    def answer(self, question):
        """Answer ``question`` from the top ``self.k`` retrieved paragraphs.

        The raw text of the hits is turned into a SQuAD-style input and read
        with ``process_one_question``, which also receives the retrieval
        scores, ``self.use_ir_score`` and ``self.mu``. The text of the best
        non-"empty" prediction is returned.
        """
        hits = self.searcher.search(question, k=self.k)
        ir_scores = [hit.score for hit in hits]
        paragraphs = [hit.raw for hit in hits]
        input_ = build_squad_input(question, paragraphs)
        examples = self.processor._create_examples(input_["data"], "dev")
        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=self.tokenizer,
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_training=False,
            return_dataset="pt",
            threads=1,
        )
        _all_results, predictions = process_one_question(
            features, dataset, self.model, self.tokenizer, examples,
            self.device, self.use_ir_score, self.mu, ir_scores)

        # Span score is the sum of its start and end logits.
        candidates = predictions['0']
        scores = np.array([p['start_logit'] + p['end_logit'] for p in candidates])
        texts = [p['text'] for p in candidates]
        return self._select_answer(scores, texts)
コード例 #13
0
 def __init__(self, ranker, index, topn=10, topw=10, original_q_w=0.5):
     """Initialise the feedback base class and open a searcher on ``index``.

     Args:
         ranker: Forwarded to ``RelevanceFeedback`` and stored as ``self.ranker``.
         index: Forwarded to ``RelevanceFeedback`` and opened with ``SimpleSearcher``.
         topn: Forwarded to ``RelevanceFeedback`` (presumably the number of
             feedback documents — TODO confirm).
         topw: Stored as ``self.topw`` (presumably the number of expansion terms).
         original_q_w: Stored as ``self.original_q_w`` (presumably the weight
             kept for the original query).
     """
     RelevanceFeedback.__init__(
         self, ranker=ranker, prels=None, anserini=None, index=index, topn=topn)
     self.ranker = ranker
     self.topw, self.original_q_w = topw, original_q_w
     self.searcher = SimpleSearcher(index)
コード例 #14
0
 def index(self):
     """Build the Lucene index under ``./index/`` and open a searcher on it.

     Splits ``./data/livivo/documents/`` into chunks and converts every chunk
     in a process pool. It then runs ``JIndexCollection`` with the
     module-level ``ARGS`` (presumably pointing at ``./index/convert/`` —
     TODO confirm) and opens ``self.searcher`` on ``./index/``. The temporary
     chunk and convert directories are removed along the way.
     """
     self._mkdir('./index/')
     self._mkdir('./index/convert/')
     self._mkdir('./index/chunks/')
     self._make_chuncks("./data/livivo/documents/")
     # The context manager shuts the pool down even if a conversion raises;
     # map() itself blocks until every chunk has been processed.
     with Pool() as pool:
         pool.map(self._convert_chunks, os.listdir("./index/chunks/"))
     shutil.rmtree('./index/chunks')
     JIndexCollection.main(ARGS)
     self.searcher = SimpleSearcher('./index/')
     shutil.rmtree('./index/convert/')
コード例 #15
0
ファイル: anserini.py プロジェクト: bpiwowar/experimaestro-ir
    def initialize(self):
        """Open a searcher on the index and apply the configured model to it."""
        # Imported here rather than at module level, presumably so pyserini is
        # only loaded once a retriever is actually initialized.
        from pyserini.search import SimpleSearcher

        self.searcher = SimpleSearcher(str(self.index.path))

        # Handler appears to dispatch on the annotated argument type of each
        # registered function (TODO confirm). Only BM25 is registered; it
        # copies the model's k1/b onto the searcher.
        dispatcher = Handler()

        @dispatcher()
        def handle(bm25: BM25):
            self.searcher.set_bm25(bm25.k1, bm25.b)

        dispatcher[self.model]
コード例 #16
0
 def __init__(self):
     """Load the BM25 paragraph retriever and the DilBERT reader."""
     # Retrieval: BM25 with RM3 query expansion explicitly switched off.
     searcher = SimpleSearcher(PATH_TO_WIKI_INDEX)
     searcher.set_bm25()
     searcher.unset_rm3()
     self.searcher = searcher
     self.processor = SquadV2Processor()

     # Answer-selection settings: paragraphs per question, mixing weight and
     # whether retrieval scores are used.
     self.k, self.mu, self.use_ir_score = 29, 0.5, True

     # Reader model, placed on the configured device.
     self.tokenizer = BertTokenizer.from_pretrained(
         PATH_TO_DILBERT, do_lower_case=True)
     self.model = DilBert.from_pretrained(PATH_TO_DILBERT)
     self.device = DEVICE_COMP
     self.model.to(torch.device(self.device))
コード例 #17
0
def build_searcher(
        k1=0.9,
        b=0.4,
        index_path="index/lucene-index.wiki_paragraph_drqa.pos+docvectors",
        segmented=False,
        rm3=False,
        chinese=False):
    """Create a BM25 ``SimpleSearcher`` over ``index_path``.

    Args:
        k1: BM25 ``k1`` parameter.
        b: BM25 ``b`` parameter.
        index_path: Index directory to open.
        segmented: Accepted for compatibility; not used by this function.
        rm3: When true, enable RM3 query expansion with pyserini's default
            ``set_rm3()`` settings.
        chinese: When true, call ``setLanguage("zh")`` on the wrapped searcher
            object and print a notice.

    Returns:
        The configured searcher.
    """
    searcher = SimpleSearcher(index_path)
    searcher.set_bm25(k1, b)
    if chinese:
        searcher.object.setLanguage("zh")
        print("########### we are usinig Chinese retriever ##########")
    # Honour the rm3 flag; without this it would be accepted but ignored.
    if rm3:
        searcher.set_rm3()
    return searcher
def main(output_path=OUTPUT_PATH,
         index_path=INDEX_PATH,
         queries_path=QUERIES_PATH,
         run=RUN,
         k=K):
    """Retrieve candidates with BM25, rerank them with MonoT5 and write a run.

    Args:
        output_path: Destination file. Its name must contain ``.tsv`` or
            ``.csv``, which selects the output writer.
        index_path: Index searched for candidate passages.
        queries_path: Tab-separated file with the query id in the first column
            and the query text in the second; blank lines are skipped.
        run: Passed through to the output writer (presumably the run tag).
        k: Number of BM25 candidates retrieved per query (default ``K``).

    Exits the process with status 0 on success, or 1 when ``output_path`` has
    an unsupported extension.
    """
    print('################################################')
    print("##### Performing Passage Ranking using L2R #####")
    print('################################################')
    print("Output will be placed in:", output_path,
          ", format used will be TREC")
    print('Loading pre-trained model MonoT5...')
    from pygaggle.rerank.transformer import MonoT5
    reranker = MonoT5()

    print('Fetching anserini-like indices from:', index_path)
    # Fetch some passages to rerank from MS MARCO with Pyserini (BM25).
    searcher = SimpleSearcher(index_path)
    print('Loading queries from:', queries_path)
    with open(queries_path, 'r') as f:
        # A blank line would otherwise split into [''] and fail on row[1].
        rows = [line.strip().split('\t') for line in f if line.strip()]
    queries = [Query(row[1], row[0]) for row in rows]
    print(f'Ranking queries using BM25 (k={k})')
    queries_text = []
    for query in tqdm(queries):
        # Use the ``k`` argument; the module-level ``K`` is only its default.
        hits = searcher.search(query.text, k=k)
        queries_text.append(hits_to_texts(hits))

    print('Reranking all queries using MonoT5!')
    rankings = []
    for query, texts in zip(tqdm(queries), queries_text):
        reranked = reranker.rerank(query, texts)
        reranked.sort(key=lambda x: x.score, reverse=True)
        rankings.append(reranked)

    print('Outputting to file...')
    if '.tsv' in output_path:
        output_to_tsv(queries, rankings, run, output_path)
    elif '.csv' in output_path:
        output_to_csv(queries, rankings, run, output_path)
    else:
        print(
            'ERROR: invalid output file format provided, please use either .csv or .tsv. Exiting'
        )
        sys.exit(1)
    print('SUCCESS: completed reranking, you may check the output at:',
          output_path)
    sys.exit(0)
コード例 #19
0
    def setUp(self):
        """Download and unpack the pre-built CACM index and open a searcher on it.

        A random value is appended to the tarball and index directory names
        to avoid filename clashes.
        """
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        urlretrieve(self.collection_url, self.tarball_name)

        # ``with`` closes the archive even if extraction fails.
        # NOTE(review): extractall trusts member paths of a downloaded archive;
        # pass filter='data' once the supported Python versions all accept it.
        with tarfile.open(self.tarball_name) as tarball:
            tarball.extractall(self.index_dir)

        self.searcher = SimpleSearcher(f'{self.index_dir}lucene-index.cacm')
コード例 #20
0
ファイル: pyserini_ranker.py プロジェクト: deepmipt/deepy
 def __init__(
     self,
     index_folder: str,
     n_threads: int = 1,
     top_n: int = 5,
     text_column_name: str = "contents",
     return_scores: bool = False,
     *args,
     **kwargs,
 ):
     """Open a searcher on ``index_folder`` and store the ranking options.

     Args:
         index_folder: Index location; passed through ``expand_path`` and
             converted to ``str`` before opening.
         n_threads: Stored as ``self.n_threads``.
         top_n: Stored as ``self.top_n``.
         text_column_name: Stored as ``self.text_column_name``.
         return_scores: Stored as ``self.return_scores``.
         *args, **kwargs: Accepted and ignored here.
     """
     index_path = str(expand_path(index_folder))
     self.searcher = SimpleSearcher(index_path)
     # Options kept exactly as given.
     self.n_threads, self.top_n = n_threads, top_n
     self.text_column_name, self.return_scores = text_column_name, return_scores
コード例 #21
0
ファイル: search.py プロジェクト: MXueguang/pyserini
def main(args):
    """Search ``args.index`` for ``args.query`` and print the top 1000 hits.

    When ``args.do_tokenize`` is set, the query is first split into
    multilingual BERT word pieces and re-joined with spaces. The searcher is
    given a whitespace analyzer, presumably so those pieces are not
    re-tokenized (TODO confirm).

    Each output line holds the 1-based rank, the docid and the score.
    """
    query = args.query
    index = args.index
    if args.do_tokenize:
        # The Hugging Face model id is "bert-base-multilingual-uncased";
        # "bert-multilingual-base-uncased" does not exist on the hub.
        tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-uncased')
        query = " ".join(tokenizer.tokenize(query))

    logger.info('searching for: %s', query)
    searcher = SimpleSearcher(index)
    searcher.set_analyzer(JWhiteSpaceAnalyzer())
    hits = searcher.search(query, 1000)

    for rank, hit in enumerate(hits, start=1):
        print(f'{rank:2} {hit.docid:4} {hit.score:.5f}')
コード例 #22
0
ファイル: msmarco.py プロジェクト: yuxuan-ji/pyserini
class MsMarcoDemo(cmd.Cmd):
    """Interactive shell that searches the prebuilt MS MARCO passage index.

    Commands may be prefixed with ``/`` (e.g. ``/k 20``); the prefix is
    stripped before dispatch. Any line that is not a known command is run as
    a query, and the top ``k`` hits are printed with their contents.
    """
    searcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
    k = 10
    prompt = '>>> '

    # https://stackoverflow.com/questions/35213134/command-prefixes-in-python-cli-using-cmd-in-pythons-standard-library
    def precmd(self, line):
        # startswith is safe on an empty line, where line[0] raises IndexError.
        if line.startswith('/'):
            line = line[1:]
        return line

    def do_help(self, arg):
        print('/help    : returns this message')
        print('/k [NUM] : sets k (number of hits to return) to [NUM]')

    def do_k(self, arg):
        # Report a bad value instead of letting ValueError end the session.
        try:
            new_k = int(arg)
        except ValueError:
            print(f'invalid value for k: {arg!r}')
            return
        print(f'setting k = {new_k}')
        self.k = new_k

    def do_EOF(self, line):
        return True

    def default(self, q):
        hits = self.searcher.search(q, self.k)

        for rank, hit in enumerate(hits, start=1):
            jsondoc = json.loads(hit.raw)
            print(f'{rank:2} {hit.score:.5f} {jsondoc["contents"]}')
コード例 #23
0
def main():
    """Bulk-insert raw ACL Anthology documents from a Lucene index into DynamoDB.

    Documents are grouped into batches of ``--batch-size`` items and written
    by a pool of ``--threads`` worker threads. Failed batches are logged, and
    progress is logged roughly every ``--report-interval`` records.
    """
    parser = argparse.ArgumentParser("ACL Anthology document DynamoDB bulk importer")
    parser.add_argument("--index", required=True, type=str, help="Path to ACL Anthology Lucene index")
    parser.add_argument("--table", default="ACL", type=str, help="Dynamo table to insert the raw ACL documents to")
    # type=int so a value given on the command line is a number, like the default.
    parser.add_argument("--batch-size", dest="batch", default=MAX_BATCH_SIZE, type=int, help="The size of batch insert to Dynamo")
    parser.add_argument("--threads", default=5, type=int, help="Number of threads for batch inserts")
    parser.add_argument("--report-interval", dest="report_interval", default=500, type=int, help="Output progress interval")
    args = parser.parse_args()

    # TODO: use https://github.com/castorini/pyserini/blob/master/docs/usage-collection.md once AclAnthology support is added
    searcher = SimpleSearcher(args.index)

    progress = 0
    next_report_threshold = args.report_interval
    batches = build_item_batches(searcher, args.batch)
    # The pool is sized by --threads; the batch size controls items per write.
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = {executor.submit(batch_write_dynamo, args.table, batch): batch for batch in batches}
        for future in concurrent.futures.as_completed(futures):
            batch = futures[future]
            try:
                failed_docids = future.result()
                if failed_docids:
                    logger.error("Error writing batches %s", failed_docids)
            except Exception:
                # Top-level boundary: log the failing batch and keep going.
                batch_ids = [item["id"] for item in batch]
                logger.exception("Error writing batches %s", batch_ids)
            finally:
                progress += len(batch)
                if progress > next_report_threshold:
                    logger.info("Processed %s/%s records", progress, searcher.num_docs)
                    next_report_threshold += args.report_interval
コード例 #24
0
    def index(self):
        """Convert the GESIS dataset dump, index it and load publication titles.

        Steps:
          1. Read ``./data/gesis-search/datasets/dataset.jsonl`` and build one
             ``{"id", "contents"}`` record per entry. ``contents`` is the
             title and abstract joined by a space; a list-valued field
             contributes its first element.
          2. Write the records to ``./convert/output.jsonl``.
          3. Index ``./convert`` into ``./indexes/gesis`` with
             ``JIndexCollection`` and open ``self.searcher`` on it.
          4. Fill ``self.title_lookup`` (id -> title) from
             ``./data/gesis-search/documents/publication.jsonl``. The dict is
             expected to exist already.
        """

        def first(value):
            # Missing or empty fields become ''. A non-empty list yields its
            # first element; an empty list is already mapped to ''.
            value = value or ''
            return value[0] if isinstance(value, list) else value

        data = []
        with jsonlines.open(
                './data/gesis-search/datasets/dataset.jsonl') as reader:
            for obj in reader:
                title = first(obj.get('title'))
                abstract = first(obj.get('abstract'))
                try:
                    contents = ' '.join([title, abstract])
                except TypeError as e:
                    # Non-string title/abstract: report it and skip the record.
                    print(e)
                    continue
                data.append({'id': obj.get('id'), 'contents': contents})

        # exist_ok: re-running must not trip over directories from a prior run.
        os.makedirs('./convert/', exist_ok=True)

        with jsonlines.open('./convert/output.jsonl', mode='w') as writer:
            for doc in data:
                writer.write(doc)

        os.makedirs('./indexes/', exist_ok=True)

        args = [
            "-collection", "JsonCollection", "-generator",
            "DefaultLuceneDocumentGenerator", "-threads", "1", "-input",
            "./convert", "-index", "./indexes/gesis", "-storePositions",
            "-storeDocvectors", "-storeRaw"
        ]

        JIndexCollection.main(args)
        self.searcher = SimpleSearcher('indexes/gesis')

        with jsonlines.open(
                './data/gesis-search/documents/publication.jsonl') as reader:
            for obj in reader:
                self.title_lookup[obj.get('id')] = obj.get('title')
コード例 #25
0
def _run_thread(arguments):
    """Retrieve the top-``k`` hits for one worker's share of queries.

    Args:
        arguments: dict with these keys:
            ``"id"``: worker id; worker 0 shows a tqdm progress bar.
            ``"index"``: index opened with SimpleSearcher.
            ``"k"``: hits per query.
            ``"data"``: iterable of dicts with ``"id"`` and ``"query"``.

    Returns:
        dict mapping each query id to a list of hit dicts. If a hit's docid
        parses as a JSON object, that object is extended with ``"score"`` and
        ``"text"``. Otherwise the hit becomes
        ``{"score", "text", "title": docid}``.
    """
    idz = arguments["id"]
    index = arguments["index"]
    k = arguments["k"]
    data = arguments["data"]

    # TODO: make the BM25 parameters configurable and apply them with
    # searcher.set_bm25(...) once the searcher below has been created.

    from pyserini.search import SimpleSearcher

    searcher = SimpleSearcher(index)

    _iter = tqdm(data) if idz == 0 else data

    provenance = {}
    for x in _iter:
        query_id = x["id"]
        # Drop the entity start/end markers before searching.
        query = (
            x["query"].replace(utils.ENT_END, "").replace(utils.ENT_START, "").strip()
        )

        hits = searcher.search(query, k)

        element = []
        for y in hits:
            text = str(y.raw).strip()
            try:
                doc_data = json.loads(str(y.docid).strip())
                doc_data["score"] = y.score
            except (ValueError, TypeError) as e:
                # ValueError: docid is not JSON. TypeError: it is JSON but not
                # an object (e.g. a number or string).
                print(e)
                doc_data = {
                    "score": y.score,
                    "text": text,
                    "title": y.docid,
                }
            else:
                doc_data["text"] = text
            element.append(doc_data)
        provenance[query_id] = element

    return provenance
コード例 #26
0
def main():
    """Sanity-check a local MS MARCO passage index and pyserini installation.

    Runs four checks in order and stops at the first failure:
      - a fixed BM25 query must return known top-3 docids;
      - the index must contain the expected number of terms;
      - the ``msmarco-passage-dev-subset`` topics must load with the expected
        first query;
      - the analyzer must turn a test sentence into 9 tokens.

    Prints "INSTALLATION OK" if all pass. Otherwise it prints an error line
    followed by the exception.
    """
    try:
        index_loc = "indexes/msmarco-passage/lucene-index-msmarco"

        # 1. Retrieval: top-3 BM25 docids for a fixed query.
        searcher = SimpleSearcher(index_loc)
        searcher.set_bm25(k1=0.9, b=0.4)
        expected = ['5578280', '2016011', '7004677']
        docids = [hit.docid for hit in searcher.search('this is a test query', k=3)]
        if docids != expected:
            raise Exception('Test query results do not match expected:',
                            expected, '(expecteD)', docids, '(actual)')

        # 2. Index statistics.
        indexer = IndexReader(index_loc)
        total_terms = indexer.stats()['total_terms']
        if total_terms != 352316036:
            raise Exception(
                'There are an unexpected number of terms in your index set, perhaps something went wrong while downloading and indexing the dataset?'
            )

        # 3. Topics.
        topics = get_topics("msmarco-passage-dev-subset")
        if topics == {}:
            raise Exception(
                'Could not find msmarco-passage-dev-subset... Best approach is to retry indexing the dataset.'
            )
        first_topic_id = list(topics)[0]
        if topics[first_topic_id]['title'] != "why do people grind teeth in sleep":
            raise Exception(
                'Found a different first query than expected in the dataset. Did you download the right dataset?'
            )

        # 4. Analysis (pyserini's name for tokenization/stemming).
        sample = "This is a test query in which things are tested. Found using www.google.com of course!"
        if len(indexer.analyze(sample)) != 9:
            raise Exception(
                'Tokenizer is not working correctly, something is probably wrong in Anserini. Perhaps try to install Anserini again.'
            )
    except Exception as inst:
        print('ERROR: something went wrong in the installation')
        print(inst)
    else:
        print("INSTALLATION OK")
コード例 #27
0
        def __init__(self, candidates, num_candidates_samples, path_index, sample_data, anserini_folder, set_rm3=False, seed=42):
            """Store the sampling configuration, build the index and open a BM25 searcher.

            Args:
                candidates: Stored as ``self.candidates``.
                num_candidates_samples: Stored as ``self.num_candidates_samples``.
                path_index: Directory prefix for the index. ``"anserini_index"``
                    is appended by plain concatenation, so it presumably ends
                    with a path separator (TODO confirm).
                sample_data: Stored as ``self.sample_data``.
                anserini_folder: Stored as ``self.anserini_folder``.
                set_rm3: Enables RM3 on the searcher and selects the
                    ``"BM25RM3NS"`` name instead of ``"BM25NS"``.
                seed: Seed for the ``random`` module, applied first.
            """
            random.seed(seed)
            self.candidates = candidates
            self.num_candidates_samples = num_candidates_samples
            self.path_index = path_index
            self.name = "BM25RM3NS" if set_rm3 else "BM25NS"
            self.sample_data = sample_data
            self.anserini_folder = anserini_folder
            # All attributes are set before the index is built, in case
            # _create_index reads them.
            self._create_index()

            searcher = SimpleSearcher(self.path_index + "anserini_index")
            self.searcher = searcher
            searcher.set_bm25(0.9, 0.4)
            if set_rm3:
                searcher.set_rm3()
コード例 #28
0
ファイル: BaselineSearcher.py プロジェクト: nz63paxe/IR
class BaselineSearcher:
    """Query-likelihood (Dirichlet) baseline over a pyserini index."""

    def __init__(self, index_path):
        self.searcher = SimpleSearcher(index_path)
        self.searcher.set_qld()  # use Dirichlet
        self.name = "Baseline"

    def get_name(self):
        return self.name

    def search(self, query, max_amount=10):
        """Return up to ``max_amount`` hits for ``query`` (``[]`` if it is <= 0).

        The depth is passed to the searcher as ``k``. Slicing the result of a
        default search could never yield more hits than the searcher's
        default depth, however large ``max_amount`` was.
        """
        if max_amount <= 0:
            return []
        return self.searcher.search(query, k=max_amount)

    def get_argument(self, id):
        """Return the stored raw JSON of document ``id``, parsed into Python objects."""
        arg = json.loads(self.searcher.doc(id).raw())
        return arg
コード例 #29
0
ファイル: rerank_with_maxp.py プロジェクト: yuki617/pyserini
def _strip_text_tags(raw):
    """Remove one leading ``<TEXT>`` and one trailing ``</TEXT>`` tag, if present.

    ``str.lstrip``/``str.rstrip`` treat their argument as a *set of
    characters*, so ``lstrip('<TEXT>')`` would also eat leading ``T``, ``E``
    or ``X`` letters of the body. This removes only the exact tags.
    """
    if raw.startswith('<TEXT>'):
        raw = raw[len('<TEXT>'):]
    if raw.endswith('</TEXT>'):
        raw = raw[:-len('</TEXT>')]
    return raw


def main(args):
    """Rerank a base TREC run over MS MARCO documents and write the new ranking.

    Each output line is ``qid<TAB>docid<TAB>rank``. The reranker is selected
    by ``args.bm25``, ``args.ance`` or ``args.identity``; identity passes the
    base run through unchanged. Exits with a message if none is selected.
    """
    if args.cache and not os.path.exists(args.cache):
        os.mkdir(args.cache)

    # Load queries:
    queries = load_queries(args.queries)
    # Load base run to rerank:
    base_run = TrecRun(args.input)

    # SimpleSearcher to fetch document texts.
    searcher = SimpleSearcher.from_prebuilt_index('msmarco-doc')

    output = []

    if args.bm25:
        reranker = 'bm25'
    elif args.ance:
        reranker = 'ance'
    elif not args.identity:
        sys.exit('Unknown reranking method!')

    for cnt, row in enumerate(queries, start=1):
        qid = int(row[0])
        query = row[1]
        print(f'{cnt} {qid} {query}')
        qid_results = base_run.get_docs_by_topic(qid)

        # Don't actually do reranking, just pass along the base run:
        if args.identity:
            for rank, docid in enumerate(qid_results['docid'].tolist(), start=1):
                output.append([qid, docid, rank])
            continue

        # Gather results for reranking:
        results_to_rerank = []
        for _, result in qid_results.iterrows():
            raw_doc = _strip_text_tags(searcher.doc(result['docid']).raw())
            results_to_rerank.append({
                'docid': result['docid'],
                'rank': result['rank'],
                'score': result['score'],
                'text': raw_doc
            })

        # Perform the actual reranking:
        output.extend(
            rerank(args.cache, qid, query, results_to_rerank, reranker))

    # Write the output run file:
    with open(args.output, 'w') as writer:
        for r in output:
            writer.write(f'{r[0]}\t{r[1]}\t{r[2]}\n')
コード例 #30
0
    def __init__(self, name, num_threads, index_dir=None, k1=0.9, b=0.4, use_bigrams=False, stem_bigrams=False):
        """Create one BM25 searcher per worker thread.

        ``num_threads`` is capped at the machine's CPU count. ``self.arguments``
        receives one dict per worker with keys ``"id"``, ``"ranker"`` (a
        SimpleSearcher on ``index_dir`` with BM25 ``k1``/``b``),
        ``"use_bigrams"`` and ``"stem_bigrams"``.
        """
        super().__init__(name)

        # Never start more workers than there are CPUs.
        self.num_threads = min(num_threads, int(multiprocessing.cpu_count()))

        # A separate searcher per worker, each configured the same way.
        self.arguments = []
        for worker_id in tqdm(range(self.num_threads)):
            searcher = SimpleSearcher(index_dir)
            searcher.set_bm25(k1, b)
            worker_args = {
                "id": worker_id,
                "ranker": searcher,
                "use_bigrams": use_bigrams,
                "stem_bigrams": stem_bigrams,
            }
            self.arguments.append(worker_args)