def reorder_dec_states(self, dec_states, new_beam_indices, beam_size):
    dec_hiddens = dec_states["dec_hiddens"]
    batch_size = dec_hiddens.size(0) // beam_size

    # Reorder the decoder hidden states according to the newly selected beams.
    dec_hiddens = tensor_gather_helper(gather_indices=new_beam_indices,
                                       gather_from=dec_hiddens,
                                       batch_size=batch_size,
                                       beam_size=beam_size,
                                       gather_shape=[batch_size * beam_size, -1])
    dec_states["dec_hiddens"] = dec_hiddens

    # routing_weights: [batch * beam, len, num_in_caps, num_out_caps]
    routing_weights = dec_states["routing_weights"]
    routing_weights = tensor_gather_helper(
        gather_indices=new_beam_indices,
        gather_from=routing_weights,
        batch_size=batch_size,
        beam_size=beam_size,
        gather_shape=[batch_size * beam_size, -1,
                      routing_weights.size(1),
                      routing_weights.size(2)])
    dec_states["routing_weights"] = routing_weights

    return dec_states
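
# A minimal sketch (not from this repo) of what tensor_gather_helper is assumed to do:
# select, for each batch element, the rows of a [batch * beam, ...] tensor that belong
# to the chosen beam indices, then restore the original shape. All names below are
# illustrative assumptions, not the repo's actual implementation.
import torch

def tensor_gather_helper(gather_indices, gather_from, batch_size, beam_size, gather_shape):
    # gather_indices: [batch, beam] beam ids selected for each batch element.
    # Offset each batch's beam ids so they index the flattened batch * beam dimension.
    offsets = (torch.arange(0, batch_size, device=gather_indices.device) * beam_size).long()
    flat_indices = (gather_indices + offsets.unsqueeze(1)).view(-1)
    # View the source with batch * beam as dim 0 and pick the selected rows.
    output = torch.index_select(gather_from.view(*gather_shape), 0, flat_indices)
    return output.view_as(gather_from)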
def reorder_dec_states(self, dec_states, new_beam_indices, batch_size, beam_size):
    dec_hiddens = dec_states["dec_hiddens"]
    dec_hiddens = tensor_gather_helper(gather_indices=new_beam_indices,
                                       gather_from=dec_hiddens,
                                       batch_size=batch_size,
                                       beam_size=beam_size,
                                       gather_shape=[batch_size * beam_size, -1])
    dec_states["dec_hiddens"] = dec_hiddens
    return dec_states
def reorder_dec_states(self, dec_states, new_beam_indices, beam_size):
    slf_attn_caches = dec_states["slf_attn_caches"]
    batch_size = slf_attn_caches[0][0].size(0) // beam_size
    n_head = self.decoder.n_head
    dim_per_head = self.decoder.dim_per_head

    # Reorder every cached self-attention tensor (the per-layer key/value caches)
    # according to the newly selected beams.
    slf_attn_caches = nest.map_structure(
        lambda t: tensor_gather_helper(
            gather_indices=new_beam_indices,
            gather_from=t,
            batch_size=batch_size,
            beam_size=beam_size,
            gather_shape=[batch_size * beam_size, n_head, -1, dim_per_head]),
        slf_attn_caches)
    dec_states["slf_attn_caches"] = slf_attn_caches

    return dec_states
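
# A minimal sketch, assuming nest.map_structure simply applies a function to every
# tensor inside a nested list/tuple structure (here, the per-layer key/value caches)
# while preserving the nesting. This is an assumption, not the repo's implementation.
def map_structure(fn, structure):
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, item) for item in structure)
    return fn(structure)  # leaf: a single tensor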
def beam_search(self, src_seqs, beam_size=4, alpha=1.0, max_steps=200):
    batch_size = src_seqs.size(0)

    enc_outputs = self.encode(src_seqs)
    init_dec_states = self.init_decoder(enc_outputs, expand_size=beam_size)

    # Prepare for beam searching
    beam_mask = src_seqs.new(batch_size, beam_size).fill_(1).float()
    final_lengths = src_seqs.new(batch_size, beam_size).zero_().float()
    beam_scores = src_seqs.new(batch_size, beam_size).zero_().float()
    final_word_indices = src_seqs.new(batch_size, beam_size, 1).fill_(BOS)

    dec_states = init_dec_states

    for t in range(max_steps):
        next_scores, dec_states = self.decode(
            final_word_indices.view(batch_size * beam_size, -1), dec_states)
        next_scores = -next_scores  # convert log-probs into costs (lower is better)
        next_scores = next_scores.view(batch_size, beam_size, -1)
        next_scores = mask_scores(scores=next_scores, beam_mask=beam_mask)

        beam_scores = next_scores + beam_scores.unsqueeze(2)  # [B, Bm, N] + [B, Bm, 1] ==> [B, Bm, N]

        vocab_size = beam_scores.size(-1)

        if t == 0 and beam_size > 1:
            # Force selection of the first beam at step 0: all beams hold the same
            # BOS prefix, so the other beams would only produce duplicates.
            beam_scores[:, 1:, :] = float('inf')

        # Length penalty (GNMT style)
        if alpha > 0.0:
            normed_scores = beam_scores * (5.0 + 1.0) ** alpha / (
                5.0 + beam_mask + final_lengths).unsqueeze(2) ** alpha
        else:
            normed_scores = beam_scores.detach().clone()

        normed_scores = normed_scores.view(batch_size, -1)

        # Get topK over the flattened (beam x vocab) candidates
        # indices: [batch_size, beam_size]
        _, indices = torch.topk(normed_scores,
                                k=beam_size,
                                dim=-1,
                                largest=False,
                                sorted=False)
        next_beam_ids = indices // vocab_size  # [batch_size, beam_size]: source beam of each candidate
        next_word_ids = indices % vocab_size   # [batch_size, beam_size]: word id of each candidate

        # Re-arrange by new beam indices
        beam_scores = beam_scores.view(batch_size, -1)
        beam_scores = torch.gather(beam_scores, 1, indices)

        beam_mask = tensor_gather_helper(gather_indices=next_beam_ids,
                                         gather_from=beam_mask,
                                         batch_size=batch_size,
                                         beam_size=beam_size,
                                         gather_shape=[-1])

        final_word_indices = tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=final_word_indices,
            batch_size=batch_size,
            beam_size=beam_size,
            gather_shape=[batch_size * beam_size, -1])

        final_lengths = tensor_gather_helper(gather_indices=next_beam_ids,
                                             gather_from=final_lengths,
                                             batch_size=batch_size,
                                             beam_size=beam_size,
                                             gather_shape=[-1])

        dec_states = self.reorder_dec_states(dec_states,
                                             new_beam_indices=next_beam_ids,
                                             beam_size=beam_size)

        # If next_word_ids is EOS, beam_mask_ should be 0.0
        beam_mask_ = 1.0 - next_word_ids.eq(EOS).float()
        # If an EOS was already generated at an earlier step, replace the new token with PAD
        next_word_ids.masked_fill_((beam_mask_ + beam_mask).eq(0.0), PAD)
        beam_mask = beam_mask * beam_mask_  # once EOS is produced, the beam stays closed

        # Only still-open beams grow in length
        final_lengths += beam_mask

        final_word_indices = torch.cat(
            (final_word_indices, next_word_ids.unsqueeze(2)), dim=2)

        if beam_mask.eq(0.0).all():
            break

    # Length penalty
    if alpha > 0.0:
        scores = beam_scores * (5.0 + 1.0) ** alpha / (5.0 + final_lengths) ** alpha
    else:
        scores = beam_scores / final_lengths

    # Ascending sort: scores are costs, so the best hypothesis comes first
    _, reranked_ids = torch.sort(scores, dim=-1, descending=False)

    # Drop the leading BOS and re-order the hypotheses within each batch element
    return tensor_gather_helper(
        gather_indices=reranked_ids,
        gather_from=final_word_indices[:, :, 1:].contiguous(),
        batch_size=batch_size,
        beam_size=beam_size,
        gather_shape=[batch_size * beam_size, -1])
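
# Hedged usage sketch: how beam_search might be called at inference time.
# `model` and `src_seqs` are placeholders; only the beam_search signature above is
# taken from the source.
model.eval()
with torch.no_grad():
    hyps = model.beam_search(src_seqs, beam_size=4, alpha=1.0, max_steps=200)
# hyps: [batch, beam, seq_len] word ids without the leading BOS, sorted best-first
# along the beam dimension, so hyps[:, 0, :] is the top hypothesis for each sentence.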