Example #1
    def forward(self, encoded_attentioned, target):
        """
        Args:
            padded_input: N x T x D
            input_lengths: N

        Returns:
            enc_output: N x T x H
        """
        # Prepare masks
        ys_in = self.preprocess(target)
        non_pad_mask = (target > 0).unsqueeze(-1)

        slf_attn_mask_subseq = get_subsequent_mask(ys_in)
        slf_attn_mask_keypad = get_attn_key_pad_mask(
            seq_k=ys_in, seq_q=ys_in, pad_idx=0)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        ys_in_emb = self.dropout(self.tgt_word_emb(ys_in) + self.positional_encoding(ys_in))

        dec_output = self.input_affine(torch.cat([encoded_attentioned, ys_in_emb], -1))

        for dec_layer in self.layer_stack:
            dec_output = dec_layer(
                dec_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)

        dec_output = torch.cat([encoded_attentioned, dec_output], -1)

        logits = self.tgt_word_prj(dec_output)

        return logits
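
The mask helpers used above are not part of the snippet. Below is a minimal sketch, assuming the common Transformer convention that a value of 1 marks a position attention must ignore and that pad_idx marks padding tokens:

import torch

def get_subsequent_mask(seq):
    # seq: N x T. Returns an N x T x T mask with 1 strictly above the
    # diagonal, i.e. the future positions each query may not attend to.
    sz_b, len_s = seq.size()
    mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8),
        diagonal=1)
    return mask.unsqueeze(0).expand(sz_b, -1, -1)  # N x T x T

def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
    # Marks key positions equal to pad_idx so attention ignores them.
    # Returns N x Tq x Tk.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(pad_idx).to(torch.uint8)  # N x Tk
    return padding_mask.unsqueeze(1).expand(-1, len_q, -1)

With both masks as 0/1 tensors, the (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) expression in the snippet is simply a logical OR of the two conditions.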
Example #2
    def forward(self,
                padded_input,
                encoder_padded_outputs,
                encoder_input_lengths,
                return_attns=False):
        """
        Args:
            padded_input: N x To
            encoder_padded_outputs: N x Ti x H
        Returns:
        """
        dec_slf_attn_list, dec_enc_attn_list = [], []

        # Get decoder input and output
        ys_in_pad, ys_out_pad = self.preprocess(padded_input)

        # Prepare masks
        non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)

        slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                     seq_q=ys_in_pad,
                                                     pad_idx=self.eos_id)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        output_length = ys_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                              encoder_input_lengths,
                                              output_length)

        # Forward
        dec_output = self.dropout(
            self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
            self.positional_encoding(ys_in_pad))

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output,
                encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        # before softmax
        seq_logit = self.tgt_word_prj(dec_output)

        # Return
        pred, gold = seq_logit, ys_out_pad

        if return_attns:
            return pred, gold, dec_slf_attn_list, dec_enc_attn_list
        return pred, gold
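
Likewise, get_non_pad_mask and the three-argument get_attn_pad_mask used here are not shown. One plausible implementation consistent with how they are called (examples #4 and #5 call a two-argument variant that differs only in how the encoder lengths are passed):

import torch

def get_non_pad_mask(padded_input, pad_idx):
    # N x T x 1 float mask: 1.0 for real tokens, 0.0 for padding.
    return padded_input.ne(pad_idx).float().unsqueeze(-1)

def get_attn_pad_mask(encoder_padded_outputs, encoder_input_lengths, expand_length):
    # Masks encoder frames beyond each utterance's true length so the
    # decoder's cross-attention ignores them. Returns N x expand_length x Ti.
    n, t_in = encoder_padded_outputs.size(0), encoder_padded_outputs.size(1)
    pad_mask = torch.ones(n, t_in, dtype=torch.uint8,
                          device=encoder_padded_outputs.device)
    for i, length in enumerate(encoder_input_lengths):
        pad_mask[i, :length] = 0
    return pad_mask.unsqueeze(1).expand(-1, expand_length, -1)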
Example #3
    def step_forward(self, ys, encoded_attentioned, t):
        # -- Prepare masks
        non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # 1xix1
        slf_attn_mask = get_subsequent_mask(ys)

        # -- Forward
        target_emb = self.tgt_word_emb(ys) + self.positional_encoding(ys)
        dec_output = self.input_affine(torch.cat([encoded_attentioned[:, :t+1, :], target_emb], -1))

        for dec_layer in self.layer_stack:
            dec_output = dec_layer(
                dec_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)

        dec_output = torch.cat([encoded_attentioned[:, :t+1, :], dec_output], -1)

        seq_logit = self.tgt_word_prj(dec_output[:, -1])

        local_scores = F.log_softmax(seq_logit, dim=1)

        return local_scores
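
step_forward is the per-step scorer of an autoregressive search. For context, a hypothetical greedy loop around it (the decoder, sos_id, eos_id and maxlen names are assumptions, not part of the snippet):

import torch

def greedy_decode(decoder, encoded_attentioned, maxlen=100):
    # Start from <sos>; at step t the prefix ys has length t + 1, matching
    # the encoded_attentioned[:, :t+1, :] slice inside step_forward.
    ys = torch.full((1, 1), decoder.sos_id, dtype=torch.long,
                    device=encoded_attentioned.device)
    for t in range(maxlen):
        local_scores = decoder.step_forward(ys, encoded_attentioned, t)  # 1 x V
        next_token = local_scores.argmax(dim=-1, keepdim=True)           # 1 x 1
        ys = torch.cat([ys, next_token], dim=1)
        if next_token.item() == decoder.eos_id:
            break
    return ys[0].tolist()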
Example #4
    def step(self, prefixs, encoded, len_encoded):

        non_pad_mask = torch.ones_like(prefixs).float().unsqueeze(-1) # Nxix1
        slf_attn_mask = get_subsequent_mask(prefixs)

        output_length = prefixs.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(len_encoded, output_length)

        # Forward
        dec_output = self.tgt_word_emb(prefixs) + self.positional_encoding(prefixs)

        for dec_layer in self.layer_stack:
            dec_output = dec_layer(
                dec_output, encoded,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

        # before softmax
        logits = self.tgt_word_prj(dec_output[:, -1, :])
        scores = F.log_softmax(logits, -1) # [batch*beam, size_output]

        return scores
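
step scores every live prefix in one batch (hence the [batch*beam, size_output] comment). A sketch of one beam-expansion step built on top of it; the flattened [batch*beam, ...] layout and the bookkeeping names are assumptions:

import torch

def expand_beams(decoder, prefixs, prefix_scores, encoded, len_encoded, beam):
    # prefixs: [batch*beam, i] token prefixes, prefix_scores: [batch*beam, 1].
    local_scores = decoder.step(prefixs, encoded, len_encoded)  # [batch*beam, V]
    vocab = local_scores.size(-1)
    total = prefix_scores + local_scores                        # accumulated log-probs
    batch = prefixs.size(0) // beam
    total = total.view(batch, beam * vocab)                     # all candidates per utterance
    best_scores, best_ids = total.topk(beam, dim=-1)            # keep `beam` of them
    beam_ids = torch.div(best_ids, vocab, rounding_mode='floor')
    token_ids = best_ids % vocab
    return best_scores, beam_ids, token_ids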
Example #5
    def forward(self, targets, encoder_padded_outputs, encoder_input_lengths):
        """
        Args:
            padded_input: N x To
            encoder_padded_outputs: N x Ti x H

        Returns:
        """
        # Get decoder input and output
        targets_sos, targets_eos = self.preprocess(targets)

        # Prepare masks
        non_pad_mask = (targets_sos > 0).unsqueeze(-1)

        slf_attn_mask_subseq = get_subsequent_mask(targets_sos)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=targets_sos,
                                                     seq_q=targets_sos,
                                                     pad_idx=0)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
        output_length = targets_sos.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_input_lengths, output_length)

        # Forward
        dec_output = self.dropout(self.tgt_word_emb(targets_sos) +
                                  self.positional_encoding(targets_sos))

        for dec_layer in self.layer_stack:
            dec_output = dec_layer(
                dec_output, encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

        # before softmax
        logits = self.tgt_word_prj(dec_output)

        return logits, targets_eos
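
preprocess is not shown in the snippet. A plausible version, assuming targets is N x To, 0 is the padding index (matching pad_idx=0 in the masks above), and that, for brevity, <eos> is appended after the padded region (real implementations insert it before the padding):

import torch

def preprocess(targets, sos_id, eos_id):
    # Build decoder input (<sos> y1 ... yT) and training target (y1 ... yT <eos>).
    n = targets.size(0)
    sos = torch.full((n, 1), sos_id, dtype=targets.dtype, device=targets.device)
    eos = torch.full((n, 1), eos_id, dtype=targets.dtype, device=targets.device)
    targets_sos = torch.cat([sos, targets], dim=1)
    targets_eos = torch.cat([targets, eos], dim=1)
    return targets_sos, targets_eos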
Example #6
    def beam_decode(self, encoded, beam=5, nbest=1, maxlen=100):
        """Beam search, decode one utterence now.
        Args:
            encoder_outputs: T x H
            char_list: list of character
            args: args.beam
        Returns:
            nbest_hyps:
        """

        # prepare sos
        ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoded).long()

        # yseq: 1xT
        hyp = {'score': 0.0, 'yseq': ys}
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                ys = hyp['yseq']  # 1 x i

                # -- Prepare masks
                non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # 1xix1
                slf_attn_mask = get_subsequent_mask(ys)

                # -- Forward
                dec_output = self.dropout(
                    self.tgt_word_emb(ys) + self.positional_encoding(ys))

                for dec_layer in self.layer_stack:
                    dec_output, *_ = dec_layer(
                        dec_output, encoded,
                        non_pad_mask=non_pad_mask,
                        slf_attn_mask=slf_attn_mask,
                        dec_enc_attn_mask=None)

                seq_logit = self.tgt_word_prj(dec_output[:, -1])

                local_scores = F.log_softmax(seq_logit, dim=1)
                # topk scores
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = torch.ones(1, (1+ys.size(1))).type_as(encoded).long()
                    new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                    new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]
            # end for hyp in hyps
            hyps = hyps_best_kept

            # append eos at the final step so that at least some hypotheses are ended
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp['yseq'] = torch.cat([hyp['yseq'],
                                             torch.ones(1, 1).fill_(self.eos_id).type_as(encoded).long()], dim=1)

            # move ended hypotheses to the final list and remove them from the active set
            # (note: the number of active hyps can then drop below beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][0, -1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            hyps = remained_hyps
            if len(hyps) > 0:
                print('remaining hypotheses: ' + str(len(hyps)))
            else:
                print('no hypotheses left. Finished decoding.')
                break
        # end for i in range(maxlen)
        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[
            :min(len(ended_hyps), nbest)]
        # compatible with the LAS implementation
        for hyp in nbest_hyps:
            hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
        return nbest_hyps
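
A hedged usage sketch for beam_decode; decoder and encoded are placeholders for a trained decoder instance and its matching encoder output:

import torch

def decode_utterance(decoder, encoded, nbest=3):
    with torch.no_grad():
        nbest_hyps = decoder.beam_decode(encoded, beam=5, nbest=nbest, maxlen=100)
    for hyp in nbest_hyps:
        # 'yseq' is already a plain list of token ids, 'score' a summed log-prob.
        print('%.2f' % float(hyp['score']), hyp['yseq'])
    return nbest_hyps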