Exemplo n.º 1
0
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        Computes the weighted multi-task training loss and, when enabled,
        evaluation metrics, then reports them through ``self.reporter``.

        The loss combines these terms:
        - the speech-translation (ST) attention loss;
        - optional ASR losses (CTC and/or attention, mixed by ``mtlalpha``)
          scaled by ``asr_weight``;
        - an optional text MT loss scaled by ``mt_weight``.

        The metrics are:
        - greedy CTC CER;
        - beam-search ASR CER/WER (not in training mode);
        - BLEU (not in training mode).

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor
            (B, Lmax) of the translation target; when ``self.multilingual``
            is set, its first column holds the target language ID
        :param torch.Tensor ys_pad_src: batch of padded token id sequence
            tensor (B, Lmax_src) of the source transcription, padded with -1
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Extract target language ID
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
        else:
            tgt_lang_ids = None

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(xs_pad, ilens)

        # 2. ST attention loss
        self.loss_st, self.acc, _ = self.dec(hs_pad,
                                             hlens,
                                             ys_pad,
                                             lang_ids=tgt_lang_ids)

        # 3. ASR CTC loss (disabled when ASR is off or pure attention is used)
        if self.asr_weight == 0 or self.mtlalpha == 0:
            self.loss_ctc = 0.0
        else:
            self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad_src)

        # 4. ASR attention loss (disabled when ASR is off or pure CTC is used)
        if self.asr_weight == 0 or self.mtlalpha == 1:
            self.loss_asr = 0.0
            acc_asr = 0.0
        else:
            self.loss_asr, acc_asr, _ = self.dec_asr(hs_pad, hlens, ys_pad_src)

        # 5. MT attention loss (text source -> translation target)
        if self.mt_weight == 0:
            self.loss_mt = 0.0
            acc_mt = 0.0
        else:
            # NOTE: ys_pad_src is padded with -1
            ilens_mt = torch.sum(ys_pad_src != -1, dim=1).cpu().numpy()
            ys_src = [y[y != -1] for y in ys_pad_src]  # parse padded ys_src
            ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with self.pad
            hs_pad_mt, hlens_mt, _ = self.enc_mt(
                self.dropout_mt(self.embed_mt(ys_zero_pad_src)), ilens_mt)
            # NOTE(review): unlike the ST loss above, no lang_ids are passed to
            # the shared decoder here; confirm this is intended for
            # multilingual models.
            self.loss_mt, acc_mt, _ = self.dec(hs_pad_mt, hlens_mt, ys_pad)

        # 6. CTC-based CER without beam search (argmax + label collapsing)
        if (self.asr_weight == 0
                or self.mtlalpha == 0) or self.char_list is None:
            cer_ctc = None
        else:
            cers = []

            y_hats = self.ctc.argmax(hs_pad).data
            for i, y in enumerate(y_hats):
                # collapse consecutive identical frame labels
                y_hat = [x[0] for x in groupby(y)]
                y_true = ys_pad_src[i]

                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true
                    if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                seq_hat_text = seq_hat_text.replace(self.blank, "")
                seq_true_text = "".join(seq_true).replace(self.space, " ")

                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                # empty references are skipped to avoid dividing by zero
                if len(ref_chars) > 0:
                    cers.append(
                        editdistance.eval(hyp_chars, ref_chars) /
                        len(ref_chars))

            cer_ctc = sum(cers) / len(cers) if cers else None

        # 7. ASR CER/WER with beam search (skipped in training mode)
        if self.training or (self.asr_weight == 0 or self.mtlalpha == 1
                             or not (self.report_cer or self.report_wer)):
            cer, wer = 0.0, 0.0
        else:
            if (self.asr_weight > 0 and
                    self.mtlalpha > 0) and self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None

            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps_asr = self.dec_asr.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                lpz,
                self.recog_args,
                self.char_list,
                self.rnnlm,
            )
            # remove <sos> and <eos>
            y_hats = [
                nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps_asr
            ]
            for i, y_hat in enumerate(y_hats):
                # ASR hypotheses are source-language transcriptions, so they
                # must be scored against ys_pad_src, not the translation
                # target ys_pad.
                y_true = ys_pad_src[i]

                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true
                    if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(
                    self.recog_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                seq_true_text = "".join(seq_true).replace(
                    self.recog_args.space, " ")

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

            # guard against a batch whose references are all empty
            total_words = sum(word_ref_lens)
            total_chars = sum(char_ref_lens)
            wer = (float(sum(word_eds)) / total_words
                   if self.report_wer and total_words > 0 else 0.0)
            cer = (float(sum(char_eds)) / total_chars
                   if self.report_cer and total_chars > 0 else 0.0)

        # 8. BLEU with beam search (skipped in training mode)
        if self.training or not self.report_bleu:
            self.bleu = 0.0
        else:
            lpz = None  # CTC scores are not used for translation decoding

            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                lpz,
                self.trans_args,
                self.char_list,
                self.rnnlm,
                lang_ids=tgt_lang_ids.squeeze(1).tolist()
                if self.multilingual else None,
            )
            # remove <sos> and <eos>
            list_of_refs = []
            hyps = []
            y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true
                    if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(
                    self.trans_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.trans_args.blank, "")
                seq_true_text = "".join(seq_true).replace(
                    self.trans_args.space, " ")

                hyps += [seq_hat_text.split(" ")]
                list_of_refs += [[seq_true_text.split(" ")]]

            self.bleu = nltk.corpus_bleu(list_of_refs, hyps) * 100

        # 9. Weighted multi-task loss:
        #    (1 - asr_w - mt_w) * ST
        #    + asr_w * (alpha * CTC + (1 - alpha) * ASR attention)
        #    + mt_w * MT
        alpha = self.mtlalpha
        self.loss = ((1 - self.asr_weight - self.mt_weight) * self.loss_st +
                     self.asr_weight * (alpha * self.loss_ctc +
                                        (1 - alpha) * self.loss_asr) +
                     self.mt_weight * self.loss_mt)
        loss_st_data = float(self.loss_st)
        loss_asr_data = float(alpha * self.loss_ctc +
                              (1 - alpha) * self.loss_asr)
        loss_mt_data = None if self.mt_weight == 0 else float(self.loss_mt)

        # only report finite losses below the threshold; otherwise warn
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Exemplo n.º 2
0
    def forward(self, xs_pad, ilens, ys_pad):
        """Compute the MT attention loss for a batch (and BLEU in eval mode).

        The inputs and targets first go through
        ``self.target_language_biasing``. The source is then embedded, has
        dropout applied, and is encoded. The attention decoder returns the
        loss, accuracy and perplexity, which are stored on ``self``.

        When the model is not in training mode and ``self.report_bleu`` is
        set, the batch is also beam-search decoded and a corpus BLEU score
        (scaled by 100) is stored in ``self.bleu``; otherwise ``self.bleu``
        is 0.0.

        :param torch.Tensor xs_pad: batch of padded source sequences; it is
            passed through ``self.embed``, so presumably token ids (B, Tmax)
            -- TODO confirm
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # Encoder: apply target-language biasing, then embed -> dropout -> encode
        xs_pad, ys_pad = self.target_language_biasing(xs_pad, ilens, ys_pad)
        embedded = self.dropout(self.embed(xs_pad))
        hs_pad, hlens, _ = self.enc(embedded, ilens)

        # Attention decoder: loss, accuracy and perplexity
        self.loss, self.acc, self.ppl = self.dec(hs_pad, hlens, ys_pad)

        # BLEU is only computed outside training and when requested
        if not self.training and self.report_bleu:
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                None,  # no CTC scores for translation decoding
                self.trans_args,
                self.char_list,
                self.rnnlm,
            )
            hyps = []
            list_of_refs = []
            for b, nbest in enumerate(nbest_hyps):
                # best hypothesis without <sos> and <eos>
                best_ids = nbest[0]["yseq"][1:-1]
                hyp_tokens = [
                    self.char_list[int(tok)]
                    for tok in best_ids
                    if int(tok) != -1
                ]
                ref_tokens = [
                    self.char_list[int(tok)]
                    for tok in ys_pad[b]
                    if int(tok) != -1
                ]
                hyp_text = ("".join(hyp_tokens)
                            .replace(self.trans_args.space, " ")
                            .replace(self.trans_args.blank, ""))
                ref_text = "".join(ref_tokens).replace(
                    self.trans_args.space, " ")
                hyps.append(hyp_text.split(" "))
                list_of_refs.append([ref_text.split(" ")])
            self.bleu = nltk.corpus_bleu(list_of_refs, hyps) * 100
        else:
            self.bleu = 0.0

        # Report only non-NaN losses
        loss_data = float(self.loss)
        if math.isnan(loss_data):
            logging.warning("loss (=%f) is not correct", loss_data)
        else:
            self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
        return self.loss