def greedy_search(self, x, max_sum_len, len_oovs, x_padding_masks):
    """Generate a summary by always picking the highest-probability token
    conditioned on the previous word.

    Args:
        x (Tensor): Input sequence as the source.
        max_sum_len (int): The maximum length a summary can have.
        len_oovs (Tensor): Numbers of out-of-vocabulary tokens.
        x_padding_masks (Tensor): The padding masks for the input
            sequences, with shape (batch_size, seq_len).

    Returns:
        summary (list): The token list of the resulting summary.
    """
    # Get encoder output and states.
    encoder_output, encoder_states = self.model.encoder(
        replace_oovs(x, self.vocab))

    # Initialize the decoder's hidden states with the encoder's hidden states.
    decoder_states = self.model.reduce_state(encoder_states)

    # Initialize the decoder's input at time step 0 with the SOS token.
    x_t = torch.ones(1) * self.vocab.SOS
    x_t = x_t.to(self.DEVICE, dtype=torch.int64)
    summary = [self.vocab.SOS]
    coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)

    # Generate the hypothesis for at most max_sum_len decode steps.
    while int(x_t.item()) != self.vocab.EOS \
            and len(summary) < max_sum_len:
        context_vector, attention_weights, coverage_vector = \
            self.model.attention(decoder_states,
                                 encoder_output,
                                 x_padding_masks,
                                 coverage_vector)
        p_vocab, decoder_states, p_gen = \
            self.model.decoder(x_t.unsqueeze(1),
                               decoder_states,
                               context_vector)
        final_dist = self.model.get_final_distribution(
            x, p_gen, p_vocab, attention_weights, torch.max(len_oovs))
        # Pick the next token with maximum probability.
        x_t = torch.argmax(final_dist, dim=1).to(self.DEVICE)
        decoder_word_idx = x_t.item()
        summary.append(decoder_word_idx)
        # Map extended-vocabulary ids back to UNK before the next step,
        # since the decoder embedding does not cover them.
        x_t = replace_oovs(x_t, self.vocab)

    return summary
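# ---------------------------------------------------------------------------
# `replace_oovs` is imported from the repo's utils and is not shown in this
# section. The sketch below is a hedged reconstruction based only on how it
# is called here (not the repo's actual implementation, and it assumes the
# vocab object supports len()): map every extended-vocabulary id
# (>= len(vocab)) back to UNK so the embedding layer never sees an
# out-of-range index.
# ---------------------------------------------------------------------------
import torch


def replace_oovs(in_tensor, vocab):
    """Sketch: replace extended-vocabulary ids (>= len(vocab)) with UNK."""
    oov_token = torch.full(in_tensor.shape, vocab.UNK,
                           dtype=in_tensor.dtype, device=in_tensor.device)
    return torch.where(in_tensor >= len(vocab), oov_token, in_tensor)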
def forward(self, x, x_len, y, len_oovs, batch, num_batches, teacher_forcing):
    """Define the forward propagation for the seq2seq model.

    Args:
        x (Tensor): Input sequences as source with shape
            (batch_size, seq_len).
        x_len (int): Sequence length of the current batch.
        y (Tensor): Input sequences as reference with shape
            (batch_size, y_len).
        len_oovs (Tensor): The numbers of out-of-vocabulary words for
            samples in this batch.
        batch (int): The index of the current batch.
        num_batches (int): Number of batches in the epoch.
        teacher_forcing (bool): Whether to use teacher forcing.

    Returns:
        batch_loss (Tensor): The average loss of the current batch.
    """
    x_copy = replace_oovs(x, self.v)
    x_padding_masks = torch.ne(x, 0).byte().float()

    # Call encoder forward propagation.
    encoder_output, encoder_states = self.encoder(x_copy,
                                                  self.decoder.embedding)
    # Reduce encoder hidden states.
    decoder_states = self.reduce_state(encoder_states)
    # Initialize coverage vector.
    coverage_vector = torch.zeros(x.size()).to(self.DEVICE)
    # Calculate the loss at every decode step.
    step_losses = []
    # Use the ground truth (the SOS column) as the decoder's first input.
    x_t = y[:, 0]
    for t in range(y.shape[1] - 1):
        # With teacher forcing, feed the ground-truth token at step t;
        # otherwise x_t is the model's own prediction from the last step.
        if teacher_forcing:
            x_t = y[:, t]
        x_t = replace_oovs(x_t, self.v)
        y_t = y[:, t + 1]
        # Get the context vector from the attention network.
        context_vector, attention_weights, coverage_vector = \
            self.attention(decoder_states, encoder_output,
                           x_padding_masks, coverage_vector)
        # Get the vocab distribution and hidden states from the decoder.
        p_vocab, decoder_states, p_gen = self.decoder(x_t.unsqueeze(1),
                                                      decoder_states,
                                                      context_vector)
        final_dist = self.get_final_distribution(x, p_gen, p_vocab,
                                                 attention_weights,
                                                 torch.max(len_oovs))
        # The prediction at step t becomes the input at step t+1.
        x_t = torch.argmax(final_dist, dim=1).to(self.DEVICE)

        # Get the probabilities the model predicts for the target tokens.
        if not config.pointer:
            y_t = replace_oovs(y_t, self.v)
        target_probs = torch.gather(final_dist, 1, y_t.unsqueeze(1))
        target_probs = target_probs.squeeze(1)

        # Apply a mask so that padded zeros do not affect the loss.
        mask = torch.ne(y_t, 0).byte()
        # Smooth to prevent a NaN loss caused by log(0).
        loss = -torch.log(target_probs + config.eps)

        if config.coverage:
            # Add the coverage loss.
            ct_min = torch.min(attention_weights, coverage_vector)
            cov_loss = torch.sum(ct_min, dim=1)
            loss = loss + config.LAMBDA * cov_loss

        mask = mask.float()
        loss = loss * mask
        step_losses.append(loss)

    sample_losses = torch.sum(torch.stack(step_losses, 1), 1)
    # Get the non-padded length of each sequence in the batch.
    seq_len_mask = torch.ne(y, 0).byte().float()
    batch_seq_len = torch.sum(seq_len_mask, dim=1)

    # Normalize each sample's loss by its target sequence length,
    # then average over the batch.
    batch_loss = torch.mean(sample_losses / batch_seq_len)
    return batch_loss
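# ---------------------------------------------------------------------------
# `forward` mixes the generation and copy distributions through
# `self.get_final_distribution`, which is defined elsewhere in the model.
# The sketch below follows the standard pointer-generator formulation of
# See et al. (2017), P(w) = p_gen * P_vocab(w) + (1 - p_gen) * (attention
# mass on w), with the vocabulary axis extended by `max_oov` slots for this
# batch's OOV words. Treat it as an assumption about the repo's method
# (including the p_gen clipping bounds), not a copy of it.
# ---------------------------------------------------------------------------
def get_final_distribution(self, x, p_gen, p_vocab, attention_weights,
                           max_oov):
    """Sketch of the pointer-generator final distribution."""
    batch_size = x.size(0)
    # Clip p_gen away from 0/1 so neither distribution vanishes entirely.
    p_gen = torch.clamp(p_gen, 0.001, 0.999)
    p_vocab_weighted = p_gen * p_vocab
    attention_weighted = (1 - p_gen) * attention_weights

    # Extend the vocab axis with slots for the batch's in-article OOVs
    # (max_oov arrives as a 0-dim tensor, hence the int() cast).
    extension = torch.zeros((batch_size, int(max_oov)),
                            device=p_vocab.device)
    p_vocab_extended = torch.cat([p_vocab_weighted, extension], dim=1)

    # Scatter the copy probabilities onto the source-token ids; since x
    # holds extended-vocabulary ids, OOV words land in the new slots.
    final_dist = p_vocab_extended.scatter_add_(1, x, attention_weighted)
    return final_dist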
def beam_search(self, x, max_sum_len, beam_width, len_oovs, x_padding_masks):
    """Use beam search to generate a summary.

    Args:
        x (Tensor): Input sequence as the source.
        max_sum_len (int): The maximum length a summary can have.
        beam_width (int): Beam size.
        len_oovs (Tensor): Numbers of out-of-vocabulary tokens.
        x_padding_masks (Tensor): The padding masks for the input
            sequences.

    Returns:
        result (list): The token list of the best hypothesis.
    """
    # Run the source sequence through the encoder
    # (encoder forward propagation).
    ###########################################
    #          TODO: module 4 task 2          #
    ###########################################
    encoder_output, encoder_states = self.model.encoder(
        replace_oovs(x, self.vocab), self.model.decoder.embedding)
    coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)
    # Initialize decoder states with the encoder's forward states.
    decoder_states = self.model.reduce_state(encoder_states)

    # Initialize the hypothesis with a Beam instance.
    init_beam = Beam([self.vocab.SOS], [0], decoder_states, coverage_vector)

    # Get the beam size, and create one list for the current candidates
    # and one for the completed hypotheses.
    k = beam_width
    curr, completed = [init_beam], []

    # Run beam search for at most max_sum_len steps.
    for _ in range(max_sum_len):
        # Collect the k best hypotheses after adding a new token.
        topk = []
        for beam in curr:
            # When an EOS token is generated, move the hypothesis to the
            # completed list and decrease the beam size.
            if beam.tokens[-1] == self.vocab.EOS:
                completed.append(beam)
                k -= 1
                continue
            for can in self.best_k(beam, k, encoder_output,
                                   x_padding_masks, x,
                                   torch.max(len_oovs)):
                # Use topk as a heap to keep track of the top k candidates,
                # comparing hypotheses by sequence score and breaking ties
                # with object ids.
                add2heap(topk, (can.seq_score(), id(can), can), k)

        curr = [items[2] for items in topk]
        # Stop when there are enough completed hypotheses.
        if len(completed) == beam_width:
            break

    # When there are not enough completed hypotheses, take whatever we
    # have in the current best k as the final candidates.
    completed += curr
    # Sort the hypotheses by normalized probability and choose the best one.
    result = sorted(completed,
                    key=lambda beam: beam.seq_score(),
                    reverse=True)[0].tokens
    return result
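# ---------------------------------------------------------------------------
# `add2heap` comes from the repo's utils. A minimal sketch consistent with
# the call site above (a bounded min-heap holding at most k items, ordered
# by the tuple's first element, with id() in the tuple breaking score ties);
# the repo's actual helper may differ.
# ---------------------------------------------------------------------------
import heapq


def add2heap(heap, item, k):
    """Sketch: push item onto a min-heap capped at k elements."""
    if len(heap) < k:
        heapq.heappush(heap, item)
    else:
        # Once full, pushpop evicts the smallest item, keeping the k best.
        heapq.heappushpop(heap, item)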
def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):
    """Get the best k tokens to extend the current sequence at the
    current time step.

    Args:
        beam (utils.Beam): The candidate beam to be extended.
        k (int): Beam size.
        encoder_output (Tensor): The LSTM output from the encoder.
        x_padding_masks (Tensor): The padding masks for the input
            sequences.
        x (Tensor): Source token ids.
        len_oovs (Tensor): Number of OOV tokens in a batch.

    Returns:
        best_k (list(Beam)): The list of the best k candidates.
    """
    # Use the decoder to generate the vocab distribution for the next token.
    x_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)
    x_t = x_t.to(self.DEVICE)

    # Get the context vector from the attention network.
    context_vector, attention_weights, coverage_vector = \
        self.model.attention(beam.decoder_states,
                             encoder_output,
                             x_padding_masks,
                             beam.coverage_vector)

    # Replace the indexes of OOV words with the index of the UNK token
    # to prevent index-out-of-bound errors in the decoder.
    p_vocab, decoder_states, p_gen = \
        self.model.decoder(replace_oovs(x_t, self.vocab),
                           beam.decoder_states,
                           context_vector)

    final_dist = self.model.get_final_distribution(x, p_gen, p_vocab,
                                                   attention_weights,
                                                   torch.max(len_oovs))
    # Calculate log probabilities.
    log_probs = torch.log(final_dist.squeeze())
    # Filter forbidden tokens at the first decode step.
    if len(beam.tokens) == 1:
        forbidden_ids = [
            self.vocab[u"台独"],
            self.vocab[u"吸毒"],
            self.vocab[u"黄赌毒"]
        ]
        log_probs[forbidden_ids] = -float('inf')
    # EOS token penalty, following the definition in
    # https://opennmt.net/OpenNMT/translation/beam_search/.
    log_probs[self.vocab.EOS] *= \
        config.gamma * x.size()[1] / len(beam.tokens)
    log_probs[self.vocab.UNK] = -float('inf')

    # Get the top k tokens and their corresponding log probabilities.
    topk_probs, topk_idx = torch.topk(log_probs, k)
    # Extend the current hypothesis with the top k tokens,
    # producing k new hypotheses.
    best_k = [
        beam.extend(idx, log_probs[idx], decoder_states, coverage_vector)
        for idx in topk_idx.tolist()
    ]

    return best_k
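# ---------------------------------------------------------------------------
# The `Beam` class used by beam_search/best_k lives in the repo's utils. The
# sketch below matches its call sites here (constructor, extend, seq_score);
# the alpha-based length penalty follows the OpenNMT definition linked above
# and is an assumption about the real implementation, which may also add
# terms such as a coverage penalty.
# ---------------------------------------------------------------------------
class Beam(object):
    """Sketch of the hypothesis container used during beam search."""

    def __init__(self, tokens, log_probs, decoder_states, coverage_vector):
        self.tokens = tokens                  # token ids decoded so far
        self.log_probs = log_probs            # per-step log-probabilities
        self.decoder_states = decoder_states  # decoder hidden state
        self.coverage_vector = coverage_vector

    def extend(self, token, log_prob, decoder_states, coverage_vector):
        """Return a new hypothesis: the same prefix plus one more token."""
        return Beam(self.tokens + [token],
                    self.log_probs + [log_prob],
                    decoder_states,
                    coverage_vector)

    def seq_score(self, alpha=0.7):
        """Length-normalized log-probability (OpenNMT-style penalty)."""
        len_penalty = ((5 + len(self.tokens)) / 6) ** alpha
        return sum(self.log_probs) / len_penalty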