def get_bleu(encoder_decoder: EncoderDecoder, data_path, model_name, data_type):
    """Compute the average sentence-level BLEU score of the model on a split.

    Reads tab-separated rows of the form ``<gold_nl>\t<sql>`` from
    ``<data_path>copynet_<data_type>.txt``, generates a prediction for each
    SQL input, and averages sentence BLEU of the prediction against the gold
    natural-language reference.  For any split other than ``'dev'`` the
    predictions are also written, one per line, to
    ``results/<basename of model_name>.txt``.

    Args:
        encoder_decoder: trained model exposing ``get_response(str) -> str``.
        data_path: directory containing the ``copynet_*`` data files.
        model_name: model identifier; only its last path component is used
            to name the results file.
        data_type: dataset split name, e.g. ``'dev'`` or ``'test'``.

    Returns:
        Mean BLEU score scaled to 0-100 (0.0 if the data file is empty).
    """
    total_score = 0.0
    num = 0
    out_file = None
    # Context manager / try-finally ensure both handles are closed even if
    # generation or scoring raises (the original leaked them on error).
    with open(data_path + 'copynet_' + data_type + '.txt', 'r',
              encoding='utf-8') as test_file:
        try:
            if data_type != 'dev':
                out_file = open('results/' + model_name.split('/')[-1] + '.txt',
                                'w', encoding='utf-8')
            for row in tqdm(test_file):
                # Each row is "<gold natural language>\t<sql>"; split once.
                fields = row.split('\t')
                gold_nl = fields[0]
                sql = fields[1]
                predicted = encoder_decoder.get_response(sql)
                predicted = predicted.replace('<SOS>', '').replace('<EOS>', '')
                predicted = predicted.rstrip()
                if out_file is not None:
                    out_file.write(predicted + '\n')
                # NOTE(review): deliberately unsmoothed BLEU, matching the
                # original scoring; zero-count n-grams yield warnings/zeros.
                score = sentence_bleu([gold_nl.split()], predicted.split())
                total_score += score
                num += 1
        finally:
            if out_file is not None:
                out_file.close()
    # Guard against an empty data file (the original divided by zero here).
    final_score = (total_score * 100 / num) if num else 0.0
    if data_type == 'dev':
        print('DEV set')
    else:
        print('TEST set')
    print('BLEU score is ' + str(final_score))
    return final_score
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader, model_name,
          val_data_loader: DataLoader, keep_prob, teacher_forcing_schedule, lr,
          max_length):
    """Train the encoder-decoder with NLLLoss + Adam, logging to TensorBoard.

    Runs one epoch per entry of ``teacher_forcing_schedule``.  Each epoch
    iterates the training batches, then evaluates on the validation loader,
    logs scalars/histograms/embeddings/sample generations via the module-level
    ``writer``, and checkpoints the whole model to
    ``./model/<model_name>/<model_name>_<epoch>.pt``.

    Args:
        encoder_decoder: model to train; also provides ``get_response`` for
            sample-text logging and ``lang.tok_to_idx`` for embedding metadata.
        train_data_loader: yields ``(input_idxs, target_idxs, input_tokens,
            target_tokens)`` with idx tensors of shape (batch, max_len),
            0 used as padding.
        model_name: used for the checkpoint directory and file names.
        val_data_loader: loader passed to ``evaluate`` after each epoch.
        keep_prob: dropout keep probability forwarded to the model.
        teacher_forcing_schedule: per-epoch teacher-forcing ratios; its length
            is the number of epochs.
        lr: Adam learning rate.
        max_length: fixed decoder output length used to flatten log-probs.
    """
    global_step = 0
    # Index 0 is the padding token; exclude it from the loss.
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './model/' + model_name + '/'

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)

        for batch_idx, (input_idxs, target_idxs, input_tokens, target_tokens) in enumerate(tqdm(train_data_loader)):
            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length, so sort descending (as required
            # for packed sequences) and trim inputs to the longest length.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
            target_variable = Variable(target_idxs[order, :])

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)

            batch_size = input_variable.shape[0]
            # Flatten (batch, max_length, vocab) -> (batch*max_length, vocab)
            # so NLLLoss can compare against the flattened targets.
            flattened_outputs = output_log_probs.view(batch_size * max_length, -1)
            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))
            batch_loss.backward()
            optimizer.step()

            # Strip padding/EOS from predictions; wrap each target in a list
            # because corpus_bleu expects a list of references per hypothesis.
            batch_outputs = trim_seqs(output_seqs)
            batch_targets = [[list(seq[seq > 0])] for seq in list(to_np(target_variable))]
            batch_bleu_score = corpus_bleu(
                batch_targets, batch_outputs,
                smoothing_function=SmoothingFunction().method1)

            # Log sample generations densely early in training, sparsely later.
            if global_step < 10 or (global_step % 10 == 0 and global_step < 100) or (
                    global_step % 100 == 0 and epoch < 2):
                input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('schedule', output_string, global_step=global_step)

                # BUG FIX: this literal was broken across a raw newline in the
                # source (a syntax error); restored as the single-line string
                # used verbatim at the end of each epoch below.
                input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('cancel', output_string, global_step=global_step)

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score, global_step)
                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag, value, global_step, bins='doane')
                    writer.add_histogram('grads/' + tag, to_np(value.grad), global_step, bins='doane')

            global_step += 1

        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score', val_bleu_score, global_step=global_step)

        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings, metadata=encoder_vocab,
                             global_step=0, tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings, metadata=decoder_vocab,
                             global_step=0, tag='decoder_embeddings')

        input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('schedule', output_string, global_step=global_step)

        input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('cancel', output_string, global_step=global_step)

        print('val loss: %.5f, val BLEU score: %.5f' % (val_loss, val_bleu_score), flush=True)
        # NOTE(review): saves the whole module object (pickle of the class),
        # not just a state_dict — loading requires the identical class code.
        torch.save(encoder_decoder, "%s%s_%i.pt" % (model_path, model_name, epoch))
        print('-' * 100, flush=True)