Example #1
    def reorder_dec_states(self, dec_states, new_beam_indices, beam_size):
        """Re-order the cached decoder states so that each beam follows its newly selected parent beam."""

        dec_hiddens = dec_states["dec_hiddens"]

        batch_size = dec_hiddens.size(0) // beam_size

        dec_hiddens = tensor_gather_helper(
            gather_indices=new_beam_indices,
            gather_from=dec_hiddens,
            batch_size=batch_size,
            beam_size=beam_size,
            gather_shape=[batch_size * beam_size, -1])

        dec_states["dec_hiddens"] = dec_hiddens

        # [batch, len, num_in_caps, num_out_caps]
        routing_weights = dec_states["routing_weights"]
        routing_weights = tensor_gather_helper(gather_indices=new_beam_indices,
                                               gather_from=routing_weights,
                                               batch_size=batch_size,
                                               beam_size=beam_size,
                                               gather_shape=[
                                                   batch_size * beam_size, -1,
                                                   routing_weights.size(1),
                                                   routing_weights.size(2)
                                               ])

        dec_states["routing_weights"] = routing_weights

        return dec_states
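
Every example on this page delegates the actual re-ordering to tensor_gather_helper, which is not listed here. The following is a minimal sketch of what such a helper could look like, inferred only from the call sites above; the project's real implementation may differ in its details:

import torch

def tensor_gather_helper(gather_indices, gather_from, batch_size, beam_size, gather_shape):
    # gather_indices: [batch_size, beam_size], parent-beam ids in [0, beam_size)
    # gather_from: a state tensor whose leading dimension(s) flatten to batch_size * beam_size
    # gather_shape: the shape gather_from is viewed as before gathering along dim 0
    offsets = torch.arange(0, batch_size, device=gather_from.device) * beam_size
    flat_indices = (gather_indices + offsets.unsqueeze(1)).view(-1)
    # For every (batch, beam) slot, pick the row belonging to its selected parent beam
    output = torch.index_select(gather_from.view(*gather_shape), 0, flat_indices)
    return output.view_as(gather_from)
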
Example #2
    def reorder_dec_states(self, dec_states, new_beam_indices, batch_size, beam_size):
        """Re-order the decoder hidden states; this variant takes batch_size from the caller."""

        dec_hiddens = dec_states["dec_hiddens"]

        dec_hiddens = tensor_gather_helper(gather_indices=new_beam_indices,
                                           gather_from=dec_hiddens,
                                           batch_size=batch_size,
                                           beam_size=beam_size,
                                           gather_shape=[batch_size * beam_size, -1])

        dec_states["dec_hiddens"] = dec_hiddens

        return dec_states
Example #3
    def reorder_dec_states(self, dec_states, new_beam_indices, beam_size):
        """Re-order the decoder's multi-head self-attention caches."""

        slf_attn_caches = dec_states['slf_attn_caches']

        batch_size = slf_attn_caches[0][0].size(0) // beam_size

        n_head = self.decoder.n_head
        dim_per_head = self.decoder.dim_per_head

        slf_attn_caches = nest.map_structure(
            lambda t: tensor_gather_helper(gather_indices=new_beam_indices,
                                           gather_from=t,
                                           batch_size=batch_size,
                                           beam_size=beam_size,
                                           gather_shape=[batch_size * beam_size, n_head, -1, dim_per_head]),
            slf_attn_caches)

        dec_states['slf_attn_caches'] = slf_attn_caches

        return dec_states
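
Example #3 additionally relies on nest.map_structure to apply the gather to every tensor inside the nested cache. A minimal stand-in with the same calling convention (an assumption; the project's own nest module, modelled after tf.nest, may support more container types) could be:

def map_structure(fn, structure):
    # Recursively apply fn to every leaf tensor in nested lists/tuples,
    # preserving the container types on the way back up.
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, item) for item in structure)
    return fn(structure)
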
Example #4
    def beam_search(self, src_seqs, beam_size=4, alpha=1.0, max_steps=200):
        """Beam-search decoding with a GNMT-style length penalty (alpha), for at most max_steps steps."""

        batch_size = src_seqs.size(0)

        enc_outputs = self.encode(src_seqs)
        init_dec_states = self.init_decoder(enc_outputs, expand_size=beam_size)

        # Prepare buffers for beam search
        beam_mask = src_seqs.new(batch_size, beam_size).fill_(1).float()
        final_lengths = src_seqs.new(batch_size, beam_size).zero_().float()
        beam_scores = src_seqs.new(batch_size, beam_size).zero_().float()
        final_word_indices = src_seqs.new(batch_size, beam_size, 1).fill_(BOS)

        dec_states = init_dec_states

        for t in range(max_steps):

            next_scores, dec_states = self.decode(
                final_word_indices.view(batch_size * beam_size, -1),
                dec_states)

            next_scores = -next_scores  # work with costs (negative log-probs): lower is better
            next_scores = next_scores.view(batch_size, beam_size, -1)
            next_scores = mask_scores(scores=next_scores, beam_mask=beam_mask)

            # [B, Bm, V] + [B, Bm, 1] ==> [B, Bm, V]
            beam_scores = next_scores + beam_scores.unsqueeze(2)

            vocab_size = beam_scores.size(-1)

            if t == 0 and beam_size > 1:
                # All beams are identical at step 0, so force every hypothesis
                # to be drawn from the first beam
                beam_scores[:, 1:, :] = float('inf')

            # GNMT-style length penalty
            if alpha > 0.0:
                normed_scores = beam_scores * (5.0 + 1.0)**alpha / (
                    5.0 + beam_mask + final_lengths).unsqueeze(2)**alpha
            else:
                normed_scores = beam_scores.detach().clone()

            normed_scores = normed_scores.view(batch_size, -1)

            # Jointly pick the beam_size lowest-cost continuations over all beams
            # indices: [batch_size, beam_size]
            _, indices = torch.topk(normed_scores,
                                    k=beam_size,
                                    dim=-1,
                                    largest=False,
                                    sorted=False)
            next_beam_ids = torch.div(indices, vocab_size, rounding_mode='floor')  # [batch_size, beam_size]
            next_word_ids = indices % vocab_size  # [batch_size, beam_size]

            # Re-arrange by new beam indices
            beam_scores = beam_scores.view(batch_size, -1)
            beam_scores = torch.gather(beam_scores, 1, indices)

            beam_mask = tensor_gather_helper(gather_indices=next_beam_ids,
                                             gather_from=beam_mask,
                                             batch_size=batch_size,
                                             beam_size=beam_size,
                                             gather_shape=[-1])

            final_word_indices = tensor_gather_helper(
                gather_indices=next_beam_ids,
                gather_from=final_word_indices,
                batch_size=batch_size,
                beam_size=beam_size,
                gather_shape=[batch_size * beam_size, -1])

            final_lengths = tensor_gather_helper(gather_indices=next_beam_ids,
                                                 gather_from=final_lengths,
                                                 batch_size=batch_size,
                                                 beam_size=beam_size,
                                                 gather_shape=[-1])

            dec_states = self.reorder_dec_states(
                dec_states,
                new_beam_indices=next_beam_ids,
                beam_size=beam_size)

            # A beam whose newly generated token is EOS is finished from now on
            beam_mask_ = 1.0 - next_word_ids.eq(EOS).float()
            # If the beam had already finished at an earlier step, replace the
            # current token with PAD
            next_word_ids.masked_fill_((beam_mask_ + beam_mask).eq(0.0), PAD)
            beam_mask = beam_mask * beam_mask_

            # Only beams that are still active grow in length
            final_lengths += beam_mask

            final_word_indices = torch.cat(
                (final_word_indices, next_word_ids.unsqueeze(2)), dim=2)

            if beam_mask.eq(0.0).all():
                break

        # Length penalty
        if alpha > 0.0:
            scores = beam_scores * (5.0 + 1.0)**alpha / (5.0 +
                                                         final_lengths)**alpha
        else:
            scores = beam_scores / final_lengths

        # Ascending sort: lower (length-normalised) cost is better
        _, reranked_ids = torch.sort(scores, dim=-1, descending=False)

        return tensor_gather_helper(
            gather_indices=reranked_ids,
            gather_from=final_word_indices[:, :, 1:].contiguous(),
            batch_size=batch_size,
            beam_size=beam_size,
            gather_shape=[batch_size * beam_size, -1])
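
Example #4 also calls mask_scores after every decoding step; that helper is not listed here either. A plausible sketch, assuming its only job is to freeze finished beams in a way consistent with the rest of the loop (eos_id stands in for the project's EOS constant, which the original presumably reads from module scope), is:

import torch

def mask_scores(scores, beam_mask, eos_id):
    # scores: [batch, beam, vocab] step costs (negative log-probs)
    # beam_mask: [batch, beam], 1.0 for active beams, 0.0 for finished ones
    # For finished beams, make every continuation except eos_id infinitely
    # expensive, so the beam keeps its accumulated score and can only be
    # extended with EOS (which beam_search then rewrites to PAD).
    vocab_size = scores.size(-1)
    finished_row = scores.new_full((vocab_size,), float('inf'))
    finished_row[eos_id] = 0.0
    mask = beam_mask.unsqueeze(2)  # [batch, beam, 1]
    return scores * mask + (1.0 - mask) * finished_row.view(1, 1, -1)

With a helper like this, a finished beam accumulates zero extra cost per step, which matches how beam_search stops growing final_lengths once a beam's mask drops to zero.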