Ejemplo n.º 1
0
    def _load_model(self):
        """Load the ColBERT model and build the inference wrapper around it.

        Calls ``load_colbert`` with ``self.args``; its ``do_print`` flag is
        true only when ``self.process_idx`` is 0. The returned model has
        ``.cuda()`` and ``.eval()`` called on it and is then wrapped in a
        ``ModelInference`` configured with ``self.args.amp``.

        Sets ``self.colbert``, ``self.checkpoint`` and ``self.inference``.
        """
        loaded = load_colbert(self.args, do_print=(self.process_idx == 0))
        self.colbert, self.checkpoint = loaded

        # Keep whatever ``.cuda()`` returns as the model used from here on.
        self.colbert = self.colbert.cuda()
        self.colbert.eval()

        self.inference = ModelInference(self.colbert, amp=self.args.amp)
Ejemplo n.º 2
0
def main():
    """Command-line entry point for exhaustive (non-indexed) re-ranking evaluation.

    Seeds ``random`` with a fixed value, parses the command line, and then,
    inside ``Run.context()``, loads the model, the qrels and the re-ranking
    inputs before calling ``evaluate_recall`` and ``evaluate``.

    Inputs come from one of two sources:

    * ``--collection`` and ``--queries`` given together: queries, collection
      and top-K pids are loaded separately;
    * neither given: queries, top-K docs and top-K pids all come from
      ``load_topK(args.topK)``.
    """
    random.seed(12345)

    parser = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.')

    # Shared argument groups first, then the one script-specific option.
    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()
    parser.add_argument('--depth', dest='depth', required=False, default=None, type=int)

    args = parser.parse()

    with Run.context():
        model, checkpoint = load_colbert(args)
        args.colbert = model
        args.checkpoint = checkpoint

        args.qrels = load_qrels(args.qrels)

        if not (args.collection or args.queries):
            # Neither path was supplied: everything comes from the top-K file.
            args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)
        else:
            # The two paths must be supplied together.
            assert args.collection and args.queries

            args.queries = load_queries(args.queries)
            args.collection = load_collection(args.collection)

            topk_pids, qrels = load_topK_pids(args.topK, args.qrels)
            args.topK_pids = topk_pids
            args.qrels = qrels

        # Short-circuiting is only accepted when qrels are available.
        assert not (args.shortcircuit and not args.qrels), (
            "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) "
            "can only be applied if qrels is provided."
        )

        evaluate_recall(args.qrels, args.queries, args.topK_pids)
        evaluate(args)
Ejemplo n.º 3
0
def main():
    """Command-line entry point for end-to-end retrieval and ranking.

    Seeds ``random`` with a fixed value, parses the command line, normalises
    ``--depth`` and ``--part-range``, and then, inside ``Run.context()``,
    loads the model, qrels and queries, resolves the index paths and calls
    either ``batch_retrieve`` (with ``--batch``) or ``retrieve``.

    ``--part-range`` is written as ``START..END`` and is converted to
    ``range(START, END)``. A ``--depth`` that is not greater than zero is
    replaced by ``None``.
    """
    random.seed(12345)

    parser = Arguments(
        description='End-to-end retrieval and ranking with ColBERT.')

    # Shared argument groups.
    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_ranking_input()
    parser.add_retrieval_input()

    # Script-specific options.
    parser.add_argument('--faiss_name', dest='faiss_name', default=None, type=str)
    parser.add_argument('--faiss_depth', dest='faiss_depth', default=1024, type=int)
    parser.add_argument('--part-range', dest='part_range', default=None, type=str)
    parser.add_argument('--batch', dest='batch', default=False, action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    # A depth that is not positive is stored as None.
    if not (args.depth > 0):
        args.depth = None

    if args.part_range:
        first_part, end_part = (int(bound) for bound in args.part_range.split('..'))
        args.part_range = range(first_part, end_part)

    with Run.context():
        model, checkpoint = load_colbert(args)
        args.colbert = model
        args.checkpoint = checkpoint

        args.qrels = load_qrels(args.qrels)
        args.queries = load_queries(args.queries)

        args.index_path = os.path.join(args.index_root, args.index_name)

        # An explicit --faiss_name wins; otherwise the name is derived from args.
        if args.faiss_name is None:
            faiss_file = get_faiss_index_name(args)
        else:
            faiss_file = args.faiss_name
        args.faiss_index_path = os.path.join(args.index_path, faiss_file)

        run_retrieval = batch_retrieve if args.batch else retrieve
        run_retrieval(args)
Ejemplo n.º 4
0
def main():
    """Command-line entry point for re-ranking over a ColBERT index.

    Seeds ``random`` with a fixed value, parses the command line, converts
    ``--part-range`` (written as ``START..END``) to ``range(START, END)``,
    and then, inside ``Run.context()``, loads the model, queries, qrels and
    top-K pids, resolves the index path and calls either ``batch_rerank``
    (with ``--batch``) or ``rerank``.
    """
    random.seed(12345)

    parser = Arguments(description='Re-ranking over a ColBERT index')

    # Shared argument groups.
    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()
    parser.add_index_use_input()

    # Script-specific options.
    parser.add_argument('--step', dest='step', default=1, type=int)
    parser.add_argument('--part-range', dest='part_range', default=None, type=str)
    parser.add_argument('--log-scores', dest='log_scores', default=False, action='store_true')
    parser.add_argument('--batch', dest='batch', default=False, action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    if args.part_range:
        first_part, end_part = (int(bound) for bound in args.part_range.split('..'))
        args.part_range = range(first_part, end_part)

    with Run.context():
        model, checkpoint = load_colbert(args)
        args.colbert = model
        args.checkpoint = checkpoint

        args.queries = load_queries(args.queries)
        args.qrels = load_qrels(args.qrels)

        # load_topK_pids returns the pids together with a qrels value,
        # which replaces the one loaded just above.
        topk_pids, qrels = load_topK_pids(args.topK, qrels=args.qrels)
        args.topK_pids = topk_pids
        args.qrels = qrels

        args.index_path = os.path.join(args.index_root, args.index_name)

        run_rerank = batch_rerank if args.batch else rerank
        run_rerank(args)