def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
    """Build a MonoT5 reranker from passage-ranking evaluation options.

    The model is obtained via ``MonoT5.get_model`` using ``options.model``,
    ``options.from_tf`` and ``options.device``; the tokenizer is obtained via
    ``MonoT5.get_tokenizer`` using ``options.model_type`` and
    ``options.batch_size``.

    Args:
        options: Evaluation options carrying the attributes listed above.

    Returns:
        A ``MonoT5`` instance wrapping the model and tokenizer.
    """
    t5_model = MonoT5.get_model(
        options.model, from_tf=options.from_tf, device=options.device)
    t5_tokenizer = MonoT5.get_tokenizer(
        options.model_type, batch_size=options.batch_size)
    return MonoT5(t5_model, t5_tokenizer)
def construct_t5(options: KaggleEvaluationOptions) -> Reranker:
    """Build a MonoT5 reranker whose model comes from a cached T5 loader.

    The loader is configured from the module-level ``SETTINGS`` object; the
    loaded model is moved to ``options.device`` and put in eval mode. The
    tokenizer is obtained via ``MonoT5.get_tokenizer`` using
    ``options.model_type``, ``options.do_lower_case`` and
    ``options.batch_size``.

    Args:
        options: Evaluation options carrying the attributes listed above.

    Returns:
        A ``MonoT5`` instance wrapping the model and tokenizer.
    """
    cached_loader = CachedT5ModelLoader(
        SETTINGS.t5_model_dir,
        SETTINGS.cache_dir,
        'ranker',
        SETTINGS.t5_model_type,
        SETTINGS.flush_cache,
    )
    target_device = torch.device(options.device)

    # Same load -> to(device) -> eval() chain, spelled out step by step.
    t5_model = cached_loader.load()
    t5_model = t5_model.to(target_device)
    t5_model = t5_model.eval()

    t5_tokenizer = MonoT5.get_tokenizer(
        options.model_type,
        do_lower_case=options.do_lower_case,
        batch_size=options.batch_size,
    )
    return MonoT5(t5_model, t5_tokenizer)
示例#3
0
def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
    """Build a MonoT5 reranker from pretrained model and tokenizer names.

    The model is loaded with ``T5ForConditionalGeneration.from_pretrained``
    from ``options.model`` (passing ``options.from_tf``), moved to
    ``options.device`` and put in eval mode. The tokenizer loaded from
    ``options.model_type`` is wrapped in a ``T5BatchTokenizer`` with
    ``options.batch_size``.

    Args:
        options: Evaluation options carrying the attributes listed above.

    Returns:
        A ``MonoT5`` instance wrapping the model and batch tokenizer.
    """
    target_device = torch.device(options.device)
    t5_model = T5ForConditionalGeneration.from_pretrained(
        options.model, from_tf=options.from_tf)
    t5_model = t5_model.to(target_device).eval()

    batch_tokenizer = T5BatchTokenizer(
        AutoTokenizer.from_pretrained(options.model_type),
        options.batch_size)
    return MonoT5(t5_model, batch_tokenizer)
