Example #1
0
def init_query_encoder(encoder, topics_name, encoded_queries, device):
    """Pick the query encoder to use, trying three sources in order.

    1. ``encoder``: when truthy, an encoder class is selected by substring
       match on its value and built with ``encoder_dir`` and ``device``.
    2. ``encoded_queries``: when truthy, a value that exists on the
       filesystem is passed straight to ``QueryEncoder``; any other value
       goes through ``QueryEncoder.load_encoded_queries``.
    3. ``topics_name``: looked up in a fixed table of encoded-query names.

    Returns ``None`` when none of the three sources applies.
    """
    if encoder:
        model_args = {'encoder_dir': encoder, 'device': device}
        # Order matters: the first matching substring wins.
        if 'dpr' in encoder:
            return DprQueryEncoder(**model_args)
        if 'tct_colbert' in encoder:
            return TctColBertQueryEncoder(**model_args)
        if 'ance' in encoder:
            return AnceQueryEncoder(**model_args)
        if 'sentence' in encoder:
            # Names containing 'sentence' additionally request mean pooling
            # and L2 normalisation.
            return AutoQueryEncoder(pooling='mean',
                                    l2_norm=True,
                                    **model_args)
        return AutoQueryEncoder(**model_args)

    if encoded_queries:
        if not os.path.exists(encoded_queries):
            return QueryEncoder.load_encoded_queries(encoded_queries)
        return QueryEncoder(encoded_queries)

    # Topic name -> name of the encoded-query set used as a fallback.
    default_encodings = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test'
    }
    fallback_name = default_encodings.get(topics_name)
    if fallback_name is None:
        return None
    return QueryEncoder.load_encoded_queries(fallback_name)
Example #2
0
def init_query_encoder(encoder, topics_name, encoded_queries, device):
    """Resolve a query encoder from a model, an encoded-query set, or a topic.

    The three inputs are consulted in this order and the first usable one
    decides the result:

    * ``encoder`` (truthy): an encoder class is chosen by substring match
      and constructed with ``encoder_dir=encoder`` and ``device=device``.
    * ``encoded_queries`` (truthy): passed to ``QueryEncoder`` when it exists
      on the filesystem, otherwise to ``QueryEncoder.load_encoded_queries``.
    * ``topics_name``: mapped through a fixed table of encoded-query names.

    Returns ``None`` if nothing matched.
    """
    if encoder:
        shared = {'encoder_dir': encoder, 'device': device}
        # Checked top to bottom; the first substring hit is used.
        if 'dpr' in encoder:
            return DPRQueryEncoder(**shared)
        if 'tct_colbert' in encoder:
            return TCTColBERTQueryEncoder(**shared)
        if 'ance' in encoder:
            return AnceQueryEncoder(**shared)
        if 'sentence' in encoder:
            return SBERTQueryEncoder(**shared)
        return AutoQueryEncoder(**shared)

    if encoded_queries:
        if not os.path.exists(encoded_queries):
            return QueryEncoder.load_encoded_queries(encoded_queries)
        return QueryEncoder(encoded_queries)

    # Topic name -> name of the encoded-query set used as a fallback.
    topic_to_encoded = {
        'msmarco-passage-dev-subset': 'msmarco-passage-dev-subset-tct_colbert',
        'dpr-nq-dev': 'dpr-nq-dev-multi',
        'dpr-nq-test': 'dpr-nq-test-multi',
        'dpr-trivia-dev': 'dpr-trivia-dev-multi',
        'dpr-trivia-test': 'dpr-trivia-test-multi',
        'dpr-wq-test': 'dpr-wq-test-multi',
        'dpr-squad-test': 'dpr-squad-test-multi',
        'dpr-curated-test': 'dpr-curated-test-multi'
    }
    encoded_name = topic_to_encoded.get(topics_name)
    if encoded_name is None:
        return None
    return QueryEncoder.load_encoded_queries(encoded_name)
