class MsMarcoPassageLoader:
    def __init__(self, index_path: str):
        self.searcher = SimpleSearcher(index_path)

    def load_passage(self, id: str) -> MsMarcoPassage:
        try:
            passage = self.searcher.doc(id).lucene_document().get('raw')
        except AttributeError:
            raise ValueError('passage unretrievable')
        return MsMarcoPassage(passage)
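# Usage sketch for MsMarcoPassageLoader. The index path and passage id below
# are illustrative assumptions, not part of the original code:
loader = MsMarcoPassageLoader('indexes/msmarco-passage')
passage = loader.load_passage('7187158')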
class SearchResultFormatter:
    def __init__(self, index):
        self._index = SimpleSearcher(index)
        self._query = None
        self._docids = []
        self._doc_content = []
        self._doc_scores = []
        self._doc_embeddings = []

    def pyserini_search_result(self, hits, query):
        self.clear()
        for hit in hits:
            self._docids.append(hit.docid)
            self._doc_scores.append(hit.score)
            self._doc_content.append(hit.raw)
        return query, zip(self._docids, self._doc_content, self._doc_scores), None

    def hnswlib_search_result(self, labels, distances, query_embedding, doc_embeddings):
        self.clear()
        # Since only one query is processed at a time:
        labels = labels[0]
        distances = distances[0]
        for label, distance in zip(labels, distances):
            self._docids.append(label)
            # hnswlib returns distances; 1 - distance turns them into
            # similarity-style scores (higher is better).
            self._doc_scores.append(1.0 - distance)
            self._doc_content.append(self._index.doc(label).get('raw'))
            # self._doc_embeddings.append(doc_embeddings)
        return query_embedding, zip(self._docids, self._doc_content, self._doc_scores), doc_embeddings

    def get_doc(self, did):
        doc = self._index.doc(did)
        return doc.raw()

    def clear(self):
        self._docids = []
        self._doc_content = []
        self._doc_scores = []
        self._doc_embeddings = []
        self._query = None
class RM3Searcher:
    def __init__(self, index_path):
        self.searcher = SimpleSearcher(index_path)
        self.searcher.set_qld()  # use Dirichlet smoothing
        self.searcher.set_rm3(10, 10, 0.5, True)
        self.name = "RM3"

    def get_name(self):
        return self.name

    def search(self, query, max_amount=10):
        hits = self.searcher.search(query)[:max_amount]
        return hits

    def get_argument(self, id):
        arg = json.loads(self.searcher.doc(id).raw())
        return arg
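# Usage sketch for RM3Searcher; the index path and query are assumptions.
# search() retrieves with QLD plus RM3 pseudo-relevance feedback, as
# configured in __init__:
rm3_searcher = RM3Searcher('indexes/my-index')
for hit in rm3_searcher.search('climate change', max_amount=5):
    print(hit.docid, hit.score)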
class Cord19AbstractLoader:
    double_space_pattern = re.compile(r'\s\s+')

    def __init__(self, index_path: str):
        self.searcher = SimpleSearcher(index_path)

    @lru_cache(maxsize=1024)
    def load_document(self, id: str) -> Cord19Abstract:
        try:
            article = json.loads(
                self.searcher.doc(id).lucene_document().get('raw'))
        except json.decoder.JSONDecodeError:
            raise ValueError('article not found')
        except AttributeError as e:
            logging.error(e)
            raise ValueError('document unretrievable')
        return Cord19Abstract(article['csv_metadata']['title'],
                              abstract=article['csv_metadata']['abstract']
                              if 'abstract' in article['csv_metadata'] else '')
def extract_documents():
    experiments = ["run.cw12.bm25+rm3", "run.cw12.bm25"]
    # searcher = SimpleSearcher('/data/anserini/lucene-index.gov2.pos+docvectors+rawdocs')
    # searcher = SimpleSearcher.from_prebuilt_index('robust04')
    searcher = SimpleSearcher(
        '/data/anserini/lucene-index.cw12b13.pos+docvectors+rawdocs')
    for experiment in experiments:
        file_address = "../data/cw12/" + experiment + ".txt"
        with open(file_address, "r") as index_file:
            if not os.path.exists("../data/cw12/" + experiment):
                os.makedirs("../data/cw12/" + experiment)
            for line_number, line in enumerate(index_file):
                # print(line.split(" ")[3])
                idx = line.split(" ")[2]
                write_address = "../data/cw12/" + experiment + "/" + idx + ".txt"
                doc = searcher.doc(idx)
                with open(write_address, "w") as file_to_write:
                    file_to_write.write(doc.raw())
                if line_number % 1000 == 0:
                    print(line_number)
class Cord19DocumentLoader:
    double_space_pattern = re.compile(r'\s\s+')

    def __init__(self, index_path: str):
        self.searcher = SimpleSearcher(index_path)

    @lru_cache(maxsize=1024)
    def load_document(self, id: str) -> Cord19Document:
        def unfold(entries):
            return '\n'.join(x['text'] for x in entries)

        try:
            article = json.loads(
                self.searcher.doc(id).lucene_document().get('raw'))
        except json.decoder.JSONDecodeError:
            raise ValueError('article not found')
        except AttributeError:
            raise ValueError('document unretrievable')
        ref_entries = article['ref_entries'].values()
        return Cord19Document(unfold(article['body_text']),
                              unfold(ref_entries),
                              abstract=unfold(article['abstract'])
                              if 'abstract' in article else '')
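# Usage sketch shared by both CORD-19 loaders above; the index path and docid
# are assumptions. load_document raises ValueError for missing or
# unretrievable articles, so callers should be prepared to catch it:
loader = Cord19DocumentLoader('indexes/cord19')
try:
    doc = loader.load_document('ug7v899j')
except ValueError as e:
    print(e)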
def extract_expanded_documents():
    experiment = "unbiased_expansions"
    searcher = SimpleSearcher(
        '/data/anserini/lucene-index.cw12b13.pos+docvectors+rawdocs')
    # searcher = SimpleSearcher.from_prebuilt_index('robust04')
    lambdas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    for my_lambda in lambdas:
        print(my_lambda)
        my_directory = "../data/cw12/" + experiment + \
            "/expanded_landa_" + str(my_lambda)
        file_address = my_directory + ".txt"
        with open(file_address, "r") as index_file:
            if not os.path.exists(my_directory):
                os.makedirs(my_directory)
            for line_number, line in enumerate(index_file):
                # print(line.split(" ")[3])
                idx = line.split(" ")[2]
                write_address = my_directory + "/" + idx + ".txt"
                doc = searcher.doc(idx)
                with open(write_address, "w") as file_to_write:
                    file_to_write.write(doc.raw())
with open(args.t5_input, 'w') as fout_t5, open(args.t5_input_ids, 'w') as fout_tsv:
    for num_examples, (query_id, candidate_doc_ids) in enumerate(
            tqdm(run.items(), total=len(run))):
        query = queries[query_id]
        seen = {}
        for candidate_doc_id in candidate_doc_ids:
            if candidate_doc_id.split("#")[0] in seen:
                passage, ind_desc = seen[candidate_doc_id.split("#")[0]]
                candidate_doc_id = candidate_doc_id.split("#")[0]
            else:
                if args.year == 2020:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split("#")
                        content = index.doc(
                            f"<urn:uuid:{candidate_doc_id}>").contents()
                        ind_desc = int(ind_desc)
                    except Exception:
                        print(candidate_doc_id)
                        content = ""
                        ind_desc = 0
                else:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split("#")
                        ind_desc = int(ind_desc)
                        content = index.doc(candidate_doc_id).raw()
                    except Exception:
                        print(candidate_doc_id)
                        content = ""
                        ind_desc = 0
args = parser.parse_args()

plt.switch_backend('agg')
searcher = SimpleSearcher(args.index)

# Determine how many documents to iterate over:
if args.max:
    num_docs = args.max if args.max < searcher.num_docs else searcher.num_docs
else:
    num_docs = searcher.num_docs

print(f'Computing lengths for {num_docs} documents from {args.index}')
doclengths = []
for i in tqdm(range(num_docs)):
    doclengths.append(len(searcher.doc(i).raw().split()))
doclengths = np.asarray(doclengths)

# Compute bins:
bins = np.arange(args.bin_min, args.bin_max + args.bin_width, args.bin_width)

# If user wants raw output of counts:
if args.output:
    counts, bins = np.histogram(doclengths, bins=bins)
    np.savetxt(f'{args.output}-counts.txt', counts, fmt="%s")
    np.savetxt(f'{args.output}-bins.txt', bins, fmt="%s")

# If user wants plot:
if args.plot:
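    # The original snippet is truncated after `if args.plot:`. A plausible
    # completion (a sketch, not the original code; assumes args.plot holds an
    # output filename for the figure):
    plt.hist(doclengths, bins=bins)
    plt.xlabel('Document length (words)')
    plt.ylabel('Count')
    plt.savefig(args.plot)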
class TestSearch(unittest.TestCase):
    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        filename, headers = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()

        self.searcher = SimpleSearcher(f'{self.index_dir}lucene-index.cacm')

    def test_basic(self):
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(3204, self.searcher.num_docs)

        self.assertTrue(isinstance(hits, List))

        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertEqual(hits[0].lucene_docid, 3133)
        self.assertEqual(len(hits[0].contents), 1500)
        self.assertEqual(len(hits[0].raw), 1532)
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)

        # Test accessing the raw Lucene document and fetching fields from it:
        self.assertEqual(hits[0].lucene_document.getField('id').stringValue(), 'CACM-3134')
        self.assertEqual(hits[0].lucene_document.get('id'), 'CACM-3134')  # simpler call, same result as above
        self.assertEqual(
            len(hits[0].lucene_document.getField('raw').stringValue()), 1532)
        self.assertEqual(len(hits[0].lucene_document.get('raw')), 1532)  # simpler call, same result as above

        self.assertTrue(isinstance(hits[9], JSimpleSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        hits = self.searcher.search('search')

        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3058')
        self.assertAlmostEqual(hits[0].score, 2.85760, places=5)

        self.assertTrue(isinstance(hits[9], JSimpleSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-3040')
        self.assertAlmostEqual(hits[9].score, 2.68780, places=5)

    def test_batch(self):
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], threads=2)

        self.assertEqual(3204, self.searcher.num_docs)

        self.assertTrue(isinstance(results, Dict))

        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(results['q1'][0].docid, 'CACM-3134')
        self.assertAlmostEqual(results['q1'][0].score, 4.76550, places=5)

        self.assertTrue(isinstance(results['q1'][9], JSimpleSearcherResult))
        self.assertEqual(results['q1'][9].docid, 'CACM-2516')
        self.assertAlmostEqual(results['q1'][9].score, 4.21740, places=5)

        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(results['q2'][0].docid, 'CACM-3058')
        self.assertAlmostEqual(results['q2'][0].score, 2.85760, places=5)

        self.assertTrue(isinstance(results['q2'][9], JSimpleSearcherResult))
        self.assertEqual(results['q2'][9].docid, 'CACM-3040')
        self.assertAlmostEqual(results['q2'][9].score, 2.68780, places=5)

    def test_basic_k(self):
        hits = self.searcher.search('information retrieval', k=100)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(len(hits), 100)

    def test_batch_k(self):
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], k=100, threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))
        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q1']), 100)
        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q2']), 100)

    def test_basic_fields(self):
        # This test just provides a sanity check; it's not that interesting as it only searches one field.
        hits = self.searcher.search('information retrieval', k=42, fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(len(hits), 42)

    def test_batch_fields(self):
        # This test just provides a sanity check; it's not that interesting as it only searches one field.
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], k=42, threads=2,
            fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))
        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q1']), 42)
        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q2']), 42)

    def test_different_similarity(self):
        # qld, default mu
        self.searcher.set_qld()
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 3.68030, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1927')
        self.assertAlmostEqual(hits[9].score, 2.53240, places=5)

        # bm25, default parameters
        self.searcher.set_bm25()
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        # qld, custom mu
        self.searcher.set_qld(100)
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 6.35580, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2631')
        self.assertAlmostEqual(hits[9].score, 5.18960, places=5)

        # bm25, custom parameters
        self.searcher.set_bm25(0.8, 0.3)
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.86880, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.33320, places=5)

    def test_rm3(self):
        self.searcher.set_rm3()
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.18010, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 1.70330, places=5)

        self.searcher.unset_rm3()
        self.assertFalse(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        self.searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3)
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.17190, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1457')
        self.assertAlmostEqual(hits[9].score, 1.43700, places=5)

    def test_doc_int(self):
        # The doc method is overloaded: if the input is an int, it's assumed to be a Lucene internal docid.
        doc = self.searcher.doc(1)
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual('CACM-0002', doc.id())
        self.assertEqual('CACM-0002', doc.docid())
        self.assertEqual('CACM-0002', doc.get('id'))
        self.assertEqual('CACM-0002', doc.lucene_document().getField('id').stringValue())

        # These are all equivalent ways to get the 'raw' field.
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field.
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist.
        self.assertTrue(self.searcher.doc(314159) is None)

    def test_doc_str(self):
        # The doc method is overloaded: if the input is a str, it's assumed to be an external collection docid.
        doc = self.searcher.doc('CACM-0002')
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual(doc.lucene_document().getField('id').stringValue(), 'CACM-0002')
        self.assertEqual(doc.id(), 'CACM-0002')
        self.assertEqual(doc.docid(), 'CACM-0002')
        self.assertEqual(doc.get('id'), 'CACM-0002')

        # These are all equivalent ways to get the 'raw' field.
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field.
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist.
        self.assertTrue(self.searcher.doc('foo') is None)

    def test_doc_by_field(self):
        self.assertEqual(
            self.searcher.doc('CACM-3134').docid(),
            self.searcher.doc_by_field('id', 'CACM-3134').docid())

        # Should return None if we request a field/value pair that doesn't exist.
        self.assertTrue(self.searcher.doc_by_field('foo', 'bar') is None)

    def tearDown(self):
        self.searcher.close()
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
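# The suite above runs under the standard unittest runner, e.g. (the module
# name here is an assumption):
#
#     python -m unittest -v test_search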
    searcher = SimpleSearcher(args.index)
else:
    searcher = SimpleSearcher.from_prebuilt_index(args.index)

if not searcher:
    exit()

retrieval = {}
with open(args.input) as f_in:
    for line in tqdm(f_in.readlines()):
        question_id, _, doc_id, _, score, _ = line.strip().split()
        question_id = int(question_id)
        question = qas[question_id]['title']
        answers = qas[question_id]['answers']
        if answers[0] == '"':
            answers = answers[1:-1].replace('""', '"')
        answers = eval(answers)
        ctx = json.loads(searcher.doc(doc_id).raw())['contents']
        if question_id not in retrieval:
            retrieval[question_id] = {
                'question': question,
                'answers': answers,
                'contexts': []
            }
        retrieval[question_id]['contexts'].append({
            'docid': doc_id,
            'score': score,
            'text': ctx
        })

json.dump(retrieval, open(args.output, 'w'), indent=4)
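# For reference, the JSON written above has this shape (values are
# illustrative, not from the original source; note that score stays a string):
# {
#   "12": {
#     "question": "...",
#     "answers": ["..."],
#     "contexts": [{"docid": "...", "score": "...", "text": "..."}]
#   }
# }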
with open(args.t5_input, 'w') as fout_t5, open(args.t5_input_ids, 'w') as fout_tsv:
    for num_examples, (query_id, candidate_doc_ids) in enumerate(
            tqdm(run.items(), total=len(run))):
        query = queries[query_id]
        seen = {}
        for candidate_doc_id in candidate_doc_ids:
            if candidate_doc_id.split("#")[0] in seen:
                passage, ind_desc = seen[candidate_doc_id.split("#")[0]]
                candidate_doc_id = candidate_doc_id.split("#")[0]
            else:
                if args.year == 2020:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split("#")
                        content = index.doc(
                            f"<urn:uuid:{candidate_doc_id}>").contents()
                        ind_desc = int(ind_desc)
                    except Exception:
                        print(candidate_doc_id)
                        content = ""
                        ind_desc = 0
                elif args.year == 2021:
                    candidate_doc_id, ind_desc = candidate_doc_id.split("#")
                    ind_desc = int(ind_desc)
                    content = json.loads(
                        index.doc(str(candidate_doc_id)).raw())
                    content = content['text']
                else:
                    try:
                        candidate_doc_id, ind_desc = candidate_doc_id.split("#")
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import json
import sys

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')

from pyserini.search import SimpleSearcher

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--qrels', type=str, help='qrels file', required=True)
    parser.add_argument('--index', type=str, help='index location', required=True)
    args = parser.parse_args()

    searcher = SimpleSearcher(args.index)

    with open(args.qrels, 'r') as reader:
        for line in reader.readlines():
            arr = line.split('\t')
            doc = json.loads(searcher.doc(arr[2]).raw())['contents']
            print(f'{arr[2]}\t{doc}')
class CAsT_RawDataLoader():
    def __init__(
            self,
            CAsT_index="/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/CAsT_collection_with_meta.index",
            treccastweb_dir="/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/treccastweb",
            NIST_qrels=[
                "/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/2019qrels.txt",
                "/nfs/phd_by_carlos/notebooks/datasets/TREC_CAsT/2020qrels.txt"
            ],
            **kwargs):
        self.searcher = SimpleSearcher(CAsT_index)
        self.q_rels = {}
        for q_rel_file in NIST_qrels:
            with open(q_rel_file) as NIST_fp:
                for line in NIST_fp.readlines():
                    q_id, _, d_id, score = line.split(" ")
                    if int(score) < 3:  # ignore some of the worst ranked
                        continue
                    if q_id not in self.q_rels:
                        self.q_rels[q_id] = []
                    self.q_rels[q_id].append(d_id)

        with open(
                os.path.join(
                    treccastweb_dir,
                    "2020/2020_manual_evaluation_topics_v1.0.json")) as y2_fp:
            y2_data = json.load(y2_fp)
        for topic in y2_data:
            topic_id = topic["number"]
            for turn in topic["turn"]:
                turn_id = turn["number"]
                q_id = f"{topic_id}_{turn_id}"
                if q_id not in self.q_rels:
                    self.q_rels[q_id] = []
                self.q_rels[q_id].append(turn["manual_canonical_result_id"])

        year1_query_collection, self.year1_topics = self.load_CAsT_topics_file(
            os.path.join(treccastweb_dir,
                         "2019/data/evaluation/evaluation_topics_v1.0.json"))
        year2_query_collection, self.year2_topics = self.load_CAsT_topics_file(
            os.path.join(treccastweb_dir,
                         "2020/2020_manual_evaluation_topics_v1.0.json"))
        self.query_collection = {
            **year1_query_collection,
            **year2_query_collection
        }

        with open(
                os.path.join(
                    treccastweb_dir,
                    "2019/data/evaluation/evaluation_topics_annotated_resolved_v1.0.tsv"
                )) as resolved_f:
            reader = csv.reader(resolved_f, delimiter="\t")
            for row in reader:
                q_id, resolved_query = row
                if q_id in self.query_collection:
                    self.query_collection[q_id][
                        "manual_rewritten_utterance"] = resolved_query

    def NIST_result_curve(self, score):
        """Map a NIST relevance score onto [0, 1]: 0->0, 1~>0.1, 3~>0.6, 4->1."""
        return (1 / 16) * (score**2)

    def load_CAsT_topics_file(self, file):
        query_collection = {}
        topics = {}
        with open(file) as topics_fp:
            topics_data = json.load(topics_fp)
        for topic in topics_data:
            previous_turns = []
            topic_id = topic["number"]
            for turn in topic["turn"]:
                turn_id = turn["number"]
                q_id = f"{topic_id}_{turn_id}"
                # if q_id not in self.q_rels:
                #     continue
                query_collection[q_id] = turn
                topics[q_id] = previous_turns[:]
                previous_turns.append(q_id)
        return query_collection, topics

    def get_doc(self, doc_id):
        raw_text = self.searcher.doc(doc_id).raw()
        paragraph = raw_text[raw_text.find('<BODY>\n') + 7:raw_text.find('\n</BODY>')]
        return paragraph

    def get_split(self, split):
        '''
        return: dict: {q_id: [q_id, ...]}: contains the query turns and the
        previous query turns for the topic
        '''
        dev_topic_cutoff = 71
        if split == "train":
            return {
                k: v
                for k, v in self.year1_topics.items()
                if int(k.split("_")[0]) < dev_topic_cutoff
            }
        elif split == "dev":
            return {
                k: v
                for k, v in self.year1_topics.items()
                if int(k.split("_")[0]) >= dev_topic_cutoff
            }
        elif split == "all":
            return self.year1_topics
        elif split == "eval":
            return self.year2_topics
        else:
            raise Exception(f"Split '{split}' not recognised")

    def get_topics(self, split, ignore_missing_q_rels=False):
        '''
        split: str: "train", "dev", "all", "eval"
        returns: [dict]: [{'q_id': "32_4", 'q_rel': ["CAR_xxx", ...],
                           'prev_turns': ["32_3", ...]}, ...]
        '''
        topic_split = self.get_split(split)
        samples = [{
            'prev_turns': prev_turns,
            'q_id': q_id,
            'q_rel': self.q_rels.get(q_id)
        } for q_id, prev_turns in topic_split.items()
                   if q_id in self.q_rels or ignore_missing_q_rels]
        return samples

    def get_query(self, q_id, utterance_type="raw_utterance"):
        '''
        >>> raw_data_loader.get_query("31_4", utterance_type="manual_rewritten_utterance")
        '''
        return self.query_collection[q_id][utterance_type]
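# Usage sketch for CAsT_RawDataLoader; the constructor defaults above are
# site-specific NFS paths, so pass your own index, treccastweb checkout, and
# qrels files. The paths and q_id below are illustrative assumptions:
cast_loader = CAsT_RawDataLoader(
    CAsT_index='indexes/CAsT_collection_with_meta.index',
    treccastweb_dir='datasets/treccastweb',
    NIST_qrels=['datasets/2019qrels.txt'])
train_samples = cast_loader.get_topics('train')
query = cast_loader.get_query('31_4', utterance_type='manual_rewritten_utterance')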