def main():
    """Rerank an initial run with MonoT5 and write the result in two formats.

    Command-line arguments:
        --model_name_or_path: Reranker model (has a default).
        --initial_run: Initial run to be reranked (required).
        --corpus: Document collection (required).
        --output_run: Path prefix for the reranked run (required).
        --queries: Path to the queries file (required).

    Two files are written: ``<output_run>-trec.txt`` with tab-separated
    ``query_id, Q0, docid, rank, score, model`` rows and
    ``<output_run>-marco.txt`` with tab-separated ``query_id, docid, rank``
    rows. Ranks are 1-based.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path",
                        default='unicamp-dl/mt5-base-multi-msmarco',
                        type=str,
                        required=False,
                        help="Reranker model.")
    parser.add_argument("--initial_run",
                        default=None,
                        type=str,
                        required=True,
                        help="Initial run to be reranked.")
    parser.add_argument("--corpus",
                        default=None,
                        type=str,
                        required=True,
                        help="Document collection.")
    parser.add_argument("--output_run",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to save the reranked run.")
    parser.add_argument("--queries",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the queries file.")

    args = parser.parse_args()
    model = MonoT5(args.model_name_or_path)
    run = load_run(args.initial_run)
    corpus = load_corpus(args.corpus)
    queries = load_queries(args.queries)

    # Run reranker. Context managers guarantee both output files are flushed
    # and closed even if reranking raises part-way through.
    with open(args.output_run + '-trec.txt', 'w') as trec, \
            open(args.output_run + '-marco.txt', 'w') as marco:
        for query_id in tqdm(run.keys()):
            query = Query(queries[query_id])
            texts = [
                Text(corpus[doc_id], {'docid': doc_id}, 0)
                for doc_id in run[query_id]
            ]
            # NOTE(review): ranks follow the order returned by model.rerank;
            # this assumes it is already sorted by score -- confirm.
            reranked = model.rerank(query, texts)
            for rank, document in enumerate(reranked):
                trec.write(
                    f'{query_id}\tQ0\t{document.metadata["docid"]}\t{rank+1}\t{document.score}\t{args.model_name_or_path}\n'
                )
                marco.write(
                    f'{query_id}\t{document.metadata["docid"]}\t{rank+1}\n')
    print("Done!")
def main(output_path=OUTPUT_PATH,
         index_path=INDEX_PATH,
         queries_path=QUERIES_PATH,
         run=RUN,
         k=K):
    """Retrieve passages with BM25, rerank them with MonoT5, and write them out.

    Args:
        output_path: Destination file; must contain ``.tsv`` or ``.csv`` to
            select the output writer.
        index_path: Path of the index handed to ``SimpleSearcher``.
        queries_path: Tab-separated queries file; the first column is passed
            as the second ``Query`` argument and the second column as the
            query text.
        run: Value forwarded unchanged to the output writer.
        k: Number of hits requested from the searcher for each query.

    The function always terminates the process: ``sys.exit(0)`` on success,
    ``sys.exit(1)`` when ``output_path`` has an unsupported format.
    """
    print('################################################')
    print("##### Performing Passage Ranking using L2R #####")
    print('################################################')
    print("Output will be placed in:", output_path,
          ", format used will be TREC")
    print('Loading pre-trained model MonoT5...')
    from pygaggle.rerank.transformer import MonoT5
    reranker = MonoT5()

    print('Fetching anserini-like indices from:', index_path)
    # fetch some passages to rerank from MS MARCO with Pyserini (BM25)
    searcher = SimpleSearcher(index_path)
    print('Loading queries from:', queries_path)
    with open(queries_path, 'r') as f:
        # Skip blank lines (e.g. a trailing newline), which would otherwise
        # raise IndexError when the second column is read.
        rows = [line.strip().split('\t') for line in f if line.strip()]
    queries = [Query(fields[1], fields[0]) for fields in rows]
    print(f'Ranking queries using BM25 (k={k})')
    queries_text = []
    for query in tqdm(queries):
        # Use the `k` argument, not the module-level default `K`, so that
        # callers overriding `k` actually change the retrieval depth.
        hits = searcher.search(query.text, k=k)
        queries_text.append(hits_to_texts(hits))

    print('Reranking all queries using MonoT5!')
    rankings = []

    for query, texts in zip(tqdm(queries), queries_text):
        reranked = reranker.rerank(query, texts)
        # Order each query's results by descending score before output.
        reranked.sort(key=lambda x: x.score, reverse=True)
        rankings.append(reranked)

    print('Outputting to file...')
    if '.tsv' in output_path:
        output_to_tsv(queries, rankings, run, output_path)
    elif '.csv' in output_path:
        output_to_csv(queries, rankings, run, output_path)
    else:
        print(
            'ERROR: invalid output file format provided, please use either .csv or .tsv. Exiting'
        )
        sys.exit(1)
    print('SUCCESS: completed reranking, you may check the output at:',
          output_path)
    sys.exit(0)
def construct_t5(options: KaggleEvaluationOptions) -> Reranker:
    """Build a MonoT5 reranker from Kaggle evaluation options.

    Both the model and the tokenizer are derived from ``options.model``; the
    model is requested on ``options.device`` and the tokenizer is given
    ``options.batch_size``.

    Args:
        options: Evaluation options carrying the attributes listed above.

    Returns:
        A ``MonoT5`` instance wrapping the model and tokenizer.
    """
    t5_model = MonoT5.get_model(options.model, device=options.device)
    t5_tokenizer = MonoT5.get_tokenizer(
        options.model, batch_size=options.batch_size)
    return MonoT5(t5_model, t5_tokenizer)
示例#7
0
    parser.add_argument('--max_length',
                        type=int,
                        default=10,
                        help='Maximum number of sentences of each segment.')
    parser.add_argument(
        '--stride',
        type=int,
        default=5,
        help='Stride (step) in sentences between each segment.')
    return parser.parse_args(args)


# Parse the command-line arguments (everything after the script name).
args = parse_args(sys.argv[1:])

# Model: MonoT5 reranker created with its default constructor arguments.
reranker = MonoT5()
# NOTE(review): presumably the path where MonoT5 results are written --
# confirm against the code that consumes `monot5_results`.
monot5_results = args.output_monot5

# Sentencizer: a blank English spaCy pipeline with only the 'sentencizer'
# component added.
nlp = spacy.blank("en")
nlp.add_pipe('sentencizer')

# Input files, loaded from the paths given on the command line.
queries = load_queries(path=args.queries)
run = load_run(path=args.run)
corpus = load_corpus(path=args.corpus)

# Pipeline counters, all starting at zero.
# NOTE(review): presumably segments produced, documents processed, and doc
# ids missing from the corpus -- confirm where they are incremented.
n_segments = 0
n_docs = 0
n_doc_ids_not_found = 0