def forward(self, inputs, targets, inputs_len, targets_len):
    '''
    Args:
        inputs (acoustic features): [N, T, D]
        targets (phoneme sequence): [N, U]
        inputs_len: [N]
        targets_len: [N]
    Returns:
        loss: scalar Transducer loss averaged over the batch
    '''
    # Transcription (encoder) and prediction (decoder) networks;
    # the label sequence is shifted right by prepending the blank symbol
    enc_state, inputs_len = self.encoder(inputs, inputs_len)
    dec_state, _ = self.decoder(F.pad(targets, pad=[1, 0, 0, 0], value=self.blank_idx))

    # Broadcast encoder states over the label axis and prediction states over
    # the time axis to build the [N, T, U+1, *] joint grid
    dec_state = dec_state.unsqueeze(1)   # [N, 1, U+1, H_dec]
    enc_state = enc_state.unsqueeze(2)   # [N, T, 1, H_enc]
    t = enc_state.size(1)
    u = dec_state.size(2)
    dec_state = dec_state.repeat([1, t, 1, 1])
    enc_state = enc_state.repeat([1, 1, u, 1])

    # Joint network: concatenate, project, tanh, and normalize over the vocabulary
    concat_state = torch.cat([enc_state, dec_state], dim=-1)
    logits = self.out(self.tanh(self.joint(concat_state)))
    logits = F.log_softmax(logits, dim=-1)

    loss = rnnt_loss(logits, targets.int(), inputs_len.int(), targets_len.int(),
                     blank=self.blank_idx)
    return loss.mean()
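A hedged usage sketch of the forward above. The `Transducer` class name and constructor arguments are assumptions, not part of the snippet; only the tensor shapes follow the docstring.

import torch

# Hypothetical model exposing the forward() above (class name and args are assumptions)
model = Transducer(input_dim=80, vocab_size=32, blank_idx=0)

inputs = torch.randn(4, 120, 80)                       # [N, T, D] acoustic features
targets = torch.randint(1, 32, (4, 20))                # [N, U] labels, no blank tokens
inputs_len = torch.full((4,), 120, dtype=torch.int32)  # [N]
targets_len = torch.full((4,), 20, dtype=torch.int32)  # [N]

loss = model(inputs, targets, inputs_len, targets_len)
loss.backward()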
def forward_transducer(self, eouts, elens, ys):
    """Compute Transducer loss.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        ys (list): length `B`, each of which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`

    """
    # Append <sos> and <eos>
    _ys = [np2tensor(np.fromiter(y, dtype=np.int64), eouts.device) for y in ys]
    ylens = np2tensor(np.fromiter([y.size(0) for y in _ys], dtype=np.int32))
    eos = eouts.new_zeros((1,), dtype=torch.int64).fill_(self.eos)
    ys_in = pad_list([torch.cat([eos, y], dim=0) for y in _ys], self.pad)  # `[B, L+1]`
    ys_out = pad_list(_ys, self.blank)  # `[B, L]`

    # Update prediction network
    ys_emb = self.dropout_emb(self.embed(ys_in))
    dout, _ = self.recurrency(ys_emb, None)

    # Compute output distribution
    logits = self.joint(eouts, dout)  # `[B, T, L+1, vocab]`

    # Compute Transducer loss
    log_probs = torch.log_softmax(logits, dim=-1)
    assert log_probs.size(2) == ys_out.size(1) + 1

    if self.device_id >= 0:
        ys_out = ys_out.to(eouts.device)
        elens = elens.to(eouts.device)
        ylens = ylens.to(eouts.device)
        import warp_rnnt
        loss = warp_rnnt.rnnt_loss(log_probs, ys_out.int(), elens, ylens,
                                   average_frames=False,
                                   reduction='mean',
                                   gather=False)
    else:
        import warprnnt_pytorch
        self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()
        loss = self.warprnnt_loss(log_probs, ys_out.int(), elens, ylens)
        # NOTE: Transducer loss has already been normalized by bs
        # NOTE: index 0 is reserved for blank in warprnnt_pytorch

    return loss
def cal_transducer_loss(self, model_output, target, frame_length, target_length, type='rnnt'):
    # Normalize the joint-network output over the vocabulary before the loss
    log_prob = t.nn.functional.log_softmax(model_output, -1)
    rnn_t_loss = rnnt_loss(log_probs=log_prob,
                           labels=target.int(),
                           frames_lengths=frame_length.int(),
                           labels_lengths=target_length.int(),
                           reduction='mean')
    return rnn_t_loss
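A minimal sketch of calling cal_transducer_loss with dummy inputs. It assumes the joint-network output has shape `[B, T, U+1, V]` and that blank uses index 0 (warp_rnnt's default); `trainer` is a stand-in for whatever object defines the method.

import torch as t

B, T, U, V = 2, 50, 10, 30
model_output = t.randn(B, T, U + 1, V)             # raw joint-network output
target = t.randint(1, V, (B, U), dtype=t.int32)    # label sequences without blanks
frame_length = t.full((B,), T, dtype=t.int32)      # [B] encoder frame counts
target_length = t.full((B,), U, dtype=t.int32)     # [B] label lengths

# `trainer` stands in for the object exposing cal_transducer_loss (assumption)
loss = trainer.cal_transducer_loss(model_output, target, frame_length, target_length)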
for xs, ys, xn, yn in progress:
    optimizer.zero_grad()

    xs = xs.cuda(non_blocking=True)
    ys = ys.cuda(non_blocking=True)
    xn = xn.cuda(non_blocking=True)
    yn = yn.cuda(non_blocking=True)

    zs, xs, xn = model(xs, ys, xn, yn)

    ys = ys.t().contiguous()
    loss = rnnt_loss(zs, ys, xn, yn, average_frames=False, reduction="mean")
    loss.backward()

    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 100)
    optimizer.step()

    err.update(loss.item())
    grd.update(grad_norm)
    progress.set_description('epoch %d %s %s' % (epoch + 1, err, grd))

model.eval()
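The meters `err` and `grd` are not defined in the loop above; the sketch below shows one plausible implementation of a running-average meter with a string representation suitable for the progress-bar description. Names and formatting are assumptions.

class AvgMeter:
    """Running average with a printable summary (assumed helper, not from the snippet)."""
    def __init__(self, name):
        self.name = name
        self.sum = 0.0
        self.count = 0

    def update(self, value):
        self.sum += float(value)
        self.count += 1

    def __str__(self):
        avg = self.sum / max(self.count, 1)
        return '%s %.4f' % (self.name, avg)

err = AvgMeter('loss')
grd = AvgMeter('grad_norm')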
# Reward-weighted training over K sampled hypotheses per utterance
rewards = relu(SymAcc).reshape(K, -1).cuda()
rewards_mean = rewards.mean().item()
rewards -= rewards.mean(dim=0)
elu(rewards, alpha=gamma, inplace=True)

hs_k = hs_k.reshape(K, len(xs), -1)
hn_k = hn_k.reshape(K, len(xs))

model.train()

zs, xs, xn = model(xs, ys.t(), xn, yn)

loss1 = rnnt_loss(zs, ys, xn, yn).mean()
loss2 = -(zs.exp() * zs).sum(dim=-1).mean()

for k in range(K):
    ys = hs_k[k]
    yn = hn_k[k]
    ys = ys[:, :yn.max()].contiguous()

    zs = model.forward_language(ys.t(), yn)
    zs = model.forward_joint(xs, zs)

    nll = rnnt_loss(zs, ys, xn, yn)
def forward(
    self,
    eouts,
    elens,
    eouts_inter=None,
    ys=None,
    ylens=None,
    ys_in=None,
    ys_out=None,
    soft_labels=None,
    ps=None,
    plens=None,
):
    loss = 0
    loss_dict = {}

    # Prediction network
    douts, _ = self.recurrency(ys_in, dstate=None)

    # Joint network
    logits = self.joint(eouts, douts)  # (B, T, L + 1, vocab)
    log_probs = torch.log_softmax(logits, dim=-1)
    assert log_probs.size(2) == ys.size(1) + 1

    # NOTE: rnnt_loss only accepts ys, elens, ylens with torch.int
    loss_rnnt = warp_rnnt.rnnt_loss(
        log_probs,
        ys.int(),
        elens.int(),
        ylens.int(),
        average_frames=False,
        reduction="mean",
        blank=self.blank_id,
        gather=False,
    )
    loss += loss_rnnt  # main loss
    loss_dict["loss_rnnt"] = loss_rnnt

    if self.mtl_ctc_weight > 0:
        # NOTE: KD is not applied to auxiliary CTC
        loss_ctc, _, _ = self.ctc(eouts=eouts, elens=elens, ys=ys, ylens=ylens, soft_labels=None)
        loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
        loss_dict["loss_ctc"] = loss_ctc

    if self.kd_weight > 0 and soft_labels is not None:
        if self.kd_type == "word":
            loss_kd = self.transducer_kd_loss(logits, soft_labels, elens, ylens)
        elif self.kd_type == "align":
            aligns = self.forced_aligner(log_probs, elens, ys, ylens)
            loss_kd = self.transducer_kd_loss(logits, ys, soft_labels, aligns, elens, ylens)
        loss_dict["loss_kd"] = loss_kd

        if self.reduce_main_loss_kd:
            loss = (1 - self.kd_weight) * loss + self.kd_weight * loss_kd
        else:
            loss += self.kd_weight * loss_kd

    loss_dict["loss_total"] = loss

    return loss, loss_dict, logits
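Several of the snippets above call `self.joint(eouts, dout)` without showing it. Below is a hedged sketch of a typical additive joint network: project the encoder and prediction-network states to a shared dimension, broadcast over the time and label axes, apply tanh, and map to the vocabulary. The class name and dimensions are assumptions, not the exact modules used in these repositories.

import torch
import torch.nn as nn

class AdditiveJoint(nn.Module):
    def __init__(self, enc_n_units, dec_n_units, joint_dim, vocab):
        super().__init__()
        self.w_enc = nn.Linear(enc_n_units, joint_dim)
        self.w_dec = nn.Linear(dec_n_units, joint_dim)
        self.out = nn.Linear(joint_dim, vocab)

    def forward(self, eouts, douts):
        # eouts: [B, T, enc_n_units] -> [B, T, 1, joint_dim]
        # douts: [B, L+1, dec_n_units] -> [B, 1, L+1, joint_dim]
        enc = self.w_enc(eouts).unsqueeze(2)
        dec = self.w_dec(douts).unsqueeze(1)
        # Broadcast-add over the (T, L+1) grid, then map to vocab logits
        return self.out(torch.tanh(enc + dec))  # [B, T, L+1, vocab]

# Example shapes (values are illustrative only)
joint = AdditiveJoint(enc_n_units=512, dec_n_units=512, joint_dim=512, vocab=1000)
logits = joint(torch.randn(2, 100, 512), torch.randn(2, 21, 512))  # [2, 100, 21, 1000]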
def forward_rnnt(self, eouts, elens, ys):
    """Compute Transducer loss.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        ys (list): A list of length `[B]`, each element a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`

    """
    # Append <sos> and <eos>
    eos = eouts.new_zeros(1).fill_(self.eos).long()
    if self.end_pointing:
        _ys = [np2tensor(np.fromiter(y + [self.eos], dtype=np.int64), self.device_id) for y in ys]
    else:
        _ys = [np2tensor(np.fromiter(y, dtype=np.int64), self.device_id) for y in ys]
    ylens = np2tensor(np.fromiter([y.size(0) for y in _ys], dtype=np.int32))
    ys_in = pad_list([torch.cat([eos, y], dim=0) for y in _ys], self.pad)
    ys_out = pad_list(_ys, self.blank)

    # Update prediction network
    ys_emb = self.dropout_emb(self.embed(ys_in))
    dout, _ = self.recurrency(ys_emb, None)

    # Compute output distribution
    logits = self.joint(eouts, dout)

    # Compute Transducer loss
    log_probs = torch.log_softmax(logits, dim=-1)
    if self.device_id >= 0:
        ys_out = ys_out.cuda(self.device_id)
        elens = elens.cuda(self.device_id)
        ylens = ylens.cuda(self.device_id)
    assert log_probs.size(2) == ys_out.size(1) + 1
    # loss = self.warprnnt_loss(log_probs, ys_out.int(), elens, ylens)
    # NOTE: Transducer loss has already been normalized by bs
    # NOTE: index 0 is reserved for blank in warprnnt_pytorch
    import warp_rnnt
    loss = warp_rnnt.rnnt_loss(log_probs, ys_out.int(), elens, ylens,
                               average_frames=False,
                               reduction='mean',
                               gather=False)

    # Label smoothing for Transducer
    # if self.lsm_prob > 0:
    #     loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(logits,
    #                                                       ylens=elens,
    #                                                       size_average=True) * self.lsm_prob
    # TODO(hirofumi): this leads to out of memory

    return loss