import os

import torch
import torch.nn as nn
import torch.optim as optim
import visdom
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # assumed; may be tensorboardX
from tqdm import tqdm

# Project-local helpers (module paths assumed): BiLSTMSyntax, SyntaxConfig,
# SyntaxLSTM, SyntaxDataset, syntax_fn, get_word_embed, get_flag_embed,
# get_vocab, and Visdom_line are expected to be importable from this package.


def test_GraphAttention():
    # DATA_PATH = './data/dataset_eval'
    DATA_PATH = './nbc'
    DICT_PATH = './checkpoint/dict_20000.pkl'
    EMBEDDING_PATH_RANDOM = './model/save_embedding_97and3.ckpt'
    GPU_NUM = 0
    model_epoch = 7
    MODEL_PATH = f"./checkpoint/syntax_Attn/model-{model_epoch}.ckpt"
    BATCH_SIZE = 200

    torch.cuda.set_device(GPU_NUM)
    embed = get_word_embed().cuda()
    embed_flag = get_flag_embed().cuda()
    vocab = get_vocab()

    config = SyntaxConfig()
    model = BiLSTMSyntax(config)
    model.load(MODEL_PATH)
    model.cuda()

    trainset = SyntaxDataset(vocab=vocab, data_path=DATA_PATH)
    trainloader = DataLoader(dataset=trainset,
                             batch_size=config.batch_size,
                             collate_fn=syntax_fn,
                             num_workers=10,
                             pin_memory=True,
                             shuffle=False)

    correct_num = 0
    all_num = 0
    recall_correct = 0
    recall_all = 0
    model.eval()
    for index, (src, trg, labels, tags) in enumerate(tqdm(trainloader)):
        print(index)
        batch_size = src.shape[0]
        src = embed(src.cuda())
        trg = embed(trg.cuda())
        labels = labels.cuda()
        tags = tags.cuda()

        # Encoder input: word embeddings plus a zeroed 3-dim flag slot.
        flag4encoder = torch.zeros(src.shape[0], src.shape[1], 3).cuda()
        src = torch.cat([src, flag4encoder], dim=2)

        # Get the hidden states of the encoder.
        encoder_output, hidden, syntax_hidden = model.step_encoding(src, tags)

        # Start flag (id 2) for the first decoding step.
        input_flag = [[2] for _ in range(trg.shape[0])]
        input_flag = torch.Tensor(input_flag).long().cuda()
        output_labels = []
        hidden_syntax = syntax_hidden

        # Greedy step-wise decoding: feed each predicted flag back in as the
        # flag embedding of the next step.
        for step in range(trg.shape[1]):
            flag4decoder = embed_flag(input_flag)
            trg_step = trg[:, step, :].view(batch_size, 1, -1)
            tags_step = tags[:, step].view(batch_size, -1)
            trg_step = torch.cat([trg_step, flag4decoder], dim=-1)
            out, hidden_syntax, hidden = model.test_decoding(
                trg_step, tags_step, flag4decoder,
                hidden_syntax=hidden_syntax, hidden=hidden)
            input_flag = torch.max(out, 2)[1]
            output_labels.append(input_flag)
        output_labels = torch.cat(output_labels, dim=1)

        # Drop padded positions (gold label >= 2) before scoring.
        labels = labels.squeeze()
        mask_matrix = labels < 2
        predict_labels = torch.masked_select(output_labels, mask_matrix)
        ground_truth = torch.masked_select(labels, mask_matrix)
        correct_num += torch.sum(predict_labels == ground_truth).item()
        recall_correct += torch.sum(predict_labels & ground_truth).item()
        recall_all += torch.sum(ground_truth).item()
        all_num += len(ground_truth)

        P = correct_num / all_num
        R = recall_correct / recall_all
        F1 = 2 * P * R / (P + R)
        print('Precision is {}'.format(P))
        print('Recall is {}'.format(R))
        print('F1 is {} \n'.format(F1))

    P = correct_num / all_num
    R = recall_correct / recall_all
    F1 = 2 * P * R / (P + R)
    print('Finally', BATCH_SIZE)
    print('\tPrecision is {}'.format(P))
    print('\tRecall is {}'.format(R))
    print('\tF1 is {}'.format(F1))
    print(correct_num, recall_correct)
    return P, R, F1
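
# --- Illustrative sketch, not part of the original pipeline ---
# The masked P/R/F1 logic above is subtle: positions whose gold label is the
# padding/ignore id (2) are dropped, "precision" here is token accuracy over
# the kept positions, and recall counts positions where both the prediction
# and the gold label are 1. A minimal, self-contained demonstration on toy
# tensors (the helper name is hypothetical):
def _demo_masked_prf():
    labels = torch.tensor([[1, 0, 1, 2, 2]])        # 2 = ignore/padding id
    predictions = torch.tensor([[1, 1, 1, 0, 0]])
    mask = labels < 2                               # keep real positions only
    pred = torch.masked_select(predictions, mask)   # tensor([1, 1, 1])
    gold = torch.masked_select(labels, mask)        # tensor([1, 0, 1])
    p = torch.sum(pred == gold).item() / len(gold)  # 2/3 token accuracy
    r = torch.sum(pred & gold).item() / torch.sum(gold).item()  # 2/2 = 1.0
    f1 = 2 * p * r / (p + r)                        # 0.8
    return p, r, f1
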
def train(reload_dataset=False, pretrain_model_path=None, optim_fu='adam'):
    writer = SummaryWriter()
    vis = visdom.Visdom(env="syntax_compression")
    viz = Visdom_line(vis=vis, win="syntax_geted_lstm")

    # Configuration.
    DATA_DIR = '../data/train_pairs'
    DICT_PATH = '../checkpoint/dict_20000.pkl'
    EMBEDDING_PATH_RANDOM = '../model/save_embedding_97and3.ckpt'
    SAVE_EMBEDDING = False
    RELOAD_DATASET = reload_dataset
    SAVE_DATASET_OBJ = '../data/dataset.pkl'
    SAVE_MODEL_PATH = './checkpoint/syntax_gate_lstm/'
    PRINT_STEP = 10
    SAVE_STEP = 1
    GPU_NUM = 1

    torch.manual_seed(2)
    torch.cuda.set_device(GPU_NUM)
    model = SyntaxLSTM(100, 100, 10)
    model.cuda()
    if not os.path.exists(SAVE_MODEL_PATH):
        os.makedirs(SAVE_MODEL_PATH)

    # Load the embeddings.
    embed = get_word_embed().cuda()
    embed_flag = get_flag_embed().cuda()
    vocab = get_vocab()
    criterion = nn.CrossEntropyLoss(ignore_index=2)

    # if pretrain_model_path is not None:
    #     print('Loading the pre-trained model', pretrain_model_path)
    #     model.load(pretrain_model_path)
    #     model.embed.weight.requires_grad = True
    #     parameters = model.parameters()
    #     optimizer = optim.SGD(parameters, lr=0.000001)
    # else:
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(parameters, lr=0.0001)

    trainset = SyntaxDataset(vocab=vocab, reverse_src=True)
    trainloader = DataLoader(dataset=trainset,
                             batch_size=200,
                             collate_fn=syntax_fn,
                             pin_memory=True,
                             num_workers=5,
                             shuffle=True)

    global_step = 0
    loss_print = 0
    step_print = 0
    for epoch in range(500):
        epoch_loss = 0
        for index, (src, trg, labels, tags) in enumerate(tqdm(trainloader)):
            src = embed(src.cuda())
            trg = embed(trg.cuda())
            tags = tags.cuda()
            flag4encoder = torch.zeros(src.shape[0], src.shape[1], 3).cuda()
            src = torch.cat([src, flag4encoder], dim=2)

            # Teacher forcing: prepend an all-zero start flag, then shift the
            # gold labels right by one step to form the decoder flag input.
            flag4decoder = torch.zeros([labels.shape[0], 1]).long()
            flag4decoder = torch.cat([flag4decoder, labels[:, :-1]], dim=1).cuda()
            flag4decoder = embed_flag(flag4decoder)
            # Append the flag embedding as the last dims of the decoder input.
            trg = torch.cat([trg, flag4decoder], dim=2)

            labels = labels.cuda()
            out = model(src, trg, tags)
            out = out.view(-1, 2)
            labels = labels.view(-1)
            loss = criterion(out, labels)
            epoch_loss += loss.item()
            loss_print += loss.item()
            print(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1
            if global_step % PRINT_STEP == 0:
                writer.add_scalar('loss', loss_print / PRINT_STEP, step_print)
                step_print += 1
                loss_print = 0
        model.save(SAVE_MODEL_PATH + 'model-' + str(epoch) + '.ckpt')
        writer.add_scalar('epoch_loss', epoch_loss, epoch)
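
# --- Illustrative sketch, not part of the original pipeline ---
# The teacher-forcing step above shifts the gold labels right by one position
# so the decoder at step t conditions on the gold flag from step t-1, with an
# all-zero start flag at t=0. A toy demonstration of that shift (the helper
# name is hypothetical):
def _demo_flag_shift():
    labels = torch.tensor([[1, 0, 1, 0]])           # gold keep/drop flags
    start = torch.zeros(labels.shape[0], 1).long()  # start-of-sequence flag
    shifted = torch.cat([start, labels[:, :-1]], dim=1)
    return shifted                                  # tensor([[0, 1, 0, 1]])
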
def test_syntax_lstm():
    DATA_PATH = '../data/dataset_eval'
    DICT_PATH = '../checkpoint/dict_20000.pkl'
    EMBEDDING_PATH_RANDOM = '../model/save_embedding_97and3.ckpt'
    GPU_NUM = 0
    model_epoch = 20
    MODEL_PATH = f"./checkpoint/syntax_gate_lstm/model-{model_epoch}.ckpt"

    torch.cuda.set_device(GPU_NUM)
    embed = get_word_embed().cuda()
    embed_flag = get_flag_embed().cuda()
    vocab = get_vocab()

    model = SyntaxLSTM(100, 100, 10)
    model.load(MODEL_PATH)  # load the trained checkpoint before evaluating
    model.cuda()

    testset = SyntaxDataset(vocab=vocab, data_path=DATA_PATH)
    testloader = DataLoader(dataset=testset,
                            batch_size=200,
                            collate_fn=syntax_fn,
                            num_workers=2,
                            pin_memory=True,
                            shuffle=False)

    correct_num = 0
    all_num = 0
    recall_correct = 0
    recall_all = 0
    model.eval()
    for index, (src, trg, labels, tags) in enumerate(tqdm(testloader)):
        print(index)
        src = embed(src.cuda())
        trg = embed(trg.cuda())
        labels = labels.cuda()
        tags = tags.cuda()

        # Finally get the source: word embeddings plus a zeroed 3-dim flag.
        flag4encoder = torch.zeros(src.shape[0], src.shape[1], 3).cuda()
        src = torch.cat([src, flag4encoder], dim=2)

        output_labels = model.testing(src, trg, tags, embed_flag)

        # Drop padded positions (gold label >= 2) before scoring.
        labels = labels.squeeze()
        mask_matrix = labels < 2
        predict_labels = torch.masked_select(output_labels, mask_matrix)
        ground_truth = torch.masked_select(labels, mask_matrix)
        correct_num += torch.sum(predict_labels == ground_truth).item()
        recall_correct += torch.sum(predict_labels & ground_truth).item()
        recall_all += torch.sum(ground_truth).item()
        all_num += len(ground_truth)

        P = correct_num / all_num
        R = recall_correct / recall_all
        F1 = 2 * P * R / (P + R)
        print('Precision is {}'.format(P))
        print('Recall is {}'.format(R))
        print('F1 is {} \n'.format(F1))

    P = correct_num / all_num
    R = recall_correct / recall_all
    F1 = 2 * P * R / (P + R)
    print('Finally')
    print('\tPrecision is {}'.format(P))
    print('\tRecall is {}'.format(R))
    print('\tF1 is {}'.format(F1))
    print(correct_num, recall_correct)
    return P, R, F1
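
# Hypothetical entry point, shown only to illustrate how these routines fit
# together; the function names are the ones defined above, and everything
# else (call order, arguments) is an assumption, not part of the original.
if __name__ == '__main__':
    train(reload_dataset=False)  # train and checkpoint the model every epoch
    test_syntax_lstm()           # then evaluate a saved gated-LSTM checkpoint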