def _calc_mlm_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the masked-LM (mask-predict) loss and token accuracy.

    :param torch.Tensor encoder_out: encoder output sequences
    :param torch.Tensor encoder_out_lens: lengths of encoder outputs
    :param torch.Tensor ys_pad: padded target token sequences
    :param torch.Tensor ys_pad_lens: lengths of target sequences
    :return: tuple of (MLM cross-entropy loss, token accuracy over
        non-ignored positions)
    """
    # Randomly replace target tokens with <mask>; the returned target copy
    # carries ignore_id everywhere except the positions to be predicted.
    masked_in, masked_target = mask_uniform(
        ys_pad, self.mask_token, self.eos, self.ignore_id
    )

    # Decode the masked sequence conditioned on the encoder output.
    logits, _ = self.decoder(
        encoder_out, encoder_out_lens, masked_in, ys_pad_lens
    )

    # Loss and accuracy are computed only where masked_target != ignore_id.
    loss_mlm = self.criterion_mlm(logits, masked_target)
    acc_mlm = th_accuracy(
        logits.view(-1, self.vocab_size),
        masked_target,
        ignore_label=self.ignore_id,
    )
    return loss_mlm, acc_mlm
def test_mask():
    """Check special-token indices and that mask_uniform masks exactly
    the positions that remain scored in the output targets."""
    args = make_arg()
    model, x, ilens, y, data, uttid_list = prepare(args)

    # <mask> is appended after the char list, hence the +1 slot;
    # <sos>/<eos> share the index just before it.
    vocab_size = len(args.char_list) + 1
    assert model.sos == vocab_size - 2
    assert model.eos == vocab_size - 2
    assert model.mask_token == vocab_size - 1

    masked_in, target_out = mask_uniform(
        y, model.mask_token, model.eos, model.ignore_id
    )
    is_masked = (masked_in == model.mask_token).detach().numpy()
    is_scored = (target_out != model.ignore_id).detach().numpy()
    assert (is_masked == is_scored).all()
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: the combined multi-task loss (``self.loss``); the attention
        loss, CTC loss, and accuracy are stored on ``self`` and sent to
        ``self.reporter``
    :rtype: torch.Tensor
    """
    # 1. forward encoder
    # Trim padding to the longest utterance in this (possibly split)
    # mini-batch; for data parallel each replica sees different ilens.
    xs_pad = xs_pad[:, :max(ilens)]
    src_mask = make_non_pad_mask(ilens.tolist()).to(
        xs_pad.device).unsqueeze(-2)
    # Encoder also returns intermediate-layer outputs for intermediate CTC.
    hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad  # kept for visualization/inspection

    # 2. forward decoder
    # Replace a random subset of target tokens with <mask>; ys_out_pad is
    # ignore_id everywhere except the masked (scored) positions.
    ys_in_pad, ys_out_pad = mask_uniform(ys_pad, self.mask_token, self.eos,
                                         self.ignore_id)
    ys_mask = square_mask(ys_in_pad, self.eos)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute ctc loss
    loss_ctc, cer_ctc = None, None
    loss_intermediate_ctc = 0.0
    if self.mtlalpha > 0:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)
        # for visualization (evaluation mode only)
        if not self.training:
            self.ctc.softmax(hs_pad)

        # NOTE(review): this block must stay nested under mtlalpha > 0 —
        # it reuses batch_size/hs_len which are only defined in that branch.
        if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers:
            for hs_intermediate in hs_intermediates:
                # assuming hs_intermediates and hs_pad has same length / padding
                loss_inter = self.ctc(
                    hs_intermediate.view(batch_size, -1, self.adim), hs_len,
                    ys_pad)
                loss_intermediate_ctc += loss_inter
            # Average over the intermediate CTC layers.
            loss_intermediate_ctc /= len(self.intermediate_ctc_layers)

    # 5. compute cer/wer (decoder-based; skipped during training)
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # Combine losses: alpha weights CTC vs. attention, with the
    # intermediate-CTC weight carved out of the attention share.
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    else:
        self.loss = (alpha * loss_ctc +
                     self.intermediate_ctc_weight * loss_intermediate_ctc +
                     (1 - alpha - self.intermediate_ctc_weight) * loss_att)
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    # Only report finite, plausible losses; otherwise warn and skip.
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss