def _load_model(self):
    """Load the ColBERT model and checkpoint and build the inference wrapper.

    Populates ``self.colbert``, ``self.checkpoint`` and ``self.inference``.
    ``load_colbert`` is asked to print only when ``self.process_idx`` is 0.
    """
    verbose = (self.process_idx == 0)
    loaded = load_colbert(self.args, do_print=verbose)
    self.colbert, self.checkpoint = loaded

    # Rebind to the object returned by .cuda(), then switch it to eval mode
    # (presumably a torch module -- confirm against load_colbert).
    self.colbert = self.colbert.cuda()
    self.colbert.eval()

    self.inference = ModelInference(self.colbert, amp=self.args.amp)
def main():
    """Entry point for exhaustive (slow, not index-based) re-ranking evaluation.

    Seeds ``random``, parses the command line, then -- inside
    ``Run.context()`` -- loads the model, qrels and ranking inputs, and
    finally calls ``evaluate_recall`` followed by ``evaluate``.

    Raises:
        AssertionError: if only one of ``args.collection`` / ``args.queries``
            is provided, or if ``args.shortcircuit`` is set while
            ``args.qrels`` is empty.
    """
    random.seed(12345)

    parser = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.')
    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()
    parser.add_argument('--depth', dest='depth', required=False, default=None, type=int)

    args = parser.parse()

    with Run.context():
        model, checkpoint = load_colbert(args)
        args.colbert = model
        args.checkpoint = checkpoint

        args.qrels = load_qrels(args.qrels)

        if not (args.collection or args.queries):
            # Neither was supplied: queries, documents and pids all come
            # from the top-K input.
            args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)
        else:
            # Collection and queries must be given together.
            assert args.collection and args.queries

            args.queries = load_queries(args.queries)
            args.collection = load_collection(args.collection)
            args.topK_pids, args.qrels = load_topK_pids(args.topK, args.qrels)

        shortcircuit_message = (
            "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) "
            "can only be applied if qrels is provided."
        )
        assert (not args.shortcircuit) or args.qrels, shortcircuit_message

        evaluate_recall(args.qrels, args.queries, args.topK_pids)
        evaluate(args)
def main():
    """Entry point for end-to-end retrieval and ranking with ColBERT.

    Seeds ``random``, parses the command line, normalises ``args.depth`` and
    ``args.part_range``, then -- inside ``Run.context()`` -- loads the model,
    qrels and queries, resolves the index paths and calls ``batch_retrieve``
    when ``args.batch`` is set, otherwise ``retrieve``.

    Raises:
        ValueError: if ``--part-range`` is not two integers separated by
            ``..``.
    """
    random.seed(12345)

    parser = Arguments(description='End-to-end retrieval and ranking with ColBERT.')
    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_ranking_input()
    parser.add_retrieval_input()

    parser.add_argument('--faiss_name', dest='faiss_name', default=None, type=str)
    parser.add_argument('--faiss_depth', dest='faiss_depth', default=1024, type=int)
    parser.add_argument('--part-range', dest='part_range', default=None, type=str)
    parser.add_argument('--batch', dest='batch', default=False, action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    # A non-positive depth is replaced by None.
    if not args.depth > 0:
        args.depth = None

    if args.part_range:
        # "A..B" becomes range(A, B).
        bounds = [int(piece) for piece in args.part_range.split('..')]
        range_start, range_stop = bounds
        args.part_range = range(range_start, range_stop)

    with Run.context():
        model, checkpoint = load_colbert(args)
        args.colbert = model
        args.checkpoint = checkpoint

        args.qrels = load_qrels(args.qrels)
        args.queries = load_queries(args.queries)

        index_dir = os.path.join(args.index_root, args.index_name)
        args.index_path = index_dir

        # An explicit --faiss_name wins; otherwise the name is derived
        # from the arguments.
        if args.faiss_name is None:
            faiss_file = get_faiss_index_name(args)
        else:
            faiss_file = args.faiss_name
        args.faiss_index_path = os.path.join(index_dir, faiss_file)

        run_retrieval = batch_retrieve if args.batch else retrieve
        run_retrieval(args)
def main():
    """Entry point for re-ranking over a ColBERT index.

    Seeds ``random``, parses the command line, converts ``args.part_range``
    to a ``range`` when given, then -- inside ``Run.context()`` -- loads the
    model, queries, qrels and top-K pids, resolves the index path and calls
    ``batch_rerank`` when ``args.batch`` is set, otherwise ``rerank``.

    Raises:
        ValueError: if ``--part-range`` is not two integers separated by
            ``..``.
    """
    random.seed(12345)

    parser = Arguments(description='Re-ranking over a ColBERT index')
    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()
    parser.add_index_use_input()

    parser.add_argument('--step', dest='step', default=1, type=int)
    parser.add_argument('--part-range', dest='part_range', default=None, type=str)
    parser.add_argument('--log-scores', dest='log_scores', default=False, action='store_true')
    parser.add_argument('--batch', dest='batch', default=False, action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    if args.part_range:
        # "A..B" becomes range(A, B).
        bounds = [int(piece) for piece in args.part_range.split('..')]
        range_start, range_stop = bounds
        args.part_range = range(range_start, range_stop)

    with Run.context():
        model, checkpoint = load_colbert(args)
        args.colbert = model
        args.checkpoint = checkpoint

        args.queries = load_queries(args.queries)
        args.qrels = load_qrels(args.qrels)

        # load_topK_pids returns the pids together with a qrels value that
        # replaces the one just loaded.
        topk_pids, qrels = load_topK_pids(args.topK, qrels=args.qrels)
        args.topK_pids = topk_pids
        args.qrels = qrels

        args.index_path = os.path.join(args.index_root, args.index_name)

        run_reranking = batch_rerank if args.batch else rerank
        run_reranking(args)