def forward(self, xs_pad, ilens, ys_pad):
    """Compute the transducer training loss for one mini-batch.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value

    """
    # 1. encoder: the transformer encoder consumes a padding mask while the
    # RNN encoder consumes the raw length vector and returns updated lengths.
    if self.etype == 'transformer':
        xs_pad = xs_pad[:, :max(ilens)]
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        hs_pad, hs_mask, _ = self.encoder(xs_pad, ilens)
    self.hs_pad = hs_pad

    # 1.5. build decoder inputs and flat targets for the transducer loss
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder: transformer decoders need a target mask; the RNN decoder
    # in non-rnnt mode additionally takes the prediction lengths.
    if self.dtype == 'transformer':
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    elif self.rnnt_mode == 'rnnt':
        pred_pad = self.decoder(hs_pad, ys_in_pad)
    else:
        pred_pad = self.decoder(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. transducer loss
    self.loss = self.criterion(pred_pad, target, pred_len, target_len)
    loss_data = float(self.loss)

    # 4. cer/wer are only evaluated outside training and when a
    # calculator was configured
    if self.training or self.error_calculator is None:
        cer = wer = None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)

    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """Compute the transducer training loss for one mini-batch.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value

    """
    # 1. encoder: drop padding frames beyond the longest utterance first
    xs_pad = xs_pad[:, :max(ilens)]

    if "custom" in self.etype:
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)

    # 1.5. transducer preparation related
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder (custom decoders are attention-based and need a target mask)
    if "custom" in self.dtype:
        pred_pad, _ = self.decoder(
            ys_in_pad, target_mask(ys_in_pad, self.blank_id), hs_pad)
    else:
        pred_pad = self.dec(hs_pad, ys_in_pad)

    # joint network: pair every encoder frame with every decoder step
    joint_out = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

    # 3. loss computation
    self.loss = self.criterion(joint_out, target, pred_len, target_len)
    loss_data = float(self.loss)

    # 4. compute cer/wer (evaluation mode only, and only if configured)
    if self.training or self.error_calculator is None:
        cer = wer = None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss
def test_sa_transducer_mask(module):
    """NaNs injected into masked decoder positions must not corrupt the
    self-attention weights of the valid (unmasked) positions."""
    from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
    from espnet.nets.pytorch_backend.transducer.utils import prepare_loss_inputs
    from espnet.nets.pytorch_backend.transformer.mask import target_mask

    model, x, ilens, y, data = prepare(module, make_train_args())

    # dummy mask
    x_mask = (~make_pad_mask(ilens.tolist())).to(x.device).unsqueeze(-2)
    _, target, _, _ = prepare_loss_inputs(y, x_mask)
    y_mask = target_mask(target, model.blank_id)

    # poison everything past the first three decoder positions
    y = model.decoder.embed(target.type(torch.long))
    y[0, 3:] = float("nan")

    self_attn = model.decoder.decoders[0].self_attn
    self_attn(y, y, y, y_mask)

    assert not numpy.isnan(self_attn.attn[0, :, :3, :3].detach().numpy()).any()
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward: joint transducer + attention (CE) multi-task loss.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): mtlbeta-weighted transducer loss plus
            mtlgamma-weighted attention cross-entropy loss

    """
    # 1. encoder
    # if etype is transformer, deal with the padding
    # (example shapes: xs_pad [8, 393, 83], src_mask [8, 1, 393],
    #  hs_mask [8, 1, 65])
    if self.etype == "transformer":
        xs_pad = xs_pad[:, :max(ilens)]
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        logging.info("enc!!!")
        hs_pad, hs_mask, _ = self.encoder(xs_pad, ilens)
    self.hs_pad = hs_pad

    # 1.5. transducer preparation related
    # Attention (CE) targets: <sos> y... as inputs, y... <eos> as outputs;
    # ys_out_pad is padded with ignore_id so padding is skipped by the loss.
    ys = [y[y != self.ignore_id] for y in ys_pad]
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # Transducer targets: blank-prefixed decoder inputs plus the
    # per-sequence input/target lengths.
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder
    if self.dtype == "transformer":
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, pred_att, _ = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                             hs_mask)
    else:
        # NOTE(review): pred_att is only bound in the transformer branch;
        # reaching the cross-entropy below with an RNN decoder raises
        # NameError — confirm dtype is always "transformer" for this model.
        if self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. loss computation
    loss_att = F.cross_entropy(
        pred_att,
        ys_out_pad.view(-1),  # batch x olength
        ignore_index=self.ignore_id,
    )
    # -1: eos, which is removed in the loss computation
    loss_att *= np.mean([len(x) for x in ys_in]) - 1

    loss_rnnt = self.criterion(pred_pad, target, pred_len, target_len)

    self.loss_rnnt = loss_rnnt
    self.loss_att = loss_att
    # CTC branch is disabled in this configuration; only the transducer and
    # attention losses are combined.
    self.loss = self.mtlbeta * self.loss_rnnt + self.mtlgamma * self.loss_att

    loss_data = float(self.loss)
    loss_att_data = float(self.loss_att)
    loss_rnnt_data = float(self.loss_rnnt)

    # 4. compute cer/wer (evaluation mode only, and only if configured)
    if self.training or self.error_calculator is None:
        logging.info("ALL none!!!!!")
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, loss_rnnt_data, loss_att_data, cer,
                             wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Computes the transducer loss plus optional auxiliary losses
    (auxiliary-task, CTC, cross-entropy), each weighted by its
    corresponding weight attribute.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value

    """
    # 1. encoder
    # Drop padding frames beyond the longest utterance in the batch.
    xs_pad = xs_pad[:, :max(ilens)]

    if "custom" in self.etype:
        # Custom (attention-based) encoder takes a (B, 1, Tmax) non-pad mask.
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)

        _hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        # RNN encoder takes raw lengths; hs_mask holds the output lengths.
        _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)

    # When the auxiliary task is active the encoder returns a pair:
    # main output plus intermediate output(s) for the auxiliary loss.
    if self.use_aux_task:
        hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1]
    else:
        hs_pad, aux_hs_pad = _hs_pad, None

    # 1.5. transducer preparation related
    ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder
    if "custom" in self.dtype:
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    else:
        pred_pad = self.dec(hs_pad, ys_in_pad)

    # Joint network combines every encoder frame with every decoder step.
    z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

    # 3. loss computation
    loss_trans = self.criterion(z, target, pred_len, target_len)

    if self.use_aux_task and aux_hs_pad is not None:
        loss_trans += self.auxiliary_task(aux_hs_pad, pred_pad, z, target,
                                          pred_len, target_len)

    if self.use_aux_ctc:
        if "custom" in self.etype:
            # Convert the custom encoder's boolean mask to integer lengths
            # for CTC. NOTE(review): h.size(1) is the padded mask width, not
            # the per-utterance valid-frame count — confirm this matches what
            # aux_ctc expects.
            hs_mask = torch.IntTensor([h.size(1) for h in hs_mask],).to(
                hs_mask.device)

        loss_ctc = self.aux_ctc(hs_pad, hs_mask, ys_pad)
    else:
        loss_ctc = 0

    if self.use_aux_cross_entropy:
        # Label-smoothing-style CE over the decoder output projection.
        loss_ce = self.aux_cross_entropy(self.aux_decoder_output(pred_pad),
                                         ys_out_pad)
    else:
        loss_ce = 0

    # Weighted sum of the main and auxiliary losses.
    loss = (self.transducer_weight * loss_trans +
            self.aux_ctc_weight * loss_ctc +
            self.aux_cross_entropy_weight * loss_ce)

    self.loss = loss
    loss_data = float(loss)

    # 4. compute cer/wer (evaluation mode only, and only if configured)
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Experimental variant: a conv frontend followed by a per-timestep
    channel-merging layer feeds the transformer encoder.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value

    """
    # 1. encoder
    # Drop padding frames beyond the longest utterance in the batch.
    xs_pad = xs_pad[:, :max(ilens)]

    if "transformer" in self.etype:
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        batchsize = xs_pad.size(0)
        # Treat the features as a 1-channel image: (B, 1, Tmax, idim).
        inputs = xs_pad.unsqueeze(1)
        logging.info("inputs:{}".format(inputs.shape))
        logging.info("src_mask:{}".format(src_mask.shape))
        inputs_length = []
        if src_mask is not None:
            # Count non-pad frames per utterance from the boolean mask.
            for mask in src_mask.tolist():
                inputs_length.append(mask[0].count(True))
            # Run the conv frontend on each utterance's unpadded prefix just
            # to learn the subsampled output length per utterance.
            for i in range(batchsize):
                inputs_s = inputs[i].unsqueeze(0)[:, :, 0:inputs_length[i], :]
                core_out = self.conv(inputs_s)
                inputs_length[i] = core_out.size(2)
            inputs_length = torch.as_tensor(inputs_length)
        else:
            # No mask: a single shared length from the full batch.
            core_out = self.conv(inputs)
            inputs_length = core_out.size(2)
            inputs_length = torch.as_tensor(inputs_length)
        logging.info("inputs_length:{}".format(inputs_length))
        # block 1
        # the inputs shape of Conv2d is 4-dim of (bsz * c * l * w)
        # the inputs shape of Conv1d is 3-dim of (bsz * c * l)
        # the inputs shape of transformer is 3-dim of (l * bsz * c)
        # conv output format: (bsz * c * t * d)
        # NOTE(review): the conv frontend runs a second time here on the full
        # padded batch; the per-sample loop above is only used for lengths.
        inputs = self.conv(inputs)
        # we can get a batch of 16 channels feature maps in all time steps
        # merge 16 channels of one timestep to create one self-attention
        # input (batch, 16, dim)
        inputs = inputs.permute(2, 0, 1, 3)
        logging.info("inputs:{}".format(inputs.shape))
        # NOTE(review): merge is allocated on the default (CPU) device with a
        # hard-coded model dim of 512 — confirm against the encoder config
        # and GPU training.
        merge = torch.zeros(inputs.size(0), batchsize, 512)
        for t in range(inputs.size(0)):  # max_length
            # Fuse the channel maps of timestep t into one 512-dim vector
            # per batch element via the channel-merging layers.
            merge[t] = self.clayers(inputs[t], None)[0].reshape(batchsize,
                                                                512)
        xs = merge.permute(1, 0, 2)
        # Build the encoder mask from the subsampled lengths (scalar length
        # when no src_mask was available above).
        if inputs_length.dim() == 0:
            masks = make_non_pad_mask([inputs_length]).unsqueeze(-2)
        else:
            masks = make_non_pad_mask(inputs_length.tolist()).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs, masks)
    else:
        # RNN encoder path: takes raw lengths directly.
        hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
    self.hs_pad = hs_pad

    # 1.5. transducer preparation related
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask)

    # 2. decoder
    if "transformer" in self.dtype:
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    else:
        if self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. loss computation
    loss = self.criterion(pred_pad, target, pred_len, target_len)
    self.loss = loss
    loss_data = float(self.loss)

    # 4. compute cer/wer (evaluation mode only, and only if configured)
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss