def store_penultimate_state(self, xs_pad, ilens, ys_pad, moe_coes,
                            moe_coe_lens):
    """Run both encoders and the decoder and return the penultimate state.

    The two encoder outputs are mixed frame by frame with the supplied
    mixture coefficients: channel 1 of ``moe_coes`` weights the ``cn``
    encoder and channel 0 weights the ``en`` encoder. The mixed encoder
    output is also stored on ``self.hs_pad``.

    :param xs_pad: batch of padded source sequences
    :param ilens: lengths of the source sequences
    :param ys_pad: batch of padded target sequences
    :param moe_coes: mixture coefficients; indexed as ``[:, :, k]`` after a
        trailing axis is added (presumably (B, T, 2) -- confirm with caller)
    :param moe_coe_lens: lengths of the mixture coefficient sequences
    :return: penultimate decoder state as a numpy array, with a leading
        axis of size one squeezed away
    """
    # Cut the padding down to the longest item (for data parallel).
    longest_coe = max(moe_coe_lens)
    moe_coes = moe_coes[:, :longest_coe]
    longest_in = max(ilens)
    xs_pad = xs_pad[:, :longest_in]
    non_pad = ~make_pad_mask(ilens.tolist())
    src_mask = non_pad.to(xs_pad.device).unsqueeze(-2)

    # Multi-encoder forward; the mask of the second call is the one kept.
    cn_hs_pad, _ = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)

    moe_coes = moe_coes.unsqueeze(-1)
    cn_weight = moe_coes[:, :, 1]
    en_weight = moe_coes[:, :, 0]
    mixed_hs = cn_hs_pad * cn_weight + en_hs_pad * en_weight
    self.hs_pad = mixed_hs

    # Decoder forward on the <sos>-prefixed targets.
    dec_in, _ = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    dec_mask = target_mask(dec_in, self.ignore_id)
    _, _, penultimate_state = self.decoder(
        dec_in,
        dec_mask,
        mixed_hs,
        hs_mask,
        moe_coes,
        return_penultimate_state=True,
    )

    # penultimate_state is (B, T, att_dim) according to the original note.
    state = penultimate_state.squeeze(0).detach()
    return state.cpu().numpy()
