def decode_training(self, x, y, gru_dec, enc_hidden, enc_annotations):
    """Decode a batch with teacher forcing and return the per-step output distributions.

    At every step ``pos >= 1`` the decoder is fed the reference token ``y[:, pos]``
    instead of its own prediction. Step 0 is a fixed one-hot distribution on
    ``self.sos_token``.

    Args:
        x: encoder input batch. ``x.size(0)`` is taken as the batch size, and ``x`` is
            also forwarded to ``self.decoder`` (presumably for the pointer/copy
            mechanism -- TODO confirm).
        y: reference token indices, indexed as ``y[:, pos]``. ``y.size(1)`` is the
            number of decoding positions, including the initial <sos> position.
        gru_dec: recurrent decoder module. Only its ``hidden_size`` is read here, to
            keep the forward-direction slice of ``dec_output``. It is also forwarded
            to ``self.decoder``.
        enc_hidden: final encoder hidden state, used unchanged as the first decoder
            hidden state.
        enc_annotations: encoder outputs forwarded to the decoder (used for attention).

    Returns:
        ``torch.stack(probabilities, dim=1).squeeze()``, where ``probabilities[0]`` is
        the one-hot <sos> distribution of shape ``(batch_size, 1, vocabulary_size)`` and
        the remaining entries are the decoder's per-step ``prob`` outputs.
    """
    batch_size = x.size(0)
    max_len = y.size(1)

    # The decoder's first hidden state is the last encoder hidden state.
    dec_hidden = enc_hidden
    # The first decoder output (fed back for attention) starts as zeros.
    dec_output = self.decoder.init_first_output(batch_size).float()
    # The first decoder input is <sos>. These helpers come from the name `decoder`,
    # not from `self.decoder`.
    last_char_index = decoder.init_first_input_index(batch_size, self.sos_token).double()
    p_gen = decoder.init_first_p_gen(batch_size)

    # Step 0: all probability mass on <sos>.
    first_prob = torch.zeros((batch_size, 1, self.vocabulary_size), device=device)
    first_prob[:, 0, self.sos_token] = 1
    probabilities = [first_prob]

    for pos in range(1, max_len):
        # DecoderAndPointer step.
        _att, dec_output, dec_hidden, prob, _, p_gen = self.decoder(
            last_char_index,
            dec_output[:, :, :gru_dec.hidden_size],  # forward direction only
            dec_hidden,
            enc_annotations,
            x,
            p_gen,
            self.embedding,
            gru_dec,
        )
        probabilities.append(prob)
        # Teacher forcing: the next input is the reference token at `pos`, shaped
        # [B x 1 x 1]. Slicing the column replaces building a tensor from B separate
        # 0-d tensors, which converted (and synced) each element individually.
        last_char_index = y[:, pos].double().reshape(batch_size, 1, 1).to(device)

    # NOTE(review): squeeze() removes *every* size-1 dimension, so with batch_size == 1
    # the batch axis is dropped as well. This is kept as-is because callers may rely
    # on the current shape.
    return torch.stack(probabilities, dim=1).squeeze()
def decode_eval(self, x, beam_size, alpha, gru_dec, enc_hidden, enc_annotations, max_len):
    """Beam-search decode a single input and return the highest-scoring hypothesis.

    Each hypothesis is scored as the sum of the logs of its chosen per-step
    probabilities, divided by ``length_penalty(len(probabilities), alpha)``.
    Hypotheses whose last token is ``self.eos_token`` are carried forward unchanged
    and keep competing with the still-growing ones.

    Args:
        x: encoder input. ``x.size(1)`` sizes the initial all-zero attention vector,
            and ``x`` is forwarded to ``self.decoder``. The initial decoder state is
            built for a batch of 1.
        beam_size: number of hypotheses kept after each step. It is also the top-k
            used when expanding a hypothesis.
        alpha: forwarded to ``length_penalty``.
        gru_dec: recurrent decoder. Its ``hidden_size`` selects the forward slice of
            ``dec_output``, and the module is forwarded to ``self.decoder``.
        enc_hidden: used unchanged as the first decoder hidden state.
        enc_annotations: encoder outputs forwarded to the decoder.
        max_len: maximum number of positions per hypothesis, including the initial
            <sos> position.

    Returns:
        ``(probabilities, attentions, p_gens)`` of the best hypothesis, each a
        ``torch.stack`` of its per-step list. ``probabilities[0]`` is the one-hot <sos>
        vector and ``attentions[0]`` is all zeros.
    """
    # Beam element layout:
    # (score, probabilities, last_index, attentions, dec_output, dec_hidden, p_gens)
    sos_prob = torch.tensor(
        [0. if i != self.sos_token else 1. for i in range(self.vocabulary_size)]).to(device)
    beam = [(
        torch.tensor([0.]).to(device),
        [sos_prob],
        self.sos_token,
        [torch.zeros((1, x.size(1))).to(device)],
        self.decoder.init_first_output(1).float(),
        enc_hidden,  # decoder first hidden is last encoder hidden
        [decoder.init_first_p_gen(1)],
    )]

    for _ in range(1, max_len):
        # Once every hypothesis has emitted <eos> the beam can no longer change.
        if all(elem[2] == self.eos_token for elem in beam):
            break

        candidates = []
        for beam_elem in beam:
            # Finished hypotheses are kept as-is.
            if beam_elem[2] == self.eos_token:
                candidates.append(beam_elem)
                continue

            score, probs, last_index, attentions, dec_output, dec_hidden, p_gens = beam_elem
            # DecoderAndPointer step.
            att, dec_output, dec_hidden, prob, _, p_gen = self.decoder(
                torch.tensor([[[last_index]]], dtype=torch.float64).to(device),
                dec_output[:, :, :gru_dec.hidden_size],  # forward direction only
                dec_hidden,
                enc_annotations,
                x,
                p_gens[-1],
                self.embedding,
                gru_dec,
            )
            # Build new lists rather than appending, because sibling expansions
            # share these lists.
            probs = probs + [prob.squeeze()]
            attentions = attentions + [att]
            p_gens = p_gens + [p_gen]

            # Expand this hypothesis with its top-k next tokens.
            best_probs, top_indices = prob.topk(beam_size, 2)
            # The decoder appears to emit probabilities (the <sos> entry above is a
            # one-hot probability vector), so scores accumulate log-probabilities.
            # Summing raw probabilities is not a likelihood and makes dividing by the
            # length penalty meaningless. The clamp keeps log(0) finite.
            best_log_probs = best_probs.reshape(-1).clamp_min(
                torch.finfo(best_probs.dtype).tiny).log()
            for next_char, step_log_prob in zip(top_indices.reshape(-1).tolist(),
                                                best_log_probs):
                candidates.append((score + step_log_prob, probs, next_char,
                                   attentions, dec_output, dec_hidden, p_gens))

        candidates.sort(key=lambda elem: elem[0][0] / length_penalty(len(elem[1]), alpha),
                        reverse=True)
        beam = candidates[:beam_size]

    _, probabilities, _, attentions, _, _, p_gens = beam[0]
    return torch.stack(probabilities), torch.stack(attentions), torch.stack(p_gens)
