Ejemplo n.º 1
0
def predict(model_handle, text):
    """ Decode a single input text with a loaded model and return the output string.

    Args:
        model_handle: sequence unpacked as
            ``[exe, place, final_score, final_ids, final_index, processors, id_dict_array]``
            -- the executor, the device place, the three variables whose values are
            fetched from ``exe.run``, the corpus processor used to preprocess
            ``text``, and the id->token table passed to ``id_to_text``.
        text: the input line to decode.

    Returns:
        The decoded text of the top-1 trace (``topk=1``) for ``text``, with the
        last token of that trace dropped.

    Raises:
        IndexError: if preprocessing/feeding produced no sentence to decode.
    """
    batch_size = 1

    [exe, place, final_score, final_ids, final_index, processors, id_dict_array] = model_handle

    data_generator = processors.preprocessing_for_lines([text], batch_size=batch_size)
    fetch_list = [final_score.name, final_ids.name, final_index.name]

    for data in data_generator():
        data_feed, sent_num = build_data_feed(data, place, batch_size=batch_size)
        # Same convention as ``test``: a None feed means there is nothing left to run.
        if data_feed is None:
            break

        batch_score, batch_ids, batch_pre_index = exe.run(feed=data_feed,
                                                          fetch_list=fetch_list)

        # Split each fetched array along axis 1 into one piece per batch slot.
        # Only the first ``sent_num`` slots hold real sentences (the remainder is
        # presumably padding added by build_data_feed -- TODO confirm).
        slots = list(zip(np.split(batch_score, batch_size, axis=1),
                         np.split(batch_ids, batch_size, axis=1),
                         np.split(batch_pre_index, batch_size, axis=1)))[:sent_num]
        if not slots:
            continue

        score, ids, pre_index = slots[0]
        # EOS=3 is presumably the end-of-sentence token id of the vocab -- confirm.
        trace_ids, _ = trace_fianl_result(score, ids, pre_index, topk=1, EOS=3)
        # Only one input line is decoded, so the first result is the answer;
        # the final token of the trace is dropped before mapping ids to text.
        return id_to_text(trace_ids[0][:-1], id_dict_array)

    raise IndexError("predict: no decodable sentence was produced for the input text")
Ejemplo n.º 2
0
def test(config):
    """ Decode the test split with a trained model and write one output line per sentence.

    Builds the ``knowledge_seq2seq`` graph, loads parameters from
    ``config.model_path``, runs every batch from
    ``KnowledgeCorpus.data_generator(phase="test", shuffle=False)`` through the
    executor, and writes the decoded top-1 text of each real sentence in the batch
    (the first ``sent_num`` slots) to ``config.output``, one per line.

    Args:
        config: configuration object. Reads ``batch_size``, ``vocab_path``,
            ``use_gpu``, ``model_path``, ``data_dir``, ``data_prefix``,
            ``min_len``, ``max_len`` and ``output``; sets ``config.vocab_size``
            (number of lines in the vocab file) as a side effect.
    """
    batch_size = config.batch_size

    # The vocab size is the number of lines in the vocab file; the context
    # manager guarantees the handle is closed.
    with open(config.vocab_path) as vocab_file:
        config.vocab_size = sum(1 for _ in vocab_file)

    final_score, final_ids, final_index = knowledge_seq2seq(config)

    # Mark the fetch targets persistable (presumably so the executor keeps them
    # available for fetching after each run -- confirm against Fluid docs).
    final_score.persistable = True
    final_ids.persistable = True
    final_index.persistable = True

    main_program = fluid.default_main_program()

    place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()

    exe = Executor(place)
    exe.run(framework.default_startup_program())

    fluid.io.load_params(executor=exe, dirname=config.model_path,
                         main_program=main_program)
    print("load params finished")

    # test data generator
    processors = KnowledgeCorpus(
        data_dir=config.data_dir,
        data_prefix=config.data_prefix,
        vocab_path=config.vocab_path,
        min_len=config.min_len,
        max_len=config.max_len)
    test_generator = processors.data_generator(
        batch_size=batch_size,
        phase="test",
        shuffle=False)

    # id -> token table used by id_to_text
    id_dict_array = load_id2str_dict(config.vocab_path)
    fetch_list = [final_score.name, final_ids.name, final_index.name]

    # ``with`` closes (and flushes) the output file even if a batch raises.
    with open(config.output, 'w') as fout:
        for data in test_generator():
            data_feed, sent_num = build_data_feed(data, place, batch_size=batch_size)

            # A None feed marks the end of usable data.
            if data_feed is None:
                break

            batch_score, batch_ids, batch_pre_index = exe.run(feed=data_feed,
                                                              fetch_list=fetch_list)

            # Split each fetched array along axis 1 into one piece per batch slot;
            # only the first ``sent_num`` slots are real sentences (the rest are
            # presumably padding -- TODO confirm against build_data_feed).
            slots = list(zip(np.split(batch_score, batch_size, axis=1),
                             np.split(batch_ids, batch_size, axis=1),
                             np.split(batch_pre_index, batch_size, axis=1)))[:sent_num]

            for score, ids, pre_index in slots:
                # EOS=3 is presumably the end-of-sentence token id -- confirm.
                trace_ids, _ = trace_fianl_result(score, ids, pre_index, topk=1, EOS=3)
                # Drop the final token of the top-1 trace before mapping to text.
                fout.write(id_to_text(trace_ids[0][:-1], id_dict_array))
                fout.write('\n')