def train_siamese():
    """Train the siamese RNN model on the LCQMC data splits.

    Reads the YAML configuration from ``./configs/config.yml``, fetches the
    train/validation/test splits through ``data_input.get_lcqmc()``, prints
    the size of each split, and then fits a ``SiamenseRNN`` built from the
    configuration.
    """
    # Load the configuration; the context manager guarantees the file handle
    # is closed even if YAML parsing raises.
    cfg_path = "./configs/config.yml"
    with open(cfg_path, encoding='utf-8') as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    # Load the data splits.
    data_train, data_val, data_test = data_input.get_lcqmc()
    # Disabled in the original: data_train = data_train[:100]
    print("train size:{},val size:{}, test size:{}".format(
        len(data_train), len(data_val), len(data_test)))

    model = SiamenseRNN(cfg)
    model.fit(data_train, data_val, data_test)
def predict_siamese(file_='./results/'):
    """Run the saved siamese RNN on a test file and write its predictions.

    Args:
        file_: Path handed to ``get_test`` to load the test data. The output
            is written to this same path with ``.siamese.predict`` appended.

    Each output row is the corresponding ``query_arr`` row extended with the
    predicted label and the predicted probability returned by
    ``model.predict``.
    """
    # Load the configuration; the context manager guarantees the file handle
    # is closed even if YAML parsing raises.
    cfg_path = "./configs/config.yml"
    with open(cfg_path, encoding='utf-8') as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    # NOTE(review): the visible CUDA device is hard-coded to index 4 and
    # overrides any value already present in the environment.
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"

    # Convert sequences to ids.
    vocab = Vocabulary(meta_file='./data/vocab.txt',
                       max_len=cfg['max_seq_len'],
                       allow_unk=1,
                       unk='[UNK]',
                       pad='[PAD]')
    test_arr, query_arr = get_test(file_, vocab)

    # Load the model from its checkpoint directory.
    model = SiamenseRNN(cfg)
    model.restore_session(cfg["checkpoint_dir"])
    test_label, test_prob = model.predict(test_arr)

    # Index-based on purpose: a length mismatch between the queries and the
    # predictions raises IndexError instead of silently dropping rows.
    out_arr = [
        row + [test_label[idx]] + [test_prob[idx]]
        for idx, row in enumerate(query_arr)
    ]
    write_file(out_arr, file_ + '.siamese.predict')