Example #1
    def _load_model(self):
        self.colbert, self.checkpoint = load_colbert(
            self.args, do_print=(self.process_idx == 0))
        self.colbert = self.colbert.cuda()
        self.colbert.eval()

        self.inference = ModelInference(self.colbert, amp=self.args.amp)
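For orientation, the ModelInference object built here is the encoder the remaining examples rely on. A minimal usage sketch, assuming a ColBERT model already loaded as above (the strings are placeholders; the queryFromText and docFromText calls and their arguments mirror Examples #6 and #7 below):

inference = ModelInference(colbert, amp=True)

with torch.no_grad():
    # Query embeddings (one fixed-length tensor per query), as in Example #6.
    Q = inference.queryFromText(["example query"], bsize=512, to_cpu=True)

    # Per-passage embedding tensors (keep_dims=False returns a list), as in Example #7.
    D = inference.docFromText(["example passage text"], bsize=128, keep_dims=False)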
Example #2
def retrieve(args):
    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)
    milliseconds = 0

    with ranking_logger.context('ranking.tsv',
                                also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            rankings = []

            for query_idx, q in enumerate(qbatch_text):
                torch.cuda.synchronize('cuda:0')
                s = time.time()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                torch.cuda.synchronize()
                milliseconds += (time.time() - s) * 1000.0

                if len(pids):
                    print(qoffset + query_idx, q, len(scores), len(pids),
                          scores[0], pids[0],
                          milliseconds / (qoffset + query_idx + 1), 'ms')

                rankings.append(zip(pids, scores))

            for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)):
                query_idx = qoffset + query_idx

                if query_idx % 100 == 0:
                    print_message(
                        f"#> Logging query #{query_idx} (qid {qid}) now...")

                ranking = [
                    (score, pid, None)
                    for pid, score in itertools.islice(ranking, args.depth)
                ]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')
Example #3
def batch_retrieve(args):
    assert args.retrieve_only, "TODO: Combine batch (multi-query) retrieval with batch re-ranking"

    faiss_index = FaissIndex(args.index_path, args.faiss_index_path,
                             args.nprobe, args.part_range)
    inference = ModelInference(args.colbert, amp=args.amp)

    ranking_logger = RankingLogger(Run.path, qrels=None)

    with ranking_logger.context('unordered.tsv',
                                also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order,
                                     100_000,
                                     provide_offset=True):
            ...  # the remainder of this example is not included in the excerpt
Example #4
def evaluate(args):
    args.inference = ModelInference(args.colbert, amp=args.amp)
    qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids

    depth = args.depth
    collection = args.collection
    if collection is None:
        topK_docs = args.topK_docs

    def qid2passages(qid):
        if collection is not None:
            return [collection[pid] for pid in topK_pids[qid][:depth]]
        else:
            return topK_docs[qid][:depth]

    metrics = Metrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000},
                      success_depths={5, 10, 20, 50, 100, 1000},
                      total_queries=len(queries))

    ranking_logger = RankingLogger(Run.path, qrels=qrels)

    args.milliseconds = []

    with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger:
        with torch.no_grad():
            keys = sorted(list(queries.keys()))
            random.shuffle(keys)

            for query_idx, qid in enumerate(keys):
                query = queries[qid]

                print_message(query_idx, qid, query, '\n')

                if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0:
                    continue

                ranking = slow_rerank(args, query, topK_pids[qid], qid2passages(qid))

                rlogger.log(qid, ranking, [0, 1])

                if qrels:
                    metrics.add(query_idx, qid, ranking, qrels[qid])

                    for i, (score, pid, passage) in enumerate(ranking):
                        if pid in qrels[qid]:
                            print("\n#> Found", pid, "at position", i+1, "with score", score)
                            print(passage)
                            break

                    metrics.print_metrics(query_idx)
                    metrics.log(query_idx)

                print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n')
                print("rlogger.filename =", rlogger.filename)

                if len(args.milliseconds) > 1:
                    print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))

                print("\n\n")

        print("\n\n")
        try:
            print('Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))
        except:
            pass
        print("\n\n")

    print('\n\n')
    if qrels:
        assert query_idx + 1 == len(keys) == len(set(keys))
        metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries))
    print('\n\n')
Example #5
    def __init__(self,
                 colbert_model: Union[str,
                                      Tuple[colbert.modeling.colbert.ColBERT,
                                            dict]],
                 index_root: str,
                 index_name: str,
                 faiss_partitions=100,
                 memtype="mem",
                 gpu=True):

        args = Object()
        args.query_maxlen = 32
        args.doc_maxlen = 180
        args.dim = 128
        args.bsize = 128
        args.similarity = 'cosine'
        args.amp = True
        args.nprobe = 10
        args.part_range = None
        args.mask_punctuation = False
        args.partitions = faiss_partitions

        self.index_root = index_root
        self.index_name = index_name
        if index_root is None or index_name is None:
            warn(
                "No index_root and index_name specified - no index ranking possible"
            )
        else:
            self.index_path = os.path.join(index_root, index_name)
            docnos_file = os.path.join(self.index_path, "docnos.pkl.gz")
            if os.path.exists(docnos_file):
                with pt.io.autoopen(docnos_file, "rb") as f:
                    self.docid2docno = pickle.load(f)
                    # support reverse docno lookup in memory
                    self.docno2docid = {
                        docno: docid
                        for docid, docno in enumerate(self.docid2docno)
                    }
                    self.docid_as_docno = False
            else:
                self.docid_as_docno = True

        try:
            import faiss
        except:
            warn("Faiss not installed. You cannot do retrieval")

        if not gpu:
            warn("Gpu disabled, YMMV")
            import colbert.parameters
            colbert.parameters.DEVICE = torch.device("cpu")
        if isinstance(colbert_model, str):
            args.checkpoint = colbert_model
            args.colbert, args.checkpoint = load_model(args)
        else:
            assert isinstance(colbert_model, tuple)
            args.colbert, args.checkpoint = colbert_model
            from colbert.modeling.colbert import ColBERT
            assert isinstance(args.colbert, ColBERT)
            assert isinstance(args.checkpoint, dict)

        args.inference = ModelInference(args.colbert, amp=args.amp)
        self.args = args

        self.memtype = memtype

        #we load this lazily
        self.rrm = None
        self.faiss_index = None
Example #6
def batch_rerank(args):
    positions, loaded_parts, thread = prepare_ranges(args.index_path, args.dim,
                                                     args.step,
                                                     args.part_range)

    inference = ModelInference(args.colbert, amp=args.amp)
    queries, topK_pids = args.queries, args.topK_pids

    with torch.no_grad():
        queries_in_order = list(queries.values())

        print_message(
            f"#> Encoding all {len(queries_in_order)} queries in batches...")

        all_query_embeddings = inference.queryFromText(queries_in_order,
                                                       bsize=512,
                                                       to_cpu=True)
        all_query_embeddings = all_query_embeddings.to(
            dtype=torch.float16).permute(0, 2, 1).contiguous()

    for qid in queries:
        """
        Since topK_pids is a defaultdict, make sure each qid *has* actual PID information (even if empty).
        """
        assert qid in topK_pids, qid

    all_pids = flatten([[(query_index, pid) for pid in topK_pids[qid]]
                        for query_index, qid in enumerate(queries)])
    all_query_rankings = [defaultdict(list), defaultdict(list)]

    print_message(
        f"#> Will process {len(all_pids)} query--document pairs in total.")

    with torch.no_grad():
        score_by_range(positions, loaded_parts, all_query_embeddings,
                       all_query_rankings, all_pids)

    ranking_logger = RankingLogger(Run.path,
                                   qrels=None,
                                   log_scores=args.log_scores)

    with ranking_logger.context('ranking.tsv',
                                also_save_annotations=False) as rlogger:
        with torch.no_grad():
            for query_index, qid in enumerate(queries):
                if query_index % 1000 == 0:
                    print_message(
                        "#> Logging query #{} (qid {}) now...".format(
                            query_index, qid))

                pids = all_query_rankings[0][query_index]
                scores = all_query_rankings[1][query_index]

                K = min(MAX_DEPTH_LOGGED, len(scores))

                if K == 0:
                    continue

                scores_topk = torch.tensor(scores).topk(K,
                                                        largest=True,
                                                        sorted=True)

                pids, scores = torch.tensor(pids)[
                    scores_topk.indices].tolist(), scores_topk.values.tolist()

                ranking = [(score, pid, None)
                           for pid, score in zip(pids, scores)]
                assert len(ranking) <= MAX_DEPTH_LOGGED, (len(ranking),
                                                          MAX_DEPTH_LOGGED)

                rlogger.log(qid,
                            ranking,
                            is_ranked=True,
                            print_positions=[1, 2] if query_index %
                            100 == 0 else [])

    print('\n\n')
    print(ranking_logger.filename)
    print_message('#> Done.\n')

    thread.join()
Example #7
class CollectionEncoder():
    def __init__(self, args, process_idx, num_processes):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes
        self.iterator = self._initialize_iterator()

        assert 0.5 <= args.chunksize <= 128.0
        max_bytes_per_file = args.chunksize * (1024 * 1024 * 1024)

        max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0)

        minimum_subset_size = 10_000
        maximum_subset_size = max_bytes_per_file / max_bytes_per_doc
        maximum_subset_size = max(minimum_subset_size, maximum_subset_size)
        self.possible_subset_sizes = [int(maximum_subset_size)]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.index_root =", args.index_root)
        self.print_main(
            f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()
        self.indexmgr = IndexManager(args.dim)

    def _initialize_iterator(self):
        return open(self.collection)

    def _saver_thread(self):
        for args in iter(self.saver_queue.get, None):
            self._save_batch(*args)

    def _load_model(self):
        self.colbert, self.checkpoint = load_colbert(
            self.args, do_print=(self.process_idx == 0))
        self.colbert = self.colbert.cuda()
        self.colbert.eval()

        self.inference = ModelInference(self.colbert, amp=self.args.amp)

    def encode(self):
        self.saver_queue = queue.Queue(maxsize=3)
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        t0 = time.time()
        local_docs_processed = 0

        for batch_idx, (offset, lines, owner) in enumerate(
                self._batch_passages(self.iterator)):
            if owner != self.process_idx:
                continue

            t1 = time.time()
            batch = self._preprocess_batch(offset, lines)
            embs, doclens = self._encode_batch(batch_idx, batch)

            t2 = time.time()
            self.saver_queue.put((batch_idx, embs, offset, doclens))

            print(len(lines))

            t3 = time.time()
            local_docs_processed += len(lines)
            overall_throughput = compute_throughput(local_docs_processed, t0,
                                                    t3)
            this_encoding_throughput = compute_throughput(len(lines), t1, t2)
            this_saving_throughput = compute_throughput(len(lines), t2, t3)

            self.print(
                f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                f'Passages/min: {overall_throughput} (overall), ',
                f'{this_encoding_throughput} (this encoding), ',
                f'{this_saving_throughput} (this saving)')
        self.saver_queue.put(None)

        self.print("#> Joining saver thread.")
        thread.join()

    def _batch_passages(self, fi):
        """
        Must use the same seed across processes!
        """
        np.random.seed(0)

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = np.random.choice(self.possible_subset_sizes)

            L = [line for _, line in zip(range(batch_size), fi)]

            if len(L) == 0:
                break  # EOF

            yield (offset, L, owner)
            offset += len(L)

            if len(L) < batch_size:
                break  # EOF

        self.print("[NOTE] Done with local share.")

        return

    def _preprocess_batch(self, offset, lines):
        endpos = offset + len(lines)

        batch = []

        for line_idx, line in zip(range(offset, endpos), lines):
            line_parts = line.strip().split('\t')

            pid, passage, *other = line_parts

            assert len(passage) >= 1

            if len(other) >= 1:
                title, *_ = other
                passage = title + ' | ' + passage

            batch.append(passage)

            assert pid == 'id' or int(pid) == line_idx

        return batch

    def _encode_batch(self, batch_idx, batch):
        with torch.no_grad():
            embs = self.inference.docFromText(batch,
                                              bsize=self.args.bsize,
                                              keep_dims=False)
            assert type(embs) is list
            assert len(embs) == len(batch)

            local_doclens = [d.size(0) for d in embs]
            embs = torch.cat(embs)

        return embs, local_doclens

    def _save_batch(self, batch_idx, embs, offset, doclens):
        start_time = time.time()

        output_path = os.path.join(self.args.index_path,
                                   "{}.pt".format(batch_idx))
        output_sample_path = os.path.join(self.args.index_path,
                                          "{}.sample".format(batch_idx))
        doclens_path = os.path.join(self.args.index_path,
                                    'doclens.{}.json'.format(batch_idx))

        # Save the embeddings.
        print(output_path)
        self.indexmgr.save(embs, output_path)
        self.indexmgr.save(
            embs[torch.randint(0,
                               high=embs.size(0),
                               size=(embs.size(0) // 20, ))],
            output_sample_path)

        # Save the doclens.
        with open(doclens_path, 'w') as output_doclens:
            ujson.dump(doclens, output_doclens)

        throughput = compute_throughput(len(doclens), start_time, time.time())
        self.print_main(
            "#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
            "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        if self.process_idx == 0:
            self.print(*args)
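CollectionEncoder is driven by an args object and then told to encode. A rough sketch, assuming a single process and a hypothetical args namespace whose field names mirror the attributes the class reads above (load_colbert would additionally need the usual checkpoint fields):

from types import SimpleNamespace

# Hypothetical args; field names come from the attributes CollectionEncoder reads.
args = SimpleNamespace(
    collection='collection.tsv',          # TSV lines: "pid \t passage [\t title]"
    index_root='/path/to/indexes',
    index_path='/path/to/indexes/my_index',
    doc_maxlen=180, dim=128, bsize=128, amp=True,
    chunksize=6.0,                        # GiB per saved chunk; asserted to be in [0.5, 128]
    # ...plus the checkpoint-related fields consumed by load_colbert()
)

encoder = CollectionEncoder(args, process_idx=0, num_processes=1)
encoder.encode()  # writes {batch_idx}.pt, {batch_idx}.sample and doclens.{batch_idx}.json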