def retrieve(args):
    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)
    milliseconds = 0

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            rankings = []

            for query_idx, q in enumerate(qbatch_text):
                torch.cuda.synchronize('cuda:0')
                s = time.time()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                torch.cuda.synchronize()
                milliseconds += (time.time() - s) * 1000.0

                if len(pids):
                    print(qoffset + query_idx, q, len(scores), len(pids), scores[0], pids[0],
                          milliseconds / (qoffset + query_idx + 1), 'ms')

                rankings.append(zip(pids, scores))

            for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)):
                query_idx = qoffset + query_idx

                if query_idx % 100 == 0:
                    print_message(f"#> Logging query #{query_idx} (qid {qid}) now...")

                ranking = [(score, pid, None) for pid, score in itertools.islice(ranking, args.depth)]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')
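# A small, self-contained illustration of the (offset, group) batching pattern used in
# retrieve() and batch_retrieve() above/below. The real `batch(..., provide_offset=True)`
# helper is a ColBERT utility not shown in this file; the sketch below only approximates
# its assumed behavior, and batch_with_offset is a hypothetical name for illustration.
def batch_with_offset(items, size):
    for offset in range(0, len(items), size):
        yield offset, items[offset:offset + size]

_qids = ['q1', 'q2', 'q3', 'q4', 'q5']
assert list(batch_with_offset(_qids, 2)) == [(0, ['q1', 'q2']),
                                             (2, ['q3', 'q4']),
                                             (4, ['q5'])]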
def batch_retrieve(args):
    assert args.retrieve_only, "TODO: Combine batch (multi-query) retrieval with batch re-ranking"

    faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe, args.part_range)
    inference = ModelInference(args.colbert, amp=args.amp)

    ranking_logger = RankingLogger(Run.path, qrels=None)

    with ranking_logger.context('unordered.tsv', also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, 100_000, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            # Assumed loop body, mirroring retrieve() above: embed the query batch,
            # fetch candidate PIDs from the FAISS index, and log them unranked
            # (hence 'unordered.tsv'). The queryFromText/FaissIndex.retrieve usage
            # here is a plausible sketch, not verbatim from this file.
            print_message(f"#> Embedding {len(qbatch_text)} queries in parallel...")
            Q = inference.queryFromText(qbatch_text, bsize=512)

            print_message("#> Starting batch retrieval...")
            all_pids = faiss_index.retrieve(args.faiss_depth, Q, verbose=True)

            for query_idx, (qid, ranking) in enumerate(zip(qbatch, all_pids)):
                query_idx = qoffset + query_idx

                if query_idx % 1000 == 0:
                    print_message(f"#> Logging query #{query_idx} (qid {qid}) now...")

                ranking = [(None, pid, None) for pid in ranking]
                rlogger.log(qid, ranking, is_ranked=False)
def evaluate(args):
    args.inference = ModelInference(args.colbert, amp=args.amp)
    qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids

    depth = args.depth
    collection = args.collection
    if collection is None:
        topK_docs = args.topK_docs

    def qid2passages(qid):
        if collection is not None:
            return [collection[pid] for pid in topK_pids[qid][:depth]]
        else:
            return topK_docs[qid][:depth]

    metrics = Metrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000},
                      success_depths={5, 10, 20, 50, 100, 1000},
                      total_queries=len(queries))

    ranking_logger = RankingLogger(Run.path, qrels=qrels)

    args.milliseconds = []

    with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger:
        with torch.no_grad():
            keys = sorted(list(queries.keys()))
            random.shuffle(keys)

            for query_idx, qid in enumerate(keys):
                query = queries[qid]

                print_message(query_idx, qid, query, '\n')

                if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0:
                    continue

                ranking = slow_rerank(args, query, topK_pids[qid], qid2passages(qid))
                rlogger.log(qid, ranking, [0, 1])

                if qrels:
                    metrics.add(query_idx, qid, ranking, qrels[qid])

                    for i, (score, pid, passage) in enumerate(ranking):
                        if pid in qrels[qid]:
                            print("\n#> Found", pid, "at position", i + 1, "with score", score)
                            print(passage)
                            break

                    metrics.print_metrics(query_idx)
                    metrics.log(query_idx)

                print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n')
                print("rlogger.filename =", rlogger.filename)

                if len(args.milliseconds) > 1:
                    print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))

                print("\n\n")

            print("\n\n")

    try:
        print('Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))
    except ZeroDivisionError:  # no timed queries beyond the warm-up one
        pass

    print("\n\n")
    print('\n\n')

    if qrels:
        assert query_idx + 1 == len(keys) == len(set(keys))
        metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries))

    print('\n\n')
def __init__(self,
             colbert_model: Union[str, Tuple[colbert.modeling.colbert.ColBERT, dict]],
             index_root: str,
             index_name: str,
             faiss_partitions=100,
             memtype="mem",
             gpu=True):
    args = Object()
    args.query_maxlen = 32
    args.doc_maxlen = 180
    args.dim = 128
    args.bsize = 128
    args.similarity = 'cosine'
    args.amp = True
    args.nprobe = 10
    args.part_range = None
    args.mask_punctuation = False
    args.partitions = faiss_partitions

    self.index_root = index_root
    self.index_name = index_name
    if index_root is None or index_name is None:
        warn("No index_root or index_name specified - no index ranking possible")
    else:
        self.index_path = os.path.join(index_root, index_name)
        docnos_file = os.path.join(self.index_path, "docnos.pkl.gz")
        if os.path.exists(docnos_file):
            with pt.io.autoopen(docnos_file, "rb") as f:
                self.docid2docno = pickle.load(f)
                # support reverse docno lookup in memory
                self.docno2docid = {docno: docid for docid, docno in enumerate(self.docid2docno)}
                self.docid_as_docno = False
        else:
            self.docid_as_docno = True

    try:
        import faiss
    except ImportError:
        warn("Faiss not installed. You cannot do retrieval")

    if not gpu:
        warn("Gpu disabled, YMMV")
        import colbert.parameters
        colbert.parameters.DEVICE = torch.device("cpu")

    if isinstance(colbert_model, str):
        args.checkpoint = colbert_model
        args.colbert, args.checkpoint = load_model(args)
    else:
        assert isinstance(colbert_model, tuple)
        args.colbert, args.checkpoint = colbert_model

        from colbert.modeling.colbert import ColBERT
        assert isinstance(args.colbert, ColBERT)
        assert isinstance(args.checkpoint, dict)

    args.inference = ModelInference(args.colbert, amp=args.amp)
    self.args = args
    self.memtype = memtype

    # we load these lazily
    self.rrm = None
    self.faiss_index = None
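# Hypothetical usage sketch for the constructor above (this appears to be
# pyterrier_colbert's ColBERTFactory; the checkpoint filename, index paths,
# and the end_to_end() transformer below are assumptions for illustration,
# not taken from this file):
#
#   factory = ColBERTFactory("colbert.dnn",        # trained ColBERT checkpoint
#                            "/path/to/indices",    # index_root
#                            "msmarco_index",       # index_name
#                            faiss_partitions=100,
#                            gpu=True)
#   dense_e2e = factory.end_to_end()                # PyTerrier retrieval transformer, if available
#   # dense_e2e.search("what is a colbert index")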
def batch_rerank(args):
    positions, loaded_parts, thread = prepare_ranges(args.index_path, args.dim, args.step, args.part_range)

    inference = ModelInference(args.colbert, amp=args.amp)
    queries, topK_pids = args.queries, args.topK_pids

    with torch.no_grad():
        queries_in_order = list(queries.values())

        print_message(f"#> Encoding all {len(queries_in_order)} queries in batches...")

        all_query_embeddings = inference.queryFromText(queries_in_order, bsize=512, to_cpu=True)
        all_query_embeddings = all_query_embeddings.to(dtype=torch.float16).permute(0, 2, 1).contiguous()

    for qid in queries:
        # Since topK_pids is a defaultdict, make sure each qid *has* actual PID information (even if empty).
        assert qid in topK_pids, qid

    all_pids = flatten([[(query_index, pid) for pid in topK_pids[qid]]
                        for query_index, qid in enumerate(queries)])
    all_query_rankings = [defaultdict(list), defaultdict(list)]

    print_message(f"#> Will process {len(all_pids)} query--document pairs in total.")

    with torch.no_grad():
        score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids)

    ranking_logger = RankingLogger(Run.path, qrels=None, log_scores=args.log_scores)

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        with torch.no_grad():
            for query_index, qid in enumerate(queries):
                if query_index % 1000 == 0:
                    print_message("#> Logging query #{} (qid {}) now...".format(query_index, qid))

                pids = all_query_rankings[0][query_index]
                scores = all_query_rankings[1][query_index]

                K = min(MAX_DEPTH_LOGGED, len(scores))

                if K == 0:
                    continue

                scores_topk = torch.tensor(scores).topk(K, largest=True, sorted=True)

                pids, scores = torch.tensor(pids)[scores_topk.indices].tolist(), scores_topk.values.tolist()

                ranking = [(score, pid, None) for pid, score in zip(pids, scores)]
                assert len(ranking) <= MAX_DEPTH_LOGGED, (len(ranking), MAX_DEPTH_LOGGED)

                rlogger.log(qid, ranking, is_ranked=True,
                            print_positions=[1, 2] if query_index % 100 == 0 else [])

    print('\n\n')
    print(ranking_logger.filename)
    print_message('#> Done.\n')

    thread.join()
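# Self-contained illustration of the top-K selection performed in batch_rerank():
# scores are sorted descending with torch.topk, and the same permutation is applied
# to the PIDs. The scores and pids below are made-up values for demonstration only.
import torch

_scores = [0.75, 0.5, 0.125, 0.875]
_pids = [101, 202, 303, 404]

_K = min(3, len(_scores))
_topk = torch.tensor(_scores).topk(_K, largest=True, sorted=True)
_top_pids = torch.tensor(_pids)[_topk.indices].tolist()   # [404, 101, 202]
_top_scores = _topk.values.tolist()                       # [0.875, 0.75, 0.5]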
class CollectionEncoder():
    def __init__(self, args, process_idx, num_processes):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes
        self.iterator = self._initialize_iterator()

        assert 0.5 <= args.chunksize <= 128.0
        max_bytes_per_file = args.chunksize * (1024 * 1024 * 1024)

        # Each passage contributes at most doc_maxlen vectors of dim fp16 values (2 bytes each).
        max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0)

        minimum_subset_size = 10_000
        maximum_subset_size = max_bytes_per_file / max_bytes_per_doc
        maximum_subset_size = max(minimum_subset_size, maximum_subset_size)
        self.possible_subset_sizes = [int(maximum_subset_size)]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.index_root =", args.index_root)
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()
        self.indexmgr = IndexManager(args.dim)

    def _initialize_iterator(self):
        return open(self.collection)

    def _saver_thread(self):
        for args in iter(self.saver_queue.get, None):
            self._save_batch(*args)

    def _load_model(self):
        self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0))
        self.colbert = self.colbert.cuda()
        self.colbert.eval()

        self.inference = ModelInference(self.colbert, amp=self.args.amp)

    def encode(self):
        self.saver_queue = queue.Queue(maxsize=3)
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        t0 = time.time()
        local_docs_processed = 0

        for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
            if owner != self.process_idx:
                continue

            t1 = time.time()
            batch = self._preprocess_batch(offset, lines)
            embs, doclens = self._encode_batch(batch_idx, batch)

            t2 = time.time()
            self.saver_queue.put((batch_idx, embs, offset, doclens))
            print(len(lines))

            t3 = time.time()
            local_docs_processed += len(lines)

            overall_throughput = compute_throughput(local_docs_processed, t0, t3)
            this_encoding_throughput = compute_throughput(len(lines), t1, t2)
            this_saving_throughput = compute_throughput(len(lines), t2, t3)

            self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                       f'Passages/min: {overall_throughput} (overall), ',
                       f'{this_encoding_throughput} (this encoding), ',
                       f'{this_saving_throughput} (this saving)')

        self.saver_queue.put(None)

        self.print("#> Joining saver thread.")
        thread.join()

    def _batch_passages(self, fi):
        """
        Must use the same seed across processes!
        """
        np.random.seed(0)

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = np.random.choice(self.possible_subset_sizes)

            L = [line for _, line in zip(range(batch_size), fi)]

            if len(L) == 0:
                break  # EOF

            yield (offset, L, owner)
            offset += len(L)

            if len(L) < batch_size:
                break  # EOF

        self.print("[NOTE] Done with local share.")

        return

    def _preprocess_batch(self, offset, lines):
        endpos = offset + len(lines)

        batch = []

        for line_idx, line in zip(range(offset, endpos), lines):
            line_parts = line.strip().split('\t')

            pid, passage, *other = line_parts

            assert len(passage) >= 1

            if len(other) >= 1:
                title, *_ = other
                passage = title + ' | ' + passage

            batch.append(passage)

            assert pid == 'id' or int(pid) == line_idx

        return batch

    def _encode_batch(self, batch_idx, batch):
        with torch.no_grad():
            embs = self.inference.docFromText(batch, bsize=self.args.bsize, keep_dims=False)
            assert type(embs) is list
            assert len(embs) == len(batch)

            local_doclens = [d.size(0) for d in embs]
            embs = torch.cat(embs)

        return embs, local_doclens

    def _save_batch(self, batch_idx, embs, offset, doclens):
        start_time = time.time()

        output_path = os.path.join(self.args.index_path, "{}.pt".format(batch_idx))
        output_sample_path = os.path.join(self.args.index_path, "{}.sample".format(batch_idx))
        doclens_path = os.path.join(self.args.index_path, 'doclens.{}.json'.format(batch_idx))

        # Save the embeddings, plus a ~5% random sample of the vectors.
        print(output_path)
        self.indexmgr.save(embs, output_path)
        self.indexmgr.save(embs[torch.randint(0, high=embs.size(0), size=(embs.size(0) // 20,))],
                           output_sample_path)

        # Save the doclens.
        with open(doclens_path, 'w') as output_doclens:
            ujson.dump(doclens, output_doclens)

        throughput = compute_throughput(len(doclens), start_time, time.time())
        self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
                        "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        if self.process_idx == 0:
            self.print(*args)