Beispiel #2
0
    def _calc_mt_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the MT attention-decoder loss, accuracy and BLEU.

        Args:
            encoder_out: encoder output handed to the decoder.
            encoder_out_lens: lengths accompanying ``encoder_out``.
            ys_pad: padded target sequences.
            ys_pad_lens: lengths of the target sequences.

        Returns:
            Tuple ``(loss_att, acc_att, bleu_att)``. ``bleu_att`` is ``None``
            while training or when no ``mt_error_calculator`` is configured.
        """
        dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                         self.ignore_id)
        # The decoder input is one token longer than the raw target.
        dec_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        logits, _ = self.decoder(encoder_out, encoder_out_lens, dec_in,
                                 dec_in_lens)

        # 2. Compute attention loss and token accuracy
        loss_att = self.criterion_mt(logits, dec_target)
        flat_logits = logits.view(-1, self.vocab_size)
        acc_att = th_accuracy(
            flat_logits,
            dec_target,
            ignore_label=self.ignore_id,
        )

        # 3. Score the greedy hypotheses (evaluation only)
        bleu_att = None
        if not self.training and self.mt_error_calculator is not None:
            hyps = logits.argmax(dim=-1)
            bleu_att = self.mt_error_calculator(hyps.cpu(), ys_pad.cpu())

        return loss_att, acc_att, bleu_att
Beispiel #3
0
    def forward_asr(self, hs_pad, hs_mask, ys_pad):
        """Forward pass in the auxiliary ASR task.

        :param torch.Tensor hs_pad: batch of padded encoder states; viewed as
            (B, -1, adim) for the CTC branch
        :param torch.Tensor hs_mask: mask for hs_pad; summed per utterance to
            obtain the CTC input lengths
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: tuple ``(loss_att, acc, loss_ctc, cer_ctc, cer, wer)``:
            ASR attention loss, attention-decoder accuracy, ASR CTC loss,
            character error rate from the CTC prediction, and character /
            word error rate from the attention-decoder prediction. Losses
            stay at 0.0 and metrics at None for every disabled branch.
        :rtype: tuple
        """
        loss_att = 0.0
        loss_ctc = 0.0
        acc = cer = wer = cer_ctc = None

        # Auxiliary task switched off: return the defaults untouched.
        if self.asr_weight == 0:
            return loss_att, acc, loss_ctc, cer_ctc, cer, wer

        # Attention branch (skipped for pure CTC, i.e. mtlalpha == 1).
        if self.mtlalpha < 1:
            dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                             self.ignore_id)
            dec_mask = target_mask(dec_in, self.ignore_id)
            pred_pad, _ = self.decoder_asr(dec_in, dec_mask, hs_pad, hs_mask)
            loss_att = self.criterion(pred_pad, dec_target)

            flat_pred = pred_pad.view(-1, self.odim)
            acc = th_accuracy(flat_pred,
                              dec_target,
                              ignore_label=self.ignore_id)
            if not self.training:
                hyps_att = pred_pad.argmax(dim=-1)
                cer, wer = self.error_calculator_asr(hyps_att.cpu(),
                                                     ys_pad.cpu())

        # CTC branch (skipped for pure attention, i.e. mtlalpha == 0).
        if self.mtlalpha > 0:
            n_batch = hs_pad.size(0)
            hs_len = hs_mask.view(n_batch, -1).sum(1)
            ctc_input = hs_pad.view(n_batch, -1, self.adim)
            loss_ctc = self.ctc(ctc_input, hs_len, ys_pad)
            if not self.training:
                hyps_ctc = self.ctc.argmax(
                    hs_pad.view(n_batch, -1, self.adim)).data
                cer_ctc = self.error_calculator_asr(hyps_ctc.cpu(),
                                                    ys_pad.cpu(),
                                                    is_ctc=True)
                # for visualization
                self.ctc.softmax(hs_pad)

        return loss_att, acc, loss_ctc, cer_ctc, cer, wer
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Besides returning the loss, this stores ``hs_pad``, ``pred_pad``,
        ``acc``, ``loss`` and ``ppl`` on the instance and reports
        loss/accuracy/perplexity/BLEU through ``self.reporter``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: attention loss value
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats

        # 4. compute bleu (evaluation only)
        if self.training or self.error_calculator is None:
            bleu = 0.0
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_mt
        self.loss = loss

        # 5. compute perplexity
        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            # Take the batch size BEFORE flattening the targets: after
            # view(-1), size(0) would be B * T and inflate the exponent.
            # NOTE(review): this rescaling assumes the criterion divides the
            # summed loss by the batch size when normalize_length is off.
            batch_size = ys_out_pad.size(0)
            ys_out_pad = ys_out_pad.view(-1)
            ignore = ys_out_pad == self.ignore_id  # (B*T,)
            total_n_tokens = len(ys_out_pad) - ignore.sum().item()
            self.ppl = np.exp(loss_data * batch_size / total_n_tokens)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, self.acc, self.ppl, bleu)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Beispiel #5
0
def test_transformer_mask():
    """Check that NaNs at positions >= 3 do not reach the first 3x3 attention block."""
    model, x, ilens, y, data, uttid_list = prepare("pytorch", make_arg())
    tgt_in, tgt_out = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    tgt_mask = target_mask(tgt_in, model.ignore_id)
    embedded = model.decoder.embed(tgt_in)
    # Poison every position from index 3 onwards in the first sample.
    embedded[0, 3:] = float("nan")
    self_attn = model.decoder.decoders[0].self_attn
    self_attn(embedded, embedded, embedded, tgt_mask)
    # Attention among the first three positions must stay NaN-free.
    visible = self_attn.attn[0, :, :3, :3].detach().numpy()
    assert not numpy.isnan(visible).any()
Beispiel #6
0
def test_transformer_mask(module):
    """Check that NaNs at positions >= 3 do not reach the first 3x3 attention block.

    ``module`` is forwarded to ``prepare`` to select the model under test.
    """
    model, x, ilens, y, data = prepare(module)
    from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
    from espnet.nets.pytorch_backend.transformer.mask import target_mask
    tgt_in, tgt_out = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    tgt_mask = target_mask(tgt_in, model.ignore_id)
    embedded = model.decoder.embed(tgt_in)
    # Poison every position from index 3 onwards in the first sample.
    embedded[0, 3:] = float("nan")
    self_attn = model.decoder.decoders[0].self_attn
    self_attn(embedded, embedded, embedded, tgt_mask)
    # Attention among the first three positions must stay NaN-free.
    visible = self_attn.attn[0, :, :3, :3].detach().numpy()
    assert not numpy.isnan(visible).any()
Beispiel #7
0
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            The value returned by ``self.loss`` for the joint network output.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder
        encoder_out, encoder_out_lens, _ = self.encoder(
            speech,
            speech_lengths,
            left_mask=self.encoder_left_mask,
            right_mask=self.encoder_right_mask)  # return xs_pad, olens, None

        # 2. Decoder
        # todo: train right shift
        text_in, text_out = add_sos_eos(text, self.sos, self.eos,
                                        self.ignore_id)
        text_in_lens = text_lengths + 1
        decoder_out, _, _ = self.decoder(
            text_in,
            text_in_lens,
            left_mask=self.decoder_left_mask,
            right_mask=0)  # return xs_pad, olens, None

        # 3.Joint
        # h_enc: Batch of expanded hidden state (B, T, 1, D_enc)
        # h_dec: Batch of expanded hidden state (B, 1, U, D_dec)
        encoder_out = encoder_out.unsqueeze(2)
        decoder_out = decoder_out.unsqueeze(1)
        joint_out = self.joint(h_enc=encoder_out, h_dec=decoder_out)

        # 4.loss
        # pred_len must describe the time axis of joint_out, which comes from
        # the encoder output. Use the lengths returned by the encoder rather
        # than the raw input lengths: they are equal when the encoder keeps
        # the time resolution and correct when it does not.
        loss = self.loss(
            pred_pad=joint_out,  # (batch, maxlen_in, maxlen_out+1, odim)
            target=text.int(),  # (batch, maxlen_out)
            pred_len=encoder_out_lens.int(),  # (batch)
            target_len=text_lengths.int())  # (batch)
        return loss
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Stores ``loss``, ``acc``, ``bleu`` and ``ppl`` on the instance and
        reports them through ``self.reporter`` unless the loss is NaN.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: attention loss value
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        longest = max(ilens)
        xs_pad = xs_pad[:, :longest]  # for data parallel
        non_pad = ~make_pad_mask(ilens.tolist())
        src_mask = non_pad.to(xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                         self.ignore_id)
        dec_mask = target_mask(dec_in, self.ignore_id)
        pred_pad, pred_mask = self.decoder(dec_in, dec_mask, hs_pad, hs_mask)

        # 3. compute attention loss
        self.loss = self.criterion(pred_pad, dec_target)
        flat_pred = pred_pad.view(-1, self.odim)
        self.acc = th_accuracy(flat_pred,
                               dec_target,
                               ignore_label=self.ignore_id)

        # 4. compute corpus-level bleu in a mini-batch
        if not self.training:
            hyps = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(hyps.cpu(), ys_pad.cpu())
        else:
            self.bleu = None

        # 5. perplexity from the scalar loss
        loss_data = float(self.loss)
        if self.normalize_length:
            log_ppl = loss_data
        else:
            # Rescale by batch size over the number of non-ignored tokens.
            batch_size = dec_target.size(0)
            flat_target = dec_target.view(-1)  # (B*T,)
            n_ignored = (flat_target == self.ignore_id).sum().item()
            total_n_tokens = len(flat_target) - n_ignored
            log_ppl = loss_data * batch_size / total_n_tokens
        self.ppl = np.exp(log_ppl)

        if math.isnan(loss_data):
            logging.warning("loss (=%f) is not correct", loss_data)
        else:
            self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
        return self.loss
Beispiel #9
0
    def forward_ilm(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Decoder-only forward on a zero encoder output + Calc loss

        Args:
            speech: (Batch, Length, ...) only used to create the zero
                encoder output with a matching dtype and device
            speech_lengths: (Batch, ) not used by this method
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            Tuple ``(loss, stats, weight)`` as produced by
            ``force_gatherable``; ``stats`` holds ``ilm_loss``, ``ilm_acc``
            and ``ilm_ppl``, and ``weight`` is derived from the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (text.shape[0] == text_lengths.shape[0]), (text.shape,
                                                          text_lengths.shape)
        batch_size = text.shape[0]

        # for data-parallel
        text = text[:, :text_lengths.max()]

        ys_in_pad, ys_out_pad = add_sos_eos(text, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = text_lengths + 1

        # A single all-zero frame stands in for the encoder output.
        fake_encoder_out = speech.new_zeros(batch_size, 1,
                                            self.encoder._output_size)
        # 1. Forward decoder
        decoder_out, _ = self.decoder.forward_ilm(fake_encoder_out, -1,
                                                  ys_in_pad, ys_in_lens)

        # 2. Compute ilm loss
        loss_ilm = self.criterion_att(decoder_out, ys_out_pad)
        ilm_acc = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        ilm_ppl = torch.exp(loss_ilm)
        stats = dict(ilm_loss=loss_ilm.detach(),
                     ilm_acc=ilm_acc,
                     ilm_ppl=ilm_ppl.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss_ilm, stats, batch_size),
                                               loss_ilm.device)
        # Return the gathered loss, not the raw one, so that all three
        # outputs come from the same force_gatherable call.
        return loss, stats, weight
Beispiel #10
0
    def store_penultimate_state(self, xs_pad, ilens, ys_pad, bnf_feats, bnf_feats_lens):
        """Run encoder and decoder and return the penultimate decoder state.

        The encoder output is also stored on ``self.hs_pad``.

        :param xs_pad: batch of padded source sequences
        :param ilens: lengths of the source sequences
        :param ys_pad: batch of padded target sequences
        :param bnf_feats: extra features passed to the encoder as third argument
        :param bnf_feats_lens: lengths of ``bnf_feats``
        :return: penultimate decoder state as a numpy array, with a leading
            axis of size one squeezed away
        """
        # Cut the padding down to the longest item (for data parallel).
        longest_bnf = max(bnf_feats_lens)
        bnf_feats = bnf_feats[:, :longest_bnf]
        longest_in = max(ilens)
        xs_pad = xs_pad[:, :longest_in]
        non_pad = ~make_pad_mask(ilens.tolist())
        src_mask = non_pad.to(xs_pad.device).unsqueeze(-2)
        enc_out, enc_mask = self.encoder(xs_pad, src_mask, bnf_feats)
        self.hs_pad = enc_out

        # Decoder forward on the <sos>-prefixed targets.
        dec_in, _ = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        dec_mask = target_mask(dec_in, self.ignore_id)
        _, _, penultimate_state = self.decoder(
            dec_in, dec_mask, enc_out, enc_mask, return_penultimate_state=True)

        # penultimate_state is (B, T, att_dim) according to the original note.
        state = penultimate_state.squeeze(0).detach()
        return state.cpu().numpy()
Beispiel #11
0
    def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
        """Forward decoder and attention loss.

        ``batch_size`` is part of the signature but is not read here.

        :return: tuple ``(pred_pad, pred_mask, loss_att, acc)``
        """
        # Decoder forward on the <sos>-prefixed targets.
        dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                         self.ignore_id)
        dec_mask = target_mask(dec_in, self.ignore_id)
        decoded = self.decoder(dec_in, dec_mask, hs_pad, hs_mask)
        pred_pad, pred_mask = decoded

        # Attention loss and token accuracy against the <eos>-suffixed targets.
        loss_att = self.criterion(pred_pad, dec_target)
        flat_pred = pred_pad.view(-1, self.odim)
        acc = th_accuracy(flat_pred, dec_target, ignore_label=self.ignore_id)
        return pred_pad, pred_mask, loss_att, acc
Beispiel #12
0
    def _meta_collect_stats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Feed one batch of decoder outputs into the prior statistic calculator.

        Runs the encoder and decoder, then passes the decoder output, the
        <eos>-suffixed targets and the target lengths to ``self.stat``.
        Returns nothing.
        """
        dec_in, dec_target = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
        dec_in_lens = text_lengths + 1

        enc_out, enc_out_lens = self.encode(speech, speech_lengths)
        dec_out, _ = self.decoder(enc_out, enc_out_lens, dec_in, dec_in_lens)

        # Shorten the lengths by one so the <eos> label is left out.
        lens_without_eos = dec_in_lens - 1
        self.stat(dec_out, dec_target, lens_without_eos)
