def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: LID loss value
    :rtype: torch.Tensor
    """
    # forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # lid output layer
    pred_pad = self.lid_lo(hs_pad)

    # compute lid loss
    self.loss = self.criterion(pred_pad, ys_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad,
                           ignore_label=self.ignore_id)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(self.acc, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
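# NOTE (illustrative): the `(~make_pad_mask(ilens.tolist())).unsqueeze(-2)` pattern
# above builds the (B, 1, Tmax) source mask consumed by the encoder self-attention.
# A minimal sketch of the assumed make_pad_mask semantics (True at padded
# positions); a hedged stand-in, not the library's exact implementation.
import torch


def make_pad_mask_sketch(lengths):
    """Return a (B, Tmax) bool mask that is True at padded positions."""
    maxlen = max(lengths)
    seq = torch.arange(maxlen).unsqueeze(0)            # (1, Tmax)
    return seq >= torch.tensor(lengths).unsqueeze(1)   # (B, Tmax)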
def _calc_mt_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_in_lens = ys_pad_lens + 1

    # 1. Forward decoder
    decoder_out, _ = self.decoder(encoder_out, encoder_out_lens,
                                  ys_in_pad, ys_in_lens)

    # 2. Compute attention loss
    loss_att = self.criterion_mt(decoder_out, ys_out_pad)
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )

    # 3. Compute BLEU using the attention decoder
    if self.training or self.mt_error_calculator is None:
        bleu_att = None
    else:
        ys_hat = decoder_out.argmax(dim=-1)
        bleu_att = self.mt_error_calculator(ys_hat.cpu(), ys_pad.cpu())

    return loss_att, acc_att, bleu_att
def _calc_mlm_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    # 1. Apply masks
    ys_in_pad, ys_out_pad = mask_uniform(
        ys_pad, self.mask_token, self.eos, self.ignore_id
    )

    # 2. Forward decoder
    decoder_out, _ = self.decoder(
        encoder_out, encoder_out_lens, ys_in_pad, ys_pad_lens
    )

    # 3. Compute mlm loss
    loss_mlm = self.criterion_mlm(decoder_out, ys_out_pad)
    acc_mlm = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )
    return loss_mlm, acc_mlm
def forward_asr(self, hs_pad, hs_mask, ys_pad):
    """Forward pass in the auxiliary ASR task.

    :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor hs_mask: batch of input token mask (B, Lmax)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ASR attention loss value
    :rtype: torch.Tensor
    :return: accuracy in ASR attention decoder
    :rtype: float
    :return: ASR CTC loss value
    :rtype: torch.Tensor
    :return: character error rate from CTC prediction
    :rtype: float
    :return: character error rate from attention decoder prediction
    :rtype: float
    :return: word error rate from attention decoder prediction
    :rtype: float
    """
    loss_att, loss_ctc = 0.0, 0.0
    acc = None
    cer, wer = None, None
    cer_ctc = None
    if self.asr_weight == 0:
        return loss_att, acc, loss_ctc, cer_ctc, cer, wer

    # attention
    if self.mtlalpha < 1:
        ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
            ys_pad, self.sos, self.eos, self.ignore_id)
        ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
        pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad, hs_mask)
        loss_att = self.criterion(pred_pad, ys_out_pad_asr)
        acc = th_accuracy(
            pred_pad.view(-1, self.odim),
            ys_out_pad_asr,
            ignore_label=self.ignore_id,
        )
        if not self.training:
            ys_hat_asr = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(), ys_pad.cpu())

    # CTC
    if self.mtlalpha > 0:
        batch_size = hs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if not self.training:
            ys_hat_ctc = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(), ys_pad.cpu(),
                                                is_ctc=True)
            # for visualization
            self.ctc.softmax(hs_pad)
    return loss_att, acc, loss_ctc, cer_ctc, cer, wer
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 1. Encoder
    xs_pad, ilens = self.conv(xs_pad, ilens)
    xs_pad, ilens = self.lstm(xs_pad, ilens)
    pred_pad = self.lid_lo(xs_pad)

    # compute lid loss
    self.loss = self.criterion(pred_pad, ys_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad,
                           ignore_label=self.ignore_id)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(self.acc, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def test_train_acc():
    n_out = 7
    _eos = n_out - 1
    n_batch = 3
    label_length = numpy.array([4, 2, 3], dtype=numpy.int32)
    np_pred = numpy.random.rand(n_batch, max(label_length) + 1,
                                n_out).astype(numpy.float32)
    # NOTE: 0 is only used for CTC and never appears in attn targets
    np_target = [
        numpy.random.randint(1, n_out - 1, size=ol, dtype=numpy.int32)
        for ol in label_length
    ]

    eos = numpy.array([_eos], 'i')
    ys_out = [F.concat([y, eos], axis=0) for y in np_target]

    # padding for ys with -1
    # pys: utt x olen
    # NOTE: -1 is the default ignore index for chainer
    pad_ys_out = F.pad_sequence(ys_out, padding=-1)
    y_all = F.reshape(np_pred, (n_batch * (max(label_length) + 1), n_out))
    ch_acc = F.accuracy(y_all, F.concat(pad_ys_out, axis=0), ignore_label=-1)

    # NOTE: this index 0 is only for CTC, not attn, so it can be ignored;
    # unfortunately, torch cross_entropy does not accept out-of-bound ids
    th_ignore = 0
    th_pred = torch.from_numpy(y_all.data)
    th_ys = [torch.from_numpy(numpy.append(t, eos)).long() for t in np_target]
    th_target = pad_list(th_ys, th_ignore)
    th_acc = th_accuracy(th_pred, th_target, th_ignore)
    numpy.testing.assert_allclose(ch_acc.data, th_acc)
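# For context, a minimal sketch of the th_accuracy semantics this test checks
# against chainer's F.accuracy: argmax over the class dimension, then accuracy
# over positions whose target is not ignore_label. This is an illustrative
# stand-in, not necessarily the library's exact implementation.
import torch


def th_accuracy_sketch(pad_outputs, pad_targets, ignore_label):
    """pad_outputs: (B * Lmax, n_out); pad_targets: (B, Lmax)."""
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = (pad_pred.masked_select(mask)
                 == pad_targets.masked_select(mask)).sum()
    return float(numerator) / float(mask.sum())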
def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask):
    """Forward pass in the auxiliary MT task.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_out_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_mask: batch of input token mask (B, Lmax)
    :return: MT loss value
    :rtype: torch.Tensor
    :return: accuracy in MT decoder
    :rtype: float
    """
    loss, acc = 0.0, None
    if self.mt_weight == 0:
        return loss, acc

    ilens = torch.sum(xs_pad != self.ignore_id, dim=1).cpu().numpy()
    # NOTE: xs_pad is padded with -1
    xs = [x[x != self.ignore_id] for x in xs_pad]  # parse padded xs
    xs_zero_pad = pad_list(xs, self.pad)  # re-pad with zero
    xs_zero_pad = xs_zero_pad[:, :max(ilens)]  # for data parallel
    src_mask = (
        make_non_pad_mask(ilens.tolist()).to(xs_zero_pad.device).unsqueeze(-2)
    )
    hs_pad, hs_mask = self.encoder_mt(xs_zero_pad, src_mask)
    pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    loss = self.criterion(pred_pad, ys_out_pad)
    acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )
    return loss, acc
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats

    # 5. compute bleu
    if self.training or self.error_calculator is None:
        bleu = 0.0
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_mt
    self.loss = loss
    loss_data = float(self.loss)
    if self.normalize_length:
        self.ppl = np.exp(loss_data)
    else:
        batch_size = ys_out_pad.size(0)
        ys_out_pad = ys_out_pad.view(-1)
        ignore = ys_out_pad == self.ignore_id  # (B*Lmax,)
        total = len(ys_out_pad) - ignore.sum().item()
        self.ppl = np.exp(loss_data * batch_size / total)
    if not math.isnan(loss_data):
        self.reporter.report(loss_data, self.acc, self.ppl, bleu)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute attention loss
    self.loss = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    loss_data = float(self.loss)
    if self.normalize_length:
        self.ppl = np.exp(loss_data)
    else:
        batch_size = ys_out_pad.size(0)
        ys_out_pad = ys_out_pad.view(-1)
        ignore = ys_out_pad == self.ignore_id  # (B*T,)
        total_n_tokens = len(ys_out_pad) - ignore.sum().item()
        self.ppl = np.exp(loss_data * batch_size / total_n_tokens)
    if not math.isnan(loss_data):
        self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
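# NOTE (illustrative): the perplexity bookkeeping above assumes the criterion
# averages over sequences when normalize_length is False and over tokens when it
# is True. A hedged sketch of how per-token perplexity is recovered in each case:
import numpy as np


def ppl_from_loss(loss_data, batch_size, n_tokens, normalize_length):
    if normalize_length:
        return np.exp(loss_data)  # loss is already the mean per-token NLL
    # undo the per-sequence average, then normalize by the real token count
    return np.exp(loss_data * batch_size / n_tokens)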
def forward_ilm(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...) not necessary; only used to get the
            device of the tensor
        speech_lengths: (Batch, ) not necessary; only used to get the
            device of the tensor
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert text.shape[0] == text_lengths.shape[0], (text.shape, text_lengths.shape)
    batch_size = text.shape[0]

    # for data-parallel
    text = text[:, :text_lengths.max()]

    ys_in_pad, ys_out_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
    ys_in_lens = text_lengths + 1
    fake_encoder_out = speech.new_zeros(batch_size, 1, self.encoder._output_size)

    # 1. Forward decoder
    decoder_out, _ = self.decoder.forward_ilm(fake_encoder_out, -1,
                                              ys_in_pad, ys_in_lens)

    # 2. Compute ilm loss
    loss_ilm = self.criterion_att(decoder_out, ys_out_pad)
    ilm_acc = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )
    ilm_ppl = torch.exp(loss_ilm)

    stats = dict(ilm_loss=loss_ilm.detach(), ilm_acc=ilm_acc,
                 ilm_ppl=ilm_ppl.detach())
    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss_ilm, stats, batch_size),
                                           loss_ilm.device)
    return loss_ilm, stats, weight
def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
    """Forward decoder and attention loss."""
    # forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                      ignore_label=self.ignore_id)
    return pred_pad, pred_mask, loss_att, acc
def forward(self, xs_pad, ilens, ys_pad):
    """Compute scalar loss for backprop."""
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    ys_in_pad, ys_out_pad = self.add_sos_eos(ys_pad)
    ys_mask = self.target_mask(ys_in_pad)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    loss = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, pred_pad.size(-1)), ys_out_pad,
                           ignore_label=self.ignore_id)
    self.reporter.report(loss=loss, acc=self.acc)
    return loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: classification loss value
    :rtype: torch.Tensor
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # zero out padded frames before pooling
    # NOTE: the hidden size (256) is hard-coded here
    hs_mask = hs_mask.transpose(1, 2)
    hs_mask = hs_mask.repeat(1, 1, 256).type(torch.FloatTensor).to(hs_pad.device)
    hs_pad_masked = hs_pad * hs_mask
    logging.debug("hs_pad_masked.size()==>" + str(hs_pad_masked.size()))

    att_vec = self.att(hs_pad_masked)
    # att_vec = self.att(hs_pad)
    pred_pad = self.output(att_vec).unsqueeze(1)
    logging.debug("att_vec.size()==>" + str(att_vec.size()))
    logging.debug("pred_pad.size()==>" + str(pred_pad.size()))

    # compute loss
    self.loss = self.criterion(pred_pad, ys_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad,
                           ignore_label=self.ignore_id)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(self.acc, loss_data)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    tgt_lang_ids = None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute ST loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # 5. compute auxiliary ASR loss
    loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
        hs_pad, hs_mask, ys_pad_src
    )

    # 6. compute auxiliary MT loss
    loss_mt, acc_mt = 0.0, None
    if self.mt_weight > 0:
        loss_mt, acc_mt = self.forward_mt(
            ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
        )

    asr_ctc_weight = self.mtlalpha
    self.loss = (
        (1 - self.asr_weight - self.mt_weight) * loss_att
        + self.asr_weight
        * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
        + self.mt_weight * loss_mt
    )
    loss_asr_data = float(
        asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
    )
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)

    # mlp moe forward
    cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
    hs_mask = cn_hs_mask  # cn_hs_mask & en_hs_mask are identical

    # gated add module
    # lambda = sigmoid(W_cn * cn_xs + W_en * en_xs + b)  # (B, T, 1)
    # xs = lambda * cn_xs + (1 - lambda) * en_xs
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    lambda_ = self.aggregation_module(hs_pad)  # (B,T,1)/(B,T,D), range (0, 1)
    hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        # divide ys_pad into cn_ys & en_ys;
        # note that these targets can be passed directly to the ctc modules
        cn_ys, en_ys = partial_target(ys_pad, self.language_divider)
        cn_loss_ctc = self.cn_ctc(
            cn_hs_pad.view(batch_size, -1, self.adim), hs_len, cn_ys)
        en_loss_ctc = self.en_ctc(
            en_hs_pad.view(batch_size, -1, self.adim), hs_len, en_ys)
        loss_ctc = 0.5 * (cn_loss_ctc + en_loss_ctc)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             None, None, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
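# A minimal, self-contained sketch of the gated-add aggregation described in the
# comment above (lambda = sigmoid(W_cn * cn_xs + W_en * en_xs + b)); the class
# and parameter names here are illustrative, not the model's actual ones. A
# single linear layer over the concatenated streams realizes the two weight
# matrices plus bias in one call.
import torch


class GatedAdd(torch.nn.Module):
    def __init__(self, adim):
        super().__init__()
        self.gate = torch.nn.Linear(2 * adim, 1)

    def forward(self, cn_hs, en_hs):
        # (B, T, 2*adim) -> (B, T, 1) gate in (0, 1)
        lam = torch.sigmoid(self.gate(torch.cat((cn_hs, en_hs), dim=-1)))
        # convex combination of the two encoder streams per frame
        return lam * cn_hs + (1 - lam) * en_hs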
def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, tgt_lang_ids=None):
    """Decoder forward.

    :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
    :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
    :param int strm_idx: stream index indicating the index of the decoding stream
    :param torch.Tensor tgt_lang_ids: batch of target language id tensor (B, 1)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy
    :rtype: float
    """
    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

    # attention index for the attention module
    # in SPA (speaker parallel attention), att_idx is used to select the
    # attention module. In other cases, it is 0.
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlens should be a list of integers
    hlens = list(map(int, hlens))

    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    if self.replace_sos:
        ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(tgt_lang_ids, ys)]
    else:
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding for ys with -1
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get dim, length info
    batch = ys_out_pad.size(0)
    olength = ys_out_pad.size(1)
    logging.info(self.__class__.__name__ + ' input lengths: ' + str(hlens))
    logging.info(self.__class__.__name__ + ' output lengths: '
                 + str([y.size(0) for y in ys_out]))

    # initialization
    c_list = [self.zero_state(hs_pad)]
    z_list = [self.zero_state(hs_pad)]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad))
        z_list.append(self.zero_state(hs_pad))
    att_w = None
    z_all = []
    self.att[att_idx].reset()  # reset pre-computation of h

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence
    for i in six.moves.range(olength):
        att_c, att_w = self.att[att_idx](hs_pad, hlens,
                                         self.dropout_dec[0](z_list[0]), att_w)
        if i > 0 and random.random() < self.sampling_probability:
            logging.info(' scheduled sampling ')
            z_out = self.output(z_all[-1])
            z_out = np.argmax(z_out.detach().cpu(), axis=1)
            z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
            ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
        else:
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
        if self.context_residual:
            z_all.append(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c),
                                   dim=-1))  # utt x (zdim + hdim)
        else:
            z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

    z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)

    # compute loss
    y_all = self.output(z_all)
    if LooseVersion(torch.__version__) < LooseVersion('1.0'):
        reduction_str = 'elementwise_mean'
    else:
        reduction_str = 'mean'
    self.loss = F.cross_entropy(y_all, ys_out_pad.view(-1),
                                ignore_index=self.ignore_id,
                                reduction=reduction_str)
    # -1: eos, which is removed in the loss computation
    self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
    acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
    logging.info('att loss:' + ''.join(str(self.loss.item()).split('\n')))

    # compute perplexity
    ppl = np.exp(self.loss.item() * np.mean([len(x) for x in ys_in])
                 / np.sum([len(x) for x in ys_in]))

    # show predicted character sequence for debug
    if self.verbose > 0 and self.char_list is not None:
        ys_hat = y_all.view(batch, olength, -1)
        ys_true = ys_out_pad
        for (i, y_hat), y_true in zip(enumerate(ys_hat.detach().cpu().numpy()),
                                      ys_true.detach().cpu().numpy()):
            if i == MAX_DECODER_OUTPUT:
                break
            idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
            idx_true = y_true[y_true != self.ignore_id]
            seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
            seq_true = [self.char_list[int(idx)] for idx in idx_true]
            seq_hat = "".join(seq_hat)
            seq_true = "".join(seq_true)
            logging.info("groundtruth[%d]: " % i + seq_true)
            logging.info("prediction [%d]: " % i + seq_hat)

    if self.labeldist is not None:
        if self.vlabeldist is None:
            self.vlabeldist = to_device(self, torch.from_numpy(self.labeldist))
        loss_reg = -torch.sum(
            (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
        ) / len(ys_in)
        self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

    return self.loss, acc, ppl
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask, penultimate_state = self.decoder(
        ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute lid multitask loss
    src_att = self.lid_src_att(penultimate_state, hs_pad, hs_pad, hs_mask)
    pred_lid_pad = self.lid_output_layer(src_att)
    loss_lid, lid_ys_out_pad = self.lid_criterion(pred_lid_pad, ys_out_pad)
    lid_acc = th_accuracy(
        pred_lid_pad.view(-1, self.lid_odim), lid_ys_out_pad,
        ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    lid_alpha = self.lid_mtl_alpha
    if alpha == 0:
        self.loss = loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        raise Exception("LID MTL does not support pure ctc mode")
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data, lid_acc)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, ys_pad_mono=None):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. RNN Encoder
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # 2. post-processing layer for target dimension
    pred_pad = self.poster(hs_pad)
    pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
    self.pred_pad = pred_pad
    if pred_pad.size(1) != ys_pad.size(1):
        if pred_pad.size(1) < ys_pad.size(1):
            ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
        else:
            raise ValueError(
                "target size {} and pred size {} are mismatched".format(
                    ys_pad.size(1), pred_pad.size(1)))
    if ys_pad_mono is not None:
        pred_pad_mono = self.poster_mono(hs_pad)
        pred_pad_mono = pred_pad_mono.view(pred_pad_mono.size(0), -1,
                                           self.mono_odim)
        self.pred_pad_mono = pred_pad_mono
        if pred_pad_mono.size(1) != ys_pad_mono.size(1):
            if pred_pad_mono.size(1) < ys_pad_mono.size(1):
                ys_pad_mono = ys_pad_mono[:, :pred_pad_mono.size(1)].contiguous()
            else:
                raise ValueError(
                    "target size {} and pred size {} are mismatched".format(
                        ys_pad_mono.size(1), pred_pad_mono.size(1)))

    # 3. CTC loss
    if self.mtlalpha == 0:
        self.loss_ctc = None
    else:
        self.loss_ctc = self.ctc(pred_pad, hlens, ys_pad)

    # 4. CE loss
    if LooseVersion(torch.__version__) < LooseVersion("1.0"):
        reduction_str = "elementwise_mean"
    else:
        reduction_str = "mean"
    self.loss_ce_tri = F.cross_entropy(
        pred_pad.view(-1, self.odim),
        ys_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction=reduction_str,
    )
    if ys_pad_mono is not None:
        self.loss_ce_mono = F.cross_entropy(
            pred_pad_mono.view(-1, self.mono_odim),  # NOTE: mono output dim
            ys_pad_mono.view(-1),
            ignore_index=self.ignore_id,
            reduction=reduction_str,
        )
    else:
        self.loss_ce_mono = 0
    self.loss_ce = 0.6 * self.loss_ce_tri + 0.4 * self.loss_ce_mono
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad,
                           ignore_label=self.ignore_id)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer, cer_ctc = None, None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = self.loss_ce
        loss_ce_data = float(self.loss_ce)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = self.loss_ctc
        loss_ce_data = None
        loss_ctc_data = float(self.loss_ctc)
    else:
        self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_ce
        loss_ce_data = float(self.loss_ce)
        loss_ctc_data = float(self.loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_ce_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    batch_size = xs_pad.shape[0]
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    if isinstance(self.encoder.embed, EncoderConv2d):
        xs, hs_mask = self.encoder.embed(xs_pad, torch.sum(src_mask, 2).squeeze())
        hs_mask = hs_mask.unsqueeze(1)
    else:
        xs, hs_mask = self.encoder.embed(xs_pad, src_mask)
    if enc_mask is not None:
        enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]]
    enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
    hs_pad, _ = self.encoder.encoders(xs, enc_mask)
    if self.encoder.normalize_before:
        hs_pad = self.encoder.after_norm(hs_pad)

    # CTC forward
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    if dec_mask is not None:
        dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
    self.hs_pad = hs_pad
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)

    # trigger mask
    hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    return self.loss, loss_ctc_data, loss_att_data, self.acc
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward Transformer encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    if xs_pad.size(1) != ys_pad.size(1):
        if xs_pad.size(1) < ys_pad.size(1):
            ys_pad = ys_pad[:, :xs_pad.size(1)].contiguous()
        else:
            raise ValueError(
                "target size {} is smaller than input size {}".format(
                    ys_pad.size(1), xs_pad.size(1)))
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. post-processing layer for target dimension
    if self.outer:
        post_pad = self.poster(hs_pad)
        post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
        if post_pad.size(1) != xs_pad.size(1):
            if post_pad.size(1) < xs_pad.size(1):
                xs_pad = xs_pad[:, :post_pad.size(1)].contiguous()
            else:
                raise ValueError(
                    "target size {} and pred size {} are mismatched".format(
                        xs_pad.size(1), post_pad.size(1)))
        if self.residual:
            post_pad = post_pad + self.matcher_res(xs_pad)
        else:
            post_pad = torch.cat([post_pad, xs_pad], dim=-1)
        pred_pad = self.matcher(post_pad)
    else:
        pred_pad = self.poster(hs_pad)
        pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
    self.pred_pad = pred_pad
    if pred_pad.size(1) != ys_pad.size(1):
        if pred_pad.size(1) < ys_pad.size(1):
            ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
        else:
            raise ValueError(
                "target size {} and pred size {} are mismatched".format(
                    ys_pad.size(1), pred_pad.size(1)))

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id
    )

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(pred_pad.view(batch_size, -1, self.adim),
                            hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(pred_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    # 4. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        pass  # logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = mask_uniform(ys_pad, self.mask_token, self.eos,
                                         self.ignore_id)
    ys_mask = square_mask(ys_in_pad, self.eos)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute ctc loss
    loss_ctc, cer_ctc = None, None
    loss_intermediate_ctc = 0.0
    if self.mtlalpha > 0:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        # for visualization
        if not self.training:
            self.ctc.softmax(hs_pad)

        if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers:
            for hs_intermediate in hs_intermediates:
                # assuming hs_intermediates and hs_pad have the same
                # length / padding
                loss_inter = self.ctc(
                    hs_intermediate.view(batch_size, -1, self.adim),
                    hs_len, ys_pad)
                loss_intermediate_ctc += loss_inter
            loss_intermediate_ctc /= len(self.intermediate_ctc_layers)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    else:
        self.loss = (alpha * loss_ctc
                     + self.intermediate_ctc_weight * loss_intermediate_ctc
                     + (1 - alpha - self.intermediate_ctc_weight) * loss_att)
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
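# A hedged sketch of the uniform masking assumed by mask_uniform above: per
# sequence, sample the number of masked tokens uniformly at random, feed
# <mask> at those positions in the decoder input, and supervise only the
# masked positions in the output (everything else stays ignore_id). This is
# illustrative only, not the library's implementation.
import torch


def mask_uniform_sketch(ys_pad, mask_token, eos, ignore_id=-1):
    ys_in = ys_pad.masked_fill(ys_pad == ignore_id, eos)
    ys_out = ys_pad.new_full(ys_pad.size(), ignore_id)
    for b in range(ys_pad.size(0)):
        valid = (ys_pad[b] != ignore_id).nonzero(as_tuple=True)[0]
        if len(valid) == 0:
            continue
        n_mask = torch.randint(1, len(valid) + 1, (1,)).item()
        picked = valid[torch.randperm(len(valid))[:n_mask]]
        ys_in[b, picked] = mask_token
        ys_out[b, picked] = ys_pad[b, picked]
    return ys_in, ys_out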
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    # src_lang_ids = None
    tgt_lang_ids, tgt_lang_ids_src = None, None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
    if self.one_to_many:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
        if self.do_asr:
            tgt_lang_ids_src = ys_pad_src[:, 0:1]
            ys_pad_src = ys_pad_src[:, 1:]  # remove target language ID at the beginning

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel; bs x max_ilens x idim
    if self.lang_tok == "encoder-pre-sum":
        lang_embed = self.language_embeddings(tgt_lang_ids)  # bs x 1 x idim
        xs_pad = xs_pad + lang_embed
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)  # bs x 1 x max_ilens
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    # hs_pad: bs x (max_ilens/4) x adim; hs_mask: bs x 1 x (max_ilens/4)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)  # bs x max_lens
    if self.do_asr:
        ys_in_pad_src, ys_out_pad_src = add_sos_eos(
            ys_pad_src, self.sos, self.eos, self.ignore_id)  # bs x max_lens_src

    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    if self.lang_tok == "decoder-pre":
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        if self.do_asr:
            ys_in_pad_src = torch.cat([tgt_lang_ids_src, ys_in_pad_src[:, 1:]],
                                      dim=1)

    ys_mask = target_mask(ys_in_pad, self.ignore_id)  # bs x max_lens x max_lens
    if self.do_asr:
        ys_mask_src = target_mask(ys_in_pad_src,
                                  self.ignore_id)  # bs x max_lens_src x max_lens_src

    if self.wait_k_asr > 0:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id,
                                       wait_k_cross=self.wait_k_asr)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id,
                                           wait_k_cross=-self.wait_k_asr)
    elif self.wait_k_st > 0:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id,
                                       wait_k_cross=-self.wait_k_st)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id,
                                           wait_k_cross=self.wait_k_st)
    else:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id,
                                       wait_k_cross=0)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id, wait_k_cross=0)

    pred_pad, pred_mask, pred_pad_asr, pred_mask_asr = self.dual_decoder(
        ys_in_pad, ys_mask, ys_in_pad_src, ys_mask_src, hs_pad, hs_mask,
        cross_mask, cross_mask_asr,
        cross_self=self.cross_self, cross_src=self.cross_src,
        cross_self_from=self.cross_self_from,
        cross_src_from=self.cross_src_from)
    self.pred_pad = pred_pad
    self.pred_pad_asr = pred_pad_asr
    pred_pad_mt = None

    # 3. compute attention loss
    loss_asr, loss_mt = 0.0, 0.0
    loss_att = self.criterion(pred_pad, ys_out_pad)
    loss_asr = self.criterion(pred_pad_asr, ys_out_pad_src)

    # Multi-task w/ MT
    if self.mt_weight > 0:
        # forward MT encoder
        ilens_mt = torch.sum(ys_pad_src != self.ignore_id, dim=1).cpu().numpy()
        # NOTE: ys_pad_src is padded with -1
        ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
        ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
        ys_zero_pad_src = ys_zero_pad_src[:, :max(ilens_mt)]  # for data parallel
        src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(
            ys_zero_pad_src.device).unsqueeze(-2)
        # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad)
        hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
        # forward MT decoder
        pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt, hs_mask_mt)
        # compute loss
        loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)
    if pred_pad_asr is not None:
        self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim),
                                   ys_out_pad_src, ignore_label=self.ignore_id)
    else:
        self.acc_asr = 0.0
    if pred_pad_mt is not None:
        self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                                  ignore_label=self.ignore_id)
    else:
        self.acc_mt = 0.0

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0 or self.asr_weight == 0:
        loss_ctc = 0.0
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad_src)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(),
                                            is_ctc=True)

    # 5. compute cer/wer
    cer, wer = None, None  # TODO(hirofumi0810): fix later
    # if self.training or (self.asr_weight == 0 or self.mtlalpha == 1
    #                      or not (self.report_cer or self.report_wer)):
    #     cer, wer = None, None
    # else:
    #     ys_hat = pred_pad.argmax(dim=-1)
    #     cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    self.loss = ((1 - self.asr_weight - self.mt_weight) * loss_att
                 + self.asr_weight * (alpha * loss_ctc + (1 - alpha) * loss_asr)
                 + self.mt_weight * loss_mt)
    loss_asr_data = float(alpha * loss_ctc + (1 - alpha) * loss_asr)
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)
    # logging.info(f'loss_st_data={loss_st_data}')

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data,
                             self.acc_asr, self.acc_mt, self.acc,
                             cer_ctc, cer, wer,
                             0.0,  # TODO(hirofumi0810): bleu
                             loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def recog(args):
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.
    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={'train': False})

    import time

    import kaldiio

    from espnet.nets.pytorch_backend.nets_utils import th_accuracy

    with torch.no_grad(), \
            kaldiio.WriteHelper('ark,scp:{o}.ark,{o}.scp'.format(o=args.out)) as f:
        ys = []
        xs = []
        for idx, utt_id in enumerate(js.keys()):
            logging.info('(%d/%d) decoding ' + utt_id, idx, len(js.keys()))
            batch = [(utt_id, js[utt_id])]
            data = load_inputs_and_targets(batch)
            feat = data[0][0]
            ys.append(data[1][0])
            # x = torch.LongTensor(x).to(device)

            # decode and write
            start_time = time.time()
            scores, outs = model.inference(feat, ys, args, train_args.char_list)
            xs.append(scores)
            logging.info("inference speed = %s msec / frame." % (
                (time.time() - start_time) * 1000 / int(outs.size(0))))
            logging.info('(%d/%d) %s (size:%d->%d)' % (
                idx + 1, len(js.keys()), utt_id, len(feat), outs.size(0)))
            f[utt_id] = outs.cpu().numpy()

    preds = torch.stack(xs).view(len(xs), -1)
    labels = torch.LongTensor(ys).view(len(xs), 1)
    acc = th_accuracy(preds, labels, -1)
    logging.warning("Final acc is (%.2f)" % (acc * 100))
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)

    # mlp moe forward
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)

    # gated add module with a fixed gate
    # lambda = sigmoid(W_cn * cn_xs + W_en * en_xs + b)  # (B, T, 1)
    # xs = lambda * cn_xs + (1 - lambda) * en_xs
    lambda_ = self.enc_lambda
    hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    tgt_lang_ids = None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute ST loss
    loss_asr_att, loss_asr_ctc, loss_mt = 0.0, 0.0, 0.0
    acc_asr, acc_mt = 0.0, 0.0
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # 5. compute auxiliary ASR loss
    cer, wer = None, None
    cer_ctc = None
    if self.asr_weight > 0:
        # attention
        if self.mtlalpha < 1:
            ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                ys_pad_src, self.sos, self.eos, self.ignore_id)
            ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
            pred_pad_asr, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr,
                                               hs_pad, hs_mask)
            loss_asr_att = self.criterion(pred_pad_asr, ys_out_pad_asr)
            acc_asr = th_accuracy(
                pred_pad_asr.view(-1, self.odim),
                ys_out_pad_asr,
                ignore_label=self.ignore_id,
            )
            if not self.training:
                ys_hat_asr = pred_pad_asr.argmax(dim=-1)
                cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(),
                                                     ys_pad_src.cpu())
        # CTC
        if self.mtlalpha > 0:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_asr_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                    hs_len, ys_pad_src)
            ys_hat_ctc = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            if not self.training:
                cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                    ys_pad_src.cpu(),
                                                    is_ctc=True)

    # 6. compute auxiliary MT loss
    if self.mt_weight > 0:
        ilens_mt = torch.sum(ys_pad_src != self.ignore_id, dim=1).cpu().numpy()
        # NOTE: ys_pad_src is padded with -1
        ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
        ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
        ys_zero_pad_src = ys_zero_pad_src[:, :max(ilens_mt)]  # for data parallel
        src_mask_mt = ((~make_pad_mask(ilens_mt.tolist())).to(
            ys_zero_pad_src.device).unsqueeze(-2))
        hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
        pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt, hs_mask_mt)
        loss_mt = self.criterion(pred_pad_mt, ys_out_pad)
        acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                             ignore_label=self.ignore_id)

    alpha = self.mtlalpha
    self.loss = ((1 - self.asr_weight - self.mt_weight) * loss_att
                 + self.asr_weight * (alpha * loss_asr_ctc
                                      + (1 - alpha) * loss_asr_att)
                 + self.mt_weight * loss_mt)
    loss_asr_data = float(alpha * loss_asr_ctc + (1 - alpha) * loss_asr_att)
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    moe_coes = moe_coes[:, :max(moe_coe_lens)].long()  # for data parallel
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)

    # mlp moe forward
    cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)

    # gated add module:
    #   lambda = sigmoid(W_cn * cn_xs + W_en * en_xs + b)  # (B, T, 1)
    #   xs = lambda * cn_xs + (1 - lambda) * en_xs
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    lambda_ = F.softmax(
        self.aggre_scaling * self.aggregation_module(hs_pad), -1
    ).unsqueeze(-1)
    ctc_hs_pad = lambda_[:, :, 0] * cn_hs_pad + lambda_[:, :, 1] * en_hs_pad
    ctc_hs_mask = cn_hs_mask
    # flat attention mode: (B, T, D) * 2 --> (B, 2T, D)
    s2s_hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=1)
    # mask: (B, 1, T) --> (B, 1, 2T)
    s2s_hs_mask = torch.cat((cn_hs_mask, en_hs_mask), dim=-1)
    # self.hs_pad = hs_pad

    # compute lid loss here, using lambda_
    # moe_coes (B, T, 2) ==> (B, T); 0 for cn, 1 for en
    moe_coes = moe_coes[:, :, 0]
    lambda_ = lambda_.squeeze(-1)
    if self.lid_mtl_alpha == 0.0:
        loss_lid = 0.0
    else:
        loss_lid = self.lid_criterion(lambda_, moe_coes)
    lid_acc = th_accuracy(
        lambda_.view(-1, 2), moe_coes, ignore_label=self.ignore_id
    ) if self.log_lid_mtl_acc else None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = ctc_hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(ctc_hs_pad.view(batch_size, -1, self.adim),
                            hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                ctc_hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask,
                                           s2s_hs_pad, s2s_hs_mask)
        self.pred_pad = pred_pad
        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    lid_alpha = self.lid_mtl_alpha
    if alpha == 0:
        self.loss = loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc + lid_alpha * loss_lid
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = (alpha * loss_ctc + (1 - alpha) * loss_att
                     + lid_alpha * loss_lid)
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                             cer_ctc, cer, wer, loss_data, lid_acc)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
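# --- Illustrative sketch of the softmax-gated aggregation used above: a linear
# projection over the concatenated CN/EN encoder outputs yields per-frame
# mixture weights.  `aggregation` and `scaling` stand in for
# self.aggregation_module and self.aggre_scaling; both names are assumptions.
import torch
import torch.nn.functional as F

def gated_aggregate(cn_hs, en_hs, aggregation, scaling=1.0):
    """Mix two encoder streams (B, T, D) with per-frame softmax weights."""
    lam = F.softmax(
        scaling * aggregation(torch.cat((cn_hs, en_hs), dim=-1)), dim=-1
    ).unsqueeze(-1)  # (B, T, 2, 1)
    return lam[:, :, 0] * cn_hs + lam[:, :, 1] * en_hs

# usage, with D=4: aggregation = torch.nn.Linear(2 * 4, 2)
# out = gated_aggregate(torch.randn(2, 5, 4), torch.randn(2, 5, 4), aggregation)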
def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
    """Decoder forward.

    :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        [in multi-encoder case, list of torch.Tensor,
        [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
    :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        [in multi-encoder case, list of torch.Tensor, [(B), (B), ..., ] ]
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
    :param int strm_idx: stream index indicates the index of decoding stream.
    :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy
    :rtype: float
    """
    # to support multiple encoder asr mode; in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        hs_pad = [hs_pad]
        hlens = [hlens]

    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
    # attention index for the attention module
    # in SPA (speaker parallel attention),
    # att_idx is used to select attention module. In other cases, it is 0.
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlens should be list of list of integer
    hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    if self.replace_sos:
        ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
    else:
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding for ys with -1
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get dim, length info
    batch = ys_out_pad.size(0)
    olength = ys_out_pad.size(1)
    for idx in range(self.num_encs):
        logging.info(
            self.__class__.__name__
            + "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, hlens[idx]))
    logging.info(self.__class__.__name__ + " output lengths: "
                 + str([y.size(0) for y in ys_out]))

    # initialization
    c_list = [self.zero_state(hs_pad[0])]
    z_list = [self.zero_state(hs_pad[0])]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad[0]))
        z_list.append(self.zero_state(hs_pad[0]))
    z_all = []
    if self.num_encs == 1:
        att_w = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * self.num_encs  # atts
        for idx in range(self.num_encs + 1):
            # reset pre-computation of h in atts and han
            self.att[idx].reset()

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence
    for i in six.moves.range(olength):
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w)
        else:
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    hs_pad[idx],
                    hlens[idx],
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[idx],
                )
            hs_pad_han = torch.stack(att_c_list, dim=1)
            hlens_han = [self.num_encs] * len(ys_in)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                hs_pad_han,
                hlens_han,
                self.dropout_dec[0](z_list[0]),
                att_w_list[self.num_encs],
            )
        if i > 0 and random.random() < self.sampling_probability:
            logging.info(" scheduled sampling ")
            z_out = self.output(z_all[-1])
            z_out = np.argmax(z_out.detach().cpu(), axis=1)
            z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
            ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
        else:
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
        if self.context_residual:
            z_all.append(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )  # utt x (zdim + hdim)
        else:
            z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

    z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
    # compute loss
    y_all = self.output(z_all)
    logging.info("max:{}".format(torch.max(y_all, dim=1)))
    if LooseVersion(torch.__version__) < LooseVersion("1.0"):
        reduction_str = "elementwise_mean"
    else:
        reduction_str = "mean"
    self.loss = F.cross_entropy(
        y_all,
        ys_out_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction=reduction_str,
    )
    logging.info("loss_att:{}".format(self.loss))
    # compute perplexity
    ppl = math.exp(self.loss.item())
    # -1: eos, which is removed in the loss computation
    self.loss *= np.mean([len(x) for x in ys_in]) - 1
    acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
    logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))

    # show predicted character sequence for debug
    if self.verbose > 0 and self.char_list is not None:
        ys_hat = y_all.view(batch, olength, -1)
        ys_true = ys_out_pad
        for (i, y_hat), y_true in zip(
                enumerate(ys_hat.detach().cpu().numpy()),
                ys_true.detach().cpu().numpy()):
            if i == MAX_DECODER_OUTPUT:
                break
            idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
            idx_true = y_true[y_true != self.ignore_id]
            seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
            seq_true = [self.char_list[int(idx)] for idx in idx_true]
            seq_hat = "".join(seq_hat)
            seq_true = "".join(seq_true)
            logging.info("groundtruth[%d]: " % i + seq_true)
            logging.info("prediction [%d]: " % i + seq_hat)

    if self.labeldist is not None:
        if self.vlabeldist is None:
            self.vlabeldist = to_device(self, torch.from_numpy(self.labeldist))
        loss_reg = -torch.sum(
            (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
        ) / len(ys_in)
        self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

    return self.loss, acc, ppl
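# --- Illustrative sketch of the scheduled-sampling branch in the loop above
# (Bengio et al., 2015): with probability `sampling_probability`, the decoder
# is fed the embedding of its own previous prediction instead of the ground
# truth.  `embed` and `output` stand in for the decoder's embedding and output
# layers; the caller would still concatenate the result with the attention
# context att_c, as the full code does.
import random

def pick_token_embedding(eys, i, z_prev, embed, output, sampling_probability):
    """Return the token embedding fed to the RNN at step i."""
    if i > 0 and random.random() < sampling_probability:
        prev_ids = output(z_prev).argmax(dim=1)  # previous prediction
        return embed(prev_ids)                   # feed back its embedding
    return eys[:, i, :]                          # teacher forcing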
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # CTC forward
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    self.hs_pad = hs_pad
    cer_ctc = None
    batch_size = xs_pad.size(0)
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                            hs_len, ys_pad)

    # trigger mask
    start_time = time.time()

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    return self.loss, loss_ctc_data, loss_att_data, self.acc
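# --- Illustrative sketch: make_non_pad_mask(ilens) is the complement of
# make_pad_mask(ilens), and together with unsqueeze(-2) it yields the
# (B, 1, Tmax) source mask built above.  The plain-torch version below
# reproduces that construction for clarity; the helper name is hypothetical.
import torch

def non_pad_mask(ilens, max_len=None):
    """True for real frames, False for padding; returns shape (B, 1, Tmax)."""
    max_len = max_len or max(ilens)
    ar = torch.arange(max_len).unsqueeze(0)          # (1, Tmax)
    mask = ar < torch.as_tensor(ilens).unsqueeze(1)  # (B, Tmax)
    return mask.unsqueeze(-2)                        # (B, 1, Tmax)

# non_pad_mask([3, 5])[0] -> [[True, True, True, False, False]]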
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    if self.decoder is not None:
        if self.decoder_mode == "maskctc":
            ys_in_pad, ys_out_pad = mask_uniform(
                ys_pad, self.mask_token, self.eos, self.ignore_id
            )
            ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
        else:
            ys_in_pad, ys_out_pad = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id
            )
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
    else:
        loss_att = None
        self.acc = None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        # for visualization
        if not self.training:
            self.ctc.softmax(hs_pad)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
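# --- Rough approximation of the mask_uniform target preparation used in the
# "maskctc" branch above: a random subset of target tokens is replaced by the
# <mask> token, and only those positions keep a label in ys_out (everything
# else becomes ignore_id and drops out of the loss).  The real ESPnet helper
# first draws the number of masked tokens per utterance; this sketch masks
# each token independently and is illustrative only.
import torch

def mask_uniform_sketch(ys_pad, mask_token, ignore_id=-1, p=0.5):
    ys_in = ys_pad.clone()
    ys_out = torch.full_like(ys_pad, ignore_id)
    real = ys_pad != ignore_id
    masked = real & (torch.rand_like(ys_pad, dtype=torch.float) < p)
    ys_in[masked] = mask_token
    ys_out[masked] = ys_pad[masked]
    return ys_in, ys_out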
def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
    # here we use interpolation_coe to 'fix' initial moe_coes
    interp_factor = self.interp_factor  # 0.1 for example, similar to lsm
    moe_coes = (1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)

    # multi-encoder forward
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
    moe_coes = moe_coes.unsqueeze(-1)
    hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                            hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad
        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                             cer_ctc, cer, wer, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
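# --- Numeric example of the coefficient smoothing above: with two experts and
# interp_factor = 0.1, hard 0/1 mixture coefficients are pulled toward
# uniform, exactly analogous to label smoothing.  Values are illustrative.
import torch

moe = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # (B=1, T=2, num_experts=2)
interp_factor = 0.1
smoothed = (1 - interp_factor) * moe + interp_factor / moe.shape[2]
# smoothed -> [[[0.95, 0.05], [0.05, 0.95]]]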