예제 #1
0
파일: train.py 프로젝트: hjzf/dssm-1
def predict_siamese_bert(file_="./results/input/test"):
    """Run the Siamese BERT model on a test file and write its predictions.

    The configuration is read from ``./configs/config_bert.yml`` and the
    vocabulary from ``./data/vocab.txt``. The model is restored from
    ``cfg["checkpoint_dir"]``, and every row returned in ``query_arr`` is
    extended with the predicted label and probability before being written
    to ``file_ + '.siamese.bert.predict'``.

    Side effect: sets the ``CUDA_VISIBLE_DEVICES`` environment variable
    to ``"4"`` for the current process.

    Args:
        file_: Path of the test input file handed to
            ``data_input.get_test_bert``; also the prefix of the output path.
    """
    # Load the configuration.
    cfg_path = "./configs/config_bert.yml"
    # Use a context manager so the config file handle is closed once parsed.
    with open(cfg_path, encoding='utf-8') as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    # vocab: converts a sequence into ids.
    vocab = Vocabulary(meta_file='./data/vocab.txt',
                       max_len=cfg['max_seq_len'],
                       allow_unk=1,
                       unk='[UNK]',
                       pad='[PAD]')
    # Read the data.
    test_arr, query_arr = data_input.get_test_bert(file_, vocab)
    print("test size:{}".format(len(test_arr)))
    model = SiamenseBert(cfg)
    model.restore_session(cfg["checkpoint_dir"])
    test_label, test_prob = model.predict(test_arr)
    # Each query row gets its predicted label and probability appended
    # (rows are concatenated with lists, so they are presumably lists).
    out_arr = [
        x + [test_label[i]] + [test_prob[i]] for i, x in enumerate(query_arr)
    ]
    write_file(
        out_arr,
        file_ + '.siamese.bert.predict',
    )
예제 #2
0
파일: train.py 프로젝트: hjzf/dssm-1
def siamese_bert_sentence_embedding(file_="./results/input/test.single"):
    """Compute Siamese BERT sentence embeddings for a file of queries.

    Each input line is one query; the output pairs every query with its
    embedding vector, serialized as a comma-separated string, and is written
    to ``file_ + '.siamese.bert.embedding'``.

    The configuration is read from ``./configs/config_bert.yml`` (with
    ``batch_size`` overridden to 64), the vocabulary from
    ``./data/vocab.txt``, and the model is restored from
    ``cfg["checkpoint_dir"]``.

    Side effect: sets the ``CUDA_VISIBLE_DEVICES`` environment variable
    to ``"7"`` for the current process.

    Args:
        file_: Path of the input file handed to
            ``data_input.get_test_bert_single``; also the prefix of the
            output path.
    """
    # Load the configuration.
    cfg_path = "./configs/config_bert.yml"
    # Use a context manager so the config file handle is closed once parsed.
    with open(cfg_path, encoding='utf-8') as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
    cfg['batch_size'] = 64
    os.environ["CUDA_VISIBLE_DEVICES"] = "7"
    # vocab: converts a sequence into ids.
    vocab = Vocabulary(meta_file='./data/vocab.txt',
                       max_len=cfg['max_seq_len'],
                       allow_unk=1,
                       unk='[UNK]',
                       pad='[PAD]')
    # Read the data.
    test_arr, query_arr = data_input.get_test_bert_single(file_, vocab)
    print("test size:{}".format(len(test_arr)))
    model = SiamenseBert(cfg)
    model.restore_session(cfg["checkpoint_dir"])
    test_label = model.predict_embedding(test_arr)
    # Serialize each embedding as a comma-separated string of its components.
    test_label = [",".join([str(y) for y in x]) for x in test_label]
    out_arr = [[x, test_label[i]] for i, x in enumerate(query_arr)]
    print("write to file...")
    write_file(
        out_arr,
        file_ + '.siamese.bert.embedding',
    )