Example #1
# Imports this snippet relies on; embed_questions, add_passages, validate and
# add_hasanswer are helper functions defined elsewhere in the same script.
import glob
import json
import pickle
import time
from pathlib import Path

import transformers

import src.data
import src.index
import src.model
import src.util


def main(opt):
    logger = src.util.init_logger(is_main=True)
    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    data = src.data.load_data(opt.data)
    model_class = src.model.Retriever
    model = model_class.from_pretrained(opt.model_path)

    # move the retriever to the GPU and optionally cast to fp16 for faster inference
    model.cuda()
    model.eval()
    if not opt.no_fp16:
        model = model.half()

    # index = src.index.Indexer(model.config.indexing_dimension, opt.n_subquantizers, opt.n_bits)

    # # index all passages
    # input_paths = glob.glob(opt.passages_embeddings)
    # input_paths = sorted(input_paths)
    # embeddings_dir = Path(input_paths[0]).parent
    # index_path = embeddings_dir / 'index.faiss'
    # if opt.save_or_load_index and index_path.exists():
    #     index.deserialize_from(embeddings_dir)
    # else:
    #     logger.info(f'Indexing passages from files {input_paths}')
    #     start_time_indexing = time.time()
    #     index_encoded_data(index, input_paths, opt.indexing_batch_size)
    #     logger.info(f'Indexing time: {time.time()-start_time_indexing:.1f} s.')
    #     if opt.save_or_load_index:
    #         index.serialize(embeddings_dir)

    # embed every question with the retriever, then dump (id, embedding) pairs to disk
    questions_embedding, question_ids = embed_questions(opt, data, model, tokenizer)
    lout = []
    for ct, (qe, qid) in enumerate(zip(questions_embedding, question_ids)):
        lout.append((str(ct), (qid, qe)))
    with open("fid.pkl", 'wb') as fout:
        pickle.dump(lout, fout)
    exit()  # stop here; the retrieval code below is never reached in this variant
    # get top k results for every question from the passage index
    start_time_retrieval = time.time()
    top_ids_and_scores = index.search_knn(questions_embedding, opt.n_docs)
    logger.info(f'Search time: {time.time()-start_time_retrieval:.1f} s.')

    # map passage id -> (text, title), attach retrieved passages, and check
    # whether each retrieved passage contains a gold answer
    passages = src.util.load_passages(opt.passages)
    passages = {x[0]: (x[1], x[2]) for x in passages}

    add_passages(data, passages, top_ids_and_scores)
    hasanswer = validate(data, opt.validation_workers)
    add_hasanswer(data, hasanswer)
    output_path = Path(opt.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as fout:
        json.dump(data, fout, indent=4)
    logger.info(f'Saved results to {output_path}')
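
This first example stops right after writing fid.pkl, so the search itself never runs here. As a rough sketch of what can be done with that dump (not part of the original code; the shard filename and the top-k value of 10 are assumptions, and it presumes passage embeddings pickled as an (ids, vectors) pair as in Example #2 below):

import pickle

import numpy as np

with open('fid.pkl', 'rb') as f:
    questions = pickle.load(f)                       # list of (str_index, (question_id, embedding))
q_ids = [qid for _, (qid, _) in questions]
q_vecs = np.stack([np.asarray(qe, dtype=np.float32) for _, (_, qe) in questions])

with open('passage_embeddings_00', 'rb') as f:       # hypothetical shard file written by Example #2
    p_ids, p_vecs = pickle.load(f)
p_vecs = np.asarray(p_vecs, dtype=np.float32)

scores = q_vecs @ p_vecs.T                           # dense retrieval scores are inner products
topk = np.argsort(-scores, axis=1)[:, :10]           # 10 best passage rows per question
top_passage_ids = [[p_ids[j] for j in row] for row in topk]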
Example #2
# Imports this snippet relies on; embed_passages is a helper function defined
# elsewhere in the same script.
import pickle
from pathlib import Path

import transformers

import src.model
import src.util


def main(opt):
    logger = src.util.init_logger(is_main=True)
    tokenizer = transformers.BertTokenizerFast.from_pretrained(
        'bert-base-uncased')
    model_class = src.model.Retriever
    #model, _, _, _, _, _ = src.util.load(model_class, opt.model_path, opt)
    model = model_class.from_pretrained(opt.model_path)

    model.eval()
    model = model.to(opt.device)
    if not opt.no_fp16:
        model = model.half()

    passages = src.util.load_passages(opt.passages)

    # keep only this worker's shard of the passage collection;
    # the last shard absorbs any remainder
    shard_size = len(passages) // opt.num_shards
    start_idx = opt.shard_id * shard_size
    end_idx = start_idx + shard_size
    if opt.shard_id == opt.num_shards - 1:
        end_idx = len(passages)

    passages = passages[start_idx:end_idx]
    logger.info(
        f'Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}'
    )

    allids, allembeddings = embed_passages(opt, passages, model, tokenizer)

    # write this shard's (ids, embeddings) pair to <output_path>_<shard_id>
    output_path = Path(opt.output_path)
    save_file = output_path.parent / (output_path.name +
                                      f'_{opt.shard_id:02d}')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f'Saving {len(allids)} passage embeddings to {save_file}')
    with open(save_file, mode='wb') as f:
        pickle.dump((allids, allembeddings), f)

    logger.info(
        f'Total passages processed {len(allids)}. Written to {save_file}.')
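
Each invocation of this second example writes one shard file. A minimal sketch of how those shards could later be pulled together into a single exact inner-product index (using plain FAISS here rather than the repository's src.index.Indexer wrapper; the glob pattern is an assumption and the embeddings are taken to unpickle as a 2-D float array):

import glob
import pickle

import faiss
import numpy as np

all_ids, all_vecs = [], []
for shard_file in sorted(glob.glob('passage_embeddings_*')):   # hypothetical shard file naming
    with open(shard_file, 'rb') as f:
        ids, vecs = pickle.load(f)
    all_ids.extend(ids)
    all_vecs.append(np.asarray(vecs, dtype=np.float32))

vectors = np.concatenate(all_vecs, axis=0)
index = faiss.IndexFlatIP(vectors.shape[1])   # exact (brute-force) inner-product search
index.add(vectors)

# query later with: scores, rows = index.search(question_vectors, k)
# and map each row back to a passage id via all_ids[row]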