def doc_distance():
    """Compute topic-model distances between two documents from the current request.

    Reads the form fields ``category`` and ``type`` from ``request``.

    When ``type`` is ``'doc'``, the documents come from the uploaded files
    ``text1`` and ``text2``. Each file is saved with ``save_file`` and read
    back with ``read_file``. Otherwise the documents come from the form fields
    ``text1`` and ``text2``, which are UTF-8 encoded and stripped.

    The texts are tokenized and scored with an ``InferenceEngineWrapper`` built
    from ``get_model_dir(category)`` and ``get_lda_conf()``.

    Returns:
        A JSON string (``ensure_ascii=False``). On success it has the keys
        ``"Jensen-Shannon Divergence"`` and ``"Hellinger Distance"``. If either
        uploaded file could not be saved, it has a single ``"error"`` key.
    """
    category = request.form['category']
    in_type = request.form['type']
    if in_type == 'doc':
        f1 = request.files['text1']
        f2 = request.files['text2']
        # Bail out if either upload fails to save. Otherwise f_text1/f_text2
        # would be left unbound and the handler would crash further down.
        # `and` still short-circuits, so f2 is not saved if f1 fails.
        if not (save_file(f1) and save_file(f2)):
            return json.dumps({"error": "failed to save uploaded files"},
                              ensure_ascii=False)
        f_text1 = read_file(f1)
        f_text2 = read_file(f2)
    else:
        f_text1 = request.form['text1'].encode('utf-8').strip()
        f_text2 = request.form['text2'].encode('utf-8').strip()

    # NOTE(review): a new engine is constructed on every request (presumably
    # loading the model each time); consider caching one per category.
    inference_engine_wrapper = InferenceEngineWrapper(get_model_dir(category),
                                                      get_lda_conf())
    doc1_seg = inference_engine_wrapper.tokenize(f_text1)
    doc2_seg = inference_engine_wrapper.tokenize(f_text2)
    # distances[0] is reported as the JS divergence, distances[1] as Hellinger.
    distances = inference_engine_wrapper.cal_doc_distance(doc1_seg, doc2_seg)

    return json.dumps(
        {
            "Jensen-Shannon Divergence": distances[0],
            "Hellinger Distance": distances[1]
        },
        ensure_ascii=False)
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# Author: [email protected]

import sys
from familia_wrapper import InferenceEngineWrapper

if sys.version_info < (3, 0):
    # On Python 2, input() evaluates what the user types; raw_input returns
    # the text verbatim, matching Python 3's input().
    input = raw_input

if __name__ == '__main__':
    if len(sys.argv) < 3:
        sys.stderr.write("Usage:python {} {} {}.\n".format(
            sys.argv[0], "model_dir", "conf_file"))
        # Use sys.exit: the bare exit() builtin is injected by the `site`
        # module and is missing under `python -S` or in frozen executables.
        sys.exit(-1)

    # Read command-line arguments.
    model_dir = sys.argv[1]
    conf_file = sys.argv[2]
    # Build the InferenceEngineWrapper once and reuse it for every query.
    inference_engine_wrapper = InferenceEngineWrapper(model_dir, conf_file)
    while True:
        # Read two long documents. Stop cleanly on EOF (Ctrl-D) or Ctrl-C
        # instead of exiting with a traceback.
        try:
            doc1 = input("Enter Document1: ").strip()
            doc2 = input("Enter Document2: ").strip()
        except (EOFError, KeyboardInterrupt):
            sys.stdout.write("\n")
            break
        # NOTE(review): raw text is passed here. Confirm whether
        # cal_doc_distance tokenizes internally or expects the output of
        # InferenceEngineWrapper.tokenize.
        distances = inference_engine_wrapper.cal_doc_distance(doc1, doc2)
        # Print the results.
        print("Jensen-Shannon Divergence = {}".format(distances[0]))
        print("Hellinger Distance = {}".format(distances[1]))