def forward(self, x, y=None):
    """Encode ``x`` and decode it: teacher-forced in training, greedy in eval.

    Args:
        x: a pair ``(inputs, lengths)`` passed to ``self.encoder``. ``inputs.size(0)``
            is the batch size, and ``inputs.size(1)`` sizes the initial all-zero
            attention vectors.
        y: reference token indices. They are required in training mode and must be
            ``None`` in eval mode (checked with ``assert``). They are indexed as
            ``y[:, pos]``, and ``y.size(1)`` sets the number of training steps.

    Returns:
        Training: ``torch.stack(probabilities, dim=1).squeeze()``.

        Eval: a tuple ``(probabilities, attentions, p_gens)`` of stacked tensors.
        Greedy decoding stops after ``self.max_string_length`` positions, or once
        every sequence in the batch has produced ``self.eos_token``.
    """
    assert y is not None if self.training else y is None
    x, lengths = x
    batch_size = x.size(0)

    # Encoder pass.
    enc_annotations, _annotations_len, enc_hidden = self.encoder(x, lengths)
    # The first decoder hidden state averages the two halves of enc_hidden
    # (forward and backward passes) -> [num_dir x B x H].
    # NOTE(review): this pairs forward with backward states only if the encoder puts
    # all forward states first (e.g. a single-layer bidirectional GRU). A multi-layer
    # nn.GRU interleaves them per layer. TODO confirm against the encoder.
    fw_to_bw = enc_hidden.size(0) // 2
    dec_hidden = 0.5 * (enc_hidden[:fw_to_bw] + enc_hidden[fw_to_bw:])

    # The first decoder output (needed for attention) starts as zeros.
    dec_output = self.decoder.init_first_output(batch_size).float()
    # The first decoder input is <sos>. These helpers come from the name `decoder`,
    # not from `self.decoder`.
    last_char_index = decoder.init_first_input_index(batch_size, self.sos_token).double()
    p_gen = decoder.init_first_p_gen(batch_size)

    # Step 0: all probability mass on <sos>.
    first_prob = torch.zeros((batch_size, 1, self.vocabulary_size), device=device)
    first_prob[:, 0, self.sos_token] = 1
    probabilities = [first_prob]

    max_len = y.size(1) if self.training else self.max_string_length
    sentence_end = [False] * batch_size
    attentions = [torch.zeros((batch_size, x.size(1)), device=device)]
    p_gens = [p_gen[0]]

    for pos in range(1, max_len):
        # DecoderAndPointer step.
        # NOTE(review): unlike decode_training/decode_eval, this call passes neither
        # the embedding nor gru_dec, and it passes the full dec_output.
        att, dec_output, dec_hidden, prob, _, p_gen = self.decoder(
            last_char_index, dec_output, dec_hidden, enc_annotations, x, p_gen)
        probabilities.append(prob)

        if self.training:
            # Teacher forcing: the next input is the reference token, shaped
            # [B x 1 x 1]. Slicing the column avoids converting B separate 0-d
            # tensors one by one.
            last_char_index = y[:, pos].double().reshape(batch_size, 1, 1).to(device)
            continue

        attentions.append(att)
        p_gens.append(p_gen.squeeze(2))
        # Greedy choice. argmax output carries no gradient, so no .data/.detach is
        # needed.
        last_char_index = prob.argmax(2).unsqueeze(2).double()

        # A sequence stays ended once it has emitted <eos>.
        sentence_end = [
            ended or ch == self.eos_token
            for ended, ch in zip(sentence_end, last_char_index.view(-1).tolist())
        ]
        if all(sentence_end):
            break

    # NOTE(review): squeeze() removes *every* size-1 dimension, so with batch_size == 1
    # the batch axis is dropped as well. Kept as-is for caller compatibility.
    stacked_probabilities = torch.stack(probabilities, dim=1).squeeze()
    if self.training:
        return stacked_probabilities
    return stacked_probabilities, torch.stack(attentions), torch.stack(p_gens)