Example #1
import time

import pandas as pd

# RecallByMilvus and gen_id2corpus are project helpers defined elsewhere.
def search_in_milvus(text_embedding, query_text):
    collection_name = 'faq_finance'
    partition_tag = 'partition_1'
    client = RecallByMilvus()
    start_time = time.time()
    status, results = client.search(collection_name=collection_name,
                                    vectors=text_embedding,
                                    partition_tag=partition_tag)
    end_time = time.time()
    print('Milvus search took {} seconds'.format(end_time - start_time))

    corpus_file = "data/qa_pair.csv"
    id2corpus = gen_id2corpus(corpus_file)
    list_data = []
    for line in results:
        for item in line:
            idx = item.id
            distance = item.distance
            text = id2corpus[idx]
            print(text, distance)
            list_data.append([query_text, text, distance])
    df = pd.DataFrame(list_data, columns=['query_text', 'text', 'distance'])
    df = df.sort_values(by="distance", ascending=True)
    df.to_csv('data/recall_predict.csv',
              columns=['text', 'distance'],
              sep='\t',
              header=None,
              index=False)
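
All of these examples call a gen_id2corpus helper that is not shown. A minimal sketch of what it might look like, assuming the corpus file holds one passage per line and the Milvus/index IDs are simply the zero-based line numbers (the file layout and return format are inferred from how id2corpus is used above):

def gen_id2corpus(corpus_file):
    """Map each zero-based line number to the corpus text on that line."""
    id2corpus = {}
    with open(corpus_file, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            id2corpus[idx] = line.rstrip('\n')
    return id2corpus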
Example #2
from functools import partial

from paddlenlp.data import Pad, Tuple
from paddlenlp.datasets import MapDataset

# convert_example, create_dataloader, gen_id2corpus and gen_text_file are
# project helpers defined elsewhere.
def build_data_loader(args, tokenizer):
    """Build the corpus_data_loader and the text_data_loader."""

    id2corpus = gen_id2corpus(args.corpus_file)

    # convert_example expects a dict as input
    corpus_list = [{idx: text} for idx, text in id2corpus.items()]
    corpus_ds = MapDataset(corpus_list)

    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # text_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # text_segment
    ): [data for data in fn(samples)]

    corpus_data_loader = create_dataloader(corpus_ds,
                                           mode='predict',
                                           batch_size=args.batch_size,
                                           batchify_fn=batchify_fn,
                                           trans_fn=trans_func)

    # build text data_loader
    text_list, text2similar_text = gen_text_file(args.similar_text_pair_file)

    text_ds = MapDataset(text_list)

    text_data_loader = create_dataloader(text_ds,
                                         mode='predict',
                                         batch_size=args.batch_size,
                                         batchify_fn=batchify_fn,
                                         trans_fn=trans_func)

    d = {
        "text_data_loader": text_data_loader,
        "corpus_data_loader": corpus_data_loader,
        "id2corpus": id2corpus,
        "text2similar_text": text2similar_text,
        "text_list": text_list
    }

    return d
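
A hypothetical call site for build_data_loader, just to show how the returned dictionary might be unpacked; the args fields and the tokenizer come from the surrounding script and are assumptions here:

data = build_data_loader(args, tokenizer)
corpus_data_loader = data["corpus_data_loader"]
text_data_loader = data["text_data_loader"]
id2corpus = data["id2corpus"]

for batch in corpus_data_loader:
    # Each batch is the [text_input, text_segment] pair produced by batchify_fn.
    input_ids, token_type_ids = batch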
Example #3
def search_in_milvus(text_embedding):
    collection_name = 'literature_search'
    partition_tag = 'partition_2'
    client = RecallByMilvus()
    status, results = client.search(collection_name=collection_name,
                                    vectors=text_embedding.tolist(),
                                    partition_tag=partition_tag)
    corpus_file = "milvus/milvus_data.csv"
    id2corpus = gen_id2corpus(corpus_file)
    for line in results:
        for item in line:
            idx = item.id
            distance = item.distance
            text = id2corpus[idx]
            print(idx, text, distance)
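
Since the embedding is passed through .tolist() before the Milvus search, text_embedding is presumably a 2-D numpy array with one row per query. A hypothetical call with a random vector purely to illustrate the expected shape; the 256-dimensional size is an assumption, not taken from the example:

import numpy as np

query_embedding = np.random.rand(1, 256).astype('float32')
search_in_milvus(query_embedding)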
Example #4
        "ernie-1.0")

    model = SemanticIndexBase(pretrained_model,
                              output_emb_size=args.output_emb_size)
    model = paddle.DataParallel(model)

    # Load pretrained semantic model
    if args.params_path and os.path.isfile(args.params_path):
        state_dict = paddle.load(args.params_path)
        model.set_dict(state_dict)
        logger.info("Loaded parameters from %s" % args.params_path)
    else:
        raise ValueError(
            "Please set --params_path with correct pretrained model file")

    id2corpus = gen_id2corpus(args.corpus_file)

    # convert_example expects a dict as input
    corpus_list = [{idx: text} for idx, text in id2corpus.items()]
    corpus_ds = MapDataset(corpus_list)

    corpus_data_loader = create_dataloader(corpus_ds,
                                           mode='predict',
                                           batch_size=args.batch_size,
                                           batchify_fn=batchify_fn,
                                           trans_fn=trans_func)

    # Need better way to get inner model of DataParallel
    inner_model = model._layers

    final_index = build_index(args, corpus_data_loader, inner_model)
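
build_index itself is not part of the snippet. A minimal sketch of what it might do, assuming an hnswlib HNSW index over inner-product similarity; the hnsw_* arguments and the model.get_semantic_embedding helper are assumptions rather than confirmed APIs:

import hnswlib
import numpy as np

def build_index(args, data_loader, model):
    # Assumed: an HNSW index built from the model's output embeddings.
    index = hnswlib.Index(space='ip', dim=args.output_emb_size)
    index.init_index(max_elements=args.hnsw_max_elements,
                     ef_construction=args.hnsw_ef,
                     M=args.hnsw_m)
    embeddings = []
    for batch_embedding in model.get_semantic_embedding(data_loader):
        embeddings.append(batch_embedding.numpy())
    index.add_items(np.concatenate(embeddings, axis=0))
    return index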