Example #1
class MsMarcoPassageLoader:
    def __init__(self, index_path: str):
        self.searcher = SimpleSearcher(index_path)

    def load_passage(self, id: str) -> MsMarcoPassage:
        try:
            passage = self.searcher.doc(id).lucene_document().get('raw')
        except AttributeError:
            raise ValueError('passage unretrievable')
        return MsMarcoPassage(passage)
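A minimal usage sketch for the loader above, assuming its missing imports (SimpleSearcher from pyserini.search and the MsMarcoPassage wrapper) are in scope; the index path and passage id are hypothetical:

loader = MsMarcoPassageLoader('indexes/msmarco-passage')  # hypothetical index path
try:
    passage = loader.load_passage('7187158')  # hypothetical passage id
except ValueError:
    passage = None  # id not present in the index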
Example #2
class SearchResultFormatter:
    def __init__(self, index):
        self._index = SimpleSearcher(index)
        self._query = None
        self._docids = []
        self._doc_content = []
        self._doc_scores = []
        self._doc_embeddings = []

    def pyserini_search_result(self, hits, query):
        self.clear()
        for hit in hits:
            self._docids.append(hit.docid)
            self._doc_scores.append(hit.score)
            self._doc_content.append(hit.raw)
        return query, zip(self._docids, self._doc_content, self._doc_scores), None

    def hnswlib_search_result(self, labels, distances, query_embedding, doc_embeddings):
        self.clear()
        # hnswlib returns results batched per query; only one query is issued
        # at a time here, so take the first (and only) row.
        labels = labels[0]
        distances = distances[0]
        for label, distance in zip(labels, distances):
            self._docids.append(label)
            # 1 - distance recovers cosine similarity, assuming a cosine-space hnswlib index.
            self._doc_scores.append(1.0 - distance)
            self._doc_content.append(self._index.doc(label).get('raw'))
        return query_embedding, zip(self._docids, self._doc_content, self._doc_scores), doc_embeddings

    def get_doc(self, did):
        doc = self._index.doc(did)
        return doc.raw()

    def clear(self):
        self._docids = []
        self._doc_content = []
        self._doc_scores = []
        self._doc_embeddings = []
        self._query = None
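A sketch of feeding Pyserini hits through the formatter above; the index path and query are hypothetical:

from pyserini.search import SimpleSearcher

index_path = 'indexes/msmarco-passage'  # hypothetical
searcher = SimpleSearcher(index_path)
formatter = SearchResultFormatter(index_path)

query = 'what is a lobster roll'
hits = searcher.search(query, k=10)
query, results, _ = formatter.pyserini_search_result(hits, query)
for docid, content, score in results:  # results is the zip of ids, contents, scores
    print(docid, score)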
Example #3
import json

from pyserini.search import SimpleSearcher


class RM3Searcher:
    def __init__(self, index_path):
        self.searcher = SimpleSearcher(index_path)
        self.searcher.set_qld()  # query likelihood with Dirichlet smoothing
        # fb_terms=10, fb_docs=10, original_query_weight=0.5, and log the expanded query
        self.searcher.set_rm3(10, 10, 0.5, True)
        self.name = "RM3"

    def get_name(self):
        return self.name

    def search(self, query, max_amount=10):
        # Pass k explicitly: search() returns only 10 hits by default,
        # so slicing the result could never yield more than 10.
        hits = self.searcher.search(query, k=max_amount)
        return hits

    def get_argument(self, id):
        arg = json.loads(self.searcher.doc(id).raw())
        return arg
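A usage sketch for the RM3 searcher, assuming the index stores each document as raw JSON (which get_argument expects); the index path and query are hypothetical:

searcher = RM3Searcher('indexes/args-me')  # hypothetical argument index
for hit in searcher.search('nuclear energy is safe', max_amount=5):
    argument = searcher.get_argument(hit.docid)  # dict parsed from the stored JSON
    print(hit.docid, round(hit.score, 4))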
Example #4
class Cord19AbstractLoader:
    double_space_pattern = re.compile(r'\s\s+')

    def __init__(self, index_path: str):
        self.searcher = SimpleSearcher(index_path)

    @lru_cache(maxsize=1024)
    def load_document(self, id: str) -> Cord19Abstract:
        try:
            article = json.loads(
                self.searcher.doc(id).lucene_document().get('raw'))
        except json.decoder.JSONDecodeError:
            raise ValueError('article not found')
        except AttributeError as e:
            logging.error(e)
            raise ValueError('document unretrievable')
        return Cord19Abstract(article['csv_metadata']['title'],
                              abstract=article['csv_metadata']['abstract'] if 'abstract' in article['csv_metadata'] else '')
Example #5
def extract_documents():
    experiments = ["run.cw12.bm25+rm3", "run.cw12.bm25"]
    # searcher = SimpleSearcher('/data/anserini/lucene-index.gov2.pos+docvectors+rawdocs')
    # searcher = SimpleSearcher.from_prebuilt_index('robust04')
    searcher = SimpleSearcher(
        '/data/anserini/lucene-index.cw12b13.pos+docvectors+rawdocs')

    for experiment in experiments:
        file_address = "../data/cw12/" + experiment + ".txt"
        with open(file_address, "r") as index_file:
            if not os.path.exists("../data/cw12/" + experiment):
                os.makedirs("../data/cw12/" + experiment)
            for line_number, line in enumerate(index_file):
                # TREC run format: qid Q0 docid rank score tag -> column 2 is the docid
                idx = line.split(" ")[2]
                write_address = "../data/cw12/" + experiment + "/" + idx + ".txt"
                doc = searcher.doc(idx)
                with open(write_address, "w") as file_to_write:
                    file_to_write.write(doc.raw())
                if line_number % 1000 == 0:
                    print(line_number)
Example #6
class Cord19DocumentLoader:
    double_space_pattern = re.compile(r'\s\s+')

    def __init__(self, index_path: str):
        self.searcher = SimpleSearcher(index_path)

    @lru_cache(maxsize=1024)
    def load_document(self, id: str) -> Cord19Document:
        def unfold(entries):
            return '\n'.join(x['text'] for x in entries)
        try:
            article = json.loads(
                self.searcher.doc(id).lucene_document().get('raw'))
        except json.decoder.JSONDecodeError:
            raise ValueError('article not found')
        except AttributeError:
            raise ValueError('document unretrievable')
        ref_entries = article['ref_entries'].values()
        return Cord19Document(unfold(article['body_text']),
                              unfold(ref_entries),
                              abstract=unfold(article['abstract']) if 'abstract' in article else '')
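Because load_document is wrapped in lru_cache, repeating a lookup skips the index round trip; note that caching a method also keys on self, so each loader instance has (and is kept alive by) its own cache entries. A sketch with a hypothetical index path and docid:

loader = Cord19DocumentLoader('indexes/cord19-full-text')  # hypothetical path
doc = loader.load_document('ij3ncdb6')    # hypothetical docid; first call hits the index
doc = loader.load_document('ij3ncdb6')    # repeat call is served from the cache
print(loader.load_document.cache_info())  # CacheInfo(hits=1, misses=1, ...)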
Example #7
def extract_expanded_documents():
    experiment = "unbiased_expansions"
    searcher = SimpleSearcher(
        '/data/anserini/lucene-index.cw12b13.pos+docvectors+rawdocs')

    # searcher = SimpleSearcher.from_prebuilt_index('robust04')
    lambdas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    for my_lambda in lambdas:
        print(my_lambda)
        # note: the on-disk directory naming uses "landa"
        my_directory = "../data/cw12/" + experiment + \
                       "/expanded_landa_" + str(my_lambda)
        file_address = my_directory + ".txt"
        with open(file_address, "r") as index_file:
            if not os.path.exists(my_directory):
                os.makedirs(my_directory)
            for line_number, line in enumerate(index_file):
                # TREC run format: column 2 is the docid
                idx = line.split(" ")[2]
                write_address = my_directory + "/" + idx + ".txt"
                doc = searcher.doc(idx)
                with open(write_address, "w") as file_to_write:
                    file_to_write.write(doc.raw())
Example #8
with open(args.t5_input, 'w') as fout_t5, open(args.t5_input_ids,
                                               'w') as fout_tsv:
    for num_examples, (query_id, candidate_doc_ids) in enumerate(
            tqdm(run.items(), total=len(run))):
        query = queries[query_id]
        seen = {}
        for candidate_doc_id in candidate_doc_ids:
            if candidate_doc_id.split("#")[0] in seen:
                passage, ind_desc = seen[candidate_doc_id.split("#")[0]]
                candidate_doc_id = candidate_doc_id.split("#")[0]
            else:
                if args.year == 2020:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split(
                            "#")
                        content = index.doc(
                            f"<urn:uuid:{candidate_doc_id}>").contents()
                        ind_desc = int(ind_desc)
                    except Exception:  # doc missing or malformed id; fall back to empty content
                        print(candidate_doc_id)
                        content = ""
                        ind_desc = 0
                else:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split(
                            "#")
                        ind_desc = int(ind_desc)
                        content = index.doc(candidate_doc_id).raw()
                    except Exception:  # doc missing or malformed id; fall back to empty content
                        print(candidate_doc_id)
                        content = ""
                        ind_desc = 0
Example #9
    args = parser.parse_args()

    plt.switch_backend('agg')

    searcher = SimpleSearcher(args.index)

    # Determine how many documents to iterate over:
    if args.max:
        num_docs = args.max if args.max < searcher.num_docs else searcher.num_docs
    else:
        num_docs = searcher.num_docs

    print(f'Computing lengths for {num_docs} docs from {args.index}')
    doclengths = []
    for i in tqdm(range(num_docs)):
        doclengths.append(len(searcher.doc(i).raw().split()))

    doclengths = np.asarray(doclengths)

    # Compute bins:
    bins = np.arange(args.bin_min, args.bin_max + args.bin_width,
                     args.bin_width)

    # If user wants raw output of counts:
    if args.output:
        counts, bins = np.histogram(doclengths, bins=bins)
        np.savetxt(f'{args.output}-counts.txt', counts, fmt="%s")
        np.savetxt(f'{args.output}-bins.txt', bins, fmt="%s")

    # If user wants plot:
    if args.plot:
Example #10
class TestSearch(unittest.TestCase):
    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        filename, headers = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()

        self.searcher = SimpleSearcher(f'{self.index_dir}lucene-index.cacm')

    def test_basic(self):
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))

        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertEqual(hits[0].lucene_docid, 3133)
        self.assertEqual(len(hits[0].contents), 1500)
        self.assertEqual(len(hits[0].raw), 1532)
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)

        # Test accessing the raw Lucene document and fetching fields from it:
        self.assertEqual(hits[0].lucene_document.getField('id').stringValue(),
                         'CACM-3134')
        self.assertEqual(hits[0].lucene_document.get('id'),
                         'CACM-3134')  # simpler call, same result as above
        self.assertEqual(
            len(hits[0].lucene_document.getField('raw').stringValue()), 1532)
        self.assertEqual(len(hits[0].lucene_document.get('raw')),
                         1532)  # simpler call, same result as above

        self.assertTrue(isinstance(hits[9], JSimpleSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        hits = self.searcher.search('search')

        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3058')
        self.assertAlmostEqual(hits[0].score, 2.85760, places=5)

        self.assertTrue(isinstance(hits[9], JSimpleSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-3040')
        self.assertAlmostEqual(hits[9].score, 2.68780, places=5)

    def test_batch(self):
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))

        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(results['q1'][0].docid, 'CACM-3134')
        self.assertAlmostEqual(results['q1'][0].score, 4.76550, places=5)

        self.assertTrue(isinstance(results['q1'][9], JSimpleSearcherResult))
        self.assertEqual(results['q1'][9].docid, 'CACM-2516')
        self.assertAlmostEqual(results['q1'][9].score, 4.21740, places=5)

        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(results['q2'][0].docid, 'CACM-3058')
        self.assertAlmostEqual(results['q2'][0].score, 2.85760, places=5)

        self.assertTrue(isinstance(results['q2'][9], JSimpleSearcherResult))
        self.assertEqual(results['q2'][9].docid, 'CACM-3040')
        self.assertAlmostEqual(results['q2'][9].score, 2.68780, places=5)

    def test_basic_k(self):
        hits = self.searcher.search('information retrieval', k=100)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(len(hits), 100)

    def test_batch_k(self):
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'],
            k=100,
            threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))
        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q1']), 100)
        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q2']), 100)

    def test_basic_fields(self):
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        hits = self.searcher.search('information retrieval',
                                    k=42,
                                    fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(len(hits), 42)

    def test_batch_fields(self):
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'],
            k=42,
            threads=2,
            fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))
        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q1']), 42)
        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q2']), 42)

    def test_different_similarity(self):
        # qld, default mu
        self.searcher.set_qld()
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 3.68030, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1927')
        self.assertAlmostEqual(hits[9].score, 2.53240, places=5)

        # bm25, default parameters
        self.searcher.set_bm25()
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        # qld, custom mu
        self.searcher.set_qld(100)
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 6.35580, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2631')
        self.assertAlmostEqual(hits[9].score, 5.18960, places=5)

        # bm25, custom parameters
        self.searcher.set_bm25(0.8, 0.3)
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.86880, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.33320, places=5)

    def test_rm3(self):
        self.searcher.set_rm3()
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.18010, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 1.70330, places=5)

        self.searcher.unset_rm3()
        self.assertFalse(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        self.searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3)
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.17190, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1457')
        self.assertAlmostEqual(hits[9].score, 1.43700, places=5)

    def test_doc_int(self):
        # The doc method is overloaded: if input is int, it's assumed to be a Lucene internal docid.
        doc = self.searcher.doc(1)
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual('CACM-0002', doc.id())
        self.assertEqual('CACM-0002', doc.docid())
        self.assertEqual('CACM-0002', doc.get('id'))
        self.assertEqual('CACM-0002',
                         doc.lucene_document().getField('id').stringValue())

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc(314159) is None)

    def test_doc_str(self):
        # The doc method is overloaded: if input is str, it's assumed to be an external collection docid.
        doc = self.searcher.doc('CACM-0002')
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual(doc.lucene_document().getField('id').stringValue(),
                         'CACM-0002')
        self.assertEqual(doc.id(), 'CACM-0002')
        self.assertEqual(doc.docid(), 'CACM-0002')
        self.assertEqual(doc.get('id'), 'CACM-0002')

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc('foo') is None)

    def test_doc_by_field(self):
        self.assertEqual(
            self.searcher.doc('CACM-3134').docid(),
            self.searcher.doc_by_field('id', 'CACM-3134').docid())

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc_by_field('foo', 'bar') is None)

    def tearDown(self):
        self.searcher.close()
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
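The excerpt stops at tearDown; to run the file directly, the usual unittest entry point would be appended (an assumption, not shown in the original):

if __name__ == '__main__':
    unittest.main()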
Example #11
        searcher = SimpleSearcher(args.index)
    else:
        searcher = SimpleSearcher.from_prebuilt_index(args.index)
    if not searcher:
        exit()

    retrieval = {}
    with open(args.input) as f_in:
        for line in tqdm(f_in.readlines()):
            question_id, _, doc_id, _, score, _ = line.strip().split()
            question_id = int(question_id)
            question = qas[question_id]['title']
            answers = qas[question_id]['answers']
            if answers[0] == '"':
                answers = answers[1:-1].replace('""', '"')
            # literal_eval is safer than eval() for the stringified answer list
            # (requires `import ast` at the top of the script)
            answers = ast.literal_eval(answers)
            ctx = json.loads(searcher.doc(doc_id).raw())['contents']
            if question_id not in retrieval:
                retrieval[question_id] = {
                    'question': question,
                    'answers': answers,
                    'contexts': []
                }
            retrieval[question_id]['contexts'].append({
                'docid': doc_id,
                'score': score,
                'text': ctx
            })

    json.dump(retrieval, open(args.output, 'w'), indent=4)
Example #12
with open(args.t5_input, 'w') as fout_t5, open(args.t5_input_ids,
                                               'w') as fout_tsv:
    for num_examples, (query_id, candidate_doc_ids) in enumerate(
            tqdm(run.items(), total=len(run))):
        query = queries[query_id]
        seen = {}
        for candidate_doc_id in candidate_doc_ids:
            if candidate_doc_id.split("#")[0] in seen:
                passage, ind_desc = seen[candidate_doc_id.split("#")[0]]
                candidate_doc_id = candidate_doc_id.split("#")[0]
            else:
                if args.year == 2020:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split(
                            "#")
                        content = index.doc(
                            f"<urn:uuid:{candidate_doc_id}>").contents()
                        ind_desc = int(ind_desc)
                    except Exception:  # doc missing or malformed id; fall back to empty content
                        print(candidate_doc_id)
                        content = ""
                        ind_desc = 0
                elif args.year == 2021:
                    candidate_doc_id, ind_desc = candidate_doc_id.split("#")
                    ind_desc = int(ind_desc)
                    content = json.loads(
                        index.doc(str(candidate_doc_id)).raw())
                    content = content['text']
                else:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split(
                            "#")
Example #13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import json
import sys

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')

from pyserini.search import SimpleSearcher


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--qrels', type=str, help='qrels file', required=True)
    parser.add_argument('--index', type=str, help='index location', required=True)
    args = parser.parse_args()

    searcher = SimpleSearcher(args.index)
    with open(args.qrels, 'r') as reader:
        for line in reader.readlines():
            arr = line.split('\t')
            doc = json.loads(searcher.doc(arr[2]).raw())['contents']
            print(f'{arr[2]}\t{doc}')
Example #14
class CAsT_RawDataLoader():
    def __init__(
            self,
            CAsT_index="/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/CAsT_collection_with_meta.index",
            treccastweb_dir="/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/treccastweb",
            NIST_qrels=[
                "/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/2019qrels.txt",
                '/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/2020qrels.txt'
            ],
            **kwargs):

        self.searcher = SimpleSearcher(CAsT_index)
        self.q_rels = {}
        for q_rel_file in NIST_qrels:
            with open(q_rel_file) as NIST_fp:
                for line in NIST_fp.readlines():
                    q_id, _, d_id, score = line.split(" ")
                    if int(score) < 3:
                        # skip judgments below the high-relevance threshold
                        continue
                    if q_id not in self.q_rels:
                        self.q_rels[q_id] = []
                    self.q_rels[q_id].append(d_id)

        with open(
                os.path.join(
                    treccastweb_dir,
                    "2020/2020_manual_evaluation_topics_v1.0.json")) as y2_fp:
            y2_data = json.load(y2_fp)
            for topic in y2_data:
                topic_id = topic["number"]
                for turn in topic["turn"]:
                    turn_id = turn["number"]
                    q_id = f"{topic_id}_{turn_id}"
                    if q_id not in self.q_rels:
                        self.q_rels[q_id] = []
                    self.q_rels[q_id].append(
                        turn["manual_canonical_result_id"])

        year1_query_collection, self.year1_topics = self.load_CAsT_topics_file(
            os.path.join(treccastweb_dir,
                         "2019/data/evaluation/evaluation_topics_v1.0.json"))
        year2_query_collection, self.year2_topics = self.load_CAsT_topics_file(
            os.path.join(treccastweb_dir,
                         "2020/2020_manual_evaluation_topics_v1.0.json"))

        self.query_collection = {
            **year1_query_collection,
            **year2_query_collection
        }

        with open(
                os.path.join(
                    treccastweb_dir,
                    "2019/data/evaluation/evaluation_topics_annotated_resolved_v1.0.tsv"
                )) as resolved_f:
            reader = csv.reader(resolved_f, delimiter="\t")
            for row in reader:
                q_id, resolved_query = row
                if q_id in self.query_collection:
                    self.query_collection[q_id][
                        "manual_rewritten_utterance"] = resolved_query

    def NIST_result_curve(self, score):
        """Map a NIST relevance grade onto [0, 1] with a quadratic curve:
        0 -> 0.0, 1 -> 0.0625, 3 -> 0.5625, 4 -> 1.0."""
        return (1 / 16) * (score ** 2)

    def load_CAsT_topics_file(self, file):
        query_collection = {}
        topics = {}
        with open(file) as topics_fp:
            topics_data = json.load(topics_fp)
            for topic in topics_data:
                previous_turns = []
                topic_id = topic["number"]
                for turn in topic["turn"]:
                    turn_id = turn["number"]
                    q_id = f"{topic_id}_{turn_id}"
                    #                     if q_id not in self.q_rels:
                    #                         continue
                    query_collection[q_id] = turn
                    topics[q_id] = previous_turns[:]
                    previous_turns.append(q_id)
            return query_collection, topics

    def get_doc(self, doc_id):
        raw_text = self.searcher.doc(doc_id).raw()
        paragraph = raw_text[raw_text.find('<BODY>\n') +
                             7:raw_text.find('\n</BODY>')]
        return paragraph

    def get_split(self, split):
        '''
        return: dict: {q_id: [q_id,...]}: contains the query turns and the previous query turns for the topic
        '''
        dev_topic_cutoff = 71
        if split == "train":
            return {
                k: v
                for k, v in self.year1_topics.items()
                if int(k.split("_")[0]) < dev_topic_cutoff
            }
        elif split == "dev":
            return {
                k: v
                for k, v in self.year1_topics.items()
                if int(k.split("_")[0]) >= dev_topic_cutoff
            }
        elif split == "all":
            return self.year1_topics
        elif split == "eval":
            return self.year2_topics
        else:
            raise Exception(f"Split '{split}' not recognised")

    def get_topics(self, split, ignore_missing_q_rels=False):
        '''
        split: str: "train", "dev", "all", "eval"
        returns: [dict]: [{'q_id': "32_4", 'q_rel': ["CAR_xxx", ...], 'prev_turns': ["32_3", ...]}, ...]
        '''
        topic_split = self.get_split(split)
        samples = [{
            'prev_turns': prev_turns,
            'q_id': q_id,
            'q_rel': self.q_rels.get(q_id)
        } for q_id, prev_turns in topic_split.items()
                   if q_id in self.q_rels or ignore_missing_q_rels]
        return samples

    def get_query(self, q_id, utterance_type="raw_utterance"):
        '''
        >>> raw_data_loader.get_query("31_4", utterance_type="manual_rewritten_utterance")
        '''
        return self.query_collection[q_id][utterance_type]
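Finally, a sketch of driving the CAsT loader end to end, assuming the site-specific default paths in __init__ resolve locally:

loader = CAsT_RawDataLoader()  # relies on the hard-coded default paths above
for sample in loader.get_topics('train')[:3]:
    q_id = sample['q_id']
    print(q_id, loader.get_query(q_id))  # defaults to the raw utterance
    for d_id in sample['q_rel'][:1]:
        print('  ', loader.get_doc(d_id)[:80])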