def training_loop(train_dict, val_dict, idx_dict, encoder, decoder, criterion,
                  optimizer, opts):
    """Runs the main training loop; evaluates the model on the val set every epoch.

        * Prints training and val loss each epoch.
        * Prints qualitative translation results each epoch using TEST_SENTENCE.
        * Saves an attention map for TEST_WORD_ATTN each epoch (when the decoder uses attention).

    Arguments:
        train_dict: The training word pairs, organized by source and target lengths.
        val_dict: The validation word pairs, organized by source and target lengths.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model (with or without attention) to generate output tokens.
        criterion: Used to compute the CrossEntropyLoss for each decoder output.
        optimizer: Implements a step rule to update the parameters of the encoder and decoder.
        opts: The command-line arguments.
    """

    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    loss_log = open(os.path.join(opts.checkpoint_path, 'loss_log.txt'), 'w')

    best_val_loss = 1e6
    train_losses = []
    val_losses = []

    for epoch in range(opts.nepochs):

        # Decay the learning rate multiplicatively at the start of every epoch.
        optimizer.param_groups[0]['lr'] *= opts.lr_decay

        epoch_losses = []

        for key in train_dict:
            input_strings, target_strings = zip(*train_dict[key])
            input_tensors = [
                torch.LongTensor(
                    utils.string_to_index_list(s, char_to_index, end_token))
                for s in input_strings
            ]
            target_tensors = [
                torch.LongTensor(
                    utils.string_to_index_list(s, char_to_index, end_token))
                for s in target_strings
            ]

            num_tensors = len(input_tensors)
            num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))

            for i in range(num_batches):
                start = i * opts.batch_size
                end = start + opts.batch_size

                inputs = utils.to_var(
                    torch.stack(input_tensors[start:end]), opts.cuda)
                targets = utils.to_var(
                    torch.stack(target_tensors[start:end]), opts.cuda)

                # The final batch may be smaller than opts.batch_size.
                BS = inputs.size(0)

                encoder_annotations, encoder_hidden = encoder(inputs)

                # The last hidden state of the encoder becomes the first hidden state of the decoder.
                decoder_hidden = encoder_hidden

                start_vector = torch.ones(BS).long().unsqueeze(1) * start_token  # BS x 1
                decoder_input = utils.to_var(start_vector, opts.cuda)            # BS x 1

                loss = 0.0

                seq_len = targets.size(1)  # Gets seq_len from BS x seq_len

                # Decide once per batch whether to condition on ground-truth tokens.
                use_teacher_forcing = np.random.rand() < opts.teacher_forcing_ratio

                for step in range(seq_len):
                    decoder_output, decoder_hidden, attention_weights = decoder(
                        decoder_input, decoder_hidden, encoder_annotations)

                    current_target = targets[:, step]
                    # Cross entropy between the decoder distribution and the ground truth.
                    loss += criterion(decoder_output, current_target)
                    ni = F.softmax(decoder_output, dim=1).data.max(1)[1]

                    if use_teacher_forcing:
                        # With teacher forcing, use the ground-truth token to condition the next step.
                        decoder_input = targets[:, step].unsqueeze(1)
                    else:
                        # Without teacher forcing, use the model's own predictions to condition the next step.
                        decoder_input = utils.to_var(ni.unsqueeze(1), opts.cuda)

                loss /= float(seq_len)
                epoch_losses.append(loss.item())

                # Zero gradients
                optimizer.zero_grad()

                # Compute gradients
                loss.backward()

                # Update the parameters of the encoder and decoder
                optimizer.step()

        train_loss = np.mean(epoch_losses)
        val_loss = evaluate(val_dict, encoder, decoder, idx_dict, criterion, opts)

        if val_loss < best_val_loss:
            best_val_loss = val_loss  # Track the best loss so we only checkpoint on improvement.
            checkpoint(encoder, decoder, idx_dict, opts)

        if not opts.no_attention:
            # Save attention maps for the fixed word TEST_WORD_ATTN throughout training.
            utils.visualize_attention(
                TEST_WORD_ATTN,
                encoder,
                decoder,
                idx_dict,
                opts,
                save=os.path.join(
                    opts.checkpoint_path,
                    'train_attns/attn-epoch-{}.png'.format(epoch)))

        gen_string = utils.translate_sentence(TEST_SENTENCE, encoder, decoder,
                                              idx_dict, opts)
        print("Epoch: {:3d} | Train loss: {:.3f} | Val loss: {:.3f} | Gen: {:20s}"
              .format(epoch, train_loss, val_loss, gen_string))

        loss_log.write('{} {} {}\n'.format(epoch, train_loss, val_loss))
        loss_log.flush()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Update the loss plot after every epoch.
        save_loss_plot(train_losses, val_losses, opts)

    loss_log.close()
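# Note on the learning-rate schedule in training_loop (illustrative only, not part of
# the original code): multiplying param_groups[0]['lr'] by opts.lr_decay at the top of
# every epoch gives an effective rate of initial_lr * opts.lr_decay ** (epoch + 1)
# during epoch `epoch`. This is roughly what PyTorch's built-in exponential scheduler
# would produce if stepped once per epoch, e.g.:
#
#     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opts.lr_decay)
#     ...
#     scheduler.step()  # once per epoch
#
# The manual version is kept above because it mirrors the original code and only
# touches the first parameter group.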
def evaluate(data_dict, encoder, decoder, idx_dict, criterion, opts):
    """Evaluates the model on a held-out validation or test set.

    Arguments:
        data_dict: The validation/test word pairs, organized by source and target lengths.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model (with or without attention) to generate output tokens.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        criterion: Used to compute the CrossEntropyLoss for each decoder output.
        opts: The command-line arguments.

    Returns:
        mean_loss: The average loss over all batches from data_dict.
    """

    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    losses = []

    for key in data_dict:
        input_strings, target_strings = zip(*data_dict[key])
        input_tensors = [
            torch.LongTensor(
                utils.string_to_index_list(s, char_to_index, end_token))
            for s in input_strings
        ]
        target_tensors = [
            torch.LongTensor(
                utils.string_to_index_list(s, char_to_index, end_token))
            for s in target_strings
        ]

        num_tensors = len(input_tensors)
        num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))

        for i in range(num_batches):
            start = i * opts.batch_size
            end = start + opts.batch_size

            inputs = utils.to_var(
                torch.stack(input_tensors[start:end]), opts.cuda)
            targets = utils.to_var(
                torch.stack(target_tensors[start:end]), opts.cuda)

            # The final batch may be smaller than opts.batch_size.
            BS = inputs.size(0)

            encoder_annotations, encoder_hidden = encoder(inputs)

            # The final hidden state of the encoder becomes the initial hidden state of the decoder.
            decoder_hidden = encoder_hidden

            start_vector = torch.ones(BS).long().unsqueeze(1) * start_token  # BS x 1
            decoder_input = utils.to_var(start_vector, opts.cuda)            # BS x 1

            loss = 0.0

            seq_len = targets.size(1)  # Gets seq_len from BS x seq_len

            for step in range(seq_len):
                decoder_output, decoder_hidden, attention_weights = decoder(
                    decoder_input, decoder_hidden, encoder_annotations)

                current_target = targets[:, step]
                # Cross entropy between the decoder distribution and the ground truth.
                loss += criterion(decoder_output, current_target)
                ni = F.softmax(decoder_output, dim=1).data.max(1)[1]  # Greedy predictions (unused here).

                # Evaluation always teacher-forces: condition the next step on the
                # ground-truth token so the loss is comparable across epochs.
                decoder_input = targets[:, step].unsqueeze(1)

            loss /= float(seq_len)
            losses.append(loss.item())

    mean_loss = np.mean(losses)

    return mean_loss
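# Example wiring (a sketch only): how these functions are typically invoked. The data
# dictionaries (train_dict, val_dict, test_dict), the model objects, and the option name
# `opts.learning_rate` are assumptions here -- the real argument parsing, data loading,
# and model construction live elsewhere in this repo -- so the sketch is kept as a
# comment rather than executable code.
#
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(
#         list(encoder.parameters()) + list(decoder.parameters()),
#         lr=opts.learning_rate)
#     training_loop(train_dict, val_dict, idx_dict,
#                   encoder, decoder, criterion, optimizer, opts)
#     test_loss = evaluate(test_dict, encoder, decoder, idx_dict, criterion, opts)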