def forward_asr(self, hs_pad, hs_mask, ys_pad):
    """Forward pass in the auxiliary ASR task.

    Computes the attention-decoder loss and/or the CTC loss over the shared
    encoder output, depending on the multitask weight ``self.mtlalpha``
    (attention branch when ``mtlalpha < 1``, CTC branch when ``mtlalpha > 0``).
    When ``self.asr_weight == 0`` the task is disabled and neutral values are
    returned immediately.

    :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor hs_mask: batch of input token mask (B, Lmax)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ASR attention loss value
    :rtype: torch.Tensor
    :return: accuracy in ASR attention decoder
    :rtype: float
    :return: ASR CTC loss value
    :rtype: torch.Tensor
    :return: character error rate from CTC prediction
    :rtype: float
    :return: character error rate from attention decoder prediction
    :rtype: float
    :return: word error rate from attention decoder prediction
    :rtype: float
    """
    # Neutral defaults; losses stay 0.0 / metrics stay None when a branch is skipped.
    loss_att, loss_ctc = 0.0, 0.0
    acc = None
    cer, wer = None, None
    cer_ctc = None
    if self.asr_weight == 0:
        # Auxiliary ASR task disabled entirely.
        return loss_att, acc, loss_ctc, cer_ctc, cer, wer
    # attention branch (skipped when mtlalpha == 1, i.e. pure CTC)
    if self.mtlalpha < 1:
        ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
            ys_pad, self.sos, self.eos, self.ignore_id)
        ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
        pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad,
                                       hs_mask)
        loss_att = self.criterion(pred_pad, ys_out_pad_asr)
        acc = th_accuracy(
            pred_pad.view(-1, self.odim),
            ys_out_pad_asr,
            ignore_label=self.ignore_id,
        )
        if not self.training:
            # Error rates are only evaluated outside training (greedy argmax).
            ys_hat_asr = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(),
                                                 ys_pad.cpu())
    # CTC branch (skipped when mtlalpha == 0, i.e. pure attention)
    if self.mtlalpha > 0:
        batch_size = hs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if not self.training:
            ys_hat_ctc = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                ys_pad.cpu(), is_ctc=True)
            # for visualization (side effect only; result intentionally unused)
            self.ctc.softmax(hs_pad)
    return loss_att, acc, loss_ctc, cer_ctc, cer, wer
def store_penultimate_state(self, xs_pad, ilens, ys_pad, moe_coes,
                            moe_coe_lens):
    """Run the MoE encoders and decoder, returning the penultimate state.

    The two language-specific encoder outputs are mixed with the given
    mixture-of-experts coefficients, the decoder is run in teacher-forcing
    mode, and its penultimate hidden state (B, T, att_dim) is returned as a
    numpy array with the batch axis squeezed out.
    """
    # Trim padded tails so data-parallel shards agree on sequence length.
    moe_coes = moe_coes[:, :max(moe_coe_lens)]
    xs_pad = xs_pad[:, :max(ilens)]
    enc_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)

    # Both encoders see the identical input and mask.
    cn_out, hs_mask = self.cn_encoder(xs_pad, enc_mask)
    en_out, hs_mask = self.en_encoder(xs_pad, enc_mask)

    # Mix encoder outputs: index 1 weights the CN stream, index 0 the EN stream.
    coes = moe_coes.unsqueeze(-1)
    mixed = cn_out * coes[:, :, 1] + en_out * coes[:, :, 0]
    self.hs_pad = mixed

    # Teacher-forced decoder pass, requesting the penultimate hidden state.
    dec_in, _ = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    dec_mask = target_mask(dec_in, self.ignore_id)
    _, _, penultimate = self.decoder(dec_in, dec_mask, mixed, hs_mask, coes,
                                     return_penultimate_state=True)

    # (B, T, att_dim) -> (T, att_dim) numpy array for plotting.
    return penultimate.squeeze(0).detach().cpu().numpy()
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder (teacher forcing)
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                           ys_out_pad,
                           ignore_label=self.ignore_id)

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    # 5. compute bleu (evaluation only)
    if self.training or self.error_calculator is None:
        bleu = 0.0
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # perplexity (copied from e2e_mt)
    self.loss = loss
    loss_data = float(self.loss)
    if self.normalize_length:
        self.ppl = np.exp(loss_data)
    else:
        # BUGFIX: capture the batch size BEFORE flattening. The previous code
        # read ys_out_pad.size(0) after view(-1), i.e. B*T instead of B,
        # grossly inflating the perplexity. This matches the sibling MT
        # forward() implementation in this file.
        batch_size = ys_out_pad.size(0)
        ys_out_pad = ys_out_pad.view(-1)
        ignore = ys_out_pad == self.ignore_id  # (B*T,)
        total = len(ys_out_pad) - ignore.sum().item()
        self.ppl = np.exp(loss_data * batch_size / total)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, self.acc, self.ppl, bleu)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad): """E2E forward. Args: xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim) ilens (torch.Tensor): batch of lengths of input sequences (B) ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax) Returns: loss (torch.Tensor): transducer loss value """ # 1. encoder if self.etype == 'transformer': xs_pad = xs_pad[:, :max(ilens)] src_mask = (~make_pad_mask(ilens.tolist())).to( xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) else: hs_pad, hlens = xs_pad, ilens hs_pad, hlens, _ = self.encoder(hs_pad, hlens) hs_mask = hlens self.hs_pad = hs_pad # 1.5. transducer preparation related ys_in_pad, target, pred_len, target_len = prepare_loss_inputs( ys_pad, hs_mask) # 2. decoder if self.dtype == 'transformer': ys_mask = target_mask(ys_in_pad, self.blank_id) pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad) else: if self.rnnt_mode == 'rnnt': pred_pad = self.decoder(hs_pad, ys_in_pad) else: pred_pad = self.decoder(hs_pad, ys_in_pad, pred_len) self.pred_pad = pred_pad # 3. loss computation loss = self.criterion(pred_pad, target, pred_len, target_len) self.loss = loss loss_data = float(self.loss) # 4. compute cer/wer if self.training or self.error_calculator is None: cer, wer = None, None else: cer, wer = self.error_calculator(hs_pad, ys_pad) if not math.isnan(loss_data): self.reporter.report(loss_data, cer, wer) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss
