Example #1
    def _forward_rnn(self, input, input_length, h_0=None):
        """
        :param input: Input sequence.
            FloatTensor with shape [batch_size, input_len, dim]

        :param input_length: Lengths of the input sequences.
            LongTensor with shape [batch_size, ]

        :param h_0: Optional initial hidden state passed to the RNN.
        """
        total_length = input.size(1)

        # 1. Sort the batch by length (required before packing the padded batch)
        oidx, sidx, slens = sort_batch(input_length)

        input_sorted = torch.index_select(input, index=sidx, dim=0)

        if h_0 is not None:
            h_0_sorted = nest.map_structure(
                lambda t: torch.index_select(t, 1, sidx), h_0)
        else:
            h_0_sorted = None

        # 2. Pack the padded batch and run the RNN
        input_packed = pack_padded_sequence(input_sorted,
                                            slens,
                                            batch_first=True)

        out_packed, h_n_sorted = self.rnn(input_packed, h_0_sorted)

        # 3. Unpack and restore the original batch order
        out_sorted = pad_packed_sequence(out_packed,
                                         batch_first=True,
                                         total_length=total_length)[0]
        out = torch.index_select(out_sorted, dim=0, index=oidx)

        h_n_sorted = nest.map_structure(
            lambda t: torch.index_select(t, 1, oidx), h_n_sorted)

        return out.contiguous(), h_n_sorted
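
The `sort_batch` helper used above is not part of this example; a minimal sketch of what it is assumed to do (sort by decreasing length for `pack_padded_sequence` and return the indices that undo the sort) could look like this:

import torch

# Hypothetical sketch of sort_batch (its real implementation is not shown here).
def sort_batch(seq_len):
    """seq_len: LongTensor [batch_size]. Returns (original_idx, sorted_idx, sorted_lens)."""
    # Sort lengths in decreasing order, as required by pack_padded_sequence.
    sorted_lens, sorted_idx = torch.sort(seq_len, descending=True)
    # original_idx[i] is the position of the i-th original element inside the
    # sorted batch, so index_select(..., original_idx) restores the original order.
    _, original_idx = torch.sort(sorted_idx)
    return original_idx, sorted_idx, sorted_lens
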
    def reorder_dec_states(self, dec_states, new_beam_indices, beam_size):
        slf_attn_caches = dec_states['slf_attn_caches']
        batch_size = slf_attn_caches[0][0].size(0) // beam_size
        n_head = self.decoder.n_head
        dim_per_head = self.decoder.dim_per_head
        # Re-arrange each cached tensor according to the new beam indices
        slf_attn_caches = nest.map_structure(
            lambda t: tensor_gather_helper(
                gather_indices=new_beam_indices,
                gather_from=t,
                batch_size=batch_size,
                beam_size=beam_size,
                gather_shape=[batch_size * beam_size, n_head, -1, dim_per_head]),
            slf_attn_caches)

        dec_states['slf_attn_caches'] = slf_attn_caches
        return dec_states
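
`tensor_gather_helper` is used both here and in the beam search below but is not shown in this example. A plausible sketch, assuming it views `gather_from` as `gather_shape` and then selects rows by flattened (batch, beam) indices:

import torch

# Hypothetical sketch of tensor_gather_helper (assumed behaviour, not the original code).
# gather_indices: LongTensor [batch_size, beam_size] with the beam ids to keep.
# gather_from:    tensor whose leading dimension(s) flatten to batch_size * beam_size.
# gather_shape:   the view applied to gather_from before gathering along dim 0.
def tensor_gather_helper(gather_indices, gather_from, batch_size, beam_size, gather_shape):
    # Offset the per-sentence beam ids so they index the flattened
    # [batch_size * beam_size, ...] dimension.
    offsets = torch.arange(batch_size, device=gather_indices.device) * beam_size
    flat_indices = (gather_indices + offsets.unsqueeze(1)).view(-1)
    gathered = torch.index_select(gather_from.view(*gather_shape), 0, flat_indices)
    # Give the result back the original shape of gather_from.
    return gathered.view(*gather_from.size())
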
    def batch_beam_search(self, src_seq, beam_size=5, max_steps=150):

        batch_size = src_seq.size(0)

        enc_output, enc_mask = self.encoder(src_seq)  # [batch_size, seq_len, dim]

        # dec_caches = self.decoder.compute_caches(enc_output)

        # Tile beam_size times
        enc_mask = tile_batch(enc_mask, multiplier=beam_size, batch_dim=0)
        enc_output = tile_batch(enc_output, multiplier=beam_size, batch_dim=0)

        # Word indices kept in each beam, initialised with BOS
        final_word_indices = src_seq.new(batch_size, beam_size, 1).fill_(Vocab.BOS)
        # Length of each hypothesis
        final_lengths = enc_output.new(batch_size, beam_size).fill_(0.0)
        # Mask of beams: 1.0 while a beam is still alive, 0.0 once it is finished
        beam_mask = enc_output.new(batch_size, beam_size).fill_(1.0)
        # Accumulated score of each beam
        beam_scores = enc_output.new(batch_size, beam_size).fill_(0.0)

        self_attn_caches = None  # Every element has shape [batch_size * beam_size, num_heads, seq_len, dim_head]
        enc_attn_caches = None

        for t in range(max_steps):

            inp_t = final_word_indices.view(-1, final_word_indices.size(-1))

            dec_output, self_attn_caches, enc_attn_caches \
                = self.decoder(tgt_seq=inp_t,
                               enc_output=enc_output,
                               enc_mask=enc_mask,
                               enc_attn_caches=enc_attn_caches,
                               self_attn_caches=self_attn_caches) # [batch_size * beam_size, seq_len, dim]

            # Scores are negated so that smaller is better (cf. largest=False in topk below)
            next_scores = -self.generator(dec_output[:, -1].contiguous())  # [batch_size * beam_size, n_words]
            next_scores = next_scores.view(batch_size, beam_size, -1)
            next_scores = mask_scores(next_scores, beam_mask=beam_mask)

            beam_scores = next_scores + beam_scores.unsqueeze(
                2)  # [B, Bm, N] + [B, Bm, 1] ==> [B, Bm, N]

            vocab_size = beam_scores.size(-1)
            if t == 0:
                # At the first step all beams are identical, so expand from beam 0 only
                beam_scores = beam_scores[:, 0, :].contiguous()

            beam_scores = beam_scores.view(batch_size, -1)

            # Get top-k candidates over all beams (smallest scores are best)
            beam_scores, indices = torch.topk(beam_scores,
                                              k=beam_size,
                                              dim=-1,
                                              largest=False,
                                              sorted=False)
            next_beam_ids = indices // vocab_size  # integer division: which beam each candidate came from
            next_word_ids = indices % vocab_size

            # Re-arrange by new beam indices
            beam_mask = tensor_gather_helper(gather_indices=next_beam_ids,
                                             gather_from=beam_mask,
                                             batch_size=batch_size,
                                             beam_size=beam_size,
                                             gather_shape=[-1])

            final_word_indices = tensor_gather_helper(
                gather_indices=next_beam_ids,
                gather_from=final_word_indices,
                batch_size=batch_size,
                beam_size=beam_size,
                gather_shape=[batch_size * beam_size, -1])

            final_lengths = tensor_gather_helper(gather_indices=next_beam_ids,
                                                 gather_from=final_lengths,
                                                 batch_size=batch_size,
                                                 beam_size=beam_size,
                                                 gather_shape=[-1])

            # Re-arrange the self-attention caches by the new beam indices
            self_attn_caches = nest.map_structure(
                lambda cache: tensor_gather_helper(
                    gather_indices=next_beam_ids,
                    gather_from=cache,
                    batch_size=batch_size,
                    beam_size=beam_size,
                    gather_shape=[batch_size * beam_size,
                                  self.decoder.n_head, -1,
                                  self.decoder.dim_per_head]),
                self_attn_caches)

            # A beam that just produced EOS is finished: clear its mask bit
            beam_mask_ = 1.0 - next_word_ids.eq(Vocab.EOS).float()
            # If a beam was already finished at an earlier step, emit PAD instead of a new word
            next_word_ids.masked_fill_((beam_mask_ + beam_mask).eq(0.0), Vocab.PAD)
            beam_mask = beam_mask * beam_mask_

            # # If an EOS or PAD is encountered, set the beam mask to 0.0
            # beam_mask_ = next_word_ids.gt(Vocab.EOS).float()
            # beam_mask = beam_mask * beam_mask_

            final_lengths += beam_mask

            final_word_indices = torch.cat(
                (final_word_indices, next_word_ids.unsqueeze(2)), dim=2)

            if beam_mask.eq(0.0).all():
                break

        # Length-normalise the accumulated scores and re-rank the beams;
        # scores are costs (lower is better), so ascending sort puts the best beam first
        scores = beam_scores / (final_lengths + 1e-2)

        _, reranked_ids = torch.sort(scores, dim=-1, descending=False)

        return tensor_gather_helper(
            gather_indices=reranked_ids,
            gather_from=final_word_indices[:, :, 1:].contiguous(),
            batch_size=batch_size,
            beam_size=beam_size,
            gather_shape=[batch_size * beam_size, -1])
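
`tile_batch` and `mask_scores` are likewise external helpers. As an illustration, a minimal sketch of `tile_batch`, assuming it simply repeats every batch element `multiplier` times along the batch dimension so that encoder states line up with the `batch_size * beam_size` decoder rows:

import torch

# Hypothetical sketch of tile_batch (assumed behaviour, not the original code).
# With multiplier=3, a batch [a, b] becomes [a, a, a, b, b, b], which matches the
# row layout expected by the tensor_gather_helper sketch above (beams of one
# sentence occupy adjacent rows).
def tile_batch(tensor, multiplier, batch_dim=0):
    if multiplier == 1:
        return tensor
    return torch.repeat_interleave(tensor, repeats=multiplier, dim=batch_dim)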