def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader, model_name,
          val_data_loader: DataLoader, keep_prob, teacher_forcing_schedule, lr,
          max_length, use_decay, data_path):
    """Train the encoder-decoder with one teacher forcing ratio per epoch,
    log batch loss/BLEU to TensorBoard, and save a checkpoint whenever the
    dev-set BLEU score improves."""
    global_step = 0
    # Index 0 is the padding token, so it is excluded from the loss.
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './saved/' + model_name + '/'

    # Halve the learning rate after every epoch only when decay is requested.
    gamma = 0.5 if use_decay else 1.0
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    # val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)
    best_bleu = 0.0

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)
        print('lr: ' + str(scheduler.get_lr()))

        for batch_idx, (input_idxs, target_idxs, input_tokens, target_tokens) in enumerate(tqdm(train_data_loader)):
            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length.
            '''
            print(input_idxs[0])
            print(input_tokens[0])
            print(target_idxs[0])
            print(target_tokens[0])
            '''
            # Sort the batch by descending input length before feeding the encoder.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
            target_variable = Variable(target_idxs[order, :])

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(input_variable,
                                                            list(sorted_lengths),
                                                            targets=target_variable,
                                                            keep_prob=keep_prob,
                                                            teacher_forcing=teacher_forcing)
            batch_size = input_variable.shape[0]
            flattened_outputs = output_log_probs.view(batch_size * max_length, -1)

            batch_loss = loss_function(flattened_outputs, target_variable.contiguous().view(-1))
            batch_loss.backward()
            optimizer.step()

            # Trim model outputs and drop padding from targets before scoring corpus BLEU.
            batch_outputs = trim_seqs(output_seqs)
            batch_targets = [[list(seq[seq > 0])] for seq in list(to_np(target_variable))]

            # batch_bleu_score = corpus_bleu(batch_targets, batch_outputs, smoothing_function=SmoothingFunction().method2)
            batch_bleu_score = corpus_bleu(batch_targets, batch_outputs)

            '''
            if global_step < 10 or (global_step % 10 == 0 and global_step < 100) or (global_step % 100 == 0 and epoch < 2):
                input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('schedule', output_string, global_step=global_step)

                input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('cancel', output_string, global_step=global_step)
            '''

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score, global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag, value, global_step, bins='doane')
                    writer.add_histogram('grads/' + tag, to_np(value.grad), global_step, bins='doane')

            global_step += 1

            # Optional short-circuit for quick debugging runs.
            debug = False
            if debug and batch_idx == 5:
                break

        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)
        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score', val_bleu_score, global_step=global_step)

        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings, metadata=encoder_vocab, global_step=0, tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings, metadata=decoder_vocab, global_step=0, tag='decoder_embeddings')

        '''
        input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('schedule', output_string, global_step=global_step)

        input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('cancel', output_string, global_step=global_step)
        '''

        # Checkpoint on the dev-split BLEU score, not the batch-level estimate.
        calc_bleu_score = get_bleu(encoder_decoder, data_path, None, 'dev')
        print('val loss: %.5f, val BLEU score: %.5f' % (val_loss, calc_bleu_score), flush=True)

        if calc_bleu_score > best_bleu:
            print("Best BLEU score! Saving model...")
            best_bleu = calc_bleu_score
            torch.save(encoder_decoder, "%s%s_%i_%.3f.pt" % (model_path, model_name, epoch, calc_bleu_score))

        print('-' * 100, flush=True)
        scheduler.step()
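
# --- Hedged usage sketch (illustrative assumption, not part of the original code) ---
# One plausible way to drive the schedule-based train() above. The model name,
# hyperparameter values, linearly decaying teacher forcing schedule, and data
# path below are all guesses for illustration only:
#
#     num_epochs = 10
#     schedule = [1.0 - epoch / num_epochs for epoch in range(num_epochs)]
#     train(encoder_decoder, train_loader, 'baseline', val_loader,
#           keep_prob=0.8, teacher_forcing_schedule=schedule, lr=1e-3,
#           max_length=200, use_decay=True, data_path='./data')
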
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader, model_name,
          val_data_loader: DataLoader, keep_prob, teacher_forcing_schedule, lr,
          max_length, device, test_data_loader: DataLoader):
    """Device-aware training variant: moves each batch to `device`, tracks
    exact-match training accuracy, checkpoints after every epoch, and returns
    the trained model. The validation and test loaders are accepted but not
    used inside this loop."""
    global_step = 0
    # Index 0 is the padding token, so it is excluded from the loss.
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './model/' + model_name + '/'
    trained_model = encoder_decoder

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)

        correct_predictions = 0.0
        all_predictions = 0.0

        for batch_idx, (input_idxs, target_idxs, input_tokens, target_tokens) in enumerate(tqdm(train_data_loader)):
            # Empty the CUDA cache at each batch.
            torch.cuda.empty_cache()

            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            input_variable = input_idxs[order, :][:, :max(lengths)]
            input_variable = input_variable.to(device)
            target_variable = target_idxs[order, :]
            target_variable = target_variable.to(device)

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(input_variable,
                                                            list(sorted_lengths),
                                                            targets=target_variable,
                                                            keep_prob=keep_prob,
                                                            teacher_forcing=teacher_forcing)
            batch_size = input_variable.shape[0]
            output_sentences = output_seqs.squeeze(2)
            flattened_outputs = output_log_probs.view(batch_size * max_length, -1)

            batch_loss = loss_function(flattened_outputs, target_variable.contiguous().view(-1))

            batch_outputs = trim_seqs(output_seqs)
            batch_inputs = [[list(seq[seq > 0])] for seq in list(to_np(input_variable))]
            batch_targets = [[list(seq[seq > 0])] for seq in list(to_np(target_variable))]

            # Exact-match accuracy: a prediction counts as correct only when the
            # whole trimmed output sequence equals the padding-stripped target.
            for i in range(len(batch_outputs)):
                y_i = batch_outputs[i]
                tgt_i = batch_targets[i][0]
                if y_i == tgt_i:
                    correct_predictions += 1.0
                all_predictions += 1.0

            batch_loss.backward()
            optimizer.step()

            batch_bleu_score = corpus_bleu(batch_targets, batch_outputs,
                                           smoothing_function=SmoothingFunction().method1)

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score, global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag, value, global_step, bins='doane')
                    writer.add_histogram('grads/' + tag, to_np(value.grad), global_step, bins='doane')

            global_step += 1

        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings, metadata=encoder_vocab, global_step=0, tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings, metadata=decoder_vocab, global_step=0, tag='decoder_embeddings')

        print('training accuracy %.5f' % (100.0 * (correct_predictions / all_predictions)))

        # Save a checkpoint after every epoch.
        torch.save(encoder_decoder, "%s%s_%i.pt" % (model_path, model_name, epoch))
        trained_model = encoder_decoder
        print('-' * 100, flush=True)

    torch.save(encoder_decoder, "%s%s_final.pt" % (model_path, model_name))
    return trained_model
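

# Hedged usage sketch (assumption, not part of the original script): one way the
# device-aware train() above might be invoked, assuming EncoderDecoder is an
# nn.Module and the data loaders have been built elsewhere. The model name,
# hyperparameters, and constant teacher forcing schedule are illustrative guesses.
def _example_device_training(encoder_decoder, train_loader, val_loader, test_loader,
                             num_epochs=10):
    # Prefer a GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder_decoder.to(device)

    # Keep teacher forcing fixed at 1.0 for every epoch in this sketch.
    schedule = [1.0] * num_epochs
    return train(encoder_decoder, train_loader, 'device_run', val_loader,
                 keep_prob=0.8, teacher_forcing_schedule=schedule, lr=1e-3,
                 max_length=200, device=device, test_data_loader=test_loader)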