def _forward_rnn(self, input, input_length, h_0=None):
    """
    :param input: Input sequence.
        FloatTensor with shape [batch_size, input_len, dim]

    :param input_length: Lengths of the sequences (not a mask).
        LongTensor with shape [batch_size, ]
    """
    total_length = input.size(1)

    # 1. Sort the batch by descending length, as required by
    #    pack_padded_sequence, keeping the inverse permutation `oidx`.
    oidx, sidx, slens = sort_batch(input_length)

    input_sorted = torch.index_select(input, index=sidx, dim=0)

    if h_0 is not None:
        # Hidden states are laid out as [num_layers, batch_size, dim],
        # so the batch permutation is applied along dim 1.
        h_0_sorted = nest.map_structure(
            lambda t: torch.index_select(t, 1, sidx), h_0)
    else:
        h_0_sorted = None

    # 2. Run the RNN on the packed sequence.
    input_packed = pack_padded_sequence(input_sorted, slens, batch_first=True)

    out_packed, h_n_sorted = self.rnn(input_packed, h_0_sorted)

    # 3. Unpack and restore the original batch order.
    out_sorted = pad_packed_sequence(
        out_packed, batch_first=True, total_length=total_length)[0]
    out = torch.index_select(out_sorted, dim=0, index=oidx)

    h_n = nest.map_structure(
        lambda t: torch.index_select(t, 1, oidx), h_n_sorted)

    return out.contiguous(), h_n
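# --- Illustration (assumption, not part of the original module) ---
# `sort_batch` is used above but defined elsewhere in the repo. This is a
# minimal sketch of what it is assumed to return, matching the call
# `oidx, sidx, slens = sort_batch(input_length)`: the permutation `sidx`
# that sorts the batch by descending length (required by
# pack_padded_sequence), its inverse `oidx` for restoring the original
# order, and the sorted lengths `slens`. The real helper may differ.
def _sort_batch_sketch(lengths):
    slens, sidx = torch.sort(lengths, descending=True)  # sorted lengths + permutation
    oidx = torch.argsort(sidx)  # inverse permutation: restores original order
    return oidx, sidx, slens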
def reorder_dec_states(self, dec_states, new_beam_indices, beam_size):
    """Re-order the decoder's self-attention caches after a beam-search
    step, so that each kept beam points at the cache of its parent beam."""
    slf_attn_caches = dec_states['slf_attn_caches']

    batch_size = slf_attn_caches[0][0].size(0) // beam_size

    n_head = self.decoder.n_head
    dim_per_head = self.decoder.dim_per_head

    slf_attn_caches = nest.map_structure(
        lambda t: tensor_gather_helper(
            gather_indices=new_beam_indices,
            gather_from=t,
            batch_size=batch_size,
            beam_size=beam_size,
            gather_shape=[batch_size * beam_size, n_head, -1, dim_per_head]),
        slf_attn_caches)

    dec_states['slf_attn_caches'] = slf_attn_caches

    return dec_states
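# --- Illustration (assumption, not part of the original module) ---
# `tensor_gather_helper` is defined elsewhere in the repo. This is a hedged
# sketch of the gather pattern its call sites imply: offset per-batch beam
# indices into flat [batch * beam] positions, select the surviving rows of
# `gather_from` viewed as `gather_shape`, and restore the original shape.
# The real helper's internals may differ.
def _tensor_gather_sketch(gather_indices, gather_from, batch_size, beam_size,
                          gather_shape):
    # Per-batch offsets into the flattened [batch_size * beam_size] axis.
    offsets = torch.arange(0, batch_size, device=gather_indices.device) * beam_size
    flat_indices = (gather_indices + offsets.unsqueeze(1)).view(-1)
    # Select the kept beams along dim 0 of the reshaped tensor, then
    # restore the input's original shape.
    output = gather_from.view(*gather_shape).index_select(0, flat_indices)
    return output.view(*gather_from.size())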
def batch_beam_search(self, src_seq, beam_size=5, max_steps=150):
    batch_size = src_seq.size(0)

    enc_output, enc_mask = self.encoder(src_seq)  # [batch_size, seq_len, dim]
    # dec_caches = self.decoder.compute_caches(enc_output)

    # Tile the encoder states beam_size times so that each beam attends to
    # its own copy: [batch_size * beam_size, seq_len, dim].
    enc_mask = tile_batch(enc_mask, multiplier=beam_size, batch_dim=0)
    enc_output = tile_batch(enc_output, multiplier=beam_size, batch_dim=0)

    final_word_indices = src_seq.new(batch_size, beam_size, 1).fill_(Vocab.BOS)  # word indices in the beam
    final_lengths = enc_output.new(batch_size, beam_size).fill_(0.0)  # lengths of the hypotheses
    beam_mask = enc_output.new(batch_size, beam_size).fill_(1.0)  # 1.0 while a beam is still alive
    beam_scores = enc_output.new(batch_size, beam_size).fill_(0.0)  # accumulated costs of the beams

    # Every cache element has shape [batch_size * beam_size, num_heads, seq_len, dim_head]
    self_attn_caches = None
    enc_attn_caches = None

    for t in range(max_steps):
        inp_t = final_word_indices.view(-1, final_word_indices.size(-1))

        # dec_output: [batch_size * beam_size, seq_len, dim]
        dec_output, self_attn_caches, enc_attn_caches = self.decoder(
            tgt_seq=inp_t,
            enc_output=enc_output,
            enc_mask=enc_mask,
            enc_attn_caches=enc_attn_caches,
            self_attn_caches=self_attn_caches)

        # Negate the log-probabilities so that lower scores are better.
        next_scores = -self.generator(dec_output[:, -1].contiguous())  # [batch_size * beam_size, n_words]
        next_scores = next_scores.view(batch_size, beam_size, -1)
        next_scores = mask_scores(next_scores, beam_mask=beam_mask)

        beam_scores = next_scores + beam_scores.unsqueeze(2)  # [B, Bm, N] + [B, Bm, 1] ==> [B, Bm, N]

        vocab_size = beam_scores.size(-1)

        if t == 0:
            # All beams start from the same BOS token, so only keep the
            # first beam to avoid selecting beam_size identical hypotheses.
            beam_scores = beam_scores[:, 0, :].contiguous()

        beam_scores = beam_scores.view(batch_size, -1)

        # Get the top-k candidates across all beams (smallest costs).
        beam_scores, indices = torch.topk(
            beam_scores, k=beam_size, dim=-1, largest=False, sorted=False)

        # Floor-divide to recover the parent beam of each candidate.
        next_beam_ids = torch.div(indices, vocab_size, rounding_mode='floor')
        next_word_ids = indices % vocab_size

        # Re-arrange the beam state by the new beam indices.
        beam_mask = tensor_gather_helper(gather_indices=next_beam_ids,
                                         gather_from=beam_mask,
                                         batch_size=batch_size,
                                         beam_size=beam_size,
                                         gather_shape=[-1])

        final_word_indices = tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=final_word_indices,
            batch_size=batch_size,
            beam_size=beam_size,
            gather_shape=[batch_size * beam_size, -1])

        final_lengths = tensor_gather_helper(gather_indices=next_beam_ids,
                                             gather_from=final_lengths,
                                             batch_size=batch_size,
                                             beam_size=beam_size,
                                             gather_shape=[-1])

        self_attn_caches = nest.map_structure(
            lambda c: tensor_gather_helper(
                gather_indices=next_beam_ids,
                gather_from=c,
                batch_size=batch_size,
                beam_size=beam_size,
                gather_shape=[batch_size * beam_size,
                              self.decoder.n_head, -1,
                              self.decoder.dim_per_head]),
            self_attn_caches)

        # If next_word_ids is EOS, beam_mask_ should be 0.0.
        beam_mask_ = 1.0 - next_word_ids.eq(Vocab.EOS).float()
        # If an EOS was already generated at an earlier step, replace the
        # new token with PAD.
        next_word_ids.masked_fill_((beam_mask_ + beam_mask).eq(0.0), Vocab.PAD)

        beam_mask = beam_mask * beam_mask_

        # # If an EOS or PAD is encountered, set the beam mask to 0.0
        # beam_mask_ = next_word_ids.gt(Vocab.EOS).float()
        # beam_mask = beam_mask * beam_mask_

        # Only still-alive beams grow in length.
        final_lengths += beam_mask

        final_word_indices = torch.cat(
            (final_word_indices, next_word_ids.unsqueeze(2)), dim=2)

        # Stop early once every beam has finished.
        if beam_mask.eq(0.0).all():
            break

    # Length-normalize the costs and sort ascending (best hypothesis first).
    scores = beam_scores / (final_lengths + 1e-2)

    _, reranked_ids = torch.sort(scores, dim=-1, descending=False)

    return tensor_gather_helper(
        gather_indices=reranked_ids,
        gather_from=final_word_indices[:, :, 1:].contiguous(),  # strip BOS
        batch_size=batch_size,
        beam_size=beam_size,
        gather_shape=[batch_size * beam_size, -1])
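# --- Usage illustration (assumption, not part of the original module) ---
# A sketch of how `batch_beam_search` would be called at inference time.
# The `model` object and the token ids are hypothetical; only the call
# signature follows the method above.
model.eval()
with torch.no_grad():
    # Hypothetical source batch of ids: [batch_size=1, src_len=5].
    src_seq = torch.tensor([[Vocab.BOS, 17, 42, 9, Vocab.EOS]])
    hyps = model.batch_beam_search(src_seq, beam_size=5, max_steps=150)
    # hyps has shape [batch_size, beam_size, out_len]; beams are re-ranked
    # by length-normalized score, so hyps[:, 0] is the best hypothesis.
    best = hyps[:, 0]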