def test_transformer_mask():
    """NaNs injected into masked target positions must not leak into attention."""
    args = make_arg()
    model, x, ilens, y, data, uttid_list = prepare("pytorch", args)
    ys_in, ys_out = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    mask = target_mask(ys_in, model.ignore_id)
    embedded = model.decoder.embed(ys_in)
    # Poison everything past position 3 of the first sequence.
    embedded[0, 3:] = float("nan")
    attn_layer = model.decoder.decoders[0].self_attn
    attn_layer(embedded, embedded, embedded, mask)
    visible = attn_layer.attn[0, :, :3, :3].detach().numpy()
    assert not numpy.isnan(visible).any()
def test_transformer_mask(module):
    """NaNs injected into masked target positions must not leak into attention."""
    model, x, ilens, y, data = prepare(module)
    from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
    from espnet.nets.pytorch_backend.transformer.mask import target_mask

    ys_in, ys_out = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    mask = target_mask(ys_in, model.ignore_id)
    embedded = model.decoder.embed(ys_in)
    # Poison everything past position 3 of the first sequence.
    embedded[0, 3:] = float("nan")
    attn_layer = model.decoder.decoders[0].self_attn
    attn_layer(embedded, embedded, embedded, mask)
    visible = attn_layer.attn[0, :, :3, :3].detach().numpy()
    assert not numpy.isnan(visible).any()
def forward(self, feats, feats_lengths):
    """Encode features, decode, and classify the decoder hidden states.

    :param torch.Tensor feats: padded feature sequences (batch-first)
    :param torch.Tensor feats_lengths: lengths of each feature sequence (B,)
    :return: classifier logits over the decoder hidden states
    """
    # Non-pad source mask, shape (B, 1, T), already on the input device.
    # (Fixed: a redundant second .to(feats.device) call was removed.)
    src_mask = make_non_pad_mask(
        feats_lengths.tolist()).to(feats.device).unsqueeze(-2)
    h, hs_mask = self.encoder.encoder(feats, src_mask)
    # 0.9 is passed through to init_decoder; presumably a sampling/keep
    # threshold — TODO confirm against init_decoder's signature.
    ys_in_pad = self.encoder.init_decoder(h, .9)
    ys_mask = target_mask(
        ys_in_pad, ignore_id=self.encoder.eos).to(ys_in_pad.device)
    _, mask, encodings = self.encoder.decoder(
        ys_in_pad, ys_mask, h, hs_mask, return_hidden=True)
    # Per-sequence valid length from the last row of the decoder mask.
    encoding_lengths = mask.sum(-1)[:, -1]
    logits = self.classifier(encodings, encoding_lengths)
    return logits
def forward(self, xs_pad, ilens, ys_pad): """E2E forward. Args: xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim) ilens (torch.Tensor): batch of lengths of input sequences (B) ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax) Returns: loss (torch.Tensor): transducer loss value """ # 1. encoder xs_pad = xs_pad[:, :max(ilens)] if "custom" in self.etype: src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) else: hs_pad, hs_mask, _ = self.enc(xs_pad, ilens) # 1.5. transducer preparation related ys_in_pad, target, pred_len, target_len = prepare_loss_inputs( ys_pad, hs_mask) # 2. decoder if "custom" in self.dtype: ys_mask = target_mask(ys_in_pad, self.blank_id) pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad) else: pred_pad = self.dec(hs_pad, ys_in_pad) z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1)) # 3. loss computation loss = self.criterion(z, target, pred_len, target_len) self.loss = loss loss_data = float(loss) # 4. compute cer/wer if self.training or self.error_calculator is None: cer, wer = None, None else: cer, wer = self.error_calculator(hs_pad, ys_pad) if not math.isnan(loss_data): self.reporter.report(loss_data, cer, wer) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss
def forward(self, xs_pad, ilens, ys_pad): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = (~make_pad_mask(ilens.tolist())).to( xs_pad.device).unsqueeze(-2) xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) # 3. compute attention loss self.loss = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # 4. compute corpus-level bleu in a mini-batch if self.training: self.bleu = None else: ys_hat = pred_pad.argmax(dim=-1) self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) loss_data = float(self.loss) if self.normalize_length: self.ppl = np.exp(loss_data) else: batch_size = ys_out_pad.size(0) ys_out_pad = ys_out_pad.view(-1) ignore = ys_out_pad == self.ignore_id # (B*T,) total_n_tokens = len(ys_out_pad) - ignore.sum().item() self.ppl = np.exp(loss_data * batch_size / total_n_tokens) if not math.isnan(loss_data): self.reporter.report(loss_data, self.acc, self.ppl, self.bleu) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss
def store_penultimate_state(self, xs_pad, ilens, ys_pad, bnf_feats,
                            bnf_feats_lens):
    """Run encoder + decoder and return the decoder's penultimate state.

    Returns the decoder's penultimate hidden state (B, T, att_dim) as a
    numpy array with the batch axis squeezed out, for plotting.
    """
    # Trim padded tails so data-parallel shards agree on sequence length.
    bnf_feats = bnf_feats[:, :max(bnf_feats_lens)]
    xs_pad = xs_pad[:, :max(ilens)]
    enc_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)

    enc_out, enc_mask_out = self.encoder(xs_pad, enc_mask, bnf_feats)
    self.hs_pad = enc_out

    # Teacher-forced decoder pass, requesting the penultimate hidden state.
    dec_in, _ = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    dec_mask = target_mask(dec_in, self.ignore_id)
    _, _, penultimate = self.decoder(dec_in, dec_mask, enc_out, enc_mask_out,
                                     return_penultimate_state=True)

    # (B, T, att_dim) -> (T, att_dim) numpy array.
    return penultimate.squeeze(0).detach().cpu().numpy()
def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
    """Run the attention decoder and compute its loss and accuracy.

    Note: ``batch_size`` is unused here but kept for interface
    compatibility with callers.
    """
    # Teacher-forcing inputs/targets with <sos>/<eos> added.
    dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                     self.ignore_id)
    dec_mask = target_mask(dec_in, self.ignore_id)
    pred_pad, pred_mask = self.decoder(dec_in, dec_mask, hs_pad, hs_mask)

    # Cross-entropy style attention loss plus token-level accuracy.
    loss_att = self.criterion(pred_pad, dec_target)
    acc = th_accuracy(pred_pad.view(-1, self.odim),
                      dec_target,
                      ignore_label=self.ignore_id)
    return pred_pad, pred_mask, loss_att, acc
def test_sa_transducer_mask(module):
    """NaNs in masked target positions must not leak through decoder self-attention."""
    from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
    from espnet.nets.pytorch_backend.transducer.utils import prepare_loss_inputs
    from espnet.nets.pytorch_backend.transformer.mask import target_mask

    train_args = make_train_args()
    model, x, ilens, y, data = prepare(module, train_args)

    # dummy mask over the input frames
    x_mask = (~make_pad_mask(ilens.tolist())).to(x.device).unsqueeze(-2)
    _, target, _, _ = prepare_loss_inputs(y, x_mask)

    tgt_mask = target_mask(target, model.blank_id)
    embedded = model.decoder.embed(target.type(torch.long))
    # Poison everything past position 3 of the first sequence.
    embedded[0, 3:] = float("nan")

    attn_layer = model.decoder.decoders[0].self_attn
    attn_layer(embedded, embedded, embedded, tgt_mask)
    visible = attn_layer.attn[0, :, :3, :3].detach().numpy()
    assert not numpy.isnan(visible).any()
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value
    """
    # 1. encoder
    # if etype is transformer, deal with the padding
    # xs_pad:[8, 393, 83]
    # ilens:[393 * 8]
    # src_mask:[8, 1, 393]
    # hs_mask:[8, 1, 65]
    xs_pad = xs_pad[:, :max(ilens)]
    if "transformer" in self.etype:
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        batchsize = xs_pad.size(0)
        # Add a channel axis for the 2-D convolution front-end: (B, 1, T, D).
        inputs = xs_pad.unsqueeze(1)
        logging.info("inputs:{}".format(inputs.shape))
        logging.info("src_mask:{}".format(src_mask.shape))
        inputs_length = []
        if src_mask is not None:
            # Per-utterance valid frame counts from the boolean mask rows.
            for mask in src_mask.tolist():
                inputs_length.append(mask[0].count(True))
            # Run the conv front-end per utterance on its un-padded slice
            # only to measure the subsampled output length.
            for i in range(batchsize):
                inputs_s = inputs[i].unsqueeze(0)[:, :, 0:inputs_length[i], :]
                core_out = self.conv(inputs_s)
                inputs_length[i] = core_out.size(2)
            inputs_length = torch.as_tensor(inputs_length)
        else:
            core_out = self.conv(inputs)
            inputs_length = core_out.size(2)
            inputs_length = torch.as_tensor(inputs_length)
        logging.info("inputs_length:{}".format(inputs_length))
        # block 1
        # the inputs shape of Conv2d is 4-dim of (bsz * c * l * w)
        # the inputs shape of Conv1d is 3-dim of (bsz * c * l)
        # the inputs shape of transformer is 3-dim of (l * bsz * c)
        # conv output format: (bsz * c * t * d)
        inputs = self.conv(inputs)
        # we can get a batch of 16 channels feature maps in all time steps
        # merge 16 channels of one timestep to create one self-attention
        # input (batch, 16, dim)
        inputs = inputs.permute(2, 0, 1, 3)
        logging.info("inputs:{}".format(inputs.shape))
        # NOTE(review): 512 is hard-coded here; presumably the model
        # dimension of self.clayers — TODO confirm.
        merge = torch.zeros(inputs.size(0), batchsize, 512)
        for t in range(inputs.size(0)):  # max_length
            merge[t] = self.clayers(inputs[t], None)[0].reshape(
                batchsize, 512)
        xs = merge.permute(1, 0, 2)
        if inputs_length.dim() == 0:
            # Scalar length (the src_mask is None path above).
            masks = make_non_pad_mask([inputs_length]).unsqueeze(-2)
        else:
            masks = make_non_pad_mask(inputs_length.tolist()).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs, masks)
    else:
        hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
    self.hs_pad = hs_pad

    # 1.5. transducer preparation related
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder
    if "transformer" in self.dtype:
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    else:
        if self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. loss computation
    loss = self.criterion(pred_pad, target, pred_len, target_len)
    self.loss = loss
    loss_data = float(self.loss)

    # 4. compute cer/wer (evaluation only)
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward with dual (CN/EN) encoders and a learned gated mixture.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # mlp moe forward: both encoders see the same input/mask
    cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
    hs_mask = cn_hs_mask  # cn_hs_mask & en_hs_mask are identical
    # gated add module:
    #   lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  # (B, T, 1)
    #   xs = lambda * cn_xs + (1 - lambda) * en_xs
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    lambda_ = self.aggregation_module(
        hs_pad)  # (B,T,1)/(B,T,D), range from (0, 1)
    hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        # divide ys_pad into cn_ys & en_ys;
        # note that this target can directly pass to ctc module
        cn_ys, en_ys = partial_target(ys_pad, self.language_divider)
        # Each language-specific CTC is trained on its own encoder stream;
        # the final CTC loss is their plain average.
        cn_loss_ctc = self.cn_ctc(
            cn_hs_pad.view(batch_size, -1, self.adim), hs_len, cn_ys)
        en_loss_ctc = self.en_ctc(
            en_hs_pad.view(batch_size, -1, self.adim), hs_len, en_ys)
        loss_ctc = 0.5 * (cn_loss_ctc + en_loss_ctc)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                           hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim),
                          ys_out_pad,
                          ignore_label=self.ignore_id)
        self.acc = acc

    # copied from e2e_asr: combine CTC and attention losses by mtlalpha
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        # cer/wer are not computed in this variant (reported as None).
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             None, None, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, feats: torch.Tensor, feats_len: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """E2E forward. Args: feats: Feature sequences. (B, F, D_feats) feats_len: Feature sequences lengths. (B,) labels: Label ID sequences. (B, L) Returns: loss: Transducer loss value """ # 1. encoder feats = feats[:, :max(feats_len)] if self.etype == "custom": feats_mask = (make_non_pad_mask(feats_len.tolist()).to( feats.device).unsqueeze(-2)) _enc_out, _enc_out_len = self.encoder(feats, feats_mask) else: _enc_out, _enc_out_len, _ = self.enc(feats, feats_len) if self.use_auxiliary_enc_outputs: enc_out, aux_enc_out = _enc_out[0], _enc_out[1] enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1] else: enc_out, aux_enc_out = _enc_out, None enc_out_len, aux_enc_out_len = _enc_out_len, None # 2. decoder dec_in = get_decoder_input(labels, self.blank_id, self.ignore_id) if self.dtype == "custom": self.decoder.set_device(enc_out.device) dec_in_mask = target_mask(dec_in, self.blank_id) dec_out, _ = self.decoder(dec_in, dec_in_mask) else: self.dec.set_device(enc_out.device) dec_out = self.dec(dec_in) # 3. transducer tasks computation losses = self.transducer_tasks( enc_out, aux_enc_out, dec_out, labels, enc_out_len, aux_enc_out_len, ) if self.training or self.error_calculator is None: cer, wer = None, None else: cer, wer = self.error_calculator( enc_out, self.transducer_tasks.get_target()) self.loss = sum(losses) loss_data = float(self.loss) if not math.isnan(loss_data): self.reporter.report( loss_data, *[float(loss) for loss in losses], cer, wer, ) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward with adaptive/dynamic span attention regularization.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # Clamp the learnable span parameters of every span-attention layer
    # before the forward pass (keeps them inside their valid range).
    if self.attention_enc_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ]:
        for layer in self.encoder.encoders:
            layer.self_attn.clamp_param()
    if self.attention_dec_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ]:
        for layer in self.decoder.decoders:
            layer.self_attn.clamp_param()

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                           ys_out_pad,
                           ignore_label=self.ignore_id)

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                 self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    # 5. compute cer/wer (evaluation only)
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr: combine CTC and attention losses by mtlalpha
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    # xkc09 Span attention loss computation
    # xkc09 Span attention size loss computation: penalize the mean span
    # size of every span-attention layer (encourages short spans).
    loss_span = 0
    if self.attention_enc_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
    ]:
        loss_span += sum([
            layer.self_attn.get_mean_span()
            for layer in self.encoder.encoders
        ])
    if self.attention_dec_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
    ]:
        loss_span += sum([
            layer.self_attn.get_mean_span()
            for layer in self.decoder.decoders
        ])

    # xkc09 Span attention ratio loss computation (pushes mean ratio
    # towards 1, since each term is 1 - mean_ratio).
    loss_ratio = 0
    if self.ratio_adaptive:
        # target_ratio = 0.5
        if self.attention_enc_type in [
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            loss_ratio += sum([
                1 - layer.self_attn.get_mean_ratio()
                for layer in self.encoder.encoders
            ])
        if self.attention_dec_type in [
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            loss_ratio += sum([
                1 - layer.self_attn.get_mean_ratio()
                for layer in self.decoder.decoders
            ])

    # Add the span regularizers only when a coefficient is configured
    # (getattr guards older configs without span_loss_coef).
    if (self.attention_enc_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ] or self.attention_dec_type in [
            'self_attn_dynamic_span', 'self_attn_adaptive_span',
            'self_attn_adaptive_span2', 'self_attn_fixed_span2',
            'self_attn_dynamic_span2'
    ]):
        if getattr(self, 'span_loss_coef', None):
            self.loss += (loss_span + loss_ratio) * self.span_loss_coef

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward combining transducer and attention losses.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): weighted sum of transducer and attention losses
    """
    # 1. encoder
    # if etype is transformer, deal with the padding
    # xs_pad:[8, 393, 83]
    # ilens:[393 * 8]
    # src_mask:[8, 1, 393]
    # hs_mask:[8, 1, 65]
    if self.etype == "transformer":
        xs_pad = xs_pad[:, :max(ilens)]
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        logging.info("enc!!!")
        hs_pad, hs_mask, _ = self.encoder(xs_pad, ilens)
    self.hs_pad = hs_pad

    # 1.5. transducer preparation related
    # ys_in_pad: sos,1,2,...,0 [8, 14]
    # target: 1,2,... [8, 13]
    # pred_len: [8]
    # target_len: [8]
    # ys_out_pad: 1,2,...,eos,-1 (attention-branch target)
    ys = [y[y != self.ignore_id] for y in ys_pad]
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_out_pad = pad_list(ys_out, self.ignore_id)
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder
    # ys_mask:[8, 16, 16]
    if self.dtype == "transformer":
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        # Transformer decoder additionally returns attention-branch logits.
        pred_pad, pred_att, _ = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                             hs_mask)
    else:
        if self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. loss computation
    loss_att = F.cross_entropy(
        pred_att,
        ys_out_pad.view(-1),  # batch x olength
        ignore_index=self.ignore_id,
    )
    # compute perplexity
    # ppl = math.exp(loss_att.item())
    # -1: eos, which is removed in the loss computation
    # Rescale the mean CE by the average target length (minus eos).
    loss_att *= np.mean([len(x) for x in ys_in]) - 1

    loss_rnnt = self.criterion(pred_pad, target, pred_len, target_len)

    alpha = self.mtlalpha
    beta = self.mtlbeta
    gamma = self.mtlgamma
    self.loss_rnnt = loss_rnnt
    self.loss_att = loss_att
    # NOTE: the CTC term (alpha) is currently disabled; only the
    # transducer and attention losses are combined.
    self.loss = beta * self.loss_rnnt + gamma * self.loss_att
    loss_data = float(self.loss)
    loss_att_data = float(self.loss_att)
    loss_rnnt_data = float(self.loss_rnnt)

    # 4. compute cer/wer (evaluation only)
    if self.training or self.error_calculator is None:
        logging.info("ALL none!!!!!")
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, loss_rnnt_data, loss_att_data, cer,
                             wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward with dual (CN/EN) encoders mixed by a fixed lambda.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # mlp moe forward: both encoders see the same input/mask
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
    # gated add module:
    #   lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  # (B, T, 1)
    #   xs = lambda * cn_xs + (1 - lambda) * en_xs
    # NOTE(review): the concatenated hs_pad below is immediately
    # overwritten and appears to be dead code in this fixed-lambda variant.
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    lambda_ = self.enc_lambda
    hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                 self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                           hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim),
                          ys_out_pad,
                          ignore_label=self.ignore_id)
        self.acc = acc

    # 5. compute cer/wer (evaluation only)
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr: combine CTC and attention losses by mtlalpha
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
    """E2E forward with dual encoders mixed by external MoE coefficients.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor moe_coes: per-frame mixture coefficients (B, Tmax, 2)
    :param torch.Tensor moe_coe_lens: lengths of the coefficient sequences (B)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
    # here we use interpolation_coe to 'fix' initial moe_coes:
    # smooth the (possibly hard) coefficients towards uniform,
    # similar to label smoothing.
    interp_factor = self.interp_factor  # 0.1 for example, similar to lsm
    moe_coes = (
        1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # multi-encoder forward: both encoders see the same input/mask
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
    # Mix streams: index 1 weights the CN stream, index 0 the EN stream.
    moe_coes = moe_coes.unsqueeze(-1)
    hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
    self.hs_pad = hs_pad

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                 self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                           hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim),
                          ys_out_pad,
                          ignore_label=self.ignore_id)
        self.acc = acc

    # 5. compute cer/wer (evaluation only)
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr: combine CTC and attention losses by mtlalpha
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc,
                             cer, wer, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: total loss value
    :rtype: torch.Tensor
    :return: ctc loss value (float, or None when mtlalpha == 0)
    :rtype: float
    :return: attention loss value (float, or None when mtlalpha == 1)
    :rtype: float
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # CTC forward
    # strip ignore_id padding to find the true max target length, then
    # trim ys_pad accordingly (for data parallel)
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    self.hs_pad = hs_pad
    batch_size = xs_pad.size(0)
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # copyied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        # BUG FIX: loss_ctc_data was never assigned on this branch in the
        # original, raising UnboundLocalError at the return below
        loss_ctc_data = float(loss_ctc)

    return self.loss, loss_ctc_data, loss_att_data, self.acc
def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
    """E2E forward for the gated mixture-of-encoders model with LID loss.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor moe_coes: per-frame language labels -- assumed
        shape (B, Tmax', 2) with column 0 holding the CN/EN class index
        used as the LID target (TODO confirm against the data loader)
    :param torch.Tensor moe_coe_lens: lengths of the moe_coes sequences (B)
    :return: ctc loass value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    moe_coes = moe_coes[:, :max(moe_coe_lens)].long()  # for data parallel
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    # mlp moe forward
    cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
    # gated add module
    """
    lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b) #(B, T, 1)
    xs = lambda * cn_xs + (1-lambda) * en_xs
    """
    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    # scaled softmax gate over the two encoders, computed from the
    # concatenated encoder outputs
    lambda_ = F.softmax(
        self.aggre_scaling * self.aggregation_module(hs_pad),
        -1).unsqueeze(-1)
    ctc_hs_pad = lambda_[:, :, 0] * cn_hs_pad + lambda_[:, :, 1] * en_hs_pad
    ctc_hs_mask = cn_hs_mask
    # plat attention mode, (B,T,D)*2 --> (B,2T,D)
    s2s_hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=1)
    # mask: (B,1,T) --> (B,1,2T)
    s2s_hs_mask = torch.cat((cn_hs_mask, en_hs_mask), dim=-1)
    # self.hs_pad = hs_pad

    # compute lid loss here, using lambda_
    # moe_coes (B, T, 2) ==> (B,T)
    moe_coes = moe_coes[:, :, 0]  # 0 for cn, 1 for en
    lambda_ = lambda_.squeeze(-1)
    if self.lid_mtl_alpha == 0.0:
        loss_lid = 0.0
    else:
        loss_lid = self.lid_criterion(lambda_, moe_coes)
    lid_acc = th_accuracy(
        lambda_.view(-1, 2), moe_coes,
        ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        # CTC runs on the gated (weighted-sum) encoder representation
        batch_size = xs_pad.size(0)
        hs_len = ctc_hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(ctc_hs_pad.view(batch_size, -1, self.adim),
                            hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(
                ctc_hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        # 2. forward decoder
        # the decoder attends over the concatenated (plat) representation
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, s2s_hs_pad,
                                           s2s_hs_mask)
        self.pred_pad = pred_pad
        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                          ignore_label=self.ignore_id)
    self.acc = acc

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copyied from e2e_asr
    alpha = self.mtlalpha
    lid_alpha = self.lid_mtl_alpha
    if alpha == 0:
        self.loss = loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc + lid_alpha * loss_lid
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (
            1 - alpha) * loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                             cer_ctc, cer, wer, loss_data, lid_acc)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
    """E2E forward with optional externally supplied attention masks.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor enc_mask: optional mask intersected with the
        encoder's padding mask -- presumably a streaming/chunk mask;
        TODO confirm against the caller
    :param torch.Tensor dec_mask: optional trigger mask intersected with
        the source-attention mask fed to the decoder
    :return: total loss, ctc loss value, attention loss value and accuracy
    :rtype: tuple
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    batch_size = xs_pad.shape[0]
    src_mask = make_non_pad_mask(ilens.tolist()).to(
        xs_pad.device).unsqueeze(-2)
    # the embed frontend is invoked directly so the externally supplied
    # enc_mask can be applied to the encoder layers below
    if isinstance(self.encoder.embed, EncoderConv2d):
        xs, hs_mask = self.encoder.embed(xs_pad,
                                         torch.sum(src_mask, 2).squeeze())
        hs_mask = hs_mask.unsqueeze(1)
    else:
        xs, hs_mask = self.encoder.embed(xs_pad, src_mask)
    # trim enc_mask to the subsampled length, then intersect with padding
    if enc_mask is not None:
        enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]]
    enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
    hs_pad, _ = self.encoder.encoders(xs, enc_mask)
    if self.encoder.normalize_before:
        hs_pad = self.encoder.after_norm(hs_pad)

    # CTC forward
    # strip ignore_id padding and re-trim ys_pad to the true max length
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    if dec_mask is not None:
        # y_len + 1 accounts for the <sos> prepended by add_sos_eos below
        dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
    self.hs_pad = hs_pad
    batch_size = xs_pad.size(0)
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)

    # trigger mask
    hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # copyied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    return self.loss, loss_ctc_data, loss_att_data, self.acc
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward with an auxiliary language-ID (LID) multi-task loss.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: loss value (attention + CTC + LID multi-task)
    :rtype: torch.Tensor
    :raises Exception: when mtlalpha == 1 (pure CTC), since the LID branch
        requires the decoder's penultimate state
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder (also returns the penultimate hidden state,
    # which feeds the LID source-attention branch below)
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask, penultimate_state = self.decoder(
        ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute lid multitask loss
    src_att = self.lid_src_att(penultimate_state, hs_pad, hs_pad, hs_mask)
    pred_lid_pad = self.lid_output_layer(src_att)
    loss_lid, lid_ys_out_pad = self.lid_criterion(pred_lid_pad, ys_out_pad)
    lid_acc = th_accuracy(
        pred_lid_pad.view(-1, self.lid_odim), lid_ys_out_pad,
        ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                 self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                            is_ctc=True)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copyied from e2e_asr
    alpha = self.mtlalpha
    lid_alpha = self.lid_mtl_alpha
    if alpha == 0:
        self.loss = loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        # FIX: removed unreachable statements that followed this raise in
        # the original (dead code after an unconditional raise)
        raise Exception("LID MTL not supports pure ctc mode")
    else:
        self.loss = alpha * loss_ctc + (
            1 - alpha) * loss_att + lid_alpha * loss_lid
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                             cer_ctc, cer, wer, loss_data, lid_acc)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward for ST with auxiliary ASR and MT tasks.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    tgt_lang_ids = None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID in the beggining

    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute ST loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # 5. compute auxiliary ASR loss
    loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
        hs_pad, hs_mask, ys_pad_src
    )

    # 6. compute auxiliary MT loss
    loss_mt, acc_mt = 0.0, None
    if self.mt_weight > 0:
        loss_mt, acc_mt = self.forward_mt(
            ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
        )

    # total loss: weighted sum of ST attention, ASR (CTC/attention mix
    # controlled by mtlalpha) and MT losses
    asr_ctc_weight = self.mtlalpha
    self.loss = (
        (1 - self.asr_weight - self.mt_weight) * loss_att
        + self.asr_weight
        * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
        + self.mt_weight * loss_mt
    )
    loss_asr_data = float(
        asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
    )
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward for ST with inline auxiliary ASR and MT losses.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loass value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    tgt_lang_ids = None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID in the beggining

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute ST loss
    loss_asr_att, loss_asr_ctc, loss_mt = 0.0, 0.0, 0.0
    acc_asr, acc_mt = 0.0, 0.0
    loss_att = self.criterion(pred_pad, ys_out_pad)

    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # 5. compute auxiliary ASR loss
    cer, wer = None, None
    cer_ctc = None
    if self.asr_weight > 0:
        # attention
        if self.mtlalpha < 1:
            ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                ys_pad_src, self.sos, self.eos, self.ignore_id)
            ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
            pred_pad_asr, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr,
                                               hs_pad, hs_mask)
            loss_asr_att = self.criterion(pred_pad_asr, ys_out_pad_asr)
            acc_asr = th_accuracy(
                pred_pad_asr.view(-1, self.odim),
                ys_out_pad_asr,
                ignore_label=self.ignore_id,
            )
            if not self.training:
                ys_hat_asr = pred_pad_asr.argmax(dim=-1)
                cer, wer = self.error_calculator_asr(
                    ys_hat_asr.cpu(), ys_pad_src.cpu())
        # CTC
        if self.mtlalpha > 0:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_asr_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                    hs_len, ys_pad_src)
            # NOTE(review): the argmax is computed even in training mode
            # (only the CER below is gated on eval) -- likely wasted work;
            # compare forward_asr which guards both
            ys_hat_ctc = self.ctc.argmax(
                hs_pad.view(batch_size, -1, self.adim)).data
            if not self.training:
                cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                    ys_pad_src.cpu(),
                                                    is_ctc=True)

    # 6. compute auxiliary MT loss
    if self.mt_weight > 0:
        ilens_mt = torch.sum(ys_pad_src != self.ignore_id,
                             dim=1).cpu().numpy()
        # NOTE: ys_pad_src is padded with -1
        ys_src = [y[y != self.ignore_id]
                  for y in ys_pad_src]  # parse padded ys_src
        ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
        ys_zero_pad_src = ys_zero_pad_src[:, :max(
            ilens_mt)]  # for data parallel
        src_mask_mt = ((~make_pad_mask(ilens_mt.tolist())).to(
            ys_zero_pad_src.device).unsqueeze(-2))
        hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src,
                                                src_mask_mt)
        pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt,
                                      hs_mask_mt)
        loss_mt = self.criterion(pred_pad_mt, ys_out_pad)
        acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                             ignore_label=self.ignore_id)

    # total loss: weighted sum of ST attention, ASR (CTC/attention mix
    # controlled by mtlalpha) and MT losses
    alpha = self.mtlalpha
    self.loss = ((1 - self.asr_weight - self.mt_weight) * loss_att +
                 self.asr_weight *
                 (alpha * loss_asr_ctc +
                  (1 - alpha) * loss_asr_att) + self.mt_weight * loss_mt)
    loss_asr_data = float(alpha * loss_asr_ctc +
                          (1 - alpha) * loss_asr_att)
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward for the dual-decoder (joint ST + ASR) model with
    optional wait-k cross-attention masking between the two decoders.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loass value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    # src_lang_ids = None
    tgt_lang_ids, tgt_lang_ids_src = None, None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID in the beggining
    if self.one_to_many:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID in the beggining
        if self.do_asr:
            tgt_lang_ids_src = ys_pad_src[:, 0:1]
            # remove target language ID in the beggining
            ys_pad_src = ys_pad_src[:, 1:]

    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    # bs x max_ilens x idim
    if self.lang_tok == "encoder-pre-sum":
        lang_embed = self.language_embeddings(tgt_lang_ids)  # bs x 1 x idim
        xs_pad = xs_pad + lang_embed
    src_mask = (~make_pad_mask(ilens.tolist())).to(
        xs_pad.device).unsqueeze(-2)  # bs x 1 x max_ilens
    # hs_pad: bs x (max_ilens/4) x adim; hs_mask: bs x 1 x (max_ilens/4)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)  # bs x max_lens
    if self.do_asr:
        ys_in_pad_src, ys_out_pad_src = add_sos_eos(
            ys_pad_src, self.sos, self.eos,
            self.ignore_id)  # bs x max_lens_src

    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    if self.lang_tok == "decoder-pre":
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        if self.do_asr:
            ys_in_pad_src = torch.cat(
                [tgt_lang_ids_src, ys_in_pad_src[:, 1:]], dim=1)

    ys_mask = target_mask(ys_in_pad,
                          self.ignore_id)  # bs x max_lens x max_lens
    if self.do_asr:
        ys_mask_src = target_mask(
            ys_in_pad_src, self.ignore_id)  # bs x max_lens x max_lens_src
    # wait-k cross masks: a positive wait_k_asr lets the ST decoder see
    # only ASR tokens at least k steps behind (and vice versa for wait_k_st)
    if self.wait_k_asr > 0:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src,
                                       self.ignore_id,
                                       wait_k_cross=self.wait_k_asr)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id,
                                           wait_k_cross=-self.wait_k_asr)
    elif self.wait_k_st > 0:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src,
                                       self.ignore_id,
                                       wait_k_cross=-self.wait_k_st)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id,
                                           wait_k_cross=self.wait_k_st)
    else:
        cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src,
                                       self.ignore_id, wait_k_cross=0)
        cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad,
                                           self.ignore_id, wait_k_cross=0)

    pred_pad, pred_mask, pred_pad_asr, pred_mask_asr = self.dual_decoder(
        ys_in_pad, ys_mask, ys_in_pad_src, ys_mask_src, hs_pad, hs_mask,
        cross_mask, cross_mask_asr,
        cross_self=self.cross_self, cross_src=self.cross_src,
        cross_self_from=self.cross_self_from,
        cross_src_from=self.cross_src_from)

    self.pred_pad = pred_pad
    self.pred_pad_asr = pred_pad_asr
    pred_pad_mt = None

    # 3. compute attention loss
    loss_asr, loss_mt = 0.0, 0.0
    loss_att = self.criterion(pred_pad, ys_out_pad)
    # compute loss
    loss_asr = self.criterion(pred_pad_asr, ys_out_pad_src)
    # Multi-task w/ MT
    if self.mt_weight > 0:
        # forward MT encoder
        ilens_mt = torch.sum(ys_pad_src != self.ignore_id,
                             dim=1).cpu().numpy()
        # NOTE: ys_pad_src is padded with -1
        ys_src = [y[y != self.ignore_id]
                  for y in ys_pad_src]  # parse padded ys_src
        ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
        ys_zero_pad_src = ys_zero_pad_src[:, :max(
            ilens_mt)]  # for data parallel
        src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(
            ys_zero_pad_src.device).unsqueeze(-2)
        # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad)
        hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src,
                                                src_mask_mt)
        # forward MT decoder
        pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt,
                                      hs_mask_mt)
        # compute loss
        loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                           ignore_label=self.ignore_id)
    if pred_pad_asr is not None:
        self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim),
                                   ys_out_pad_src,
                                   ignore_label=self.ignore_id)
    else:
        self.acc_asr = 0.0
    if pred_pad_mt is not None:
        self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim),
                                  ys_out_pad,
                                  ignore_label=self.ignore_id)
    else:
        self.acc_mt = 0.0

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0 or self.asr_weight == 0:
        loss_ctc = 0.0
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                            ys_pad_src)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                 self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(),
                                            is_ctc=True)

    # 5. compute cer/wer
    cer, wer = None, None  # TODO(hirofumi0810): fix later
    # if self.training or (self.asr_weight == 0 or self.mtlalpha == 1 or not (self.report_cer or self.report_wer)):
    #     cer, wer = None, None
    # else:
    #     ys_hat = pred_pad.argmax(dim=-1)
    #     cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copyied from e2e_asr
    alpha = self.mtlalpha
    self.loss = (1 - self.asr_weight - self.mt_weight) * loss_att + \
        self.asr_weight * \
        (alpha * loss_ctc + (1 - alpha) * loss_asr) + \
        self.mt_weight * loss_mt
    loss_asr_data = float(alpha * loss_ctc + (1 - alpha) * loss_asr)
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)
    # logging.info(f'loss_st_data={loss_st_data}')

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data,
                             self.acc_asr, self.acc_mt, self.acc,
                             cer_ctc, cer, wer,
                             0.0,  # TODO(hirofumi0810): bleu
                             loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward (supports an optional Mask-CTC decoder mode).

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    if self.decoder is not None:
        if self.decoder_mode == "maskctc":
            # Mask-CTC: randomly mask target tokens and predict them
            # (conditional masked LM objective) instead of the usual
            # left-to-right teacher forcing
            ys_in_pad, ys_out_pad = mask_uniform(
                ys_pad, self.mask_token, self.eos, self.ignore_id
            )
            ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
        else:
            ys_in_pad, ys_out_pad = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id
            )
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
    else:
        loss_att = None
        self.acc = None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        # for visualization
        if not self.training:
            self.ctc.softmax(hs_pad)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value (plus weighted
            auxiliary task / CTC / cross-entropy losses when enabled)

    """
    # 1. encoder
    xs_pad = xs_pad[:, :max(ilens)]
    if "custom" in self.etype:
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        _hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
    # with the auxiliary task the encoder returns a pair:
    # (main output, intermediate-layer output for the aux loss)
    if self.use_aux_task:
        hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1]
    else:
        hs_pad, aux_hs_pad = _hs_pad, None

    # 1.5. transducer preparation related
    ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder
    if "custom" in self.dtype:
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    else:
        pred_pad = self.dec(hs_pad, ys_in_pad)
    # joint network combines every (time, label) pair: (B, T, U, odim)
    z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

    # 3. loss computation
    loss_trans = self.criterion(z, target, pred_len, target_len)

    if self.use_aux_task and aux_hs_pad is not None:
        loss_trans += self.auxiliary_task(aux_hs_pad, pred_pad, z, target,
                                          pred_len, target_len)

    if self.use_aux_ctc:
        if "custom" in self.etype:
            # convert the attention mask back to integer lengths for CTC
            hs_mask = torch.IntTensor([h.size(1) for h in hs_mask], ).to(
                hs_mask.device)

        loss_ctc = self.aux_ctc(hs_pad, hs_mask, ys_pad)
    else:
        loss_ctc = 0

    if self.use_aux_cross_entropy:
        loss_ce = self.aux_cross_entropy(self.aux_decoder_output(pred_pad),
                                         ys_out_pad)
    else:
        loss_ce = 0

    loss = (self.transducer_weight * loss_trans +
            self.aux_ctc_weight * loss_ctc +
            self.aux_cross_entropy_weight * loss_ce)

    self.loss = loss
    loss_data = float(loss)

    # 4. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss