Пример #1
0
def init_query_encoder(encoder, topics_name, encoded_queries, device):
    """Build the query encoder used for dense retrieval.

    Resolution order:
      1. ``encoder`` (a model name or path) given: instantiate an on-the-fly
         encoder. The class is picked by the first keyword contained in
         ``encoder`` -- 'dpr', 'tct_colbert', 'ance', 'sentence' -- and
         ``AutoQueryEncoder`` is the fallback.
      2. ``encoded_queries`` given: load pre-encoded queries from that local
         path if it exists on disk, otherwise treat it as a prebuilt name.
      3. ``topics_name`` has registered prebuilt encoded queries: load them.

    Returns:
        The query encoder, or ``None`` when none of the above applies.
    """
    if encoder:
        common = {'encoder_dir': encoder, 'device': device}
        # First matching keyword wins, so the order of these checks matters.
        if 'dpr' in encoder:
            return DprQueryEncoder(**common)
        if 'tct_colbert' in encoder:
            return TctColBertQueryEncoder(**common)
        if 'ance' in encoder:
            return AnceQueryEncoder(**common)
        if 'sentence' in encoder:
            # Sentence-transformer style models: mean pooling + L2 norm.
            return AutoQueryEncoder(**common, pooling='mean', l2_norm=True)
        return AutoQueryEncoder(**common)

    if encoded_queries:
        if not os.path.exists(encoded_queries):
            return QueryEncoder.load_encoded_queries(encoded_queries)
        return QueryEncoder(encoded_queries)

    # Prebuilt encoded-query names registered per topic set.
    prebuilt_name = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test',
    }.get(topics_name)
    if prebuilt_name is None:
        return None
    return QueryEncoder.load_encoded_queries(prebuilt_name)
Пример #2
0
def init_query_encoder(encoder, topics_name, encoded_queries, device):
    """Return a query encoder chosen from the given arguments.

    A model name/path in ``encoder`` takes priority. The class is the first
    of DPR, TCT-ColBERT, ANCE or SBERT whose keyword ('dpr', 'tct_colbert',
    'ance', 'sentence') occurs in ``encoder``; otherwise ``AutoQueryEncoder``
    is used. Next, ``encoded_queries`` is loaded, either from a local path
    (when it exists on disk) or as a prebuilt name. Last, the prebuilt
    encoded queries registered for ``topics_name`` are loaded.

    Returns:
        The encoder instance, or ``None`` if nothing applies.
    """
    if encoder:
        # Keyword order is significant: the first keyword found wins.
        keyword_to_class = (
            ('dpr', DPRQueryEncoder),
            ('tct_colbert', TCTColBERTQueryEncoder),
            ('ance', AnceQueryEncoder),
            ('sentence', SBERTQueryEncoder),
        )
        for keyword, encoder_class in keyword_to_class:
            if keyword in encoder:
                return encoder_class(encoder_dir=encoder, device=device)
        return AutoQueryEncoder(encoder_dir=encoder, device=device)

    if encoded_queries:
        if os.path.exists(encoded_queries):
            return QueryEncoder(encoded_queries)
        return QueryEncoder.load_encoded_queries(encoded_queries)

    # Prebuilt encoded queries known for each supported topic set.
    prebuilt_by_topics = {
        'msmarco-passage-dev-subset': 'msmarco-passage-dev-subset-tct_colbert',
        'dpr-nq-dev': 'dpr-nq-dev-multi',
        'dpr-nq-test': 'dpr-nq-test-multi',
        'dpr-trivia-dev': 'dpr-trivia-dev-multi',
        'dpr-trivia-test': 'dpr-trivia-test-multi',
        'dpr-wq-test': 'dpr-wq-test-multi',
        'dpr-squad-test': 'dpr-squad-test-multi',
        'dpr-curated-test': 'dpr-curated-test-multi',
    }
    try:
        prebuilt_name = prebuilt_by_topics[topics_name]
    except KeyError:
        return None
    return QueryEncoder.load_encoded_queries(prebuilt_name)
Пример #3
0
def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries,
                       device, prefix):
    """Build the query encoder for dense retrieval.

    Resolution order:
      1. ``encoder`` (model name or path) given: build an on-the-fly
         encoder. 'dkrr' is checked first and uses ``prefix`` instead of a
         tokenizer. Otherwise the first of 'dpr', 'bpr', 'tct_colbert',
         'ance', 'sentence' found in ``encoder`` selects the class, and
         ``AutoQueryEncoder`` is the fallback.
      2. ``encoded_queries`` given: if it is an existing local path, load it
         (with ``BprQueryEncoder`` when the path contains 'bpr'); otherwise
         load it as a prebuilt name.
      3. Load the prebuilt encoded queries registered for ``topics_name``.

    Raises:
        ValueError: if none of the above applies.
    """
    if encoder:
        if 'dkrr' in encoder:
            # DKRR takes a query prefix rather than a tokenizer name.
            return DkrrDprQueryEncoder(encoder_dir=encoder,
                                       device=device,
                                       prefix=prefix)
        shared = dict(encoder_dir=encoder,
                      tokenizer_name=tokenizer_name,
                      device=device)
        # First matching keyword wins, so the order of these checks matters.
        if 'dpr' in encoder:
            return DprQueryEncoder(**shared)
        if 'bpr' in encoder:
            return BprQueryEncoder(**shared)
        if 'tct_colbert' in encoder:
            return TctColBertQueryEncoder(**shared)
        if 'ance' in encoder:
            return AnceQueryEncoder(**shared)
        if 'sentence' in encoder:
            # Sentence-transformer style models: mean pooling + L2 norm.
            return AutoQueryEncoder(**shared, pooling='mean', l2_norm=True)
        return AutoQueryEncoder(**shared)

    if encoded_queries:
        if not os.path.exists(encoded_queries):
            return QueryEncoder.load_encoded_queries(encoded_queries)
        if 'bpr' in encoded_queries:
            return BprQueryEncoder(encoded_query_dir=encoded_queries)
        return QueryEncoder(encoded_queries)

    # Prebuilt encoded-query names registered per topic set.
    prebuilt_name = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test',
    }.get(topics_name)
    if prebuilt_name is None:
        raise ValueError(f'No encoded queries for topic {topics_name}')
    return QueryEncoder.load_encoded_queries(prebuilt_name)
Пример #4
0
    def do_model(self, arg):
        """Switch the dense model used by the dense and hybrid searchers.

        ``arg`` must be "tct" or "ance". On success, ``self.dsearcher`` is
        rebuilt from the matching prebuilt index and ``self.hsearcher``
        combines it with the existing ``self.ssearcher``. On an unknown
        ``arg``, a message is printed and nothing is changed.
        """
        # model key -> (encoder checkpoint, prebuilt dense index)
        choices = {
            "tct": ("castorini/tct_colbert-msmarco",
                    "msmarco-passage-tct_colbert-hnsw"),
            "ance": ("castorini/ance-msmarco-passage",
                     "msmarco-passage-ance-bf"),
        }
        if arg not in choices:
            print(
                f'Model "{arg}" is invalid. Model should be one of [tct, ance].'
            )
            return

        checkpoint, index = choices[arg]
        # The encoder is only constructed once the choice is known to be valid.
        encoder_class = TctColBertQueryEncoder if arg == "tct" else AnceQueryEncoder
        encoder = encoder_class(checkpoint)

        self.dsearcher = SimpleDenseSearcher.from_prebuilt_index(index, encoder)
        self.hsearcher = HybridSearcher(self.dsearcher, self.ssearcher)
        print(f'setting model = {arg}')
Пример #5
0
def init_query_encoder(encoder, topics_name, device):
    """Build the query encoder for dense retrieval.

    If ``encoder`` (a model name or path) is given, an on-the-fly encoder is
    created. Its class is chosen by the first keyword contained in
    ``encoder``: 'dpr', 'tct_colbert', 'ance', checked in that order.
    Otherwise, the prebuilt encoded queries registered for ``topics_name``
    are loaded.

    Raises:
        ValueError: if ``encoder`` is given but matches none of the supported
            model families. An unrecognised encoder is rejected rather than
            silently ignored in favour of the topic's prebuilt encoded
            queries, which come from a different model (TCT-ColBERT or
            multi-dataset DPR).

    Returns:
        The query encoder, or ``None`` if no encoder was given and the topic
        has no prebuilt encoded queries.
    """
    encoded_queries_map = {
        'msmarco-passage-dev-subset': 'msmarco-passage-dev-subset-tct_colbert',
        'dpr-nq-dev': 'dpr-nq-dev-multi',
        'dpr-nq-test': 'dpr-nq-test-multi',
        'dpr-trivia-dev': 'dpr-trivia_qa-dev-multi',
        'dpr-trivia-test': 'dpr-trivia_qa-test-multi',
        'dpr-wq-test': 'dpr-wq-test-multi',
        'dpr-squad-test': 'dpr-squad-test-multi',
        'dpr-curated-test': 'dpr-curated_trec-test-multi',
    }
    if encoder:
        # First matching keyword wins, so the order of these checks matters.
        if 'dpr' in encoder:
            return DPRQueryEncoder(encoder_dir=encoder, device=device)
        if 'tct_colbert' in encoder:
            return TCTColBERTQueryEncoder(encoder_dir=encoder, device=device)
        if 'ance' in encoder:
            return AnceQueryEncoder(encoder_dir=encoder, device=device)
        raise ValueError(
            f'Unsupported encoder "{encoder}": expected a name containing '
            f'"dpr", "tct_colbert" or "ance"')
    if topics_name in encoded_queries_map:
        return QueryEncoder.load_encoded_queries(
            encoded_queries_map[topics_name])
    return None
