def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded target token id sequence tensor (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded source token id sequence tensor (B, Lmax)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 0. Extract target language ID
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
    else:
        tgt_lang_ids = None

    # 1. Encoder
    hs_pad, hlens, _ = self.enc(xs_pad, ilens)

    # 2. ST attention loss
    self.loss_st, self.acc, _ = self.dec(
        hs_pad, hlens, ys_pad, lang_ids=tgt_lang_ids
    )

    # 3. ASR CTC loss
    if self.asr_weight == 0 or self.mtlalpha == 0:
        self.loss_ctc = 0.0
    else:
        self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad_src)

    # 4. ASR attention loss
    if self.asr_weight == 0 or self.mtlalpha == 1:
        self.loss_asr = 0.0
        acc_asr = 0.0
    else:
        self.loss_asr, acc_asr, _ = self.dec_asr(hs_pad, hlens, ys_pad_src)

    # 5. MT attention loss
    if self.mt_weight == 0:
        self.loss_mt = 0.0
        acc_mt = 0.0
    else:
        # ys_pad_src, ys_pad = self.target_forcing(ys_pad_src, ys_pad)
        ilens_mt = torch.sum(ys_pad_src != -1, dim=1).cpu().numpy()
        # NOTE: ys_pad_src is padded with -1
        ys_src = [y[y != -1] for y in ys_pad_src]  # parse padded ys_src
        ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
        hs_pad_mt, hlens_mt, _ = self.enc_mt(
            self.dropout_mt(self.embed_mt(ys_zero_pad_src)), ilens_mt
        )
        self.loss_mt, acc_mt, _ = self.dec(hs_pad_mt, hlens_mt, ys_pad)

    # 6. compute CER without beam search (greedy CTC decoding)
    if (self.asr_weight == 0 or self.mtlalpha == 0) or self.char_list is None:
        cer_ctc = None
    else:
        cers = []
        y_hats = self.ctc.argmax(hs_pad).data
        for i, y in enumerate(y_hats):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad_src[i]

            seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")

            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars) / len(ref_chars))

        cer_ctc = sum(cers) / len(cers) if cers else None

    # 7. compute CER/WER with beam search
    if self.training or (
        self.asr_weight == 0
        or self.mtlalpha == 1
        or not (self.report_cer or self.report_wer)
    ):
        cer, wer = 0.0, 0.0
    else:
        if (
            self.asr_weight > 0
            and self.mtlalpha > 0
            and self.recog_args.ctc_weight > 0.0
        ):
            lpz = self.ctc.log_softmax(hs_pad).data
        else:
            lpz = None

        word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
        nbest_hyps_asr = self.dec_asr.recognize_beam_batch(
            hs_pad,
            torch.tensor(hlens),
            lpz,
            self.recog_args,
            self.char_list,
            self.rnnlm,
        )
        # remove <sos> and <eos>
        y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps_asr]
        for i, y_hat in enumerate(y_hats):
            y_true = ys_pad[i]

            seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
            seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
            seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ")

            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))

        wer = (
            0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens)
        )
        cer = (
            0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens)
        )

    # 8. compute BLEU
    if self.training or not self.report_bleu:
        self.bleu = 0.0
    else:
        lpz = None

        nbest_hyps = self.dec.recognize_beam_batch(
            hs_pad,
            torch.tensor(hlens),
            lpz,
            self.trans_args,
            self.char_list,
            self.rnnlm,
            lang_ids=tgt_lang_ids.squeeze(1).tolist() if self.multilingual else None,
        )
        # remove <sos> and <eos>
        list_of_refs = []
        hyps = []
        y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
        for i, y_hat in enumerate(y_hats):
            y_true = ys_pad[i]

            seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.trans_args.space, " ")
            seq_hat_text = seq_hat_text.replace(self.trans_args.blank, "")
            seq_true_text = "".join(seq_true).replace(self.trans_args.space, " ")

            hyps += [seq_hat_text.split(" ")]
            list_of_refs += [[seq_true_text.split(" ")]]

        self.bleu = nltk.corpus_bleu(list_of_refs, hyps) * 100

    # 9. weighted multi-task loss
    alpha = self.mtlalpha
    self.loss = (
        (1 - self.asr_weight - self.mt_weight) * self.loss_st
        + self.asr_weight * (alpha * self.loss_ctc + (1 - alpha) * self.loss_asr)
        + self.mt_weight * self.loss_mt
    )
    loss_st_data = float(self.loss_st)
    loss_asr_data = float(alpha * self.loss_ctc + (1 - alpha) * self.loss_asr)
    loss_mt_data = None if self.mt_weight == 0 else float(self.loss_mt)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
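
# Step 6 above collapses the frame-level CTC argmax path with itertools.groupby
# (dropping consecutive repeats), maps ids to characters, strips the blank and
# space symbols, and measures character error rate with editdistance. The
# following is a minimal standalone sketch of that recipe; the toy char_list,
# id sequences, and symbol names are invented for illustration and are not part
# of the module above.
import editdistance
from itertools import groupby

char_list = ["<blank>", "<space>", "h", "e", "l", "o"]
blank, space = "<blank>", "<space>"

y_frames = [2, 2, 0, 3, 3, 4, 0, 4, 5, 0]  # frame-level argmax ids -> "hello"
y_true_ids = [2, 3, 4, 5]                  # reference ids -> "helo"

y_hat = [idx for idx, _ in groupby(y_frames)]  # collapse repeated frames
seq_hat = "".join(char_list[idx] for idx in y_hat)
seq_hat = seq_hat.replace(space, " ").replace(blank, "")
seq_true = "".join(char_list[idx] for idx in y_true_ids).replace(space, " ")

hyp_chars = seq_hat.replace(" ", "")
ref_chars = seq_true.replace(" ", "")
cer = editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
print("CER = %.2f" % cer)  # 1 edit over 4 reference characters -> 0.25
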
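
# The weighted objective at the end of forward() mixes the ST attention loss
# with the ASR branch (CTC + attention) and the auxiliary MT loss. The helper
# below is an illustrative standalone sketch of that weighting only; the name
# combine_st_losses and its default weight values are hypothetical, but the
# formula mirrors the one used above:
#   (1 - asr_weight - mt_weight) * L_st
#   + asr_weight * (mtlalpha * L_ctc + (1 - mtlalpha) * L_asr)
#   + mt_weight * L_mt
def combine_st_losses(
    loss_st, loss_ctc, loss_asr, loss_mt, asr_weight=0.3, mt_weight=0.0, mtlalpha=0.3
):
    """Combine ST/ASR/MT losses with the same weighting as forward()."""
    st_term = (1 - asr_weight - mt_weight) * loss_st
    asr_term = asr_weight * (mtlalpha * loss_ctc + (1 - mtlalpha) * loss_asr)
    mt_term = mt_weight * loss_mt
    return st_term + asr_term + mt_term


# Example: 0.7 * 2.0 + 0.3 * (0.3 * 5.0 + 0.7 * 4.0) + 0.0 * 0.0 = 2.69
print(combine_st_losses(2.0, 5.0, 4.0, 0.0))
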
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source token id sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded target token id sequence tensor (B, Lmax)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 1. Encoder
    xs_pad, ys_pad = self.target_language_biasing(xs_pad, ilens, ys_pad)
    hs_pad, hlens, _ = self.enc(self.dropout(self.embed(xs_pad)), ilens)

    # 2. attention loss
    self.loss, self.acc, self.ppl = self.dec(hs_pad, hlens, ys_pad)

    # 3. compute BLEU
    if self.training or not self.report_bleu:
        self.bleu = 0.0
    else:
        lpz = None

        nbest_hyps = self.dec.recognize_beam_batch(
            hs_pad,
            torch.tensor(hlens),
            lpz,
            self.trans_args,
            self.char_list,
            self.rnnlm,
        )
        # remove <sos> and <eos>
        list_of_refs = []
        hyps = []
        y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
        for i, y_hat in enumerate(y_hats):
            y_true = ys_pad[i]

            seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.trans_args.space, " ")
            seq_hat_text = seq_hat_text.replace(self.trans_args.blank, "")
            seq_true_text = "".join(seq_true).replace(self.trans_args.space, " ")

            hyps += [seq_hat_text.split(" ")]
            list_of_refs += [[seq_true_text.split(" ")]]

        self.bleu = nltk.corpus_bleu(list_of_refs, hyps) * 100

    loss_data = float(self.loss)
    if not math.isnan(loss_data):
        self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
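
# Both forward() methods above report corpus-level BLEU by collecting `hyps` as
# a list of tokenized hypotheses and `list_of_refs` as a list that holds, per
# hypothesis, an inner list of tokenized references (a single reference each
# here). A standalone sketch of that structure with NLTK's corpus_bleu,
# imported via its full nltk.translate.bleu_score path; the sentences are
# invented examples, and the smoothing function is an extra safeguard for such
# short toy inputs.
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

hyps = [
    "das ist ein test".split(" "),
    "wie geht es dir".split(" "),
]
list_of_refs = [
    ["das ist ein test".split(" ")],   # one reference list per hypothesis
    ["wie geht es ihnen".split(" ")],
]
bleu = corpus_bleu(
    list_of_refs, hyps, smoothing_function=SmoothingFunction().method1
) * 100
print("BLEU = %.2f" % bleu)
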