Example #3
0
def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries,
                       device, prefix):
    """Resolve a query encoder from a model, an encoded-query set, or a topic.

    Resolution order:

    1. ``encoder`` (truthy): an encoder class is chosen by substring match.
       Every class except the 'dkrr' one receives ``encoder_dir``,
       ``tokenizer_name`` and ``device``; the 'dkrr' class receives
       ``encoder_dir``, ``device`` and ``prefix``.
    2. ``encoded_queries`` (truthy): when it exists on the filesystem it is
       opened with ``BprQueryEncoder`` if the value contains 'bpr', else with
       ``QueryEncoder``; otherwise it is passed to
       ``QueryEncoder.load_encoded_queries``.
    3. ``topics_name``: mapped through a fixed table of encoded-query names.

    Raises:
        ValueError: if none of the three sources yields an encoder.
    """
    if encoder:
        # 'dkrr' is tested before 'dpr' and takes a different argument set.
        if 'dkrr' in encoder:
            return DkrrDprQueryEncoder(encoder_dir=encoder,
                                       device=device,
                                       prefix=prefix)

        shared = {
            'encoder_dir': encoder,
            'tokenizer_name': tokenizer_name,
            'device': device,
        }
        if 'dpr' in encoder:
            return DprQueryEncoder(**shared)
        if 'bpr' in encoder:
            return BprQueryEncoder(**shared)
        if 'tct_colbert' in encoder:
            return TctColBertQueryEncoder(**shared)
        if 'ance' in encoder:
            return AnceQueryEncoder(**shared)
        if 'sentence' in encoder:
            # Names containing 'sentence' additionally request mean pooling
            # and L2 normalisation.
            return AutoQueryEncoder(pooling='mean',
                                    l2_norm=True,
                                    **shared)
        return AutoQueryEncoder(**shared)

    if encoded_queries:
        if not os.path.exists(encoded_queries):
            return QueryEncoder.load_encoded_queries(encoded_queries)
        if 'bpr' in encoded_queries:
            return BprQueryEncoder(encoded_query_dir=encoded_queries)
        return QueryEncoder(encoded_queries)

    # Topic name -> name of the encoded-query set used as a fallback.
    default_encodings = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test'
    }
    fallback_name = default_encodings.get(topics_name)
    if fallback_name is None:
        raise ValueError(f'No encoded queries for topic {topics_name}')
    return QueryEncoder.load_encoded_queries(fallback_name)
Example #4
0
def init_query_encoder(encoder, topics_name, device):
    """Build a query encoder from a model name, else from the topic name.

    A truthy ``encoder`` containing 'dpr' or 'tct_colbert' selects the
    matching encoder class, constructed with ``encoder_dir`` and ``device``.
    Any other ``encoder`` value falls through to the topic lookup, which
    loads the encoded-query set registered for ``topics_name``.

    Returns ``None`` when neither route produces an encoder.
    """
    if encoder and 'dpr' in encoder:
        return DPRQueryEncoder(encoder_dir=encoder, device=device)
    if encoder and 'tct_colbert' in encoder:
        return TCTColBERTQueryEncoder(encoder_dir=encoder, device=device)

    # Topic name -> name of the encoded-query set used as a fallback.
    topic_to_encoded = {
        'msmarco_passage_dev_subset': 'msmarco-passage-dev-subset-tct_colbert'
    }
    encoded_name = topic_to_encoded.get(topics_name)
    if encoded_name is None:
        return None
    return QueryEncoder.load_encoded_queries(encoded_name)
def check_dense(index):
    """Try to open every prebuilt dense index named in ``index``.

    There is no explicit validation: each searcher is simply initialized,
    and the check passes as long as no exception is raised. Entries whose
    name contains "bpr" are opened with ``BinaryDenseSearcher``; all others
    with ``SimpleDenseSearcher``.
    """
    # Placeholder queries, needed only so the searchers can be constructed.
    placeholder_queries = QueryEncoder.load_encoded_queries(
        'tct_colbert-msmarco-passage-dev-subset')
    print('\n')
    for index_name in index:
        print(f'# Validating "{index_name}"...')
        searcher_cls = (BinaryDenseSearcher
                        if "bpr" in index_name else SimpleDenseSearcher)
        searcher_cls.from_prebuilt_index(index_name, placeholder_queries)
        print('\n')
Example #6
0
def init_query_encoder(encoder, topics_name, device):
    """Build a query encoder from a model name, else from the topic name.

    A truthy ``encoder`` containing 'dpr' or 'tct_colbert' selects the
    matching encoder class, constructed with ``encoder_dir`` and ``device``.
    Any other ``encoder`` value falls through to the topic lookup, which
    loads the encoded-query set registered for ``topics_name``.

    Returns ``None`` when neither route produces an encoder.
    """
    if encoder and 'dpr' in encoder:
        return DPRQueryEncoder(encoder_dir=encoder, device=device)
    if encoder and 'tct_colbert' in encoder:
        return TCTColBERTQueryEncoder(encoder_dir=encoder, device=device)

    # Topic name -> name of the encoded-query set used as a fallback.
    topic_to_encoded = {
        'msmarco_passage_dev_subset': 'msmarco-passage-dev-subset-tct_colbert',
        'dpr_nq_dev': 'dpr-nq-dev-multi',
        'dpr_nq_test': 'dpr-nq-test-multi',
        'dpr_trivia_dev': 'dpr-trivia_qa-dev-multi',
        'dpr_trivia_test': 'dpr-trivia_qa-test-multi',
        'dpr_wq_test': 'dpr-wq-test-multi',
        'dpr_squad_test': 'dpr-squad-test-multi',
        'dpr_curated_test': 'dpr-curated_trec-test-multi'
    }
    try:
        encoded_name = topic_to_encoded[topics_name]
    except KeyError:
        return None
    return QueryEncoder.load_encoded_queries(encoded_name)