Пример #6
0
    parser.add_argument('--output', type=str, help='dir to store query embeddings', required=True)
    parser.add_argument('--device', type=str,
                        help='device cpu or cuda [cuda:0, cuda:1...]', default='cpu', required=False)
    args = parser.parse_args()
    device = args.device
    topics = get_topics(args.topics)

    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the window between an exists() check and mkdir().
    os.makedirs(args.output, exist_ok=True)

    # Pick the encoder class by keyword in the model name; the first matching
    # keyword wins, so the order of these checks matters.
    if 'dpr' in args.encoder:
        encoder = DprQueryEncoder(encoder_dir=args.encoder, device=device)
    elif 'tct_colbert' in args.encoder:
        encoder = TctColBertQueryEncoder(encoder_dir=args.encoder, device=device)
    elif 'ance' in args.encoder:
        encoder = AnceQueryEncoder(encoder_dir=args.encoder, device=device)
    elif 'sentence' in args.encoder:
        # Sentence-transformer style models: mean pooling + L2 norm.
        encoder = AutoQueryEncoder(encoder_dir=args.encoder, device=device, pooling='mean', l2_norm=True)
    else:
        encoder = AutoQueryEncoder(encoder_dir=args.encoder, device=device)

    # Encode each topic's 'title' field. The raw title is stored in 'text',
    # but the stripped title is what gets encoded.
    embeddings = {'id': [], 'text': [], 'embedding': []}
    for key in tqdm(topics):
        qid = str(key)
        text = topics[key]['title']
        embeddings['id'].append(qid)
        embeddings['text'].append(text)
        embeddings['embedding'].append(encoder.encode(text.strip()))
    embeddings = pd.DataFrame(embeddings)
    embeddings.to_pickle(os.path.join(args.output, 'embedding.pkl'))
Пример #7
0
        exit()

    # Check PRF Flag: pseudo-relevance feedback is enabled only for a positive
    # --prf-depth and only when the searcher is exactly a SimpleDenseSearcher.
    # NOTE(review): type(...) == ... is an exact-type check, so subclasses of
    # SimpleDenseSearcher do not enable PRF here (isinstance would accept them).
    if args.prf_depth > 0 and type(searcher) == SimpleDenseSearcher:
        PRF_FLAG = True
        if args.prf_method.lower() == 'avg':
            prfRule = DenseVectorAveragePrf()
        elif args.prf_method.lower() == 'rocchio':
            prfRule = DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta)
        # ANCE-PRF is using a new query encoder, so the input to DenseVectorAncePrf is different
        elif args.prf_method.lower() == 'ance-prf' and type(query_encoder) == AnceQueryEncoder:
            # --sparse-index may be a local index directory or a prebuilt index name.
            if os.path.exists(args.sparse_index):
                sparse_searcher = SimpleSearcher(args.sparse_index)
            else:
                sparse_searcher = SimpleSearcher.from_prebuilt_index(args.sparse_index)
            prf_query_encoder = AnceQueryEncoder(encoder_dir=args.ance_prf_encoder, tokenizer_name=args.tokenizer,
                                                 device=args.device)
            prfRule = DenseVectorAncePrf(prf_query_encoder, sparse_searcher)
        # NOTE(review): if --prf-method is none of avg/rocchio/ance-prf, or is
        # ance-prf with a non-ANCE query encoder, prfRule is never assigned here
        # even though PRF_FLAG is True. Whether a later use of prfRule guards
        # against this is not visible in this excerpt -- TODO confirm.
        print(f'Running SimpleDenseSearcher with {args.prf_method.upper()} PRF...')
    else:
        PRF_FLAG = False

    # build output path
    output_path = args.output

    print(f'Running {args.topics} topics, saving to {output_path}...')
    # Run tag handed to the output writer below.
    tag = 'Faiss'

    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
                                      max_hits=args.hits, tag=tag, topics=topics,
                                      use_max_passage=args.max_passage,
                                      max_passage_delimiter=args.max_passage_delimiter,
Пример #8
0
    parser.add_argument('--encoder',
                        type=str,
                        help='encoder name or path',
                        required=True)
    parser.add_argument('--input',
                        type=str,
                        help='query file to be encoded.',
                        required=True)
    parser.add_argument('--output',
                        type=str,
                        help='path to store query embeddings',
                        required=True)
    parser.add_argument('--device',
                        type=str,
                        help='device cpu or cuda [cuda:0, cuda:1...]',
                        default='cpu',
                        required=False)
    args = parser.parse_args()

    # Encode every "qid<TAB>text" line of the input file with an ANCE query
    # encoder, then pickle a DataFrame with columns id / text / embedding.
    encoder = AnceQueryEncoder(args.encoder, device=args.device)
    embeddings = {'id': [], 'text': [], 'embedding': []}
    # Read all lines inside a context manager so the file handle is closed
    # promptly; materialising the list also gives tqdm a known total.
    with open(args.input, 'r') as query_file:
        lines = query_file.readlines()
    for line in tqdm(lines):
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line) rather than
            # failing the two-field unpacking below.
            continue
        qid, text = line.rstrip().split('\t')
        qid = qid.strip()
        text = text.strip()
        embeddings['id'].append(qid)
        embeddings['text'].append(text)
        embeddings['embedding'].append(encoder.encode(text))
    embeddings = pd.DataFrame(embeddings)
    embeddings.to_pickle(args.output)