Beispiel #13
0
    def _extract_feats(
            self, src_text: torch.Tensor, src_text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prefix the source text with <sos> and run the optional frontend.

        Returns:
            ``(feats, feats_lengths)``: the frontend output, or the
            <sos>-prefixed text and its lengths when no frontend is set.
        """
        assert src_text_lengths.dim() == 1, src_text_lengths.shape

        # for data-parallel
        longest = src_text_lengths.max()
        src_text = src_text[:, :longest]
        src_text, _ = add_sos_eos(src_text, self.sos, self.eos, self.ignore_id)
        src_text_lengths = src_text_lengths + 1

        # No frontend and no feature extract
        if self.frontend is None:
            return src_text, src_text_lengths

        # Frontend, e.g. Embedding Lookup
        # src_text (Batch, NSamples) -> feats: (Batch, NSamples, Dim)
        feats, feats_lengths = self.frontend(src_text, src_text_lengths)
        return feats, feats_lengths
Beispiel #14
0
    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)

        Returns:
            Tensor with one summed nll value per batch item; positions whose
            target equals ``self.ignore_id`` are ignored by the cross entropy.
        """
        dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                         self.ignore_id)
        dec_in_lens = ys_pad_lens + 1

        # 1. Forward decoder -> [batch, seqlen, dim]
        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, dec_in,
                                      dec_in_lens)
        n_batch = decoder_out.size(0)
        n_class = decoder_out.size(2)

        # 2. Per-token negative log-likelihood, unreduced
        flat_logits = decoder_out.view(-1, n_class)
        flat_target = dec_target.view(-1)
        token_nll = torch.nn.functional.cross_entropy(
            flat_logits,
            flat_target,
            ignore_index=self.ignore_id,
            reduction="none",
        )

        # 3. Sum over the sequence axis for each batch item
        nll = token_nll.view(n_batch, -1).sum(dim=1)
        assert nll.size(0) == n_batch
        return nll
Beispiel #15
0
    def _meta_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        lam=0.6,
    ):
        """Return the meta decoder's softmax output, optionally adding LM mass.

        When ``self.lm`` is set, ``lam`` times the LM softmax output is added
        to the decoder softmax output (the sum is not renormalised here).
        """
        assert self.meta_encoder is not None
        assert self.meta_decoder is not None
        dec_in, _ = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
        dec_in_lens = text_lengths + 1

        meta_enc_out, meta_enc_lens = self._meta_encode(speech, speech_lengths)
        meta_dec_out, _ = self.meta_decoder(meta_enc_out, meta_enc_lens, dec_in, dec_in_lens)

        decoder_out_prob = torch.softmax(meta_dec_out, dim=-1)
        if self.lm is None:
            return decoder_out_prob

        lm_out, _ = self.lm(dec_in, None)
        lm_out_prob = torch.softmax(lm_out, dim=-1)
        return lam * lm_out_prob + decoder_out_prob
