Example #1
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        # lid output layer
        pred_pad = self.lid_lo(hs_pad)

        # compute lid loss
        self.loss = self.criterion(pred_pad, ys_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_pad,
                               ignore_label=self.ignore_id)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(self.acc, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
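Every example below reports accuracy with th_accuracy(pred.view(-1, odim), target, ignore_label=...). A minimal sketch of the padding-aware accuracy it is expected to compute, assuming ESPnet-style semantics (flattened logits, targets padded with the ignore label); the real implementation may differ in details:

import torch

def th_accuracy_sketch(pad_outputs, pad_targets, ignore_label):
    """Accuracy over non-ignored positions.

    pad_outputs: (B * Lmax, D) flattened logits; pad_targets: (B, Lmax) token ids.
    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label  # True at real (non-ignored) positions
    numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)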
Example #2
    def _calc_mt_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad,
                                      ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_mt(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.mt_error_calculator is None:
            bleu_att = None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            bleu_att = self.mt_error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, bleu_att
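Example #2 prepends &lt;sos&gt; to the decoder inputs and appends &lt;eos&gt; to the targets via add_sos_eos, which is why ys_in_lens = ys_pad_lens + 1. A minimal sketch of that helper, assuming ESPnet-style behavior (inputs re-padded with eos, targets with ignore_id); the actual function may differ in details:

import torch
from torch.nn.utils.rnn import pad_sequence

def add_sos_eos_sketch(ys_pad, sos, eos, ignore_id):
    ys = [y[y != ignore_id] for y in ys_pad]     # strip ignore_id padding per utterance
    _sos = ys_pad.new_full((1,), sos)
    _eos = ys_pad.new_full((1,), eos)
    ys_in = [torch.cat([_sos, y]) for y in ys]   # decoder input:  <sos> y1 ... yN
    ys_out = [torch.cat([y, _eos]) for y in ys]  # decoder target: y1 ... yN <eos>
    # inputs re-padded with eos, targets with ignore_id so padding is excluded from the loss
    return (pad_sequence(ys_in, batch_first=True, padding_value=eos),
            pad_sequence(ys_out, batch_first=True, padding_value=ignore_id))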
Example #3
    def _calc_mlm_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # 1. Apply masks
        ys_in_pad, ys_out_pad = mask_uniform(
            ys_pad, self.mask_token, self.eos, self.ignore_id
        )

        # 2. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_pad_lens
        )

        # 3. Compute mlm loss
        loss_mlm = self.criterion_mlm(decoder_out, ys_out_pad)
        acc_mlm = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        return loss_mlm, acc_mlm
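Example #3 masks the targets before decoding. A hypothetical illustration of that masking idea (not the actual mask_uniform, which also handles eos): replace a randomly chosen subset of target tokens with mask_token on the decoder-input side and keep only the masked positions as prediction targets, so the loss is computed where the model must fill in tokens:

import torch

def mask_uniform_sketch(ys_pad, mask_token, ignore_id):
    ys_in = ys_pad.clone()
    ys_out = torch.full_like(ys_pad, ignore_id)
    for b in range(ys_pad.size(0)):
        valid = (ys_pad[b] != ignore_id).nonzero(as_tuple=True)[0]
        if valid.numel() == 0:
            continue
        # number of masked tokens drawn uniformly from [1, length]
        n_mask = int(torch.randint(1, valid.numel() + 1, (1,)))
        picked = valid[torch.randperm(valid.numel())[:n_mask]]
        ys_in[b, picked] = mask_token
        ys_out[b, picked] = ys_pad[b, picked]  # predict only the masked positions
    return ys_in, ys_out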
Example #4
    def forward_asr(self, hs_pad, hs_mask, ys_pad):
        """Forward pass in the auxiliary ASR task.

        :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor hs_mask: batch of input token mask (B, Lmax)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ASR attention loss value
        :rtype: torch.Tensor
        :return: accuracy in ASR attention decoder
        :rtype: float
        :return: ASR CTC loss value
        :rtype: torch.Tensor
        :return: character error rate from CTC prediction
        :rtype: float
        :return: character error rate from attention decoder prediction
        :rtype: float
        :return: word error rate from attention decoder prediction
        :rtype: float
        """
        loss_att, loss_ctc = 0.0, 0.0
        acc = None
        cer, wer = None, None
        cer_ctc = None
        if self.asr_weight == 0:
            return loss_att, acc, loss_ctc, cer_ctc, cer, wer

        # attention
        if self.mtlalpha < 1:
            ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id)
            ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
            pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad,
                                           hs_mask)
            loss_att = self.criterion(pred_pad, ys_out_pad_asr)

            acc = th_accuracy(
                pred_pad.view(-1, self.odim),
                ys_out_pad_asr,
                ignore_label=self.ignore_id,
            )
            if not self.training:
                ys_hat_asr = pred_pad.argmax(dim=-1)
                cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(),
                                                     ys_pad.cpu())

        # CTC
        if self.mtlalpha > 0:
            batch_size = hs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if not self.training:
                ys_hat_ctc = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                    ys_pad.cpu(),
                                                    is_ctc=True)
                # for visualization
                self.ctc.softmax(hs_pad)
        return loss_att, acc, loss_ctc, cer_ctc, cer, wer
Example #5
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 1. Encoder
        xs_pad, ilens = self.conv(xs_pad, ilens)
        xs_pad, ilens = self.lstm(xs_pad, ilens)
        pred_pad = self.lid_lo(xs_pad)

        # compute lid loss
        self.loss = self.criterion(pred_pad, ys_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_pad,
                               ignore_label=self.ignore_id)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(self.acc, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Example #6
def test_train_acc():
    n_out = 7
    _eos = n_out - 1
    n_batch = 3
    label_length = numpy.array([4, 2, 3], dtype=numpy.int32)
    np_pred = numpy.random.rand(n_batch,
                                max(label_length) + 1,
                                n_out).astype(numpy.float32)
    # NOTE: 0 is only used for CTC and never appears in the attention target
    np_target = [
        numpy.random.randint(1, n_out - 1, size=ol, dtype=numpy.int32)
        for ol in label_length
    ]

    eos = numpy.array([_eos], 'i')
    ys_out = [F.concat([y, eos], axis=0) for y in np_target]

    # padding for ys with -1
    # pys: utt x olen
    # NOTE: -1 is default ignore index for chainer
    pad_ys_out = F.pad_sequence(ys_out, padding=-1)
    y_all = F.reshape(np_pred, (n_batch * (max(label_length) + 1), n_out))
    ch_acc = F.accuracy(y_all, F.concat(pad_ys_out, axis=0), ignore_label=-1)

    # NOTE: this index 0 is only for CTC not attn. so it can be ignored
    # unfortunately, torch cross_entropy does not accept out-of-bound ids
    th_ignore = 0
    th_pred = torch.from_numpy(y_all.data)
    th_ys = [torch.from_numpy(numpy.append(t, eos)).long() for t in np_target]
    th_target = pad_list(th_ys, th_ignore)
    th_acc = th_accuracy(th_pred, th_target, th_ignore)

    numpy.testing.assert_allclose(ch_acc.data, th_acc)
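The test above relies on pad_list to pad the target sequences with th_ignore. A minimal sketch of such a padding helper, under the assumption that it stacks variable-length tensors into a (B, Lmax, ...) batch filled with the pad value (the real pad_list may differ in details):

import torch

def pad_list_sketch(xs, pad_value):
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new_full((n_batch, max_len, *xs[0].size()[1:]), pad_value)
    for i, x in enumerate(xs):
        pad[i, : x.size(0)] = x
    return pad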
Example #7
    def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask):
        """Forward pass in the auxiliary MT task.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_out_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_mask: batch of input token mask (B, Lmax)
        :return: MT loss value
        :rtype: torch.Tensor
        :return: accuracy in MT decoder
        :rtype: float
        """
        loss, acc = 0.0, None
        if self.mt_weight == 0:
            return loss, acc

        ilens = torch.sum(xs_pad != self.ignore_id, dim=1).cpu().numpy()
        # NOTE: xs_pad is padded with -1
        xs = [x[x != self.ignore_id] for x in xs_pad]  # parse padded xs
        xs_zero_pad = pad_list(xs, self.pad)  # re-pad with zero
        xs_zero_pad = xs_zero_pad[:, : max(ilens)]  # for data parallel
        src_mask = (
            make_non_pad_mask(ilens.tolist()).to(xs_zero_pad.device).unsqueeze(-2)
        )
        hs_pad, hs_mask = self.encoder_mt(xs_zero_pad, src_mask)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        loss = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
        return loss, acc
Example #8
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats

        # 5. compute bleu
        if self.training or self.error_calculator is None:
            bleu = 0.0
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_mt
        self.loss = loss

        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            ys_out_pad = ys_out_pad.view(-1)
            ignore = ys_out_pad == self.ignore_id  # (B*Lmax,)
            total = len(ys_out_pad) - ignore.sum().item()
            self.ppl = np.exp(loss_data * ys_out_pad.size(0) / total)
        if not math.isnan(loss_data):
            self.reporter.report(loss_data, self.acc, self.ppl, bleu)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Example #9
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute attention loss
        self.loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute corpus-level bleu in a mini-batch
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            batch_size = ys_out_pad.size(0)
            ys_out_pad = ys_out_pad.view(-1)
            ignore = ys_out_pad == self.ignore_id  # (B*T,)
            total_n_tokens = len(ys_out_pad) - ignore.sum().item()
            self.ppl = np.exp(loss_data * batch_size / total_n_tokens)
        if not math.isnan(loss_data):
            self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Example #10
    def forward_ilm(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...) not necessary; only used to get the device of the tensor
            speech_lengths: (Batch, )   not necessary; only used to get the device of the tensor
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (text.shape[0] == text_lengths.shape[0]), (text.shape,
                                                          text_lengths.shape)
        batch_size = text.shape[0]

        # for data-parallel
        text = text[:, :text_lengths.max()]

        ys_in_pad, ys_out_pad = add_sos_eos(text, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = text_lengths + 1

        fake_encoder_out = speech.new_zeros(batch_size, 1,
                                            self.encoder._output_size)
        # 1. Forward decoder
        decoder_out, _ = self.decoder.forward_ilm(fake_encoder_out, -1,
                                                  ys_in_pad, ys_in_lens)

        # 2. Compute ilm loss
        loss_ilm = self.criterion_att(decoder_out, ys_out_pad)
        ilm_acc = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        ilm_ppl = torch.exp(loss_ilm)
        stats = dict(ilm_loss=loss_ilm.detach(),
                     ilm_acc=ilm_acc,
                     ilm_ppl=ilm_ppl.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss_ilm, stats, batch_size),
                                               loss_ilm.device)
        return loss_ilm, stats, weight
Example #11
    def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
        """Forward decoder and attention loss."""
        # forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim),
                          ys_out_pad,
                          ignore_label=self.ignore_id)
        return pred_pad, pred_mask, loss_att, acc
Example #12
    def forward(self, xs_pad, ilens, ys_pad):
        """Compute scalar loss for backprop"""
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        ys_in_pad, ys_out_pad = self.add_sos_eos(ys_pad)
        ys_mask = self.target_mask(ys_in_pad)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, pred_pad.size(-1)),
                               ys_out_pad,
                               ignore_label=self.ignore_id)
        self.reporter.report(loss=loss, acc=self.acc)
        return loss
Example #13
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        hs_mask = hs_mask.transpose(1, 2)
        hs_mask = hs_mask.repeat(1, 1, 256).type(torch.FloatTensor).to(hs_pad.device)
        hs_pad_masked = hs_pad * hs_mask
        logging.warning("hs_pad_masked.size()==>" + str(hs_pad_masked.size()))
        att_vec = self.att(hs_pad_masked)

        # att_vec = self.att(hs_pad)

        pred_pad = self.output(att_vec).unsqueeze(1)
        logging.warning("att_vec.size()==>" + str(att_vec.size()))
        logging.warning("pred_pad.size()==>" + str(pred_pad.size()))
        # compute loss
        self.loss = self.criterion(pred_pad, ys_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad,
                               ignore_label=self.ignore_id)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(self.acc, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
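In Example #13 the boolean encoder mask (B, 1, Tmax) is expanded with repeat(1, 1, 256) to zero out padded frames, which hard-codes the hidden size. An equivalent alternative sketch that lets broadcasting do the expansion instead:

# hs_mask: (B, 1, Tmax) boolean; hs_pad: (B, Tmax, adim)
hs_pad_masked = hs_pad * hs_mask.transpose(1, 2).to(hs_pad.dtype)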
Example #14
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST loss
        loss_att = self.criterion(pred_pad, ys_out_pad)

        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )

        # 4. compute corpus-level bleu in a mini-batch
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 5. compute auxiliary ASR loss
        loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
            hs_pad, hs_mask, ys_pad_src
        )

        # 6. compute auxiliary MT loss
        loss_mt, acc_mt = 0.0, None
        if self.mt_weight > 0:
            loss_mt, acc_mt = self.forward_mt(
                ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
            )

        asr_ctc_weight = self.mtlalpha
        self.loss = (
            (1 - self.asr_weight - self.mt_weight) * loss_att
            + self.asr_weight
            * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
            + self.mt_weight * loss_mt
        )
        loss_asr_data = float(
            asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
        )
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Example #15
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
        hs_mask = cn_hs_mask  # cn_hs_mask & en_hs_mask are identical
        # gated add module:
        #   lambda = sigmoid(W_cn * cn_xs + W_en * en_xs + b)  # (B, T, 1)
        #   xs = lambda * cn_xs + (1 - lambda) * en_xs
        hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
        lambda_ = self.aggregation_module(
            hs_pad)  # (B,T,1)/(B,T,D), range from (0, 1)
        hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)

            # divide ys_pad into cn_ys & en_ys;
            # note that this target can directly pass to ctc module
            cn_ys, en_ys = partial_target(ys_pad, self.language_divider)
            cn_loss_ctc = self.cn_ctc(
                cn_hs_pad.view(batch_size, -1, self.adim), hs_len, cn_ys)
            en_loss_ctc = self.en_ctc(
                en_hs_pad.view(batch_size, -1, self.adim), hs_len, en_ys)
            loss_ctc = 0.5 * (cn_loss_ctc + en_loss_ctc)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, None, None, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
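Example #15 assumes an aggregation_module that maps the concatenated encoder outputs to a gate in (0, 1). A minimal sketch of what such a module could look like (hypothetical; the module in this codebase may differ, e.g. by producing a per-dimension gate):

import torch

class GatedAggregationSketch(torch.nn.Module):
    """lambda = sigmoid(W [cn_xs; en_xs] + b), one gate per frame."""

    def __init__(self, adim):
        super().__init__()
        self.proj = torch.nn.Linear(2 * adim, 1)

    def forward(self, hs_cat):
        # hs_cat: (B, T, 2 * adim) -> gate of shape (B, T, 1) in (0, 1)
        return torch.sigmoid(self.proj(hs_cat))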
Example #16
    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, tgt_lang_ids=None):
        """Decoder forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :param int strm_idx: stream index indicates the index of decoding stream.
        :param torch.Tensor tgt_lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy
        :rtype: float
        """
        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # attention index for the attention module
        # in SPA (speaker parallel attention), att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlen should be list of integer
        hlens = list(map(int, hlens))

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [
                torch.cat([idx, y], dim=0) for idx, y in zip(tgt_lang_ids, ys)
            ]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        logging.info(self.__class__.__name__ + ' input lengths:  ' +
                     str(hlens))
        logging.info(self.__class__.__name__ + ' output lengths: ' +
                     str([y.size(0) for y in ys_out]))

        # initialization
        c_list = [self.zero_state(hs_pad)]
        z_list = [self.zero_state(hs_pad)]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad))
            z_list.append(self.zero_state(hs_pad))
        att_w = None
        z_all = []
        self.att[att_idx].reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in six.moves.range(olength):
            att_c, att_w = self.att[att_idx](hs_pad, hlens,
                                             self.dropout_dec[0](z_list[0]),
                                             att_w)
            if i > 0 and random.random() < self.sampling_probability:
                logging.info(' scheduled sampling ')
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c),
                               dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list,
                                              c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c),
                              dim=-1))  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
        # compute loss
        y_all = self.output(z_all)
        if LooseVersion(torch.__version__) < LooseVersion('1.0'):
            reduction_str = 'elementwise_mean'
        else:
            reduction_str = 'mean'
        self.loss = F.cross_entropy(y_all,
                                    ys_out_pad.view(-1),
                                    ignore_index=self.ignore_id,
                                    reduction=reduction_str)
        # -1: eos, which is removed in the loss computation
        self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info('att loss:' + ''.join(str(self.loss.item()).split('\n')))

        # compute perplexity
        ppl = np.exp(self.loss.item() * np.mean([len(x) for x in ys_in]) /
                     np.sum([len(x) for x in ys_in]))

        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for (i, y_hat), y_true in zip(
                    enumerate(ys_hat.detach().cpu().numpy()),
                    ys_true.detach().cpu().numpy()):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = to_device(self,
                                            torch.from_numpy(self.labeldist))
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1),
                dim=0) / len(ys_in)
            self.loss = (
                1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc, ppl
Example #17
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask, penultimate_state = self.decoder(
            ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)
        # 4. compute lid multitask loss
        src_att = self.lid_src_att(penultimate_state, hs_pad, hs_pad, hs_mask)
        pred_lid_pad = self.lid_output_layer(src_att)
        loss_lid, lid_ys_out_pad = self.lid_criterion(pred_lid_pad, ys_out_pad)
        lid_acc = th_accuracy(
            pred_lid_pad.view(-1, self.lid_odim),
            lid_ys_out_pad,
            ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        lid_alpha = self.lid_mtl_alpha
        if alpha == 0:
            self.loss = loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            raise Exception("LID MTL not supports pure ctc mode")
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (
                1 - alpha) * loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data, lid_acc)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Example #18
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_mono=None):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. RNN Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. post-processing layer for target dimension
        pred_pad = self.poster(hs_pad)
        pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
        self.pred_pad = pred_pad
        if pred_pad.size(1) != ys_pad.size(1):
            if pred_pad.size(1) < ys_pad.size(1):
                ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
            else:
                raise ValueError(
                    "target size {} and pred size {} is mismatch".format(
                        ys_pad.size(1), pred_pad.size(1)))

        if ys_pad_mono is not None:
            pred_pad_mono = self.poster_mono(hs_pad)
            pred_pad_mono = pred_pad_mono.view(pred_pad_mono.size(0), -1,
                                               self.mono_odim)
            self.pred_pad_mono = pred_pad_mono
            if pred_pad_mono.size(1) != ys_pad_mono.size(1):
                if pred_pad_mono.size(1) < ys_pad_mono.size(1):
                    ys_pad_mono = ys_pad_mono[:, :pred_pad_mono.
                                              size(1)].contiguous()
                else:
                    raise ValueError(
                        "target size {} and pred size {} is mismatch".format(
                            ys_pad_mono.size(1), pred_pad_mono.size(1)))

        # 3. CTC loss
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(pred_pad, hlens, ys_pad)

        # 3. CE loss
        if LooseVersion(torch.__version__) < LooseVersion("1.0"):
            reduction_str = "elementwise_mean"
        else:
            reduction_str = "mean"
        self.loss_ce_tri = F.cross_entropy(
            pred_pad.view(-1, self.odim),
            ys_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction=reduction_str,
        )
        if ys_pad_mono is not None:
            self.loss_ce_mono = F.cross_entropy(
                pred_pad_mono.view(-1, self.mono_odim),  # mono branch uses mono_odim
                ys_pad_mono.view(-1),
                ignore_index=self.ignore_id,
                reduction=reduction_str,
            )
        else:
            self.loss_ce_mono = 0
        self.loss_ce = 0.6 * self.loss_ce_tri + 0.4 * self.loss_ce_mono
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_pad,
                               ignore_label=self.ignore_id)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer, cer_ctc = None, None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
            cer_ctc = self.error_calculator(ys_hat.cpu(),
                                            ys_pad.cpu(),
                                            is_ctc=True)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_ce
            loss_ce_data = float(self.loss_ce)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_ce_data = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_ce
            loss_ce_data = float(self.loss_ce)
            loss_ctc_data = float(self.loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_ce_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            pass
        return self.loss
Example #19
    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        batch_size = xs_pad.shape[0]
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        if isinstance(self.encoder.embed, EncoderConv2d):
            xs, hs_mask = self.encoder.embed(xs_pad,
                                             torch.sum(src_mask, 2).squeeze())
            hs_mask = hs_mask.unsqueeze(1)
        else:
            xs, hs_mask = self.encoder.embed(xs_pad, src_mask)

        if enc_mask is not None:
            enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]]
        enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
        hs_pad, _ = self.encoder.encoders(xs, enc_mask)
        if self.encoder.normalize_before:
            hs_pad = self.encoder.after_norm(hs_pad)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        if dec_mask is not None:
            dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
        self.hs_pad = hs_pad
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # trigger mask
        hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_ctc_data, loss_att_data, self.acc
Example #20
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward Transformer encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        if xs_pad.size(1) != ys_pad.size(1):
            if xs_pad.size(1) < ys_pad.size(1):
                ys_pad = ys_pad[:, :xs_pad.size(1)].contiguous()
            else:
                raise ValueError("target size {} is smaller than input size {}".format(ys_pad.size(1), xs_pad.size(1)))
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. post-processing layer for target dimension
        if self.outer:
            post_pad = self.poster(hs_pad)
            post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
            if post_pad.size(1) != xs_pad.size(1):
                if post_pad.size(1) < xs_pad.size(1):
                    xs_pad = xs_pad[:, :post_pad.size(1)].contiguous()
                else:
                    raise ValueError("target size {} and pred size {} is mismatch".format(xs_pad.size(1), post_pad.size(1)))
            if self.residual:
                post_pad = post_pad + self.matcher_res(xs_pad)
            else:
                post_pad = torch.cat([post_pad, xs_pad], dim=-1)
            pred_pad = self.matcher(post_pad)
        else:
            pred_pad = self.poster(hs_pad)
            pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
        self.pred_pad = pred_pad
        if pred_pad.size(1) != ys_pad.size(1):
            if pred_pad.size(1) < ys_pad.size(1):
                ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
            else:
                raise ValueError("target size {} and pred size {} is mismatch".format(ys_pad.size(1), pred_pad.size(1)))

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id
        )

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(pred_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(pred_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

        # 3. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            pass
            # logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Example #21
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = mask_uniform(ys_pad, self.mask_token, self.eos,
                                             self.ignore_id)
        ys_mask = square_mask(ys_in_pad, self.eos)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute ctc loss
        loss_ctc, cer_ctc = None, None
        loss_intermediate_ctc = 0.0
        if self.mtlalpha > 0:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)
            if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers:
                for hs_intermediate in hs_intermediates:
                    # assuming hs_intermediates and hs_pad has same length / padding
                    loss_inter = self.ctc(
                        hs_intermediate.view(batch_size, -1, self.adim),
                        hs_len, ys_pad)
                    loss_intermediate_ctc += loss_inter

                loss_intermediate_ctc /= len(self.intermediate_ctc_layers)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        else:
            self.loss = (alpha * loss_ctc +
                         self.intermediate_ctc_weight * loss_intermediate_ctc +
                         (1 - alpha - self.intermediate_ctc_weight) * loss_att)
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Example #22
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        # src_lang_ids = None
        tgt_lang_ids, tgt_lang_ids_src = None, None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning

        if self.one_to_many:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning
            
            if self.do_asr:
                tgt_lang_ids_src = ys_pad_src[:, 0:1]
                ys_pad_src = ys_pad_src[:, 1:]  # remove target language ID in the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel # bs x max_ilens x idim

        if self.lang_tok == "encoder-pre-sum":
            lang_embed = self.language_embeddings(tgt_lang_ids) # bs x 1 x idim
            xs_pad = xs_pad + lang_embed

        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2) # bs x 1 x max_ilens
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # hs_pad: bs x (max_ilens/4) x adim; hs_mask: bs x 1 x (max_ilens/4)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) # bs x max_lens

        if self.do_asr:
            ys_in_pad_src, ys_out_pad_src = add_sos_eos(ys_pad_src, self.sos, self.eos, self.ignore_id) # bs x max_lens_src

        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)

        if self.lang_tok == "decoder-pre":
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
            if self.do_asr:
                ys_in_pad_src = torch.cat([tgt_lang_ids_src, ys_in_pad_src[:, 1:]], dim=1)

        ys_mask = target_mask(ys_in_pad, self.ignore_id) # bs x max_lens x max_lens

        if self.do_asr:
            ys_mask_src = target_mask(ys_in_pad_src, self.ignore_id) # bs x max_lens x max_lens_src

        if self.wait_k_asr > 0:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=self.wait_k_asr)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=-self.wait_k_asr)
        elif self.wait_k_st > 0:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=-self.wait_k_st)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=self.wait_k_st)
        else:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=0)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=0)

        pred_pad, pred_mask, pred_pad_asr, pred_mask_asr = self.dual_decoder(ys_in_pad, ys_mask, ys_in_pad_src, ys_mask_src,
                                                                                hs_pad, hs_mask, cross_mask, cross_mask_asr,
                                                                                cross_self=self.cross_self, cross_src=self.cross_src,
                                                                                cross_self_from=self.cross_self_from,
                                                                                cross_src_from=self.cross_src_from)

        self.pred_pad = pred_pad
        self.pred_pad_asr = pred_pad_asr
        pred_pad_mt = None

        # 3. compute attention loss
        loss_asr, loss_mt = 0.0, 0.0
        loss_att = self.criterion(pred_pad, ys_out_pad)

        # compute loss
        loss_asr = self.criterion(pred_pad_asr, ys_out_pad_src)
        # Multi-task w/ MT
        if self.mt_weight > 0:
            # forward MT encoder
            ilens_mt = torch.sum(ys_pad_src != self.ignore_id, dim=1).cpu().numpy()
            # NOTE: ys_pad_src is padded with -1
            ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
            ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
            ys_zero_pad_src = ys_zero_pad_src[:, :max(ilens_mt)]  # for data parallel
            src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(ys_zero_pad_src.device).unsqueeze(-2)
            # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad)
            hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
            # forward MT decoder
            pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt, hs_mask_mt)
            # compute loss
            loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)
        if pred_pad_asr is not None:
            self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim), ys_out_pad_src,
                                       ignore_label=self.ignore_id)
        else:
            self.acc_asr = 0.0
        if pred_pad_mt is not None:
            self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                                      ignore_label=self.ignore_id)
        else:
            self.acc_mt = 0.0

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0 or self.asr_weight == 0:
            loss_ctc = 0.0
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad_src)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(), is_ctc=True)

        # 5. compute cer/wer
        cer, wer = None, None  # TODO(hirofumi0810): fix later
        # if self.training or (self.asr_weight == 0 or self.mtlalpha == 1 or not (self.report_cer or self.report_wer)):
        #     cer, wer = None, None
        # else:
        #     ys_hat = pred_pad.argmax(dim=-1)
        #     cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        self.loss = (1 - self.asr_weight - self.mt_weight) * loss_att + self.asr_weight * \
            (alpha * loss_ctc + (1 - alpha) * loss_asr) + self.mt_weight * loss_mt
        loss_asr_data = float(alpha * loss_ctc + (1 - alpha) * loss_asr)
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        # logging.info(f'loss_st_data={loss_st_data}')

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data,
                                 self.acc_asr, self.acc_mt, self.acc,
                                 cer_ctc, cer, wer, 0.0,  # TODO(hirofumi0810): bleu
                                 loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
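For reference, the weighted combination above (ST attention loss plus the auxiliary ASR and MT terms) can be written as a small stand-alone helper. The function name and the toy values below are illustrative only, not part of the model code:

import torch

def combine_st_losses(loss_att, loss_ctc, loss_asr, loss_mt,
                      asr_weight, mt_weight, mtlalpha):
    """Sketch of the multi-task weighting used above.

    loss = (1 - asr_weight - mt_weight) * loss_att
           + asr_weight * (mtlalpha * loss_ctc + (1 - mtlalpha) * loss_asr)
           + mt_weight * loss_mt
    """
    st_term = (1 - asr_weight - mt_weight) * loss_att
    asr_term = asr_weight * (mtlalpha * loss_ctc + (1 - mtlalpha) * loss_asr)
    mt_term = mt_weight * loss_mt
    return st_term + asr_term + mt_term

# toy usage with scalar losses
loss = combine_st_losses(torch.tensor(2.0), torch.tensor(5.0), torch.tensor(3.0),
                         torch.tensor(2.5), asr_weight=0.3, mt_weight=0.1, mtlalpha=0.3)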
Example #23
def recog(args):
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.
    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)


    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr', load_output=True, sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={'train': False})
    import kaldiio
    import time
    with torch.no_grad(), \
            kaldiio.WriteHelper('ark,scp:{o}.ark,{o}.scp'.format(o=args.out)) as f:
        ys = []
        xs = []
        for idx, utt_id in enumerate(js.keys()):
            logging.info('(%d/%d) decoding ' + utt_id, idx, len(js.keys()))
            batch = [(utt_id, js[utt_id])]
            data = load_inputs_and_targets(batch)
            feat = data[0][0]
            ys.append(data[1][0])
            # x = torch.LongTensor(x).to(device)

            # decode and write
            start_time = time.time()
            # include the inference here
            # have the layer specification here
            # skeleton model.inference(x, args, layer)
            scores, outs = model.inference(feat, ys, args, train_args.char_list)
            xs.append(scores)
            logging.info("inference speed = %s msec / frame." % (
                (time.time() - start_time) / (int(outs.size(0)) * 1000)))
            logging.warning("output length reaches maximum length (%s)." % utt_id)
            logging.info('(%d/%d) %s (size:%d->%d)' % (
                idx + 1, len(js.keys()), utt_id, len(feat), outs.size(0)))
            f[utt_id] = outs.cpu().numpy()
        from espnet.nets.pytorch_backend.nets_utils import th_accuracy
        preds = torch.stack(xs).view(len(xs), -1)
        labels = torch.LongTensor(ys).view(len(xs), 1)
        acc = th_accuracy(preds, labels, -1)
        logging.warn("Final acc is (%.2f)" % (acc*100))
Example #24
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        # gated add module
        """ lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  #(B, T, 1)
            xs = lambda * cn_xs + (1-lambda) * en_xs 
        """
        # NOTE: this variant mixes the two encoder outputs with a fixed weight
        # self.enc_lambda rather than the learned sigmoid gate described above
        lambda_ = self.enc_lambda
        hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
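The gated-add comment in the example above describes a learned sigmoid gate over the two encoder streams, while the code uses a fixed enc_lambda. A minimal self-contained sketch of the learned gate (module and layer names are hypothetical, not the example's actual attributes) could look like:

import torch
import torch.nn as nn

class GatedAdd(nn.Module):
    """Sketch: lambda = sigmoid(W [cn_xs; en_xs] + b), xs = lambda * cn_xs + (1 - lambda) * en_xs."""

    def __init__(self, adim):
        super().__init__()
        self.gate = nn.Linear(2 * adim, 1)  # one gate value per frame

    def forward(self, cn_xs, en_xs):
        lam = torch.sigmoid(self.gate(torch.cat((cn_xs, en_xs), dim=-1)))  # (B, T, 1)
        return lam * cn_xs + (1 - lam) * en_xs

# toy usage
gate = GatedAdd(adim=4)
out = gate(torch.randn(2, 5, 4), torch.randn(2, 5, 4))  # (2, 5, 4)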
Example #25
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded source-language target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST loss
        loss_asr_att, loss_asr_ctc, loss_mt = 0.0, 0.0, 0.0
        acc_asr, acc_mt = 0.0, 0.0
        loss_att = self.criterion(pred_pad, ys_out_pad)

        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute corpus-level bleu in a mini-batch
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 5. compute auxiliary ASR loss
        cer, wer = None, None
        cer_ctc = None
        if self.asr_weight > 0:
            # attention
            if self.mtlalpha < 1:
                ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                    ys_pad_src, self.sos, self.eos, self.ignore_id)
                ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
                pred_pad_asr, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr,
                                                   hs_pad, hs_mask)
                loss_asr_att = self.criterion(pred_pad_asr, ys_out_pad_asr)

                acc_asr = th_accuracy(
                    pred_pad_asr.view(-1, self.odim),
                    ys_out_pad_asr,
                    ignore_label=self.ignore_id,
                )
                if not self.training:
                    ys_hat_asr = pred_pad_asr.argmax(dim=-1)
                    cer, wer = self.error_calculator_asr(
                        ys_hat_asr.cpu(), ys_pad_src.cpu())

            # CTC
            if self.mtlalpha > 0:
                batch_size = xs_pad.size(0)
                hs_len = hs_mask.view(batch_size, -1).sum(1)
                loss_asr_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                        hs_len, ys_pad_src)
                ys_hat_ctc = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)).data
                if not self.training:
                    cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                        ys_pad_src.cpu(),
                                                        is_ctc=True)

        # 6. compute auxiliary MT loss
        if self.mt_weight > 0:
            ilens_mt = torch.sum(ys_pad_src != self.ignore_id,
                                 dim=1).cpu().numpy()
            # NOTE: ys_pad_src is padded with -1
            ys_src = [y[y != self.ignore_id]
                      for y in ys_pad_src]  # parse padded ys_src
            ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
            ys_zero_pad_src = ys_zero_pad_src[:, :max(
                ilens_mt)]  # for data parallel
            src_mask_mt = ((~make_pad_mask(ilens_mt.tolist())).to(
                ys_zero_pad_src.device).unsqueeze(-2))
            hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src,
                                                    src_mask_mt)
            pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt,
                                          hs_mask_mt)
            loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

            acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim),
                                 ys_out_pad,
                                 ignore_label=self.ignore_id)

        alpha = self.mtlalpha
        self.loss = ((1 - self.asr_weight - self.mt_weight) * loss_att +
                     self.asr_weight * (alpha * loss_asr_ctc +
                                        (1 - alpha) * loss_asr_att) +
                     self.mt_weight * loss_mt)
        loss_asr_data = float(alpha * loss_asr_ctc +
                              (1 - alpha) * loss_asr_att)
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
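Several examples above build decoder inputs and targets with add_sos_eos; the sketch below shows the behaviour it is assumed to have (prepend <sos> for the decoder input, append <eos> for the target, pad the two differently). It is illustrative only, not the library implementation:

import torch
from espnet.nets.pytorch_backend.nets_utils import pad_list

def add_sos_eos_sketch(ys_pad, sos, eos, ignore_id):
    """Sketch of add_sos_eos: return (decoder input, decoder target)."""
    _sos = ys_pad.new([sos])
    _eos = ys_pad.new([eos])
    ys = [y[y != ignore_id] for y in ys_pad]             # strip the -1 padding
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]    # <sos> w1 ... wn
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]   # w1 ... wn <eos>
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)

# toy usage: two utterances padded with -1
ys_pad = torch.tensor([[5, 6, 7], [8, 9, -1]])
ys_in, ys_out = add_sos_eos_sketch(ys_pad, sos=1, eos=2, ignore_id=-1)
# ys_in  -> [[1, 5, 6, 7], [1, 8, 9, 2]]   (padded with eos)
# ys_out -> [[5, 6, 7, 2], [8, 9, 2, -1]]  (padded with ignore_id)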
Example #26
    def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        moe_coes = moe_coes[:, :max(moe_coe_lens)].long()  # for data parallel
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
        # gated add module
        """ lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  #(B, T, 1)
            xs = lambda * cn_xs + (1-lambda) * en_xs 
        """
        hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
        lambda_ = F.softmax(
            self.aggre_scaling * self.aggregation_module(hs_pad),
            -1).unsqueeze(-1)
        ctc_hs_pad = lambda_[:, :, 0] * cn_hs_pad + lambda_[:, :,
                                                            1] * en_hs_pad
        ctc_hs_mask = cn_hs_mask

        # flat attention mode, (B,T,D)*2 --> (B,2T,D)
        s2s_hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=1)
        # mask: (B,1,T) --> (B,1,2T)
        s2s_hs_mask = torch.cat((cn_hs_mask, en_hs_mask), dim=-1)
        # self.hs_pad = hs_pad

        # compute lid loss here, using lambda_
        # moe_coes (B, T, 2) ==> (B,T)
        moe_coes = moe_coes[:, :, 0]  # 0 for cn, 1 for en
        lambda_ = lambda_.squeeze(-1)
        if self.lid_mtl_alpha == 0.0:
            loss_lid = 0.0
        else:
            loss_lid = self.lid_criterion(lambda_, moe_coes)
        lid_acc = th_accuracy(
            lambda_.view(-1, 2), moe_coes,
            ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = ctc_hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(ctc_hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(
                    ctc_hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, s2s_hs_pad,
                                               s2s_hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        lid_alpha = self.lid_mtl_alpha
        if alpha == 0:
            self.loss = loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc + lid_alpha * loss_lid
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (
                1 - alpha) * loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data, lid_acc)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
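The aggregation step in the example above takes a per-frame, two-way softmax over the concatenated encoder outputs and mixes the streams with the resulting weights. A compact sketch of that mechanism, with hypothetical layer names, is:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoWayAggregation(nn.Module):
    """Sketch: per-frame softmax weights over two encoder streams."""

    def __init__(self, adim, scaling=1.0):
        super().__init__()
        self.proj = nn.Linear(2 * adim, 2)  # scores for (cn, en)
        self.scaling = scaling

    def forward(self, cn_hs, en_hs):
        scores = self.proj(torch.cat((cn_hs, en_hs), dim=-1))         # (B, T, 2)
        lam = F.softmax(self.scaling * scores, dim=-1).unsqueeze(-1)  # (B, T, 2, 1)
        combined = lam[:, :, 0] * cn_hs + lam[:, :, 1] * en_hs        # (B, T, D)
        return combined, lam.squeeze(-1)

# toy usage
agg = TwoWayAggregation(adim=4, scaling=2.0)
combined, weights = agg(torch.randn(2, 7, 4), torch.randn(2, 7, 4))  # (2, 7, 4), (2, 7, 2)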
Example #27
    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
        """Decoder forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
                                    [in multi-encoder case,
                                    list of torch.Tensor,
                                    [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
                                   [in multi-encoder case, list of torch.Tensor,
                                   [(B), (B), ..., ]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor
                                    (B, Lmax)
        :param int strm_idx: stream index indicates the index of decoding stream.
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy
        :rtype: float
        """
        # to support multiple encoder ASR mode; in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [
                torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)
            ]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)
        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        for idx in range(self.num_encs):
            logging.info(self.__class__.__name__ +
                         " Number of encoders: {}; enc{}: input lengths: {}.".
                         format(self.num_encs, idx + 1, hlens[idx]))
        logging.info(self.__class__.__name__ + " output lengths: " +
                     str([y.size(0) for y in ys_out]))

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset(
                )  # reset pre-computation of h in atts and han

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in six.moves.range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w)
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                logging.info(" scheduled sampling ")
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c),
                               dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list,
                                              c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c),
                              dim=-1))  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
        # compute loss
        y_all = self.output(z_all)
        logging.info("max:{}".format(torch.max(y_all, dim=1)))
        if LooseVersion(torch.__version__) < LooseVersion("1.0"):
            reduction_str = "elementwise_mean"
        else:
            reduction_str = "mean"
        self.loss = F.cross_entropy(
            y_all,
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction=reduction_str,
        )
        logging.info("loss_att:{}".format(self.loss))
        # compute perplexity
        ppl = math.exp(self.loss.item())
        # -1: eos, which is removed in the loss computation
        self.loss *= np.mean([len(x) for x in ys_in]) - 1
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))

        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for (i, y_hat), y_true in zip(
                    enumerate(ys_hat.detach().cpu().numpy()),
                    ys_true.detach().cpu().numpy()):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = to_device(self,
                                            torch.from_numpy(self.labeldist))
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1),
                dim=0) / len(ys_in)
            self.loss = (
                1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc, ppl
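The last step above interpolates the cross-entropy loss with a regularization term built from a fixed label distribution (label smoothing). A compact sketch of that interpolation, assuming a uniform distribution for the toy usage, is:

import torch
import torch.nn.functional as F

def label_smoothed_loss(y_all, loss_ce, lsm_weight, labeldist, n_utts):
    """Sketch: (1 - w) * CE + w * (-sum(log_softmax(y) * labeldist) / n_utts)."""
    loss_reg = -torch.sum(
        (F.log_softmax(y_all, dim=1) * labeldist).view(-1), dim=0) / n_utts
    return (1.0 - lsm_weight) * loss_ce + lsm_weight * loss_reg

# toy usage: 6 decoding steps over a vocabulary of 5, uniform label distribution
y_all = torch.randn(6, 5)
labeldist = torch.full((5,), 1.0 / 5)
loss = label_smoothed_loss(y_all, torch.tensor(1.2), lsm_weight=0.1,
                           labeldist=labeldist, n_utts=2)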
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        self.hs_pad = hs_pad
        cer_ctc = None
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # trigger mask
        start_time = time.time()
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copyied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
        return self.loss, loss_ctc_data, loss_att_data, self.acc
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        if self.decoder is not None:
            if self.decoder_mode == "maskctc":
                ys_in_pad, ys_out_pad = mask_uniform(
                    ys_pad, self.mask_token, self.eos, self.ignore_id
                )
                ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
            else:
                ys_in_pad, ys_out_pad = add_sos_eos(
                    ys_pad, self.sos, self.eos, self.ignore_id
                )
                ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            self.acc = th_accuracy(
                pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
            )
        else:
            loss_att = None
            self.acc = None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if not self.training and self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
    def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
        # smooth the initial (hard) moe_coes with an interpolation factor
        interp_factor = self.interp_factor  # e.g. 0.1, similar to label smoothing
        moe_coes = (
            1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]

        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)

        # multi-encoder forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        moe_coes = moe_coes.unsqueeze(-1)
        hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
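The interpolation of moe_coes near the top of this last forward pass smooths hard 0/1 expert coefficients toward a uniform mixture, much like label smoothing. A small worked example with interp_factor = 0.1 and two experts:

import torch

moe_coes = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # (B=1, T=2, 2 experts), hard labels
interp_factor = 0.1
smoothed = (1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]
# smoothed -> [[[0.95, 0.05], [0.05, 0.95]]]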