示例#1
0
def init_query_encoder(encoder, topics_name, encoded_queries, device):
    """Return a query encoder chosen from the supplied arguments, or None.

    Resolution order:
      1. ``encoder`` (truthy): instantiate an on-the-fly encoder class picked
         by the first keyword found in the name ('dpr', 'tct_colbert',
         'ance', 'sentence'); anything else gets ``AutoQueryEncoder``.
      2. ``encoded_queries`` (truthy): if it is an existing filesystem path,
         wrap it with ``QueryEncoder(...)``; otherwise treat it as a name for
         ``QueryEncoder.load_encoded_queries``.
      3. ``topics_name`` has a known pre-encoded counterpart: load that.
      4. Otherwise return None.

    Args:
        encoder: encoder name/directory, or a falsy value to skip step 1.
        topics_name: topics identifier used for the prebuilt lookup.
        encoded_queries: path or name of pre-encoded queries, or falsy.
        device: forwarded to the encoder constructor in step 1.
    """
    if encoder:
        # Keyword order matters: a name containing several keywords is
        # routed to the first match in this chain.
        if 'dpr' in encoder:
            encoder_class = DPRQueryEncoder
        elif 'tct_colbert' in encoder:
            encoder_class = TCTColBERTQueryEncoder
        elif 'ance' in encoder:
            encoder_class = AnceQueryEncoder
        elif 'sentence' in encoder:
            encoder_class = SBERTQueryEncoder
        else:
            encoder_class = AutoQueryEncoder
        return encoder_class(encoder_dir=encoder, device=device)

    if encoded_queries:
        # A non-existent path is interpreted as the identifier of a
        # prebuilt set of encoded queries.
        if not os.path.exists(encoded_queries):
            return QueryEncoder.load_encoded_queries(encoded_queries)
        return QueryEncoder(encoded_queries)

    # Topics that have a pre-encoded query set available by name.
    prebuilt_name = {
        'msmarco-passage-dev-subset': 'msmarco-passage-dev-subset-tct_colbert',
        'dpr-nq-dev': 'dpr-nq-dev-multi',
        'dpr-nq-test': 'dpr-nq-test-multi',
        'dpr-trivia-dev': 'dpr-trivia-dev-multi',
        'dpr-trivia-test': 'dpr-trivia-test-multi',
        'dpr-wq-test': 'dpr-wq-test-multi',
        'dpr-squad-test': 'dpr-squad-test-multi',
        'dpr-curated-test': 'dpr-curated-test-multi'
    }.get(topics_name)
    if prebuilt_name is None:
        return None
    return QueryEncoder.load_encoded_queries(prebuilt_name)
示例#2
0
def init_query_encoder(encoder, topics_name, device):
    """Return a query encoder for ``encoder`` or ``topics_name``, or None.

    An ``encoder`` name containing 'dpr' (checked first) or 'tct_colbert'
    is instantiated directly on ``device``. Otherwise, if ``topics_name``
    has a known pre-encoded query set, that set is loaded via
    ``QueryEncoder.load_encoded_queries``; if not, None is returned.
    """
    if encoder and 'dpr' in encoder:
        return DPRQueryEncoder(encoder_dir=encoder, device=device)
    if encoder and 'tct_colbert' in encoder:
        return TCTColBERTQueryEncoder(encoder_dir=encoder, device=device)

    # NOTE(review): an encoder name matching neither keyword is silently
    # ignored and control falls through to the pre-encoded lookup below.
    prebuilt_name = {
        'msmarco_passage_dev_subset': 'msmarco-passage-dev-subset-tct_colbert'
    }.get(topics_name)
    if prebuilt_name is None:
        return None
    return QueryEncoder.load_encoded_queries(prebuilt_name)
示例#3
0
def init_query_encoder(encoder, topics_name, device):
    """Return a query encoder for ``encoder`` or ``topics_name``, or None.

    An ``encoder`` name containing 'dpr' (checked first) or 'tct_colbert'
    is instantiated directly on ``device``. Otherwise, if ``topics_name``
    is one of the known topic keys below, the matching pre-encoded query
    set is loaded via ``QueryEncoder.load_encoded_queries``; if not, None
    is returned.
    """
    if encoder and 'dpr' in encoder:
        return DPRQueryEncoder(encoder_dir=encoder, device=device)
    if encoder and 'tct_colbert' in encoder:
        return TCTColBERTQueryEncoder(encoder_dir=encoder, device=device)

    # NOTE(review): an encoder name matching neither keyword is silently
    # ignored and control falls through to the pre-encoded lookup below.
    prebuilt_queries = {
        'msmarco_passage_dev_subset': 'msmarco-passage-dev-subset-tct_colbert',
        'dpr_nq_dev': 'dpr-nq-dev-multi',
        'dpr_nq_test': 'dpr-nq-test-multi',
        'dpr_trivia_dev': 'dpr-trivia_qa-dev-multi',
        'dpr_trivia_test': 'dpr-trivia_qa-test-multi',
        'dpr_wq_test': 'dpr-wq-test-multi',
        'dpr_squad_test': 'dpr-squad-test-multi',
        'dpr_curated_test': 'dpr-curated_trec-test-multi'
    }
    prebuilt_name = prebuilt_queries.get(topics_name)
    if prebuilt_name is None:
        return None
    return QueryEncoder.load_encoded_queries(prebuilt_name)