Code example #1
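Ensemble variant of the word-insertion step: each model's word-insertion distribution is averaged in log space (logsumexp minus the log of the number of models), the attention maps are averaged directly, and the argmax tokens are written into the current <unk> positions.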
    def forward_word_ins(self, encoder_outs, output_tokens, output_scores,
                         attn, can_ins_word):
        word_ins_score_avg = []
        word_ins_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip(encoder_out, can_ins_word),
            )
            word_ins_score = F.log_softmax(word_ins_out, 2)
            word_ins_score_avg.append(word_ins_score)
            word_ins_attn_avg.append(word_ins_attn)
        word_ins_score_avg = torch.logsumexp(
            torch.stack(word_ins_score_avg, dim=0), dim=0) - math.log(
                len(self.models))
        if word_ins_attn_avg[0] is not None:
            word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0) / len(
                self.models)
        else:
            word_ins_attn_avg = None
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )

        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.)
        return output_tokens, output_scores, attn
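The logsumexp(...) - math.log(K) pattern used above is just the arithmetic mean of the per-model probabilities, computed in log space for numerical stability. A small self-contained check (the two probability rows are made up for illustration):

import math
import torch

# two "models", each giving a probability distribution over two classes
log_p = torch.log(torch.tensor([[0.2, 0.8],
                                [0.6, 0.4]]))
avg = torch.logsumexp(log_p, dim=0) - math.log(log_p.size(0))
print(avg.exp())  # tensor([0.4000, 0.6000]) == element-wise mean of the two rows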
Code example #2
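Word insertion as a standalone function: the predicted tokens and their scores are scattered into exactly the positions that currently hold unk_idx.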
        def ins_words(
            output_tokens,
            output_scores,
            attn: Tensor,
            word_ins_attn,
            word_ins_pred,
            word_ins_scores,
            can_ins_word,
            pad_idx: int,
            unk_idx: int,
        ):
            # insert words
            if can_ins_word.sum() != 0:
                in_tokens = output_tokens[can_ins_word]
                in_scores = output_scores[can_ins_word]
                word_ins_masks = in_tokens.eq(unk_idx)
                out_tokens = in_tokens.masked_scatter(
                    word_ins_masks, word_ins_pred[word_ins_masks].float()
                )

                if in_scores is not None:
                    out_scores = in_scores.masked_scatter(
                        word_ins_masks, word_ins_scores[word_ins_masks]
                    )
                else:
                    out_scores = None
                output_tokens = _fill(output_tokens, can_ins_word, out_tokens, pad_idx)
                output_scores = _fill(output_scores, can_ins_word, out_scores, 0)
                attn = _fill(attn, can_ins_word, word_ins_attn, 0)
            return output_tokens, output_scores, attn
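The key operation in ins_words is masked_scatter: the predicted ids are copied, in order, into exactly the positions that currently hold unk_idx. A small self-contained check with made-up token ids:

import torch

unk_idx = 3
tokens = torch.tensor([[0, 3, 3, 7, 3, 2]])             # three <unk> slots
word_ins_pred = torch.tensor([[9, 11, 12, 9, 13, 9]])   # only the <unk> positions matter
masks = tokens.eq(unk_idx)
out = tokens.masked_scatter(masks, word_ins_pred[masks])
print(out)  # tensor([[ 0, 11, 12,  7, 13,  2]])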
Code example #3
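Ensemble variant of the placeholder-insertion step: per-model distributions over insertion counts are averaged in log space, eos_penalty is subtracted from the zero-insertion class, and the predicted counts are clamped to max_lens before the placeholders are applied.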
    def forward_mask_ins(self, encoder_outs, output_tokens, output_scores,
                         can_ins_mask, eos_penalty, max_lens):
        mask_ins_score_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            mask_ins_out, _ = model.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip(encoder_out, can_ins_mask),
            )
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_score_avg.append(mask_ins_score)
        mask_ins_score_avg = torch.logsumexp(
            torch.stack(mask_ins_score_avg, dim=0), dim=0) - math.log(
                len(self.models))
        mask_ins_pred = mask_ins_score_avg.max(-1)[1]
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred))
        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
        return output_tokens, output_scores
Code example #4
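Placeholder insertion as a standalone function, starting from the raw mask_ins_out logits: it applies the EOS penalty and length clamping, then inserts unk placeholders by scattering the surviving tokens to their new columns.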
        def ins_placeholders(
            output_tokens,
            output_scores,
            mask_ins_out,
            can_ins_mask,
            pad_idx: int,
            unk_idx: int,
            eos_idx: int,
            eos_penalty: float,
            max_ratio: float,
            max_lengths,
        ):
            # insert placeholders
            if can_ins_mask.sum() != 0:
                mask_ins_score = F.log_softmax(mask_ins_out, 2)
                if eos_penalty > 0.0:
                    mask_ins_score[:, :, 0] -= eos_penalty
                mask_ins_pred = mask_ins_score.max(-1)[1]
                # clamp only when a length budget was actually computed
                if max_ratio is not None and max_lengths is not None:
                    mask_ins_pred = torch.min(
                        mask_ins_pred,
                        max_lengths[can_ins_mask,
                                    None].expand_as(mask_ins_pred))
                in_tokens = output_tokens[can_ins_mask]
                in_scores = output_scores[can_ins_mask]
                in_masks = in_tokens.ne(pad_idx)
                in_lengths = in_masks.sum(1)

                # HACK: hacky way to shift all the paddings to eos first.
                in_tokens.masked_fill_(~in_masks, eos_idx)
                mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)

                out_lengths = in_lengths + mask_ins_pred.sum(1)
                out_max_len = out_lengths.max()
                out_masks = (torch.arange(out_max_len)[None, :].long() <
                             out_lengths[:, None])

                reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
                out_tokens = (torch.zeros(
                    in_tokens.size()[0],
                    out_max_len).fill_(pad_idx).masked_fill_(
                        out_masks, unk_idx))
                out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]],
                                       1)
                out_tokens.scatter_(1, reordering, in_tokens[:, 1:].float())

                if in_scores is not None:
                    in_scores.masked_fill_(~in_masks, 0)
                    out_scores = torch.zeros_like(out_tokens).to(in_scores)
                    # keep the first-position score (mirrors the token handling above)
                    out_scores = torch.cat(
                        [in_scores[:, :1], out_scores[:, 1:]], 1)
                    out_scores.scatter_(1, reordering, in_scores[:, 1:])
                else:
                    out_scores = None
                output_tokens = _fill(output_tokens, can_ins_mask, out_tokens,
                                      pad_idx)
                output_scores = _fill(output_scores, can_ins_mask, out_scores,
                                      0)
            return output_tokens, output_scores
Code example #5
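Word deletion as a standalone function: the deletion head's argmax marks tokens to drop (padding is always dropped, <s>/</s> never are), and a sort-based reordering left-compacts the surviving tokens, scores and attention rows.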
        def del_word(
            output_tokens,
            output_scores,
            attn: Tensor,
            word_del_attn: Optional[Tensor],
            word_del_out,
            can_del_word,
            pad_idx: int,
            bos_idx: int,
            eos_idx: int,
        ):
            # delete words
            # never delete the <s> / </s> boundary tokens
            if can_del_word.sum() != 0:  # skip when nothing can be deleted
                word_del_score = F.log_softmax(word_del_out, 2)
                word_del_pred = word_del_score.max(-1)[1].bool()
                in_tokens = output_tokens[can_del_word]
                in_scores = output_scores[can_del_word]
                # apply deletion to a tensor
                in_masks = in_tokens.ne(pad_idx)
                bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)

                max_len = in_tokens.size(1)
                word_del_pred.masked_fill_(~in_masks, 1)
                word_del_pred.masked_fill_(bos_eos_masks, 0)

                reordering = (torch.arange(max_len)[None, :].expand_as(
                    in_tokens).contiguous().masked_fill(
                        word_del_pred, max_len).sort(1)[1])

                _tokens = in_tokens.masked_fill(word_del_pred,
                                                pad_idx).gather(1, reordering)

                _scores = in_scores.masked_fill(word_del_pred,
                                                0).gather(1, reordering)
                if word_del_attn is not None:
                    _mask = word_del_pred[:, :, None].expand_as(word_del_attn)
                    _reordering = reordering[:, :,
                                             None].expand_as(word_del_attn)
                    _attn = word_del_attn.masked_fill(_mask, 0.0).gather(
                        1, _reordering)
                    attn = _fill(attn, can_del_word, _attn, 0)

                output_tokens = _fill(output_tokens, can_del_word, _tokens,
                                      pad_idx)
                output_scores = _fill(output_scores, can_del_word, _scores, 0)
            return output_tokens, output_scores, attn
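The sort-based reordering in del_word is compact but easy to misread: columns flagged for deletion are overwritten with max_len before the argsort, so kept columns sort to the front and deleted ones to the back, and the final gather left-compacts the sequence. A small self-contained check with made-up index values (pad=1, bos=0, eos=2 are only illustrative):

import torch

pad_idx, bos_idx, eos_idx = 1, 0, 2
in_tokens = torch.tensor([[bos_idx, 11, 12, 13, eos_idx, pad_idx]])
word_del_pred = torch.tensor([[0, 0, 1, 1, 0, 1]], dtype=torch.bool)  # drop 12, 13 (and the pad)

max_len = in_tokens.size(1)
reordering = (torch.arange(max_len)[None, :].expand_as(in_tokens)
              .contiguous().masked_fill(word_del_pred, max_len).sort(1)[1])
out = in_tokens.masked_fill(word_del_pred, pad_idx).gather(1, reordering)
print(out)  # tensor([[ 0, 11,  2,  1,  1,  1]]) -> bos, 11, eos, then padding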
Code example #6
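Placeholder insertion again, but taking precomputed mask_ins_pred counts; the insertion mechanics are the same as in code example #4.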
        def ins_placeholders(
            output_tokens,
            output_scores,
            mask_ins_pred,
            can_ins_mask,
            pad_idx: int,
            unk_idx: int,
            eos_idx: int,
        ):
            # insert placeholders
            if can_ins_mask.sum() != 0:
                in_tokens = output_tokens[can_ins_mask]
                in_scores = output_scores[can_ins_mask]
                in_masks = in_tokens.ne(pad_idx)
                in_lengths = in_masks.sum(1)

                # HACK: hacky way to shift all the paddings to eos first.
                in_tokens.masked_fill_(~in_masks, eos_idx)
                mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)

                out_lengths = in_lengths + mask_ins_pred.sum(1)
                out_max_len = out_lengths.max()
                out_masks = (torch.arange(out_max_len)[None, :].long() <
                             out_lengths[:, None])

                reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
                out_tokens = (torch.zeros(
                    in_tokens.size()[0],
                    out_max_len).fill_(pad_idx).masked_fill_(
                        out_masks, unk_idx))
                out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]],
                                       1)
                out_tokens.scatter_(1, reordering, in_tokens[:, 1:].float())

                if in_scores is not None:
                    in_scores.masked_fill_(~in_masks, 0)
                    out_scores = torch.zeros_like(out_tokens).to(in_scores)
                    # keep the first-position score (mirrors the token handling above)
                    out_scores = torch.cat(
                        [in_scores[:, :1], out_scores[:, 1:]], 1)
                    out_scores.scatter_(1, reordering, in_scores[:, 1:])
                else:
                    out_scores = None
                output_tokens = _fill(output_tokens, can_ins_mask, out_tokens,
                                      pad_idx)
                output_scores = _fill(output_scores, can_ins_mask, out_scores,
                                      0)
            return output_tokens, output_scores
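The reordering arithmetic in ins_placeholders is the mirror image of the deletion trick: the cumulative sum of (insertions + 1) gives each surviving token its new column, and every column in between is initialised to unk_idx. A small self-contained check with made-up index values (it writes column 0 directly instead of using torch.cat, which is equivalent here):

import torch

pad_idx, bos_idx, eos_idx, unk_idx = 1, 0, 2, 3
in_tokens = torch.tensor([[bos_idx, 11, 12, eos_idx]])
mask_ins_pred = torch.tensor([[2, 0, 1]])               # 2 slots after bos, 0 after 11, 1 after 12

in_masks = in_tokens.ne(pad_idx)
out_lengths = in_masks.sum(1) + mask_ins_pred.sum(1)    # [7]
out_len = int(out_lengths.max())
out_masks = torch.arange(out_len)[None, :] < out_lengths[:, None]
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)   # [[3, 4, 6]]

out_tokens = torch.full((1, out_len), pad_idx).masked_fill_(out_masks, unk_idx)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
print(out_tokens)  # tensor([[ 0,  3,  3, 11, 12,  3,  2]]) -> bos <unk> <unk> 11 12 <unk> eos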
Code example #7
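A complete single refinement pass over the batch: delete words, insert placeholders up to max_lens, fill the placeholders with predicted words, and finally trim columns that are padding in every sentence.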
    def forward_decoder(self,
                        decoder_out,
                        encoder_out,
                        eos_penalty=0.0,
                        max_ratio=None,
                        **kwargs):

        output_tokens = decoder_out["output_tokens"]
        output_scores = decoder_out["output_scores"]
        attn = decoder_out["attn"]

        if max_ratio is None:
            max_lens = output_tokens.new(output_tokens.size(0)).fill_(255)
        else:
            max_lens = ((~encoder_out["encoder_padding_mask"]).sum(1) *
                        max_ratio).clamp(min=10).long()

        # delete words
        # never delete the <s> / </s> boundary tokens
        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
        if can_del_word.sum() != 0:  # skip when nothing can be deleted
            word_del_out, word_del_attn = self.decoder.forward_word_del(
                _skip(output_tokens, can_del_word),
                _skip(encoder_out, can_del_word))
            word_del_score = F.log_softmax(word_del_out, 2)
            word_del_pred = word_del_score.max(-1)[1].bool()

            _tokens, _scores, _attn = _apply_del_words(
                output_tokens[can_del_word],
                output_scores[can_del_word],
                word_del_attn,
                word_del_pred,
                self.pad,
                self.bos,
                self.eos,
            )
            output_tokens = _fill(output_tokens, can_del_word, _tokens,
                                  self.pad)
            output_scores = _fill(output_scores, can_del_word, _scores, 0)
            attn = _fill(attn, can_del_word, _attn, 0.)

        # insert placeholders
        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
        if can_ins_mask.sum() != 0:
            mask_ins_out, _ = self.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip(encoder_out, can_ins_mask))
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_pred = mask_ins_score.max(-1)[1]
            mask_ins_pred = torch.min(
                mask_ins_pred, max_lens[:, None].expand_as(mask_ins_pred))

            _tokens, _scores = _apply_ins_masks(
                output_tokens[can_ins_mask],
                output_scores[can_ins_mask],
                mask_ins_pred,
                self.pad,
                self.unk,
                self.eos,
            )
            output_tokens = _fill(output_tokens, can_ins_mask, _tokens,
                                  self.pad)
            output_scores = _fill(output_scores, can_ins_mask, _scores, 0)

        # insert words
        can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
        if can_ins_word.sum() != 0:
            word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip(encoder_out, can_ins_word))
            word_ins_score = F.log_softmax(word_ins_out, 2)
            # keep the per-token max log-probability; _apply_ins_words expects
            # scores with the same shape as the predicted tokens
            word_ins_score, word_ins_pred = word_ins_score.max(-1)

            _tokens, _scores = _apply_ins_words(
                output_tokens[can_ins_word],
                output_scores[can_ins_word],
                word_ins_pred,
                word_ins_score,
                self.unk,
            )

            output_tokens = _fill(output_tokens, can_ins_word, _tokens,
                                  self.pad)
            output_scores = _fill(output_scores, can_ins_word, _scores, 0)
            attn = _fill(attn, can_ins_word, word_ins_attn, 0.)

        # delete some unnecessary paddings
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
        attn = None if attn is None else attn[:, :cut_off, :]
        return {
            "output_tokens": output_tokens,
            "output_scores": output_scores,
            "attn": attn,
        }
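forward_decoder performs one delete / insert-placeholder / insert-word pass, so inference drives it in a loop. A minimal sketch of such a loop, assuming only what the snippet above already exposes (forward_decoder plus the pad, bos and eos attributes); the refine name, the fixed iteration count and the <s> </s> initial canvas are illustrative, and the early-exit check that real generators add when the output stops changing is omitted:

import torch

def refine(model, encoder_out, src_tokens, max_iter=10, eos_penalty=0.0):
    # start every sentence from the minimal canvas <s> </s>
    bsz = src_tokens.size(0)
    init_tokens = src_tokens.new_full((bsz, 2), model.pad)
    init_tokens[:, 0] = model.bos
    init_tokens[:, 1] = model.eos
    decoder_out = {
        "output_tokens": init_tokens,
        "output_scores": init_tokens.new_zeros(bsz, 2).float(),
        "attn": None,
    }
    for _ in range(max_iter):
        decoder_out = model.forward_decoder(
            decoder_out, encoder_out, eos_penalty=eos_penalty)
    return decoder_out["output_tokens"], decoder_out["output_scores"]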