Beispiel #16
0
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        Combines the ST attention loss with the auxiliary ASR and MT losses,
        stores ``acc``, ``bleu`` and ``loss`` on the instance, and reports
        the statistics through ``self.reporter``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded sequences handed to
            the auxiliary ASR and MT tasks (B, Lmax)
        :return: weighted sum of the ST, ASR and MT losses
        :rtype: torch.Tensor
        """
        # 0. Extract target language ID
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            # remove target language ID in the beginning
            ys_pad = ys_pad[:, 1:]

        # 1. forward encoder
        longest = max(ilens)
        xs_pad = xs_pad[:, :longest]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        if self.replace_sos:
            # replace <sos> with target language ID
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        flat_pred = pred_pad.view(-1, self.odim)
        self.acc = th_accuracy(flat_pred, ys_out_pad, ignore_label=self.ignore_id)

        # 4. compute corpus-level bleu in a mini-batch
        if not self.training:
            hyps = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(hyps.cpu(), ys_pad.cpu())
        else:
            self.bleu = None

        # 5. compute auxiliary ASR loss
        asr_out = self.forward_asr(hs_pad, hs_mask, ys_pad_src)
        loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = asr_out

        # 6. compute auxiliary MT loss
        loss_mt = 0.0
        acc_mt = None
        if self.mt_weight > 0:
            loss_mt, acc_mt = self.forward_mt(
                ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
            )

        # 7. combine the task losses
        asr_ctc_weight = self.mtlalpha
        loss_asr = asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
        st_weight = 1 - self.asr_weight - self.mt_weight
        self.loss = (
            st_weight * loss_att
            + self.asr_weight * loss_asr
            + self.mt_weight * loss_mt
        )

        # 8. scalar copies for the reporter
        loss_asr_data = float(loss_asr)
        if self.mt_weight == 0:
            loss_mt_data = None
        else:
            loss_mt_data = float(loss_mt)
        loss_st_data = float(loss_att)
        loss_data = float(self.loss)

        loss_is_sane = loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data)
        if not loss_is_sane:
            logging.warning("loss (=%f) is not correct", loss_data)
            return self.loss

        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
        return self.loss
Beispiel #17
0
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Mixes the outputs of the two encoders with ``self.enc_lambda``,
        computes the CTC and/or attention loss depending on ``mtlalpha``,
        stores ``hs_pad``, ``acc`` and ``loss`` (and ``pred_pad`` when the
        decoder runs) on the instance and reports the statistics through
        ``self.reporter``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: combined CTC / attention loss value
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        # gated add module with the weight taken from self.enc_lambda:
        #   xs = lambda * cn_xs + (1 - lambda) * en_xs
        lambda_ = self.enc_lambda
        hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # pred_pad stays None when the attention decoder is skipped, so the
        # cer/wer step below can tell that there is nothing to score.
        pred_pad = None
        if self.mtlalpha == 1:
            loss_att, acc = None, None
            # Kept because the previous implementation reset this attribute.
            self.loss_att = None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer (needs decoder predictions, evaluation only)
        if (self.training or self.error_calculator is None
                or pred_pad is None):
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Beispiel #18
0
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        replace_label_flag: bool=False,
        decoder_out_prob=None,
    ):
        """Compute the attention-decoder loss, optionally on resampled labels.

        When ``replace_label_flag`` is set, target positions whose confidence
        (from ``calc_confidence(decoder_out_prob, ys_out_pad)``) is below
        ``self.th`` are overwritten with tokens sampled from the decoder's
        own softmax output, and the decoder input/target are rebuilt from
        those modified sequences before the loss is computed.

        Args:
            encoder_out: encoder output handed to the decoder.
            encoder_out_lens: lengths accompanying ``encoder_out``.
            ys_pad: padded targets; entries equal to ``self.ignore_id`` are
                filtered out when the labels are rebuilt.
            ys_pad_lens: lengths of the targets.
            replace_label_flag: enable the label replacement described above.
            decoder_out_prob: probabilities used only to compute the
                confidences; must not be None when ``replace_label_flag``
                is set.

        Returns:
            Tuple ``(loss_att, acc_att, cer_att, wer_att)``; the error rates
            are None while training or when no error calculator is set.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # Decoder input is one token longer than the raw target.
        ys_in_lens = ys_pad_lens + 1

        # Replace the labels
        if replace_label_flag:
            assert decoder_out_prob is not None
            from espnet.nets.pytorch_backend.nets_utils import pad_list
            confid = calc_confidence(decoder_out_prob, ys_out_pad)

            # Drop the <eos> position and mark the positions to replace:
            # one boolean mask per utterance, True where confidence < self.th.
            repl_mask = [prob[:l] < self.th for prob, l in zip(confid, ys_pad_lens)]

            # The caller-supplied decoder_out_prob was only needed for the
            # confidences above; from here on the name holds the softmax of
            # the current decoder, computed without gradients.
            with torch.no_grad():
                decoder_out, _ = self.decoder(
                    encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
                )
                decoder_out_prob = torch.softmax(decoder_out, dim=-1).detach()

            # Rebuild the decoder input: strip padding, overwrite the masked
            # positions with sampled tokens, then prepend <sos>.
            ys_in = [y[y != self.ignore_id] for y in ys_pad.clone().detach()]
            for i, (rm, y) in enumerate(zip(repl_mask, ys_in)):
                weight = decoder_out_prob[i]
                weight = weight[:len(y)]
                # One sample per position from that position's distribution.
                samples = torch.multinomial(weight, 1).squeeze(-1)
                y[rm] = samples[rm]
            _sos = ys_pad.new([self.sos])
            ys_in = [torch.cat([_sos, y], dim=0) for y in ys_in]
            ys_in_pad = pad_list(ys_in, self.eos).detach()

            # Rebuild the decoder target the same way, then append <eos>.
            # NOTE(review): this is a second, independent multinomial draw, so
            # the replaced tokens in ys_out may differ from those in ys_in --
            # confirm this is intended.
            ys_out = [y[y != self.ignore_id] for y in ys_pad.clone().detach()]
            for i, (rm, y) in enumerate(zip(repl_mask, ys_out)):
                weight = decoder_out_prob[i]
                weight = weight[:len(y)]
                samples = torch.multinomial(weight, 1).squeeze(-1)
                y[rm] = samples[rm]
            # _ignore = ys_pad.new([self.ignore_id])
            # Despite its name, _ignore holds the <eos> id that ends each target.
            _ignore = ys_pad.new([self.eos])
            ys_out = [torch.cat([y, _ignore], dim=0) for y in ys_out]
            ys_out_pad = pad_list(ys_out, self.ignore_id).detach()

        # 1. Forward decoder (with gradients) on the possibly modified input
        decoder_out, _ = self.decoder(
                encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # if replace_label_flag:
        #     num_repl = 0.0
        #     num_total = 0.0
        #     for m in repl_mask:
        #         num_repl += m.sum()
        #         num_total += len(m)
        #     pred_err_att = float(num_repl) / float(num_total)
        # else:
        #     pred_err_att = 0.0

        # Compute cer/wer using attention-decoder; scored against the
        # original ys_pad, not the resampled labels.
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward for the two-encoder model with a learned gate.

        Both encoders receive the same input. Their outputs are blended by
        ``self.aggregation_module`` to form the decoder memory, while each
        encoder output feeds its own CTC branch.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: the loss kept in ``self.loss`` (attention loss, CTC loss, or
            their ``mtlalpha``-weighted sum)
        :rtype: torch.Tensor
        """
        alpha = self.mtlalpha

        # --- encoders -------------------------------------------------------
        # trim to the longest utterance in this shard (for data parallel)
        xs_pad = xs_pad[:, :max(ilens)]
        pad_mask = make_pad_mask(ilens.tolist())
        src_mask = (~pad_mask).to(xs_pad.device).unsqueeze(-2)
        cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, _ = self.en_encoder(xs_pad, src_mask)
        # per the original author both encoder masks are identical, so the
        # cn one is used everywhere
        hs_mask = cn_hs_mask

        # --- gated add ------------------------------------------------------
        # gate = aggregation_module([cn; en]); mixed = gate*cn + (1-gate)*en
        # (original note: gate is (B,T,1) or (B,T,D) with values in (0, 1))
        gate = self.aggregation_module(
            torch.cat((cn_hs_pad, en_hs_pad), dim=-1))
        mixed = gate * cn_hs_pad + (1 - gate) * en_hs_pad
        self.hs_pad = mixed

        # --- CTC branch -----------------------------------------------------
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        loss_ctc = None
        if alpha != 0.0:
            n_batch = xs_pad.size(0)
            hs_len = hs_mask.view(n_batch, -1).sum(1)
            # split the target into cn / en parts; each part can be passed
            # straight to the corresponding ctc module
            cn_ys, en_ys = partial_target(ys_pad, self.language_divider)
            cn_loss_ctc = self.cn_ctc(
                cn_hs_pad.view(n_batch, -1, self.adim), hs_len, cn_ys)
            en_loss_ctc = self.en_ctc(
                en_hs_pad.view(n_batch, -1, self.adim), hs_len, en_ys)
            loss_ctc = 0.5 * (cn_loss_ctc + en_loss_ctc)

        # --- attention branch -----------------------------------------------
        if alpha == 1:
            # pure CTC: the decoder is not run at all
            self.loss_att = None
            acc = None
        else:
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, mixed,
                                               hs_mask)
            self.pred_pad = pred_pad

            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # --- combine (copied from e2e_asr) ----------------------------------
        loss_att_data = None
        loss_ctc_data = None
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
        elif alpha == 1:
            self.loss = loss_ctc
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        # only report sane loss values; otherwise just warn
        total = float(self.loss)
        if math.isnan(total) or not total < CTC_LOSS_THRESHOLD:
            logging.warning('loss (=%f) is not correct', total)
        else:
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, None, None, total)
        return self.loss
Beispiel #20
0
    def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
        """E2E forward with a softmax gate, flat attention memory and LID loss.

        The CTC branch sees a gate-weighted sum of the two encoder outputs,
        the decoder attends over both encoder outputs concatenated along the
        time axis, and the gate itself is trained with ``self.lid_criterion``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor moe_coes: language coefficients; cast to long and
            channel 0 of the last axis is used as the LID target
            (per the original comment: (B, T, 2), 0 for cn, 1 for en)
        :param moe_coe_lens: lengths of ``moe_coes``, used to trim padding
        :return: the loss kept in ``self.loss`` (CTC / attention mix plus
            ``lid_mtl_alpha`` times the LID loss)
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        moe_coes = moe_coes[:, :max(moe_coe_lens)].long()  # for data parallel
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)

        # gated add module:
        #   lambda = softmax(aggre_scaling * aggregation_module([cn; en]))
        #   ctc_hs = lambda[..., 0] * cn_hs + lambda[..., 1] * en_hs
        # the trailing unsqueeze lets each weight broadcast over the feature axis
        gate_in = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
        lambda_ = F.softmax(
            self.aggre_scaling * self.aggregation_module(gate_in),
            -1).unsqueeze(-1)
        ctc_hs_pad = (lambda_[:, :, 0] * cn_hs_pad +
                      lambda_[:, :, 1] * en_hs_pad)
        ctc_hs_mask = cn_hs_mask

        # plat attention mode, (B,T,D)*2 --> (B,2T,D)
        s2s_hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=1)
        # mask: (B,1,T) --> (B,1,2T)
        s2s_hs_mask = torch.cat((cn_hs_mask, en_hs_mask), dim=-1)

        # compute lid loss here, using lambda_
        # moe_coes (B, T, 2) ==> (B,T)
        moe_coes = moe_coes[:, :, 0]  # 0 for cn, 1 for en
        lambda_ = lambda_.squeeze(-1)
        if self.lid_mtl_alpha == 0.0:
            loss_lid = 0.0
        else:
            loss_lid = self.lid_criterion(lambda_, moe_coes)
        lid_acc = th_accuracy(
            lambda_.view(-1, 2), moe_coes,
            ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = ctc_hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(ctc_hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(
                    ctc_hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            # pure CTC: the decoder is skipped, so pred_pad does not exist
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, s2s_hs_pad,
                                               s2s_hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        # The mtlalpha == 1 guard is required: in pure-CTC mode there is no
        # decoder output, and reading pred_pad would raise UnboundLocalError
        # during evaluation.
        if (self.training or self.error_calculator is None
                or self.mtlalpha == 1):
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        lid_alpha = self.lid_mtl_alpha
        if alpha == 0:
            self.loss = loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc + lid_alpha * loss_lid
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (
                1 - alpha) * loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        # only report sane loss values; otherwise just warn
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data, lid_acc)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Beispiel #21
0
    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
        """E2E forward with optional extra encoder / decoder masks.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param enc_mask: optional mask; cropped to the encoder length and
            AND-ed with the encoder padding mask
        :param dec_mask: optional "trigger" mask; cropped and AND-ed with the
            memory mask handed to the decoder
        :return: ``(loss, loss_ctc_data, loss_att_data, acc)`` where the two
            ``*_data`` entries are floats or None depending on ``mtlalpha``
        :rtype: tuple
        """
        alpha = self.mtlalpha

        # 1. encoder front-end
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        embed = self.encoder.embed
        if isinstance(embed, EncoderConv2d):
            # this front-end is given a summed mask instead of the mask itself
            # NOTE(review): the bare squeeze() would also drop the batch axis
            # when B == 1 - confirm EncoderConv2d tolerates that
            xs, hs_mask = embed(xs_pad, torch.sum(src_mask, 2).squeeze())
            hs_mask = hs_mask.unsqueeze(1)
        else:
            xs, hs_mask = embed(xs_pad, src_mask)

        # combine the caller's encoder mask (if any) with the padding mask
        if enc_mask is None:
            enc_mask = hs_mask
        else:
            enc_len = hs_mask.shape[2]
            enc_mask = enc_mask[:, :enc_len, :enc_len] & hs_mask
        hs_pad, _ = self.encoder.encoders(xs, enc_mask)
        if self.encoder.normalize_before:
            hs_pad = self.encoder.after_norm(hs_pad)

        # CTC forward: first crop the targets to the longest unpadded one
        y_len = max(len(y[y != self.ignore_id]) for y in ys_pad)
        ys_pad = ys_pad[:, :y_len]
        if dec_mask is not None:
            # +1 on the target axis matches the <sos>-prefixed decoder input
            dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
        self.hs_pad = hs_pad

        loss_ctc = None
        if alpha != 0.0:
            n_batch = xs_pad.size(0)
            hs_len = hs_mask.view(n_batch, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(n_batch, -1, self.adim), hs_len,
                                ys_pad)

        # trigger mask
        if dec_mask is not None:
            hs_mask = hs_mask & dec_mask

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # combine (copied from e2e_asr)
        loss_att_data = None
        loss_ctc_data = None
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
        elif alpha == 1:
            self.loss = loss_ctc
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_ctc_data, loss_att_data, self.acc
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward with an auxiliary language-identification (LID) loss.

        The decoder also returns its penultimate state. That state attends
        over the encoder output through ``self.lid_src_att`` and is projected
        by ``self.lid_output_layer`` to produce the LID prediction.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: the loss kept in ``self.loss`` (CTC / attention mix plus
            ``lid_mtl_alpha`` times the LID loss)
        :rtype: torch.Tensor
        :raises Exception: if ``self.mtlalpha == 1`` (pure CTC mode is not
            supported together with LID multitask learning)
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask, penultimate_state = self.decoder(
            ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute lid multitask loss
        # the criterion also returns the LID targets it derived from
        # ys_out_pad; they are reused for the LID accuracy below
        src_att = self.lid_src_att(penultimate_state, hs_pad, hs_pad, hs_mask)
        pred_lid_pad = self.lid_output_layer(src_att)
        loss_lid, lid_ys_out_pad = self.lid_criterion(pred_lid_pad, ys_out_pad)
        lid_acc = th_accuracy(
            pred_lid_pad.view(-1, self.lid_odim),
            lid_ys_out_pad,
            ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        lid_alpha = self.lid_mtl_alpha
        if alpha == 0:
            self.loss = loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            # the statements that used to follow this raise were unreachable
            raise Exception("LID MTL not supports pure ctc mode")
        else:
            self.loss = alpha * loss_ctc + (
                1 - alpha) * loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        # only report sane loss values; otherwise just warn
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data, lid_acc)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward returning the loss together with its scalar parts.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ``(loss, loss_ctc_data, loss_att_data, acc)``; ``loss`` is
            the attention loss, the CTC loss or their ``mtlalpha``-weighted
            sum, and the two ``*_data`` entries are floats (None for a branch
            that does not contribute to the loss)
        :rtype: tuple
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # CTC forward: first crop the targets to the longest unpadded one
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        self.hs_pad = hs_pad
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            # this assignment was missing, which made the return below raise
            # UnboundLocalError whenever 0 < mtlalpha < 1
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_ctc_data, loss_att_data, self.acc
    def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
        """E2E forward mixing two encoders with externally supplied coefficients.

        The given ``moe_coes`` are first smoothed towards a uniform
        distribution with ``self.interp_factor`` and then used to weight the
        outputs of ``self.cn_encoder`` and ``self.en_encoder``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor moe_coes: mixture coefficients; index 1 of the last
            axis weights the cn encoder and index 0 the en encoder
            (presumably (B, T, 2) - confirm against the data loader)
        :param moe_coe_lens: lengths of ``moe_coes``, used to trim padding
        :return: the loss kept in ``self.loss`` (attention loss, CTC loss, or
            their ``mtlalpha``-weighted sum)
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
        # here we use interpolation_coe to 'fix' initial moe_coes
        interp_factor = self.interp_factor  # 0.1 for example, similar to lsm
        moe_coes = (
            1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]

        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)

        # multi-encoder forward (the mask of the second call is the one kept)
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        # trailing axis lets each coefficient broadcast over the feature axis
        moe_coes = moe_coes.unsqueeze(-1)
        hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            # pure CTC: the decoder is skipped, so pred_pad does not exist
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        # The mtlalpha == 1 guard is required: in pure-CTC mode there is no
        # decoder output, and reading pred_pad would raise UnboundLocalError
        # during evaluation.
        if (self.training or self.error_calculator is None
                or self.mtlalpha == 1):
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        # only report sane loss values; otherwise just warn
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E speech-translation forward with optional auxiliary ASR / MT tasks.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax);
            when ``self.multilingual`` is set, column 0 holds the target
            language ID
        :param torch.Tensor ys_pad_src: batch of padded source-side token
            sequences (B, Lmax) used by the auxiliary ASR and MT tasks
        :return: the loss kept in ``self.loss``, i.e. the weighted sum of the
            ST, ASR and MT losses
        :rtype: torch.Tensor
        """
        # 0. Extract target language ID (first column) and strip it off
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids, ys_pad = ys_pad[:, 0:1], ys_pad[:, 1:]

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        pad_mask = make_pad_mask(ilens.tolist())
        src_mask = (~pad_mask).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        if self.replace_sos:
            # replace <sos> with target language ID
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST loss; auxiliary terms default to zero so that the
        # final weighted sum works when a task is switched off
        loss_asr_att = loss_asr_ctc = loss_mt = 0.0
        acc_asr = acc_mt = 0.0
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute corpus-level bleu in a mini-batch (evaluation only)
        self.bleu = None if self.training else self.error_calculator(
            pred_pad.argmax(dim=-1).cpu(), ys_pad.cpu())

        # 5. compute auxiliary ASR loss
        alpha = self.mtlalpha
        cer = wer = cer_ctc = None
        if self.asr_weight > 0:
            if alpha < 1:
                # attention-based ASR branch
                asr_in, asr_out = add_sos_eos(ys_pad_src, self.sos, self.eos,
                                              self.ignore_id)
                asr_mask = target_mask(asr_in, self.ignore_id)
                asr_pred, _ = self.decoder_asr(asr_in, asr_mask, hs_pad,
                                               hs_mask)
                loss_asr_att = self.criterion(asr_pred, asr_out)
                acc_asr = th_accuracy(
                    asr_pred.view(-1, self.odim),
                    asr_out,
                    ignore_label=self.ignore_id,
                )
                if not self.training:
                    cer, wer = self.error_calculator_asr(
                        asr_pred.argmax(dim=-1).cpu(), ys_pad_src.cpu())

            if alpha > 0:
                # CTC-based ASR branch
                n_batch = xs_pad.size(0)
                hs_len = hs_mask.view(n_batch, -1).sum(1)
                ctc_in = hs_pad.view(n_batch, -1, self.adim)
                loss_asr_ctc = self.ctc(ctc_in, hs_len, ys_pad_src)
                ctc_hyp = self.ctc.argmax(ctc_in).data
                if not self.training:
                    cer_ctc = self.error_calculator_asr(ctc_hyp.cpu(),
                                                        ys_pad_src.cpu(),
                                                        is_ctc=True)

        # 6. compute auxiliary MT loss
        if self.mt_weight > 0:
            src_lens = torch.sum(ys_pad_src != self.ignore_id,
                                 dim=1).cpu().numpy()
            # NOTE: ys_pad_src is padded with -1, so strip the padding and
            # re-pad with self.pad before feeding the MT encoder
            src_tokens = [y[y != self.ignore_id] for y in ys_pad_src]
            mt_in = pad_list(src_tokens, self.pad)
            mt_in = mt_in[:, :max(src_lens)]  # for data parallel
            mt_pad_mask = make_pad_mask(src_lens.tolist())
            mt_mask = (~mt_pad_mask).to(mt_in.device).unsqueeze(-2)
            hs_pad_mt, hs_mask_mt = self.encoder_mt(mt_in, mt_mask)
            # the ST decoder is reused on top of the text encoder output
            mt_pred, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt,
                                      hs_mask_mt)
            loss_mt = self.criterion(mt_pred, ys_out_pad)
            acc_mt = th_accuracy(mt_pred.view(-1, self.odim),
                                 ys_out_pad,
                                 ignore_label=self.ignore_id)

        # weighted combination of ST / ASR / MT objectives
        loss_asr = alpha * loss_asr_ctc + (1 - alpha) * loss_asr_att
        st_weight = 1 - self.asr_weight - self.mt_weight
        self.loss = (st_weight * loss_att + self.asr_weight * loss_asr +
                     self.mt_weight * loss_mt)
        loss_asr_data = float(loss_asr)
        if self.mt_weight == 0:
            loss_mt_data = None
        else:
            loss_mt_data = float(loss_mt)
        loss_st_data = float(loss_att)

        # only report sane loss values; otherwise just warn
        total = float(self.loss)
        if math.isnan(total) or not total < CTC_LOSS_THRESHOLD:
            logging.warning("loss (=%f) is not correct", total)
        else:
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                total,
            )
        return self.loss
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
        :return: ctc loass value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        # src_lang_ids = None
        tgt_lang_ids, tgt_lang_ids_src = None, None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beggining

        if self.one_to_many:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beggining
            
            if self.do_asr:
                tgt_lang_ids_src = ys_pad_src[:, 0:1]
                ys_pad_src = ys_pad_src[:, 1:]  # remove target language ID in the beggining

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel # bs x max_ilens x idim

        if self.lang_tok == "encoder-pre-sum":
            lang_embed = self.language_embeddings(tgt_lang_ids) # bs x 1 x idim
            xs_pad = xs_pad + lang_embed

        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2) # bs x 1 x max_ilens
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # hs_pad: bs x (max_ilens/4) x adim; hs_mask: bs x 1 x (max_ilens/4)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) # bs x max_lens

        if self.do_asr:
            ys_in_pad_src, ys_out_pad_src = add_sos_eos(ys_pad_src, self.sos, self.eos, self.ignore_id) # bs x max_lens_src

        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)

        if self.lang_tok == "decoder-pre":
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
            if self.do_asr:
                ys_in_pad_src = torch.cat([tgt_lang_ids_src, ys_in_pad_src[:, 1:]], dim=1)

        ys_mask = target_mask(ys_in_pad, self.ignore_id) # bs x max_lens x max_lens

        if self.do_asr:
            ys_mask_src = target_mask(ys_in_pad_src, self.ignore_id) # bs x max_lens x max_lens_src

        if self.wait_k_asr > 0:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=self.wait_k_asr)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=-self.wait_k_asr)
        elif self.wait_k_st > 0:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=-self.wait_k_st)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=self.wait_k_st)
        else:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=0)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=0)

        pred_pad, pred_mask, pred_pad_asr, pred_mask_asr = self.dual_decoder(ys_in_pad, ys_mask, ys_in_pad_src, ys_mask_src,
                                                                                hs_pad, hs_mask, cross_mask, cross_mask_asr,
                                                                                cross_self=self.cross_self, cross_src=self.cross_src,
                                                                                cross_self_from=self.cross_self_from,
                                                                                cross_src_from=self.cross_src_from)

        self.pred_pad = pred_pad
        self.pred_pad_asr = pred_pad_asr
        pred_pad_mt = None

        # 3. compute attention loss
        loss_asr, loss_mt = 0.0, 0.0
        loss_att = self.criterion(pred_pad, ys_out_pad)

        # compute loss
        loss_asr = self.criterion(pred_pad_asr, ys_out_pad_src)
        # Multi-task w/ MT
        if self.mt_weight > 0:
            # forward MT encoder
            ilens_mt = torch.sum(ys_pad_src != self.ignore_id, dim=1).cpu().numpy()
            # NOTE: ys_pad_src is padded with -1
            ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
            ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
            ys_zero_pad_src = ys_zero_pad_src[:, :max(ilens_mt)]  # for data parallel
            src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(ys_zero_pad_src.device).unsqueeze(-2)
            # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad)
            hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
            # forward MT decoder
            pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt, hs_mask_mt)
            # compute loss
            loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)
        if pred_pad_asr is not None:
            self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim), ys_out_pad_src,
                                       ignore_label=self.ignore_id)
        else:
            self.acc_asr = 0.0
        if pred_pad_mt is not None:
            self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                                      ignore_label=self.ignore_id)
        else:
            self.acc_mt = 0.0

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0 or self.asr_weight == 0:
            loss_ctc = 0.0
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad_src)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(), is_ctc=True)

        # 5. compute cer/wer
        cer, wer = None, None  # TODO(hirofumi0810): fix later
        # if self.training or (self.asr_weight == 0 or self.mtlalpha == 1 or not (self.report_cer or self.report_wer)):
        #     cer, wer = None, None
        # else:
        #     ys_hat = pred_pad.argmax(dim=-1)
        #     cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copyied from e2e_asr
        alpha = self.mtlalpha
        self.loss = (1 - self.asr_weight - self.mt_weight) * loss_att + self.asr_weight * \
            (alpha * loss_ctc + (1 - alpha) * loss_asr) + self.mt_weight * loss_mt
        loss_asr_data = float(alpha * loss_ctc + (1 - alpha) * loss_asr)
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        # logging.info(f'loss_st_data={loss_st_data}')

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data,
                                 self.acc_asr, self.acc_mt, self.acc,
                                 cer_ctc, cer, wer, 0.0,  # TODO(hirofumi0810): bleu
                                 loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
    def forward(self, xs_pad, ilens, ys_pad):
        """Compute the joint CTC/attention training loss for one batch.

        The encoder always runs. The attention-decoder branch runs only when
        ``self.decoder`` is set, and the CTC branch only when
        ``self.mtlalpha`` is non-zero. Intermediate results are kept on the
        instance (``hs_pad``, ``pred_pad``, ``acc``, ``loss``) and the scalar
        statistics are passed to ``self.reporter`` when the loss is finite and
        below ``CTC_LOSS_THRESHOLD``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: the combined loss (also stored as ``self.loss``)
        :rtype: torch.Tensor
        """
        # --- encoder --------------------------------------------------------
        # Drop padding beyond the longest sequence of this shard (data parallel).
        xs_pad = xs_pad[:, : max(ilens)]
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # --- attention decoder (optional) -----------------------------------
        if self.decoder is None:
            loss_att = None
            self.acc = None
        else:
            if self.decoder_mode == "maskctc":
                # Mask-CTC: inputs are randomly masked targets, no causal mask.
                ys_in_pad, ys_out_pad = mask_uniform(
                    ys_pad, self.mask_token, self.eos, self.ignore_id
                )
                ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
            else:
                # Autoregressive: <sos>-prefixed inputs, <eos>-suffixed outputs.
                ys_in_pad, ys_out_pad = add_sos_eos(
                    ys_pad, self.sos, self.eos, self.ignore_id
                )
                ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
            self.pred_pad = pred_pad

            # attention loss and token accuracy
            loss_att = self.criterion(pred_pad, ys_out_pad)
            self.acc = th_accuracy(
                pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
            )

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        # --- CTC branch (optional) ------------------------------------------
        cer_ctc = None
        loss_ctc = None
        if self.mtlalpha != 0.0:
            n_batch = xs_pad.size(0)
            hs_len = hs_mask.view(n_batch, -1).sum(1)
            enc_out = hs_pad.view(n_batch, -1, self.adim)
            loss_ctc = self.ctc(enc_out, hs_len, ys_pad)
            if not self.training:
                if self.error_calculator is not None:
                    ys_hat = self.ctc.argmax(enc_out).data
                    cer_ctc = self.error_calculator(
                        ys_hat.cpu(), ys_pad.cpu(), is_ctc=True
                    )
                # for visualization
                self.ctc.softmax(hs_pad)

        # --- CER/WER from the attention decoder (evaluation only) -----------
        cer, wer = None, None
        if (
            not self.training
            and self.error_calculator is not None
            and self.decoder is not None
        ):
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # --- combine the two losses (copied from e2e_asr) -------------------
        ctc_weight = self.mtlalpha
        with_att = ctc_weight != 1
        with_ctc = ctc_weight != 0
        if with_att and with_ctc:
            self.loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
        elif with_ctc:
            self.loss = loss_ctc
        else:
            self.loss = loss_att
        loss_att_data = float(loss_att) if with_att else None
        loss_ctc_data = float(loss_ctc) if with_ctc else None

        # --- report ---------------------------------------------------------
        loss_data = float(self.loss)
        loss_is_sane = loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data)
        if not loss_is_sane:
            logging.warning("loss (=%f) is not correct", loss_data)
        else:
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        return self.loss
Beispiel #28
0
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward with an optional span-attention penalty on the loss.

        Before the forward pass, ``clamp_param()`` is called on the
        self-attention of every encoder/decoder layer whose configured
        attention type is one of the span variants. After the usual
        CTC/attention loss is combined, the summed mean spans (and, when
        ``self.ratio_adaptive`` is set, the summed ``1 - mean ratio`` terms)
        are scaled by ``self.span_loss_coef`` and added to the loss, provided
        that attribute exists and is truthy.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: the combined loss (also stored as ``self.loss``)
        :rtype: torch.Tensor
        """
        # Attention types grouped by the span bookkeeping that applies to them.
        span_types = (
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2',
        )
        span_size_types = (
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_dynamic_span2',
        )
        span_ratio_types = (
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2',
        )

        # 0. clamp span parameters before they are used in the forward pass
        enc_type = self.attention_enc_type
        enc_has_span = enc_type in span_types
        if enc_has_span:
            for enc_layer in self.encoder.encoders:
                enc_layer.self_attn.clamp_param()
        dec_type = self.attention_dec_type
        dec_has_span = dec_type in span_types
        if dec_has_span:
            for dec_layer in self.decoder.decoders:
                dec_layer.self_attn.clamp_param()

        # 1. forward encoder
        # Drop padding beyond the longest sequence of this shard (data parallel).
        xs_pad = xs_pad[:, :max(ilens)]
        non_pad = make_non_pad_mask(ilens.tolist())
        src_mask = non_pad.to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(
            ys_pad, self.sos, self.eos, self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad,
            ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        # 4. CTC branch (skipped entirely when its weight is zero)
        cer_ctc = None
        loss_ctc = None
        if self.mtlalpha != 0.0:
            n_batch = xs_pad.size(0)
            hs_len = hs_mask.view(n_batch, -1).sum(1)
            enc_out = hs_pad.view(n_batch, -1, self.adim)
            loss_ctc = self.ctc(enc_out, hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(enc_out).data
                cer_ctc = self.error_calculator(
                    ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

        # 5. compute cer/wer (evaluation only)
        cer, wer = None, None
        if not self.training and self.error_calculator is not None:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 6. combine CTC and attention losses (copied from e2e_asr)
        ctc_weight = self.mtlalpha
        with_att = ctc_weight != 1
        with_ctc = ctc_weight != 0
        if with_att and with_ctc:
            self.loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
        elif with_ctc:
            self.loss = loss_ctc
        else:
            self.loss = loss_att
        loss_att_data = float(loss_att) if with_att else None
        loss_ctc_data = float(loss_ctc) if with_ctc else None

        # 7. span-size penalty: sum of per-layer mean spans
        loss_span = 0
        if enc_type in span_size_types:
            enc_spans = [
                layer.self_attn.get_mean_span()
                for layer in self.encoder.encoders
            ]
            loss_span += sum(enc_spans)
        if dec_type in span_size_types:
            dec_spans = [
                layer.self_attn.get_mean_span()
                for layer in self.decoder.decoders
            ]
            loss_span += sum(dec_spans)

        # 8. span-ratio penalty: sum of (1 - per-layer mean ratio)
        loss_ratio = 0
        if self.ratio_adaptive:
            if enc_type in span_ratio_types:
                enc_ratio_gaps = [
                    1 - layer.self_attn.get_mean_ratio()
                    for layer in self.encoder.encoders
                ]
                loss_ratio += sum(enc_ratio_gaps)
            if dec_type in span_ratio_types:
                dec_ratio_gaps = [
                    1 - layer.self_attn.get_mean_ratio()
                    for layer in self.decoder.decoders
                ]
                loss_ratio += sum(dec_ratio_gaps)

        # Only apply the penalty when a span attention is configured and a
        # non-zero coefficient is present on the model.
        if (enc_has_span or dec_has_span) and getattr(self, 'span_loss_coef', None):
            self.loss += (loss_span + loss_ratio) * self.span_loss_coef

        # 9. report
        loss_data = float(self.loss)
        loss_is_sane = (loss_data < CTC_LOSS_THRESHOLD
                        and not math.isnan(loss_data))
        if not loss_is_sane:
            logging.warning("loss (=%f) is not correct", loss_data)
        else:
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        return self.loss