def store_penultimate_state(self, xs_pad, ilens, ys_pad, moe_coes,
                            moe_coe_lens):
    moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # multi-encoder forward
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
    moe_coes = moe_coes.unsqueeze(-1)
    hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
    self.hs_pad = hs_pad

    # forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask, penultimate_state = self.decoder(
        ys_in_pad, ys_mask, hs_pad, hs_mask, moe_coes,
        return_penultimate_state=True)

    # plot penultimate_state, (B,T,att_dim)
    return penultimate_state.squeeze(0).detach().cpu().numpy()
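# add_sos_eos and target_mask are used by every snippet in this section. For
# reference, a minimal sketch of their behavior, consistent with how these
# ESPnet-style call sites use them; treat it as a reference sketch rather than
# the exact library source:
import torch

from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask


def add_sos_eos(ys_pad, sos, eos, ignore_id):
    """Prepend <sos> to the decoder input and append <eos> to the target.

    ys_in_pad is padded with eos (so it contains no ignore_id tokens);
    ys_out_pad is padded with ignore_id so the loss skips padding.
    """
    _sos = ys_pad.new([sos])
    _eos = ys_pad.new([eos])
    ys = [y[y != ignore_id] for y in ys_pad]  # strip padding
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


def target_mask(ys_in_pad, ignore_id):
    """Combine the padding mask with a causal (subsequent) mask."""
    ys_mask = (ys_in_pad != ignore_id).unsqueeze(-2)  # (B, 1, L)
    m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
    return ys_mask & m  # (B, L, L)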
def _calc_mt_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_in_lens = ys_pad_lens + 1

    # 1. Forward decoder
    decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad,
                                  ys_in_lens)

    # 2. Compute attention loss
    loss_att = self.criterion_mt(decoder_out, ys_out_pad)
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )

    # Compute bleu using attention-decoder
    if self.training or self.mt_error_calculator is None:
        bleu_att = None
    else:
        ys_hat = decoder_out.argmax(dim=-1)
        bleu_att = self.mt_error_calculator(ys_hat.cpu(), ys_pad.cpu())
    return loss_att, acc_att, bleu_att
def forward_asr(self, hs_pad, hs_mask, ys_pad):
    """Forward pass in the auxiliary ASR task.

    :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor hs_mask: batch of input token mask (B, Lmax)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ASR attention loss value
    :rtype: torch.Tensor
    :return: accuracy in ASR attention decoder
    :rtype: float
    :return: ASR CTC loss value
    :rtype: torch.Tensor
    :return: character error rate from CTC prediction
    :rtype: float
    :return: character error rate from attention decoder prediction
    :rtype: float
    :return: word error rate from attention decoder prediction
    :rtype: float
    """
    loss_att, loss_ctc = 0.0, 0.0
    acc = None
    cer, wer = None, None
    cer_ctc = None
    if self.asr_weight == 0:
        return loss_att, acc, loss_ctc, cer_ctc, cer, wer

    # attention
    if self.mtlalpha < 1:
        ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
            ys_pad, self.sos, self.eos, self.ignore_id)
        ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
        pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr,
                                       hs_pad, hs_mask)
        loss_att = self.criterion(pred_pad, ys_out_pad_asr)
        acc = th_accuracy(
            pred_pad.view(-1, self.odim),
            ys_out_pad_asr,
            ignore_label=self.ignore_id,
        )
        if not self.training:
            ys_hat_asr = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(),
                                                 ys_pad.cpu())

    # CTC
    if self.mtlalpha > 0:
        batch_size = hs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if not self.training:
            ys_hat_ctc = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                ys_pad.cpu(), is_ctc=True)
            # for visualization
            self.ctc.softmax(hs_pad)
    return loss_att, acc, loss_ctc, cer_ctc, cer, wer
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    # 5. compute bleu
    if self.training or self.error_calculator is None:
        bleu = 0.0
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_mt
    self.loss = loss
    loss_data = float(self.loss)
    if self.normalize_length:
        self.ppl = np.exp(loss_data)
    else:
        batch_size = ys_out_pad.size(0)  # capture B before flattening
        ys_out_pad = ys_out_pad.view(-1)
        ignore = ys_out_pad == self.ignore_id  # (B*T,)
        total = len(ys_out_pad) - ignore.sum().item()
        self.ppl = np.exp(loss_data * batch_size / total)
    if not math.isnan(loss_data):
        self.reporter.report(loss_data, self.acc, self.ppl, bleu)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def test_transformer_mask():
    args = make_arg()
    model, x, ilens, y, data, uttid_list = prepare("pytorch", args)
    yi, yo = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    y_mask = target_mask(yi, model.ignore_id)
    y = model.decoder.embed(yi)
    y[0, 3:] = float("nan")
    a = model.decoder.decoders[0].self_attn
    a(y, y, y, y_mask)
    assert not numpy.isnan(a.attn[0, :, :3, :3].detach().numpy()).any()
def test_transformer_mask(module):
    model, x, ilens, y, data = prepare(module)
    from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
    from espnet.nets.pytorch_backend.transformer.mask import target_mask

    yi, yo = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    y_mask = target_mask(yi, model.ignore_id)
    y = model.decoder.embed(yi)
    y[0, 3:] = float("nan")
    a = model.decoder.decoders[0].self_attn
    a(y, y, y, y_mask)
    assert not numpy.isnan(a.attn[0, :, :3, :3].detach().numpy()).any()
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
):
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)

    # 1. Encoder
    encoder_out, encoder_out_lens, _ = self.encoder(
        speech,
        speech_lengths,
        left_mask=self.encoder_left_mask,
        right_mask=self.encoder_right_mask)  # return xs_pad, olens, None

    # 2. Decoder
    # todo: train right shift
    text_in, text_out = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
    text_in_lens = text_lengths + 1
    decoder_out, decoder_out_lens, _ = self.decoder(
        text_in, text_in_lens, left_mask=self.decoder_left_mask,
        right_mask=0)  # return xs_pad, olens, None

    # 3. Joint
    # h_enc: Batch of expanded hidden state (B, T, 1, D_enc)
    # h_dec: Batch of expanded hidden state (B, 1, U, D_dec)
    encoder_out = encoder_out.unsqueeze(2)
    decoder_out = decoder_out.unsqueeze(1)
    joint_out = self.joint(h_enc=encoder_out, h_dec=decoder_out)

    # 4. Loss
    # pred_pad (torch.Tensor): Batch of predicted sequences
    loss = self.loss(
        pred_pad=joint_out,  # (batch, maxlen_in, maxlen_out+1, odim)
        target=text.int(),  # (batch, maxlen_out)
        pred_len=speech_lengths.int(),  # (batch)
        target_len=text_lengths.int())  # (batch)
    return loss
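# self.joint above is referenced but not defined here. A minimal broadcast-add
# transducer joint consistent with the (B, T, 1, D_enc) / (B, 1, U, D_dec)
# shapes documented in step 3; the class name and layer sizes below are
# illustrative assumptions, not this repo's actual module:
import torch


class JointNetwork(torch.nn.Module):
    """Hypothetical sketch of a broadcast-add transducer joint network."""

    def __init__(self, enc_dim: int, dec_dim: int, joint_dim: int, odim: int):
        super().__init__()
        self.lin_enc = torch.nn.Linear(enc_dim, joint_dim)
        self.lin_dec = torch.nn.Linear(dec_dim, joint_dim)
        self.lin_out = torch.nn.Linear(joint_dim, odim)

    def forward(self, h_enc: torch.Tensor, h_dec: torch.Tensor) -> torch.Tensor:
        # (B, T, 1, J) + (B, 1, U, J) broadcasts to (B, T, U, J)
        joint = torch.tanh(self.lin_enc(h_enc) + self.lin_dec(h_dec))
        return self.lin_out(joint)  # (B, T, U, odim)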
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute attention loss
    self.loss = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    loss_data = float(self.loss)
    if self.normalize_length:
        self.ppl = np.exp(loss_data)
    else:
        batch_size = ys_out_pad.size(0)
        ys_out_pad = ys_out_pad.view(-1)
        ignore = ys_out_pad == self.ignore_id  # (B*T,)
        total_n_tokens = len(ys_out_pad) - ignore.sum().item()
        self.ppl = np.exp(loss_data * batch_size / total_n_tokens)
    if not math.isnan(loss_data):
        self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
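# On the perplexity rescaling above: with normalize_length=False the
# label-smoothing criterion (as in ESPnet) averages the summed loss over
# utterances rather than tokens, so recovering a per-token perplexity needs
# the batch_size / total_n_tokens correction. A toy check with made-up
# numbers:
import numpy as np

# B=2 utterances with 5 + 3 = 8 target tokens and summed loss 8.0
# => batch-averaged loss 4.0, true per-token loss 1.0
loss_data, batch_size, total_n_tokens = 4.0, 2, 8
ppl = np.exp(loss_data * batch_size / total_n_tokens)
print(ppl)  # exp(1.0) ~= 2.718, i.e. exp(per-token loss)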
def forward_ilm(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...) not necessary; only used to get the device of the tensor
        speech_lengths: (Batch, ) not necessary; only used to get the device of the tensor
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (text.shape[0] == text_lengths.shape[0]), (text.shape,
                                                      text_lengths.shape)
    batch_size = text.shape[0]

    # for data-parallel
    text = text[:, :text_lengths.max()]
    ys_in_pad, ys_out_pad = add_sos_eos(text, self.sos, self.eos,
                                        self.ignore_id)
    ys_in_lens = text_lengths + 1
    fake_encoder_out = speech.new_zeros(batch_size, 1,
                                        self.encoder._output_size)

    # 1. Forward decoder
    decoder_out, _ = self.decoder.forward_ilm(fake_encoder_out, -1,
                                              ys_in_pad, ys_in_lens)

    # 2. Compute ILM loss
    loss_ilm = self.criterion_att(decoder_out, ys_out_pad)
    ilm_acc = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )
    ilm_ppl = torch.exp(loss_ilm)
    stats = dict(ilm_loss=loss_ilm.detach(), ilm_acc=ilm_acc,
                 ilm_ppl=ilm_ppl.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss_ilm, stats, batch_size),
                                           loss_ilm.device)
    return loss_ilm, stats, weight
def store_penultimate_state(self, xs_pad, ilens, ys_pad, bnf_feats,
                            bnf_feats_lens):
    bnf_feats = bnf_feats[:, :max(bnf_feats_lens)]  # for data parallel
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask, bnf_feats)
    self.hs_pad = hs_pad

    # forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask, penultimate_state = self.decoder(
        ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)

    # plot penultimate_state, (B,T,att_dim)
    return penultimate_state.squeeze(0).detach().cpu().numpy()
def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
    """Forward decoder and attention loss."""
    # forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                      ignore_label=self.ignore_id)
    return pred_pad, pred_mask, loss_att, acc
def _meta_collect_stats(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
):
    ys_in_pad, ys_out_pad = add_sos_eos(text, self.sos, self.eos,
                                        self.ignore_id)
    ys_in_lens = text_lengths + 1
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad,
                                  ys_in_lens)
    # Prior statistic calculator
    self.stat(decoder_out, ys_out_pad,
              ys_in_lens - 1)  # eliminate <eos> label by reducing ys_pad_lens
def _extract_feats(
    self, src_text: torch.Tensor, src_text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert src_text_lengths.dim() == 1, src_text_lengths.shape

    # for data-parallel
    src_text = src_text[:, :src_text_lengths.max()]
    src_text, _ = add_sos_eos(src_text, self.sos, self.eos, self.ignore_id)
    src_text_lengths = src_text_lengths + 1

    if self.frontend is not None:
        # Frontend, e.g. embedding lookup
        # src_text (Batch, NSamples) -> feats: (Batch, NSamples, Dim)
        feats, feats_lengths = self.frontend(src_text, src_text_lengths)
    else:
        # No frontend and no feature extraction
        feats, feats_lengths = src_text, src_text_lengths
    return feats, feats_lengths
def nll(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
    """Compute negative log likelihood (nll) from the transformer decoder.

    Normally, this function is called in batchify_nll.

    Args:
        encoder_out: (Batch, Length, Dim)
        encoder_out_lens: (Batch,)
        ys_pad: (Batch, Length)
        ys_pad_lens: (Batch,)
    """
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_in_lens = ys_pad_lens + 1

    # 1. Forward decoder
    decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad,
                                  ys_in_lens)  # (batch, seqlen, dim)
    batch_size = decoder_out.size(0)
    decoder_num_class = decoder_out.size(2)

    # nll: negative log-likelihood, summed over each utterance
    nll = torch.nn.functional.cross_entropy(
        decoder_out.view(-1, decoder_num_class),
        ys_out_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction="none",
    )
    nll = nll.view(batch_size, -1)
    nll = nll.sum(dim=1)
    assert nll.size(0) == batch_size
    return nll
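# Hedged usage sketch: nll() returns one summed negative log-likelihood per
# utterance, so it can score N-best hypotheses directly. `model` and the
# encoder/target tensors below are placeholders assumed to match nll()'s
# signature; this is not code from the repo.
import torch

with torch.no_grad():
    per_utt_nll = model.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
log_prob = -per_utt_nll  # (Batch,); higher log_prob = better hypothesis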
def _meta_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    lam=0.6,
):
    assert self.meta_encoder is not None
    assert self.meta_decoder is not None

    ys_in_pad, _ = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
    ys_in_lens = text_lengths + 1
    encoder_meta_out, encoder_meta_out_lens = self._meta_encode(
        speech, speech_lengths)
    decoder_meta_out, _ = self.meta_decoder(encoder_meta_out,
                                            encoder_meta_out_lens,
                                            ys_in_pad, ys_in_lens)
    decoder_out_prob = torch.softmax(decoder_meta_out, dim=-1)
    if self.lm is not None:
        lm_out, _ = self.lm(ys_in_pad, None)
        lm_out_prob = torch.softmax(lm_out, dim=-1)
        decoder_out_prob = lam * lm_out_prob + decoder_out_prob
    return decoder_out_prob
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    tgt_lang_ids = None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute ST loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # 5. compute auxiliary ASR loss
    loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
        hs_pad, hs_mask, ys_pad_src
    )

    # 6. compute auxiliary MT loss
    loss_mt, acc_mt = 0.0, None
    if self.mt_weight > 0:
        loss_mt, acc_mt = self.forward_mt(
            ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
        )

    asr_ctc_weight = self.mtlalpha
    self.loss = (
        (1 - self.asr_weight - self.mt_weight) * loss_att
        + self.asr_weight
        * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
        + self.mt_weight * loss_mt
    )
    loss_asr_data = float(
        asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
    )
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # mlp moe forward
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)

    # gated add module
    """
    lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  # (B, T, 1)
    xs = lambda * cn_xs + (1 - lambda) * en_xs
    """
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    lambda_ = self.enc_lambda
    hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                           hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
    replace_label_flag: bool = False,
    decoder_out_prob=None,
):
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_in_lens = ys_pad_lens + 1

    # Replace the labels
    if replace_label_flag:
        assert decoder_out_prob is not None
        from espnet.nets.pytorch_backend.nets_utils import pad_list

        confid = calc_confidence(decoder_out_prob, ys_out_pad)
        # Eliminate the <eos> token and find the positions to replace
        repl_mask = [prob[:l] < self.th
                     for prob, l in zip(confid, ys_pad_lens)]
        with torch.no_grad():
            decoder_out, _ = self.decoder(
                encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
            )
            decoder_out_prob = torch.softmax(decoder_out, dim=-1).detach()

        ys_in = [y[y != self.ignore_id] for y in ys_pad.clone().detach()]
        for i, (rm, y) in enumerate(zip(repl_mask, ys_in)):
            weight = decoder_out_prob[i]
            weight = weight[:len(y)]
            samples = torch.multinomial(weight, 1).squeeze(-1)
            y[rm] = samples[rm]
        _sos = ys_pad.new([self.sos])
        ys_in = [torch.cat([_sos, y], dim=0) for y in ys_in]
        ys_in_pad = pad_list(ys_in, self.eos).detach()

        ys_out = [y[y != self.ignore_id] for y in ys_pad.clone().detach()]
        for i, (rm, y) in enumerate(zip(repl_mask, ys_out)):
            weight = decoder_out_prob[i]
            weight = weight[:len(y)]
            samples = torch.multinomial(weight, 1).squeeze(-1)
            y[rm] = samples[rm]
        # _ignore = ys_pad.new([self.ignore_id])
        _ignore = ys_pad.new([self.eos])
        ys_out = [torch.cat([y, _ignore], dim=0) for y in ys_out]
        ys_out_pad = pad_list(ys_out, self.ignore_id).detach()

    # 1. Forward decoder
    decoder_out, _ = self.decoder(
        encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
    )

    # 2. Compute attention loss
    loss_att = self.criterion_att(decoder_out, ys_out_pad)
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )

    # if replace_label_flag:
    #     num_repl = 0.0
    #     num_total = 0.0
    #     for m in repl_mask:
    #         num_repl += m.sum()
    #         num_total += len(m)
    #     pred_err_att = float(num_repl) / float(num_total)
    # else:
    #     pred_err_att = 0.0

    # Compute cer/wer using attention-decoder
    if self.training or self.error_calculator is None:
        cer_att, wer_att = None, None
    else:
        ys_hat = decoder_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
    return loss_att, acc_att, cer_att, wer_att
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # mlp moe forward
    cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
    hs_mask = cn_hs_mask  # cn_hs_mask & en_hs_mask are identical

    # gated add module
    """
    lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  # (B, T, 1)
    xs = lambda * cn_xs + (1 - lambda) * en_xs
    """
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    lambda_ = self.aggregation_module(
        hs_pad)  # (B,T,1)/(B,T,D), range in (0, 1)
    hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        # divide ys_pad into cn_ys & en_ys;
        # note that these targets can be passed directly to the ctc modules
        cn_ys, en_ys = partial_target(ys_pad, self.language_divider)
        cn_loss_ctc = self.cn_ctc(
            cn_hs_pad.view(batch_size, -1, self.adim), hs_len, cn_ys)
        en_loss_ctc = self.en_ctc(
            en_hs_pad.view(batch_size, -1, self.adim), hs_len, en_ys)
        loss_ctc = 0.5 * (cn_loss_ctc + en_loss_ctc)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                           hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             None, None, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
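# self.aggregation_module above is referenced but not shown. A minimal sketch
# of the gate described by the in-code formula
# lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b); a single Linear over the
# concatenated encoder states is an assumption consistent with the
# torch.cat((cn_hs_pad, en_hs_pad), dim=-1) input, not the repo's actual code:
import torch


class GatedAggregation(torch.nn.Module):
    """Hypothetical gate: lambda = sigmoid(W [cn_xs; en_xs] + b), in (0, 1)."""

    def __init__(self, adim: int, elementwise: bool = False):
        super().__init__()
        # (B, T, 2*adim) -> (B, T, 1) scalar gate, or (B, T, adim) per-dim gate
        self.proj = torch.nn.Linear(2 * adim, adim if elementwise else 1)

    def forward(self, hs_cat: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.proj(hs_cat))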
def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    moe_coes = moe_coes[:, :max(moe_coe_lens)].long()  # for data parallel
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # mlp moe forward
    cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)

    # gated add module
    """
    lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  # (B, T, 1)
    xs = lambda * cn_xs + (1 - lambda) * en_xs
    """
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    lambda_ = F.softmax(
        self.aggre_scaling * self.aggregation_module(hs_pad),
        -1).unsqueeze(-1)
    ctc_hs_pad = lambda_[:, :, 0] * cn_hs_pad + lambda_[:, :, 1] * en_hs_pad
    ctc_hs_mask = cn_hs_mask
    # flat attention mode, (B,T,D)*2 --> (B,2T,D)
    s2s_hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=1)
    # mask: (B,1,T) --> (B,1,2T)
    s2s_hs_mask = torch.cat((cn_hs_mask, en_hs_mask), dim=-1)
    # self.hs_pad = hs_pad

    # compute lid loss here, using lambda_
    # moe_coes (B, T, 2) ==> (B,T)
    moe_coes = moe_coes[:, :, 0]  # 0 for cn, 1 for en
    lambda_ = lambda_.squeeze(-1)
    if self.lid_mtl_alpha == 0.0:
        loss_lid = 0.0
    else:
        loss_lid = self.lid_criterion(lambda_, moe_coes)
    lid_acc = th_accuracy(
        lambda_.view(-1, 2), moe_coes,
        ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = ctc_hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(ctc_hs_pad.view(batch_size, -1, self.adim),
                            hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                ctc_hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, s2s_hs_pad,
                                           s2s_hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    lid_alpha = self.lid_mtl_alpha
    if alpha == 0:
        self.loss = loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc + lid_alpha * loss_lid
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att \
            + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data, lid_acc)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    batch_size = xs_pad.shape[0]
    src_mask = make_non_pad_mask(ilens.tolist()).to(
        xs_pad.device).unsqueeze(-2)
    if isinstance(self.encoder.embed, EncoderConv2d):
        xs, hs_mask = self.encoder.embed(xs_pad,
                                         torch.sum(src_mask, 2).squeeze())
        hs_mask = hs_mask.unsqueeze(1)
    else:
        xs, hs_mask = self.encoder.embed(xs_pad, src_mask)
    if enc_mask is not None:
        enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]]
    enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
    hs_pad, _ = self.encoder.encoders(xs, enc_mask)
    if self.encoder.normalize_before:
        hs_pad = self.encoder.after_norm(hs_pad)

    # CTC forward
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    if dec_mask is not None:
        dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
    self.hs_pad = hs_pad
    batch_size = xs_pad.size(0)
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)

    # trigger mask
    hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)
    return self.loss, loss_ctc_data, loss_att_data, self.acc
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask, penultimate_state = self.decoder(
        ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute lid multitask loss
    src_att = self.lid_src_att(penultimate_state, hs_pad, hs_pad, hs_mask)
    pred_lid_pad = self.lid_output_layer(src_att)
    loss_lid, lid_ys_out_pad = self.lid_criterion(pred_lid_pad, ys_out_pad)
    lid_acc = th_accuracy(
        pred_lid_pad.view(-1, self.lid_odim), lid_ys_out_pad,
        ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    lid_alpha = self.lid_mtl_alpha
    if alpha == 0:
        self.loss = loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        raise Exception("LID MTL does not support pure CTC mode")
        self.loss = loss_ctc  # unreachable, kept from the original
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att \
            + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data, lid_acc)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # CTC forward
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    self.hs_pad = hs_pad
    cer_ctc = None
    batch_size = xs_pad.size(0)
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)

    # trigger mask
    start_time = time.time()

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)  # was missing; left the name unbound
    return self.loss, loss_ctc_data, loss_att_data, self.acc
def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
    # here we use interpolation_coe to 'fix' initial moe_coes
    interp_factor = self.interp_factor  # 0.1 for example, similar to lsm
    moe_coes = (
        1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # multi-encoder forward
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
    moe_coes = moe_coes.unsqueeze(-1)
    hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                           hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
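# The interpolation above is label-smoothing-style: each hard expert
# coefficient is shrunk toward the uniform distribution. A quick check with
# illustrative numbers (interp_factor=0.1, two experts):
import torch

moe_coes = torch.tensor([[[1.0, 0.0]]])  # (B=1, T=1, n_experts=2), hard
interp_factor = 0.1
smoothed = (1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]
print(smoothed)  # tensor([[[0.9500, 0.0500]]])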
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    tgt_lang_ids = None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute ST loss
    loss_asr_att, loss_asr_ctc, loss_mt = 0.0, 0.0, 0.0
    acc_asr, acc_mt = 0.0, 0.0
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # 5. compute auxiliary ASR loss
    cer, wer = None, None
    cer_ctc = None
    if self.asr_weight > 0:
        # attention
        if self.mtlalpha < 1:
            ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                ys_pad_src, self.sos, self.eos, self.ignore_id)
            ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
            pred_pad_asr, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr,
                                               hs_pad, hs_mask)
            loss_asr_att = self.criterion(pred_pad_asr, ys_out_pad_asr)
            acc_asr = th_accuracy(
                pred_pad_asr.view(-1, self.odim),
                ys_out_pad_asr,
                ignore_label=self.ignore_id,
            )
            if not self.training:
                ys_hat_asr = pred_pad_asr.argmax(dim=-1)
                cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(),
                                                     ys_pad_src.cpu())
        # CTC
        if self.mtlalpha > 0:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_asr_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                    hs_len, ys_pad_src)
            ys_hat_ctc = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            if not self.training:
                cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                    ys_pad_src.cpu(),
                                                    is_ctc=True)

    # 6. compute auxiliary MT loss
    if self.mt_weight > 0:
        ilens_mt = torch.sum(ys_pad_src != self.ignore_id,
                             dim=1).cpu().numpy()
        # NOTE: ys_pad_src is padded with -1
        ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
        ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
        ys_zero_pad_src = ys_zero_pad_src[:, :max(
            ilens_mt)]  # for data parallel
        src_mask_mt = ((~make_pad_mask(ilens_mt.tolist())).to(
            ys_zero_pad_src.device).unsqueeze(-2))
        hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
        pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt,
                                      hs_mask_mt)
        loss_mt = self.criterion(pred_pad_mt, ys_out_pad)
        acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                             ignore_label=self.ignore_id)

    alpha = self.mtlalpha
    self.loss = ((1 - self.asr_weight - self.mt_weight) * loss_att +
                 self.asr_weight *
                 (alpha * loss_asr_ctc + (1 - alpha) * loss_asr_att) +
                 self.mt_weight * loss_mt)
    loss_asr_data = float(alpha * loss_asr_ctc +
                          (1 - alpha) * loss_asr_att)
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    # src_lang_ids = None
    tgt_lang_ids, tgt_lang_ids_src = None, None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
    if self.one_to_many:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
        if self.do_asr:
            tgt_lang_ids_src = ys_pad_src[:, 0:1]
            ys_pad_src = ys_pad_src[:, 1:]  # remove source language ID at the beginning

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel; bs x max_ilens x idim
    if self.lang_tok == "encoder-pre-sum":
        lang_embed = self.language_embeddings(tgt_lang_ids)  # bs x 1 x idim
        xs_pad = xs_pad + lang_embed
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)  # bs x 1 x max_ilens
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    # hs_pad: bs x (max_ilens/4) x adim; hs_mask: bs x 1 x (max_ilens/4)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)  # bs x max_lens
    if self.do_asr:
        ys_in_pad_src, ys_out_pad_src = add_sos_eos(
            ys_pad_src, self.sos, self.eos,
            self.ignore_id)  # bs x max_lens_src

    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    if self.lang_tok == "decoder-pre":
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        if self.do_asr:
            ys_in_pad_src = torch.cat(
                [tgt_lang_ids_src, ys_in_pad_src[:, 1:]], dim=1)

    ys_mask = target_mask(ys_in_pad,
                          self.ignore_id)  # bs x max_lens x max_lens
    if self.do_asr:
        ys_mask_src = target_mask(
            ys_in_pad_src, self.ignore_id)  # bs x max_lens_src x max_lens_src

    if self.wait_k_asr > 0:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src,
                                       self.ignore_id,
                                       wait_k_cross=self.wait_k_asr)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id,
                                           wait_k_cross=-self.wait_k_asr)
    elif self.wait_k_st > 0:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src,
                                       self.ignore_id,
                                       wait_k_cross=-self.wait_k_st)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id,
                                           wait_k_cross=self.wait_k_st)
    else:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src,
                                       self.ignore_id, wait_k_cross=0)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id, wait_k_cross=0)

    pred_pad, pred_mask, pred_pad_asr, pred_mask_asr = self.dual_decoder(
        ys_in_pad, ys_mask, ys_in_pad_src, ys_mask_src, hs_pad, hs_mask,
        cross_mask, cross_mask_asr,
        cross_self=self.cross_self, cross_src=self.cross_src,
        cross_self_from=self.cross_self_from,
        cross_src_from=self.cross_src_from)
    self.pred_pad = pred_pad
    self.pred_pad_asr = pred_pad_asr
    pred_pad_mt = None

    # 3. compute attention loss
    loss_asr, loss_mt = 0.0, 0.0
    loss_att = self.criterion(pred_pad, ys_out_pad)
    loss_asr = self.criterion(pred_pad_asr, ys_out_pad_src)
    # Multi-task w/ MT
    if self.mt_weight > 0:
        # forward MT encoder
        ilens_mt = torch.sum(ys_pad_src != self.ignore_id,
                             dim=1).cpu().numpy()
        # NOTE: ys_pad_src is padded with -1
        ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
        ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
        ys_zero_pad_src = ys_zero_pad_src[:, :max(
            ilens_mt)]  # for data parallel
        src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(
            ys_zero_pad_src.device).unsqueeze(-2)
        # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad)
        hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
        # forward MT decoder
        pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt,
                                      hs_mask_mt)
        # compute loss
        loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)
    if pred_pad_asr is not None:
        self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim),
                                   ys_out_pad_src,
                                   ignore_label=self.ignore_id)
    else:
        self.acc_asr = 0.0
    if pred_pad_mt is not None:
        self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim),
                                  ys_out_pad,
                                  ignore_label=self.ignore_id)
    else:
        self.acc_mt = 0.0

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0 or self.asr_weight == 0:
        loss_ctc = 0.0
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad_src)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(),
                                            is_ctc=True)

    # 5. compute cer/wer
    cer, wer = None, None  # TODO(hirofumi0810): fix later
    # if self.training or (self.asr_weight == 0 or self.mtlalpha == 1
    #                      or not (self.report_cer or self.report_wer)):
    #     cer, wer = None, None
    # else:
    #     ys_hat = pred_pad.argmax(dim=-1)
    #     cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    self.loss = (1 - self.asr_weight - self.mt_weight) * loss_att \
        + self.asr_weight * (alpha * loss_ctc + (1 - alpha) * loss_asr) \
        + self.mt_weight * loss_mt
    loss_asr_data = float(alpha * loss_ctc + (1 - alpha) * loss_asr)
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)
    # logging.info(f'loss_st_data={loss_st_data}')

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data,
                             self.acc_asr, self.acc_mt, self.acc,
                             cer_ctc, cer, wer,
                             0.0,  # TODO(hirofumi0810): bleu
                             loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    if self.decoder is not None:
        if self.decoder_mode == "maskctc":
            ys_in_pad, ys_out_pad = mask_uniform(
                ys_pad, self.mask_token, self.eos, self.ignore_id
            )
            ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
        else:
            ys_in_pad, ys_out_pad = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id
            )
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
    else:
        loss_att = None
        self.acc = None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        # for visualization
        if not self.training:
            self.ctc.softmax(hs_pad)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
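# The maskctc branch above uses mask_uniform instead of add_sos_eos. A sketch
# of its expected behavior in the Mask-CTC style: a random number of target
# positions is replaced with <mask> in the decoder input, and only those
# positions are kept as learning targets. This follows ESPnet's Mask-CTC
# recipe but should be treated as an approximation, not the exact source:
import numpy
import torch

from espnet.nets.pytorch_backend.nets_utils import pad_list


def mask_uniform(ys_pad, mask_token, eos, ignore_id):
    """Replace a uniform-random subset of tokens with <mask>.

    ys_in: targets with the sampled positions set to mask_token.
    ys_out: sampled positions keep their original label; everything else
            becomes ignore_id and is excluded from the loss.
    """
    ys = [y[y != ignore_id] for y in ys_pad]  # strip padding
    ys_out = [y.new(y.size()).fill_(ignore_id) for y in ys]
    ys_in = [y.clone() for y in ys]
    for i in range(len(ys)):
        num_samples = numpy.random.randint(1, len(ys[i]) + 1)
        idx = numpy.random.choice(len(ys[i]), num_samples, replace=False)
        ys_in[i][idx] = mask_token
        ys_out[i][idx] = ys[i][idx]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)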
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    if self.attention_enc_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ]:
        for layer in self.encoder.encoders:
            layer.self_attn.clamp_param()
    if self.attention_dec_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ]:
        for layer in self.decoder.decoders:
            layer.self_attn.clamp_param()

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    # xkc09 Span attention loss computation
    # xkc09 Span attention size loss computation
    loss_span = 0
    if self.attention_enc_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
    ]:
        loss_span += sum([
            layer.self_attn.get_mean_span()
            for layer in self.encoder.encoders
        ])
    if self.attention_dec_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
    ]:
        loss_span += sum([
            layer.self_attn.get_mean_span()
            for layer in self.decoder.decoders
        ])
    # xkc09 Span attention ratio loss computation
    loss_ratio = 0
    if self.ratio_adaptive:
        # target_ratio = 0.5
        if self.attention_enc_type in [
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            loss_ratio += sum([
                1 - layer.self_attn.get_mean_ratio()
                for layer in self.encoder.encoders
            ])
        if self.attention_dec_type in [
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            loss_ratio += sum([
                1 - layer.self_attn.get_mean_ratio()
                for layer in self.decoder.decoders
            ])
    if (self.attention_enc_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ] or self.attention_dec_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ]):
        if getattr(self, 'span_loss_coef', None):
            self.loss += (loss_span + loss_ratio) * self.span_loss_coef

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss