def pair_iter(
        dataset,
        fields: set,
        pos_source: str = 'intersect',  # one of ['qrels', 'intersect']
        neg_source: str = 'run',  # one of ['run', 'qrels', 'union']
        sampling: str = 'query',  # one of ['query', 'qrel']
        pos_minrel: int = 1,
        unjudged_rel: int = 0,
        num_neg: int = 1,
        random=None,
        inf: bool = False):
    qrels_fn = util.Lazy(lambda: dataset.qrels(fmt='df'))
    run_fn = util.Lazy(lambda: dataset.run(fmt='df'))

    # select positive candidates according to pos_source
    pos_candidates = {
        'qrels': pair_iter_pos_candidates_qrels,
        'intersect': pair_iter_pos_candidates_intersect,
    }[pos_source](dataset, qrels_fn, run_fn, pos_minrel)
    assert len(pos_candidates.index) > 0

    # select negative candidates according to neg_source, indexed by query
    neg_candidates = {
        'run': pair_iter_neg_candidates_run,
        'qrels': pair_iter_neg_candidates_qrels,
        'union': pair_iter_neg_candidates_union,
    }[neg_source](dataset, qrels_fn, run_fn, unjudged_rel)
    neg_candidates = neg_candidates.set_index('qid')
    neg_candidates.sort_index(inplace=True)
    assert len(neg_candidates.index) > 0

    # iterate positive samples, either uniformly by query or by qrel
    pos_iter = {
        'query': pair_iter_sample_by_query,
        'qrel': pair_iter_sample_by_qrel,
    }[sampling](dataset, pos_candidates, random, inf)

    for qid, pos_did, score in pos_iter:
        negs = pair_iter_filter_neg(dataset, neg_candidates, qid, pos_did, score)
        if len(negs.index) < num_neg:
            # not enough negative documents for this positive sample
            dataset.logger.debug(f'not enough negs for qid {qid} neg_candidates: {neg_candidates}')
            continue
        dids = [pos_did]
        negs = negs['did'].sample(n=num_neg, random_state=random)
        for did in negs:
            dids.append(did)
        # build one record per document: the positive first, then the sampled negatives
        result = {f: [] for f in fields}
        for did in dids:
            record = dataset.build_record(fields, query_id=qid, doc_id=did)
            for f in fields:
                result[f].append(record[f])
        yield result
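
# Illustrative usage sketch (not part of the original module): draw a single
# positive/negative training sample. The field names 'query_tok' and 'doc_tok'
# are assumptions; pass whichever fields dataset.build_record() supports in
# your setup.
def _pair_iter_example(dataset, random=None):
    pairs = pair_iter(dataset,
                      fields={'query_tok', 'doc_tok'},  # hypothetical field names
                      pos_source='intersect',
                      neg_source='run',
                      sampling='query',
                      num_neg=1,
                      random=random,
                      inf=False)
    # each yielded dict maps a field to [positive value, negative value(s)]
    return next(pairs)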
def record_iter(
        dataset,
        fields: set,
        source: str,  # one of ['run', 'qrels']
        minrel: int = None,  # minimum relevance score, or None for unfiltered
        run_threshold: int = 0,  # cutoff rank threshold (if > 0)
        shuf: bool = True,
        random=None,
        inf: bool = False):
    run_fn = util.Lazy(lambda: dataset.run('df'))
    qrels_fn = util.Lazy(lambda: dataset.qrels('df'))
    if source == 'run':
        src = run_fn()
        if run_threshold > 0:
            # cut off the run by rank (per query)
            src = src.sort_values(['qid', 'score'], ascending=False) \
                     .groupby('qid') \
                     .head(run_threshold) \
                     .reset_index(drop=True)
        if minrel is not None:
            # relevance comes from the qrels, so join and filter on the qrels score
            src = pd.merge(src, qrels_fn(), on=['qid', 'did'], suffixes=('_run', '_qrels'))
            src = src[src['score_qrels'] >= minrel]
    elif source == 'qrels':
        src = qrels_fn()
        if minrel is not None:
            src = src[src['score'] >= minrel]
    else:
        raise ValueError(f'unsupported source {source}')
    src = src.filter(items=['qid', 'did'])
    it = record_iter_sample(dataset, src, shuf, random, inf)
    for qid, did in it:
        record = dataset.build_record(fields, query_id=qid, doc_id=did)
        yield {f: record[f] for f in fields}
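
# Illustrative sketch (not part of the original module): collect a handful of
# judged (query, document) records from the qrels. As above, the field names
# are assumptions.
def _record_iter_example(dataset, n=3, random=None):
    records = []
    for record in record_iter(dataset,
                              fields={'query_tok', 'doc_tok'},  # hypothetical field names
                              source='qrels',
                              minrel=1,
                              shuf=True,
                              random=random,
                              inf=False):
        records.append(record)
        if len(records) >= n:
            break
    return records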
def init(self, force=False):
    idxs = [self.index_stem, self.doc_store]
    self._init_indices_parallel(idxs, self._init_iter_collection(), force)
    base_path = util.path_dataset(self)

    # query files: train (excluding MINI_DEV), minidev, dev, and eval
    needs_queries = []
    if force or not os.path.exists(os.path.join(base_path, 'train.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'train.queries.tsv'),
            ((qid, txt) for file, qid, txt in it
             if file == 'queries.train.tsv' and qid not in MINI_DEV)))
    if force or not os.path.exists(os.path.join(base_path, 'minidev.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'minidev.queries.tsv'),
            ((qid, txt) for file, qid, txt in it
             if file == 'queries.train.tsv' and qid in MINI_DEV)))
    if force or not os.path.exists(os.path.join(base_path, 'dev.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'dev.queries.tsv'),
            ((qid, txt) for file, qid, txt in it if file == 'queries.dev.tsv')))
    if force or not os.path.exists(os.path.join(base_path, 'eval.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'eval.queries.tsv'),
            ((qid, txt) for file, qid, txt in it if file == 'queries.eval.tsv')))

    if needs_queries and self._confirm_dua():
        with util.download_tmp(_SOURCES['queries']) as f, \
             tarfile.open(fileobj=f) as tarf, \
             contextlib.ExitStack() as ctxt:
            def _extr_subf(subf):
                for qid, txt in plaintext.read_tsv(io.TextIOWrapper(tarf.extractfile(subf))):
                    yield subf, qid, txt
            query_iter = [
                _extr_subf('queries.train.tsv'),
                _extr_subf('queries.dev.tsv'),
                _extr_subf('queries.eval.tsv'),
            ]
            query_iter = tqdm(itertools.chain(*query_iter), desc='queries')
            query_iters = util.blocking_tee(query_iter, len(needs_queries))
            for fn, it in zip(needs_queries, query_iters):
                ctxt.enter_context(util.CtxtThread(functools.partial(fn, it)))

    # qrels: train (excluding MINI_DEV), minidev, and dev
    file = os.path.join(base_path, 'train.qrels')
    if (force or not os.path.exists(file)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
        with util.finialized_file(file, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                if qid not in MINI_DEV:
                    trec.write_qrels(out, [(qid, did, score)])

    file = os.path.join(base_path, 'minidev.qrels')
    if (force or not os.path.exists(file)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
        with util.finialized_file(file, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                if qid in MINI_DEV:
                    trec.write_qrels(out, [(qid, did, score)])

    file = os.path.join(base_path, 'dev.qrels')
    if (force or not os.path.exists(file)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['dev-qrels'], 'utf8')
        with util.finialized_file(file, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                trec.write_qrels(out, [(qid, did, score)])

    # training pairs (qid/pos-did/neg-did triples)
    file = os.path.join(base_path, 'train.mspairs.gz')
    if not os.path.exists(file) and os.path.exists(os.path.join(base_path, 'qidpidtriples.train.full')):
        # legacy file name
        os.rename(os.path.join(base_path, 'qidpidtriples.train.full'), file)
    if (force or not os.path.exists(file)) and self._confirm_dua():
        util.download(_SOURCES['qidpidtriples.train.full'], file)

    # official top-1000 runs; minidev.msrun is split out of train.msrun
    if not self.config['init_skip_msrun']:
        for file_name, subf in [('dev.msrun', 'top1000.dev'),
                                ('eval.msrun', 'top1000.eval'),
                                ('train.msrun', 'top1000.train.txt')]:
            file = os.path.join(base_path, file_name)
            if (force or not os.path.exists(file)) and self._confirm_dua():
                run = {}
                with util.download_tmp(_SOURCES[file_name]) as f, \
                     tarfile.open(fileobj=f) as tarf:
                    for qid, did, _, _ in tqdm(plaintext.read_tsv(io.TextIOWrapper(tarf.extractfile(subf)))):
                        if qid not in run:
                            run[qid] = {}
                        run[qid][did] = 0.
                if file_name == 'train.msrun':
                    minidev = {qid: dids for qid, dids in run.items() if qid in MINI_DEV}
                    with self.logger.duration('writing minidev.msrun'):
                        trec.write_run_dict(os.path.join(base_path, 'minidev.msrun'), minidev)
                    run = {qid: dids for qid, dids in run.items() if qid not in MINI_DEV}
                with self.logger.duration(f'writing {file_name}'):
                    trec.write_run_dict(file, run)

    # TREC DL 2019 queries, run, and qrels
    query_path = os.path.join(base_path, 'trec2019.queries.tsv')
    if (force or not os.path.exists(query_path)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['trec2019.queries'], 'utf8')
        plaintext.write_tsv(query_path, plaintext.read_tsv(stream))

    msrun_path = os.path.join(base_path, 'trec2019.msrun')
    if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
        run = {}
        with util.download_stream(_SOURCES['trec2019.msrun'], 'utf8') as stream:
            for qid, did, _, _ in plaintext.read_tsv(stream):
                if qid not in run:
                    run[qid] = {}
                run[qid][did] = 0.
        with util.finialized_file(msrun_path, 'wt') as f:
            trec.write_run_dict(f, run)

    qrels_path = os.path.join(base_path, 'trec2019.qrels')
    if not os.path.exists(qrels_path) and self._confirm_dua():
        util.download(_SOURCES['trec2019.qrels'], qrels_path)

    # judged-only variants of the TREC DL 2019 files
    qrels_path = os.path.join(base_path, 'judgedtrec2019.qrels')
    if not os.path.exists(qrels_path):
        os.symlink('trec2019.qrels', qrels_path)

    query_path = os.path.join(base_path, 'judgedtrec2019.queries.tsv')
    judged_qids = util.Lazy(lambda: trec.read_qrels_dict(qrels_path).keys())
    if force or not os.path.exists(query_path):
        with util.finialized_file(query_path, 'wt') as f:
            for qid, qtext in plaintext.read_tsv(os.path.join(base_path, 'trec2019.queries.tsv')):
                if qid in judged_qids():
                    plaintext.write_tsv(f, [(qid, qtext)])

    msrun_path = os.path.join(base_path, 'judgedtrec2019.msrun')
    if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
        with util.finialized_file(msrun_path, 'wt') as f:
            for qid, dids in trec.read_run_dict(os.path.join(base_path, 'trec2019.msrun')).items():
                if qid in judged_qids():
                    trec.write_run_dict(f, {qid: dids})

    # A subset of dev that only contains queries that have relevance judgments
    judgeddev_path = os.path.join(base_path, 'judgeddev')
    judged_qids = util.Lazy(lambda: trec.read_qrels_dict(os.path.join(base_path, 'dev.qrels')).keys())
    if not os.path.exists(f'{judgeddev_path}.qrels'):
        os.symlink('dev.qrels', f'{judgeddev_path}.qrels')
    if not os.path.exists(f'{judgeddev_path}.queries.tsv'):
        with util.finialized_file(f'{judgeddev_path}.queries.tsv', 'wt') as f:
            for qid, qtext in plaintext.read_tsv(os.path.join(base_path, 'dev.queries.tsv')):
                if qid in judged_qids():
                    plaintext.write_tsv(f, [(qid, qtext)])
    # dev.msrun only exists when the runs were not skipped above
    if not self.config['init_skip_msrun']:
        if not os.path.exists(f'{judgeddev_path}.msrun'):
            with util.finialized_file(f'{judgeddev_path}.msrun', 'wt') as f:
                for qid, dids in trec.read_run_dict(os.path.join(base_path, 'dev.msrun')).items():
                    if qid in judged_qids():
                        trec.write_run_dict(f, {qid: dids})

    # train10: deterministic ~10% subset of train (qids divisible by 10)
    if not self.config['init_skip_train10']:
        file = os.path.join(base_path, 'train10.queries.tsv')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout:
                for qid, qtext in self.logger.pbar(plaintext.read_tsv(os.path.join(base_path, 'train.queries.tsv')), desc='filtering queries for train10'):
                    if int(qid) % 10 == 0:
                        plaintext.write_tsv(fout, [(qid, qtext)])

        file = os.path.join(base_path, 'train10.qrels')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout, open(os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                for line in self.logger.pbar(fin, desc='filtering qrels for train10'):
                    qid = line.split()[0]
                    if int(qid) % 10 == 0:
                        fout.write(line)

        if not self.config['init_skip_msrun']:
            file = os.path.join(base_path, 'train10.msrun')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(os.path.join(base_path, 'train.msrun'), 'rt') as fin:
                    for line in self.logger.pbar(fin, desc='filtering msrun for train10'):
                        qid = line.split()[0]
                        if int(qid) % 10 == 0:
                            fout.write(line)

        file = os.path.join(base_path, 'train10.mspairs.gz')
        if not os.path.exists(file):
            with gzip.open(file, 'wt') as fout, gzip.open(os.path.join(base_path, 'train.mspairs.gz'), 'rt') as fin:
                for qid, did1, did2 in self.logger.pbar(plaintext.read_tsv(fin), desc='filtering mspairs for train10'):
                    if int(qid) % 10 == 0:
                        plaintext.write_tsv(fout, [(qid, did1, did2)])

    # train_med: medical subset of train (MedMARCO qids from Georgetown-IR-Lab)
    if not self.config['init_skip_train_med']:
        med_qids = util.Lazy(lambda: {
            qid.strip()
            for qid in util.download_stream(
                'https://raw.githubusercontent.com/Georgetown-IR-Lab/covid-neural-ir/master/med-msmarco-train.txt',
                'utf8',
                expected_md5='dc5199de7d4a872c361f89f08b1163ef')
        })
        file = os.path.join(base_path, 'train_med.queries.tsv')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout:
                for qid, qtext in self.logger.pbar(plaintext.read_tsv(os.path.join(base_path, 'train.queries.tsv')), desc='filtering queries for train_med'):
                    if qid in med_qids():
                        plaintext.write_tsv(fout, [(qid, qtext)])

        file = os.path.join(base_path, 'train_med.qrels')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout, open(os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                for line in self.logger.pbar(fin, desc='filtering qrels for train_med'):
                    qid = line.split()[0]
                    if qid in med_qids():
                        fout.write(line)

        if not self.config['init_skip_msrun']:
            file = os.path.join(base_path, 'train_med.msrun')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(os.path.join(base_path, 'train.msrun'), 'rt') as fin:
                    for line in self.logger.pbar(fin, desc='filtering msrun for train_med'):
                        qid = line.split()[0]
                        if qid in med_qids():
                            fout.write(line)

        file = os.path.join(base_path, 'train_med.mspairs.gz')
        if not os.path.exists(file):
            with gzip.open(file, 'wt') as fout, gzip.open(os.path.join(base_path, 'train.mspairs.gz'), 'rt') as fin:
                for qid, did1, did2 in self.logger.pbar(plaintext.read_tsv(fin), desc='filtering mspairs for train_med'):
                    if qid in med_qids():
                        plaintext.write_tsv(fout, [(qid, did1, did2)])
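
# Minimal sketch (not part of the original module) of the subset conventions
# implemented in init() above: MINI_DEV qids are held out of the train.* files,
# train10 keeps training queries whose integer qid is divisible by 10, and
# train_med keeps the MedMARCO qids. The mini_dev/med arguments are placeholders
# for the real MINI_DEV constant and the med_qids() lookup.
def _train_subsets_for(qid, mini_dev=frozenset(), med=frozenset()):
    subsets = []
    if qid in mini_dev:
        subsets.append('minidev')       # held out of train.queries/qrels/msrun
    else:
        subsets.append('train')
        if int(qid) % 10 == 0:
            subsets.append('train10')   # deterministic ~10% sample of train
        if qid in med:
            subsets.append('train_med') # medical-domain subset
    return subsets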