Example #7
0
 def test_msmarco_passage_sbert_encoded_queries(self):
     """Every 'msmarco-passage-dev-subset' topic title must be present in
     the 'sbert-msmarco-passage-dev-subset' encoded queries."""
     query_encoder = QueryEncoder.load_encoded_queries(
         'sbert-msmarco-passage-dev-subset')
     topic_set = get_topics('msmarco-passage-dev-subset')
     for topic_id in topic_set:
         query_text = topic_set[topic_id]['title']
         self.assertTrue(query_text in query_encoder.embedding)
Example #8
0
 def test_msmarco_doc_ance_bf_encoded_queries(self):
     """Every 'maxp-msmarco-doc-dev' topic title must be present in the
     'ance_maxp-msmarco-doc-dev' encoded queries."""
     query_encoder = QueryEncoder.load_encoded_queries(
         'ance_maxp-msmarco-doc-dev')
     topic_set = get_topics('maxp-msmarco-doc-dev')
     for topic_id in topic_set:
         query_text = topic_set[topic_id]['title']
         self.assertTrue(query_text in query_encoder.embedding)
Example #9
0
 def test_trivia_test_ance_encoded_queries(self):
     """Every 'dpr-trivia-test' topic title must be present in the
     'dpr_multi-trivia-test' encoded queries."""
     # NOTE(review): the method name says "ance" but the encoded-query set
     # loaded here is 'dpr_multi-trivia-test' -- confirm which is intended.
     query_encoder = QueryEncoder.load_encoded_queries(
         'dpr_multi-trivia-test')
     topic_set = get_topics('dpr-trivia-test')
     for topic_id in topic_set:
         query_text = topic_set[topic_id]['title']
         self.assertTrue(query_text in query_encoder.embedding)
Example #10
0
 def test_ance_multi_nq_dev(self):
     """Every 'dpr-nq-dev' topic title must be present in the
     'ance_multi-nq-dev' encoded queries."""
     query_encoder = QueryEncoder.load_encoded_queries(
         'ance_multi-nq-dev')
     topic_set = get_topics('dpr-nq-dev')
     for topic_id in topic_set:
         query_text = topic_set[topic_id]['title']
         self.assertTrue(query_text in query_encoder.embedding)
Example #11
0
 def test_tct_colbert_msmarco_doc_dev(self):
     """Every 'msmarco-doc-dev' topic title must be present in the
     'tct_colbert-msmarco-doc-dev' encoded queries."""
     query_encoder = QueryEncoder.load_encoded_queries(
         'tct_colbert-msmarco-doc-dev')
     topic_set = get_topics('msmarco-doc-dev')
     for topic_id in topic_set:
         query_text = topic_set[topic_id]['title']
         self.assertTrue(query_text in query_encoder.embedding)
Example #12
0
 def test_dpr_single_nq_test(self):
     """Every 'dpr-nq-test' topic title must be present in the
     'dpr_single_nq-nq-test' encoded queries."""
     query_encoder = QueryEncoder.load_encoded_queries(
         'dpr_single_nq-nq-test')
     topic_set = get_topics('dpr-nq-test')
     for topic_id in topic_set:
         query_text = topic_set[topic_id]['title']
         self.assertTrue(query_text in query_encoder.embedding)
Example #13
0
 def test_dpr_multi_curated_test(self):
     """Every 'dpr-curated-test' topic title must be present in the
     'dpr_multi-curated-test' encoded queries."""
     query_encoder = QueryEncoder.load_encoded_queries(
         'dpr_multi-curated-test')
     topic_set = get_topics('dpr-curated-test')
     for topic_id in topic_set:
         query_text = topic_set[topic_id]['title']
         self.assertTrue(query_text in query_encoder.embedding)