def forward(self, eouts, elens, ys, forced_align=False):
    """Compute CTC loss.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (List): length `B`
        ys (List): length `B`, each of which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`
        trigger_points (IntTensor): `[B, L]`

    """
    # Concatenate all elements in ys for warpctc_pytorch
    ylens = np2tensor(np.fromiter([len(y) for y in ys], dtype=np.int32))
    ys_ctc = torch.cat([np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int32))
                        for y in ys], dim=0)
    # NOTE: do not copy to GPUs here

    # Compute CTC loss
    logits = self.output(eouts)
    loss = self.loss_fn(logits.transpose(1, 0), ys_ctc, elens, ylens)

    # Label smoothing for CTC
    if self.lsm_prob > 0:
        loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(logits, elens) * self.lsm_prob

    trigger_points = self.forced_align(logits, elens, ys, ylens) if forced_align else None

    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.prob_dict['probs'] = tensor2np(torch.softmax(logits, dim=-1))

    return loss, trigger_points
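
# The variant above delegates the actual loss computation to ``self.loss_fn``.
# Below is a minimal sketch of how such a callable could be built on PyTorch's
# built-in ``torch.nn.CTCLoss``; this is an illustrative assumption (the
# surrounding code may bind warpctc_pytorch instead), and ``build_ctc_loss_fn``
# as well as the batch-size normalization are hypothetical.
import torch
import torch.nn.functional as F


def build_ctc_loss_fn(blank=0):
    """Return a loss_fn(logits_tmajor, ys_ctc, elens, ylens) callable (sketch)."""
    ctc = torch.nn.CTCLoss(blank=blank, reduction='sum', zero_infinity=True)

    def loss_fn(logits_tmajor, ys_ctc, elens, ylens):
        # torch.nn.CTCLoss expects log-probabilities of shape `[T, B, vocab]`
        log_probs = F.log_softmax(logits_tmajor, dim=-1)
        elens_t = torch.as_tensor(elens, dtype=torch.long)
        ylens_t = torch.as_tensor(ylens, dtype=torch.long)
        loss = ctc(log_probs, ys_ctc.long(), elens_t, ylens_t)
        return loss / logits_tmajor.size(1)  # normalize by the batch size (assumption)

    return loss_fn
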
def forward(self, eouts, elens, ys):
    """Compute CTC loss.

    Args:
        eouts (FloatTensor): `[B, T, dec_n_units]`
        elens (list): A list of length `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`

    """
    # Concatenate all elements in ys for warpctc_pytorch
    ylens = np2tensor(np.fromiter([len(y) for y in ys], dtype=np.int32))
    ys_ctc = torch.cat([np2tensor(np.fromiter(y, dtype=np.int32)) for y in ys], dim=0)
    # NOTE: do not copy to GPUs here

    # Compute CTC loss
    logits = self.output(eouts)
    loss = self.warpctc_loss(logits.transpose(1, 0).cpu(),  # time-major
                             ys_ctc, elens.cpu(), ylens)
    # NOTE: the CTC loss is already normalized by the batch size
    # NOTE: index 0 is reserved for blank in warpctc_pytorch
    if self.device_id >= 0:
        loss = loss.cuda(self.device_id)

    # Label smoothing for CTC
    if self.lsm_prob > 0:
        loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(logits, elens) * self.lsm_prob

    return loss
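
# ``kldiv_lsm_ctc`` is used above as a label-smoothing regularizer. The sketch
# below shows one plausible implementation, assuming it computes the KL
# divergence between a uniform distribution over the vocabulary and the
# frame-level CTC posteriors, averaged over the batch; the real helper may
# differ in how it masks frames, treats the blank symbol, and which extra
# arguments it accepts (the last variant below passes lsm_prob and size_average).
import torch


def kldiv_lsm_ctc(logits, ylens):
    """KL(uniform || softmax(logits)) over valid frames (illustrative sketch).

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        ylens (IntTensor): `[B]`, number of valid encoder frames per utterance
    Returns:
        loss (FloatTensor): `[1]`
    """
    bs, _, vocab = logits.size()
    log_probs = torch.log_softmax(logits, dim=-1)
    loss = logits.new_zeros(1)
    for b in range(bs):
        n_frames = int(ylens[b])
        log_prob_b = log_probs[b, :n_frames]  # `[n_frames, vocab]`
        log_uniform = torch.full_like(log_prob_b, 1.0 / vocab).log()
        # KL(U || P) = sum_v u * (log u - log p_v), summed over valid frames
        loss = loss + torch.sum((1.0 / vocab) * (log_uniform - log_prob_b))
    return loss / bs
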
def forward(self, eouts, elens, ys, forced_align=False):
    """Compute CTC loss.

    Args:
        eouts (FloatTensor): `[B, T, dec_n_units]`
        elens (list): A list of length `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`
        trigger_points (IntTensor): `[B, L]`

    """
    # Concatenate all elements in ys for warpctc_pytorch
    ylens = np2tensor(np.fromiter([len(y) for y in ys], dtype=np.int32))
    ys_ctc = torch.cat([np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int32))
                        for y in ys], dim=0)
    # NOTE: do not copy to GPUs here

    # Compute CTC loss
    logits = self.output(eouts)
    loss = self.warpctc_loss(logits.transpose(1, 0),  # time-major
                             ys_ctc, elens.cpu(), ylens)
    # NOTE: the CTC loss is already normalized by the batch size
    # NOTE: index 0 is reserved for blank in warpctc_pytorch
    if self.device_id >= 0:
        loss = loss.cuda(self.device_id)

    # Label smoothing for CTC
    if self.lsm_prob > 0:
        loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(logits, elens) * self.lsm_prob

    trigger_points = None
    if forced_align:
        ys = [np2tensor(np.fromiter(y, dtype=np.int64), self.device_id) for y in ys]
        ys_in_pad = pad_list(ys, 0)  # pad with zero
        trigger_points = self.forced_aligner.align(logits.clone(), elens, ys_in_pad, ylens)

    return loss, trigger_points
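
# ``pad_list`` above batches the variable-length label sequences before forced
# alignment. A minimal sketch of such a helper, assuming it right-pads a list
# of tensors with a constant value, is given below; the actual utility in the
# surrounding codebase may differ.
import torch


def pad_list(ys, pad_value=0.0):
    """Right-pad a list of tensors `[L_i, ...]` into `[B, max(L_i), ...]` (sketch)."""
    bs = len(ys)
    max_len = max(y.size(0) for y in ys)
    ys_pad = ys[0].new_full((bs, max_len) + tuple(ys[0].size()[1:]), pad_value)
    for b, y in enumerate(ys):
        ys_pad[b, :y.size(0)] = y
    return ys_pad
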
def forward_ctc(self, eouts, elens, ys):
    """Compute CTC loss.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (list): A list of length `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`

    """
    logits = self.output_ctc(eouts)

    # Compute the auxiliary CTC loss
    elens_ctc = np2tensor(np.fromiter(elens, dtype=np.int32), -1).int()
    ys_ctc = [np2tensor(np.fromiter(y, dtype=np.int64)).long() for y in ys]  # always fwd
    ylens = np2tensor(np.fromiter([y.size(0) for y in ys_ctc], dtype=np.int32), -1).int()
    ys_ctc = torch.cat(ys_ctc, dim=0).int()
    # NOTE: Concatenate all elements in ys for warpctc_pytorch
    # NOTE: do not copy to GPUs here

    # Compute CTC loss
    loss = self.warpctc_loss(logits.transpose(1, 0).cpu(),  # time-major
                             ys_ctc, elens_ctc, ylens)
    # NOTE: the CTC loss is already normalized by the batch size
    # NOTE: index 0 is reserved for blank in warpctc_pytorch
    if self.device_id >= 0:
        loss = loss.cuda(self.device_id)

    # Label smoothing for CTC
    if self.lsm_prob > 0 and self.ctc_weight == 1:
        loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(
            logits, ylens=elens, lsm_prob=self.lsm_prob, size_average=True) * self.lsm_prob

    return loss
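
# All four variants rely on small conversion helpers. The sketch below shows
# plausible definitions of ``np2tensor`` and ``tensor2np`` as they are used
# here (numpy array -> torch tensor, optionally moved to a GPU, and back);
# the real helpers may support more devices and dtypes.
import numpy as np
import torch


def np2tensor(array, device_id=-1):
    """Wrap a numpy array as a torch tensor, optionally on GPU ``device_id`` (sketch)."""
    tensor = torch.from_numpy(array)
    if device_id >= 0:
        tensor = tensor.cuda(device_id)
    return tensor


def tensor2np(x):
    """Detach a tensor and return it as a numpy array (sketch)."""
    return x.cpu().detach().numpy()
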