def forward(self, inputs, inputs_length, targets, targets_length, rles, rles_length):
    # inputs: N x L x c
    # targets: N x L
    enc_state, _ = self.encoder(inputs, inputs_length)  # [N, L, o]
    forward_out = self.tanh(self.forward_layer(enc_state))  # [N, L, 256]
    base_logits = self.base_project_layer(forward_out).log_softmax(dim=-1)  # [N, L, 6]
    rle_logits = self.rle_project_layer(forward_out).log_softmax(dim=-1)  # [N, L, 10]

    # F.ctc_loss expects (T, N, C) log-probs, hence the transposes
    base_loss = F.ctc_loss(base_logits.transpose(0, 1), targets.int(),
                           inputs_length.int(), targets_length.int())
    rle_loss = F.ctc_loss(rle_logits.transpose(0, 1), rles.int(),
                          inputs_length.int(), rles_length.int())
    return base_loss, rle_loss

def valid(model, optimizer, epoch, dataloader, characters):
    model.eval()
    with tqdm(dataloader) as pbar, torch.no_grad():
        loss_sum = 0
        acc_sum = 0
        for batch_index, (data, target, input_lengths, target_lengths) in enumerate(pbar):
            data, target = data.cuda(), target.cuda()

            output = model(data)
            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths)

            loss = loss.item()
            acc = calc_acc(target, output, characters)

            loss_sum += loss
            acc_sum += acc

            loss_mean = loss_sum / (batch_index + 1)
            acc_mean = acc_sum / (batch_index + 1)

            pbar.set_description(f'Test : {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f}')

def ctc_loss(acts, labels, act_lens, label_lens, num_symbols=None,
             context_order=1, normalize_by_dim=None,
             allow_nonblank_selfloops=True,
             loop_using_symbol_repetitions=False,
             eval_repeats_in_context=False,
             other_data_in_batch=None):
    for condition in [
            context_order == 1,
            normalize_by_dim is None,
            not eval_repeats_in_context,
            allow_nonblank_selfloops,
            not loop_using_symbol_repetitions,
            not other_data_in_batch,
    ]:
        assert condition, "Option not supported in this loss"
    # F.ctc_loss doesn't check these conditions and may segfault later on
    assert acts.size(0) == act_lens[0]
    assert labels.max() < acts.size(2)
    return F.ctc_loss(F.log_softmax(acts, dim=-1).contiguous(),
                      labels.cuda(), act_lens, label_lens,
                      reduction='mean', zero_infinity=False)

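# A minimal, self-contained sketch of the shape contract those asserts guard:
# F.ctc_loss wants (T, N, C) log-probs and every label index strictly below C.
# All sizes below are illustrative assumptions, not values from the code above.
import torch
import torch.nn.functional as F

T, N, C = 50, 4, 20
acts = torch.randn(T, N, C)
labels = torch.randint(1, C, (N, 12), dtype=torch.long)   # labels.max() < acts.size(2)
act_lens = torch.full((N,), T, dtype=torch.long)          # acts.size(0) == act_lens[0]
label_lens = torch.full((N,), 12, dtype=torch.long)
loss = F.ctc_loss(F.log_softmax(acts, dim=-1), labels, act_lens, label_lens)
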
def sequence_ctc_loss_with_logits(
    logits: torch.FloatTensor,
    logit_mask: Union[torch.FloatTensor, torch.BoolTensor],
    targets: torch.LongTensor,
    target_mask: Union[torch.FloatTensor, torch.BoolTensor],
    blank_index: torch.LongTensor,
) -> torch.FloatTensor:
    # lengths : (batch_size,), calculated by counting the mask entries
    logit_lengths = logit_mask.bool().long().sum(1)
    target_lengths = target_mask.bool().long().sum(1)

    if (logit_lengths < target_lengths).any():
        raise ValueError("The length of the predicted alignment is shorter than "
                         "the target length; increase the upsample factor.")

    # log_logits : (T, batch_size, n_class), the shape required by ctc_loss
    log_logits = logits.log_softmax(-1).transpose(0, 1)
    targets = targets.long()

    loss = F.ctc_loss(log_logits, targets, logit_lengths, target_lengths,
                      blank=blank_index, reduction='mean')
    return loss

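# Hypothetical call to sequence_ctc_loss_with_logits above; the masks, sizes,
# and blank index are assumptions made up for the sketch.
import torch

batch, T, S, n_class = 2, 30, 8, 10
logits = torch.randn(batch, T, n_class)
logit_mask = torch.ones(batch, T)                # every frame is valid
targets = torch.randint(1, n_class, (batch, S))
target_mask = torch.ones(batch, S)               # every target symbol is valid
loss = sequence_ctc_loss_with_logits(logits, logit_mask, targets, target_mask,
                                     blank_index=0)
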
def compute_loss(
    self,
    model_output: torch.Tensor,
    target: List[str],
) -> torch.Tensor:
    """Compute CTC loss for the model.

    Args:
        model_output: predicted logits of the model
        target: list of ground-truth strings for the batch

    Returns:
        The loss of the model on the batch
    """
    gt, seq_len = self.build_target(target)
    batch_len = model_output.shape[0]
    input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
    # N x T x C -> T x N x C
    logits = model_output.permute(1, 0, 2)
    probs = F.log_softmax(logits, dim=-1)
    ctc_loss = F.ctc_loss(
        probs,
        torch.from_numpy(gt),
        input_length,
        torch.tensor(seq_len, dtype=torch.int),
        len(self.vocab),
        zero_infinity=True,
    )
    return ctc_loss

def forward(self):
    a = torch.randn(3, 2)
    b = torch.rand(3, 2)
    c = torch.rand(3)
    log_probs = torch.randn(50, 16, 20).log_softmax(2).detach()
    targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
    input_lengths = torch.full((16,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
    # len() takes a single argument, so collect the losses into a tuple
    return len((
        F.binary_cross_entropy(torch.sigmoid(a), b),
        F.binary_cross_entropy_with_logits(a, b),  # takes raw logits, not sigmoid outputs
        F.poisson_nll_loss(a, b),
        F.cosine_embedding_loss(a, b, c),
        F.cross_entropy(a, b),
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
        # F.gaussian_nll_loss(a, b, torch.ones(5, 1)),  # ENTER is not supported in mobile module
        F.hinge_embedding_loss(a, b),
        F.kl_div(a, b),
        F.l1_loss(a, b),
        F.mse_loss(a, b),
        F.margin_ranking_loss(c, c, c),
        F.multilabel_margin_loss(self.x, self.y),
        F.multilabel_soft_margin_loss(self.x, self.y),
        F.multi_margin_loss(self.x, torch.tensor([3])),
        F.nll_loss(a, torch.tensor([1, 0, 1])),
        F.huber_loss(a, b),
        F.smooth_l1_loss(a, b),
        F.soft_margin_loss(a, b),
        F.triplet_margin_loss(a, b, -b),
        # F.triplet_margin_with_distance_loss(a, b, -b),  # can't take a variable number of arguments
    ))

def ctc_label_smoothing_loss(self, log_probs, targets, lengths, weights=None):
    T, N, C = log_probs.shape
    if weights is None:  # `weights or ...` would fail on a multi-element tensor
        weights = torch.cat([torch.tensor([0.4]), (0.1 / (C - 1)) * torch.ones(C - 1)])
    log_probs_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int64)
    loss = ctc_loss(log_probs.to(torch.float32), targets, log_probs_lengths, lengths,
                    reduction='mean')
    label_smoothing_loss = -((log_probs * weights.to(log_probs.device)).mean())
    return {'total_loss': loss + label_smoothing_loss,
            'loss': loss,
            'label_smooth_loss': label_smoothing_loss}

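# Worked example of the default smoothing-weight vector built above: 0.4 on the
# blank class and the remaining 0.1 spread uniformly over the other C - 1
# classes. C = 5 is an arbitrary illustrative choice.
import torch

C = 5
weights = torch.cat([torch.tensor([0.4]), (0.1 / (C - 1)) * torch.ones(C - 1)])
print(weights)  # tensor([0.4000, 0.0250, 0.0250, 0.0250, 0.0250])
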
def forward(self, x, nx, y=None, ny=None):
    "(B, H, W) batches of mfccs, (B, W) batches of graphemes."
    with amp.autocast(enabled=self.cfg.mixed_precision):
        B, H, W = x.shape

        # first convolution, collapses H
        x = F.pad(x, (self.cfg.padding(), self.cfg.padding(), 0, 0))
        x = F.relu(self.conv(x.unsqueeze(1)))  # add empty channel dim
        x = torch.squeeze(x, 2).permute(0, 2, 1)  # B, W, C

        # dense, gru
        x = F.relu(self.dense_a(x))
        x = F.relu(self.dense_b(x))
        x = F.relu(self.gru(x)[0])

        # sum over last dimension, fwd and bwd
        x = torch.split(x, self.cfg.n_hidden, dim=-1)
        x = torch.sum(torch.stack(x, dim=-1), dim=-1)

        # head
        x = F.relu(self.dense_end(x))
        x = F.log_softmax(x, dim=2)

        # loss
        loss = None
        if y is not None and ny is not None:
            nx = self.cfg.frame_lengths(nx)
            xctc = x.permute(1, 0, 2)  # W, B, C
            loss = F.ctc_loss(xctc, y, nx, ny)

    return x, loss

def train(model, optimizer, epoch, dataloader, characters):
    model.train()
    loss_mean = 0
    acc_mean = 0
    with tqdm(dataloader) as pbar:
        for batch_index, (data, target, input_lengths, target_lengths) in enumerate(pbar):
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = model(data)

            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths)

            loss.backward()
            optimizer.step()

            loss = loss.item()
            acc = calc_acc(target, output, characters)

            if batch_index == 0:
                loss_mean = loss
                acc_mean = acc

            loss_mean = 0.1 * loss + 0.9 * loss_mean
            acc_mean = 0.1 * acc + 0.9 * acc_mean

            pbar.set_description(f'Epoch: {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f}')

def loss_score(self, batch, y_pred):
    assert len(y_pred.shape) == 3
    x, y_target = batch
    y_true, x_t_batch, y_t_batch = y_target
    y_pred = F.log_softmax(y_pred, dim=2)
    err = F.ctc_loss(y_pred, y_true, x_t_batch, y_t_batch)
    return err

def _alignment_cost(log_probs, allowed_skips_beg, allowed_skips_end, force_forbid_blank):
    # log_probs is BS x WIN_LEN x NUM_PREDS
    bs, win_len, num_preds = log_probs.size()
    assert win_len >= num_preds

    padded_log_probs = F.pad(log_probs,
                             (0, 0, allowed_skips_beg, allowed_skips_end),
                             "constant", 0)
    padded_win_len = win_len + allowed_skips_beg + allowed_skips_end
    fake_ctc_labels = torch.arange(1, num_preds + 1, dtype=torch.int).expand(bs, num_preds)

    # prepend impossible BLANK probabilities
    ctc_log_probs = padded_log_probs.permute(1, 0, 2).contiguous()
    if force_forbid_blank:
        ctc_log_probs = torch.cat(
            (torch.empty(padded_win_len, bs, 1, device=log_probs.device).fill_(-1000),
             ctc_log_probs), 2)
    # Now ctc_log_probs is win_size x BS x (num_preds + 1)
    assert ctc_log_probs.is_contiguous()

    # Normalize the log-probs over num_preds.
    # This is required, because ctc returns a bad gradient when given
    # unnormalized log probs.
    log_sum_exps = torch.logsumexp(ctc_log_probs, 2, keepdim=True)
    ctc_log_probs = ctc_log_probs - log_sum_exps

    losses = F.ctc_loss(ctc_log_probs,
                        fake_ctc_labels,
                        torch.empty(bs, dtype=torch.int).fill_(padded_win_len),
                        torch.empty(bs, dtype=torch.int).fill_(num_preds),
                        reduction='none')
    losses = losses - log_sum_exps.squeeze(2).sum(0)
    return losses

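# Illustrative call to _alignment_cost: a batch of 2 windows of 12 frames each,
# aligned against 8 ordered predictions with one frame of slack at either end.
# The sizes are assumptions for the sketch.
import torch

log_probs = torch.randn(2, 12, 8).log_softmax(2)  # BS x WIN_LEN x NUM_PREDS
costs = _alignment_cost(log_probs, allowed_skips_beg=1, allowed_skips_end=1,
                        force_forbid_blank=True)
print(costs.shape)  # torch.Size([2]): one alignment cost per batch element
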
def test_ctc_loss(self):
    torch.manual_seed(0)

    N = 16      # Batch size
    T = 50      # Input sequence length
    C = 20      # Number of classes (including blank)
    S = 30      # Target sequence length of longest target in batch (padding length)
    S_min = 10  # Minimum target length (only for testing)

    logits = torch.randn(N, T, C)
    targets = torch.randint(1, C, (N, S), dtype=torch.long)
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.randint(S_min, S, (N,), dtype=torch.long)

    config = CTCLoss.Config()
    config.blank = 0  # Needs to be set to 0 for CuDNN support.
    ctc_loss_fn = CTCLoss(config=config)
    ctc_loss_val = ctc_loss_fn(logits, targets, input_lengths, target_lengths)

    # PyTorch CTC loss
    # permute to conform to the CTC loss input tensor shape (T, N, C) in PyTorch
    log_probs = logits.permute(1, 0, 2).log_softmax(2)
    lib_ctc_loss_val = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)

    self.assertAlmostEqual(ctc_loss_val.item(), lib_ctc_loss_val.item())

def train(epoch):
    net.train()
    closs = []
    for iter_idx, (img, transcr) in enumerate(train_loader):
        optimizer.zero_grad()

        img = Variable(img.to(device))

        # geometrical and morphological deformations
        '''
        rids = torch.BoolTensor(torch.bernoulli(.1 * torch.ones(img.size(0))).bool())
        if sum(rids) > 1:
            u = 4  # 2 ** np.random.randint(2, 5)
            img[rids] = local_deform(img[rids], rm=None, dscale=u, a=np.random.uniform(2.0, 4.0))
        rids = torch.BoolTensor(torch.bernoulli(.1 * torch.ones(img.size(0))).bool())
        if sum(rids) > 1:
            u = 4  # 2 ** np.random.randint(2, 5)
            img[rids] = morphological(img[rids], rm=None, dscale=u)
        '''

        output = net(img.detach())  # .permute(2, 0, 1)

        act_lens = torch.IntTensor(img.size(0) * [output.size(0)])
        labels = torch.IntTensor([cdict[c] for c in ''.join(transcr)])
        label_lens = torch.IntTensor([len(t) for t in transcr])

        ls_output = F.log_softmax(output, dim=2).cpu()
        loss_val = F.ctc_loss(ls_output, labels, act_lens, label_lens,
                              zero_infinity=True, reduction='sum') / img.size(0)
        closs += [loss_val.data]

        loss_val.backward()
        optimizer.step()

        # mean running errors??
        if iter_idx % args.display == args.display - 1:
            logger.info('Epoch %d, Iteration %d: %f', epoch, iter_idx + 1, sum(closs) / len(closs))
            closs = []

            net.eval()
            tst_img, tst_transcr = test_set[np.random.randint(len(test_set))]
            print('orig:: ' + tst_transcr)
            with torch.no_grad():
                tst_o = net(Variable(tst_img.to(device)).unsqueeze(0))  # .permute(2, 0, 1)
            tdec = tst_o.argmax(2).permute(1, 0).cpu().numpy().squeeze()
            # collapse repeats, then strip the blank ('_')
            tt = [v for j, v in enumerate(tdec) if j == 0 or v != tdec[j - 1]]
            print('gdec:: ' + ''.join([icdict[t] for t in tt]).replace('_', ''))
            net.train()

def cal_ctc(pred, gold, pre_len, gold_len):
    pred_log_prob = F.log_softmax(pred, -1)
    loss = ctc_loss(pred_log_prob.transpose(0, 1), gold, pre_len, gold_len,
                    blank=1, zero_infinity=True)
    return loss

def train_loop(model, optimizer, writer, step, args, hp):
    dataset_train = datasets.get_dataset(hp.train_script, hp, use_spec_aug=hp.use_spec_aug)

    if hp.encoder_type == 'Conformer':
        train_sampler = datasets.LengthsBatchSampler(dataset_train, hp.batch_size * 1500,
                                                     hp.lengths_file)
        dataloader = DataLoader(dataset_train, batch_sampler=train_sampler,
                                num_workers=2, collate_fn=datasets.collate_fn)
    else:
        train_sampler = DistributedSampler(dataset_train) if args.n_gpus > 1 else None
        dataloader = DataLoader(dataset_train, batch_size=hp.batch_size, shuffle=hp.shuffle,
                                num_workers=2, sampler=train_sampler,
                                collate_fn=datasets.collate_fn, drop_last=True)

    optimizer.zero_grad()
    for d in dataloader:
        step += 1
        if hp.encoder_type == 'Conformer':
            lr = get_transformer_learning_rate(step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print(f'step = {step}')
            print(f'lr = {lr}')

        text, mel_input, pos_text, pos_mel, text_lengths, mel_lengths = d

        text = text.to(DEVICE, non_blocking=True)
        mel_input = mel_input.to(DEVICE, non_blocking=True)
        pos_text = pos_text.to(DEVICE, non_blocking=True)
        pos_mel = pos_mel.to(DEVICE, non_blocking=True)
        text_lengths = text_lengths.to(DEVICE, non_blocking=True)

        if hp.frame_stacking > 1 and hp.encoder_type != 'Wave':
            mel_input, mel_lengths = frame_stacking(mel_input, mel_lengths, hp.frame_stacking)

        predict_ts = model(mel_input, mel_lengths, text, pos_mel)

        if hp.decoder_type == 'Attention':
            loss = label_smoothing_loss(predict_ts, text, text_lengths, hp.T_norm, hp.B_norm)
        elif hp.decoder_type == 'CTC':
            predict_ts = F.log_softmax(predict_ts, dim=2).transpose(0, 1)
            loss = F.ctc_loss(predict_ts, text, mel_lengths, text_lengths, blank=0)

        # backward
        loss.backward()

        # optimizer update
        if step % hp.accum_grad == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), hp.clip)
            optimizer.step()
            optimizer.zero_grad()

        loss = loss.detach()  # a bare `loss.detach()` would have no effect

        if torch.isnan(loss):
            print('loss is nan')
            sys.exit(1)

        if hp.debug_mode == 'tensorboard':
            writer.add_scalar("Loss/train", loss, step)
            print('loss = {}'.format(loss.item()))
        else:
            print('loss = {}'.format(loss.item()))
        sys.stdout.flush()

    return step

def ctc_fallback(encoder_outputs, labels, frame_lens, label_lens, blank):
    assert len(encoder_outputs) == len(labels) == len(frame_lens) == len(label_lens)
    skipped_indices = []
    working_indices = []
    for i in range(len(encoder_outputs)):
        if torch.isinf(F.ctc_loss(encoder_outputs[i:i+1].transpose(0, 1),
                                  labels[i:i+1], frame_lens[i:i+1], label_lens[i:i+1],
                                  blank=blank)):
            skipped_indices.append(i)
        else:
            working_indices.append(i)
    return skipped_indices, working_indices

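# Sketch of how ctc_fallback might be used: find the samples whose CTC loss is
# infinite (e.g. a target longer than its input) and keep only the rest.
# Shapes and the blank index are assumptions.
import torch
import torch.nn.functional as F

encoder_outputs = torch.randn(4, 40, 12).log_softmax(-1)  # (N, T, C) log-probs
labels = torch.randint(1, 12, (4, 10))
frame_lens = torch.full((4,), 40)
label_lens = torch.full((4,), 10)
skipped, working = ctc_fallback(encoder_outputs, labels, frame_lens, label_lens, blank=0)
# recompute the batch loss over `working` indices only
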
def loss(output, output_len, targets, targets_len):
    output_trans = output.permute(1, 0, 2)  # (T, N, C), as needed by the CTC loss
    loss = F.ctc_loss(output_trans, targets, output_len, targets_len,
                      reduction='none', zero_infinity=True)
    loss /= output_len  # normalize each sample's loss by its input length
    loss = loss.mean()
    return loss

def ctc_loss(y_hat, y, y_hat_lens, target_lens):
    loss = 0.0
    # a model may return a tuple of predictions (e.g. from multiple heads)
    y_hat = y_hat if isinstance(y_hat, tuple) else [y_hat]
    for y_hat_i in y_hat:
        loss_i = F.ctc_loss(y_hat_i.cpu(), y.cpu(), y_hat_lens.cpu(), target_lens.cpu())
        loss += loss_i
    return loss

def cal_ctc_loss(logits_ctc, len_logits_ctc, targets, target_lengths):
    n_class = logits_ctc.size(-1)
    ctc_log_probs = F.log_softmax(logits_ctc, dim=-1).transpose(0, 1)
    ctc_loss = F.ctc_loss(ctc_log_probs, targets, len_logits_ctc, target_lengths,
                          reduction="none", blank=n_class - 1)
    return ctc_loss.sum()

def forward(self, inputs, inputs_length, targets, targets_length):
    if self.fir_enc_or_not:
        t_inputs, t_inputs_length = self.fir_enc(inputs, inputs_length)
    else:
        t_inputs, t_inputs_length = inputs, inputs_length

    enc_state, _ = self.encoder(t_inputs, t_inputs_length)
    # (N, L, C) -> (L, N, C); the encoder is assumed to emit log-probabilities,
    # since no log_softmax is applied before F.ctc_loss
    enc_state = enc_state.transpose(0, 1).contiguous()
    loss = F.ctc_loss(enc_state, targets.int(),
                      t_inputs_length.int(), targets_length.int())
    return loss

def ctc_loss(preds, targets, voc_size):
    # prepare targets: strip the padding index (voc_size) and concatenate
    target_lengths = (targets != voc_size).long().sum(dim=-1)
    trimmed_targets = [t[:l] for t, l in zip(targets, target_lengths)]
    targets = torch.cat(trimmed_targets)

    x = F.log_softmax(preds, dim=-1)  # preds is (T, N, C)
    input_lengths = torch.full((x.size(1),), x.size(0), dtype=torch.long)
    return F.ctc_loss(
        x, targets, input_lengths, target_lengths,
        blank=voc_size, zero_infinity=True
    )

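# Toy call for the trimming variant above: padded targets use voc_size as the
# pad index, and preds is (T, N, C) with C = voc_size + 1 so index voc_size can
# double as the blank. All sizes are illustrative.
import torch

voc_size = 5
preds = torch.randn(20, 3, voc_size + 1)                  # (T, N, C)
targets = torch.full((3, 7), voc_size, dtype=torch.long)  # start fully padded
targets[:, :4] = torch.randint(0, voc_size, (3, 4))       # 4 real symbols per sample
loss = ctc_loss(preds, targets, voc_size)
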
def aug_loss(output, target):
    S = 8       # length of the target sequence
    S_min = 5
    alpha = 0.2

    output = output.squeeze(dim=2)
    output = output.permute(2, 0, 1)  # to (sequence length, batch size, num classes)
    N = output.shape[1]  # batch size
    T = output.shape[0]  # input sequence length, 20 here; all samples are of equal length
    log_probs = output.log_softmax(2).requires_grad_()

    # compute y~ = y* % 10
    target_ = target % 10

    input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
    target_lengths = torch.full(size=(N,), fill_value=S_min, dtype=torch.long)
    return F.ctc_loss(log_probs, target, input_lengths, target_lengths, blank=20) + \
        alpha * F.ctc_loss(log_probs, target_, input_lengths, target_lengths, blank=20)

def test_ctc_loss(self):
    # force fp32 because _th_normal_ (used by the next line) is not supported for fp16
    log_probs = torch.randn(50, 16, 20, device='cuda',
                            dtype=torch.float32).log_softmax(2).detach().requires_grad_()
    targets = torch.randint(1, 20, (16, 30), device='cuda', dtype=torch.long)
    input_lengths = torch.full((16,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)

def compute_loss(self, model, net_output, sample, reduction='mean', zero_infinity=False):
    log_probs = model.get_normalized_probs(net_output, log_probs=True)
    targets = torch.cat(sample['target']).cpu()  # targets are expected to have CPU backend
    target_lengths = sample['target_length']
    input_lengths = torch.full((sample['nsentences'],), log_probs.size(0), dtype=torch.int32)
    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=self.blank_idx, reduction=reduction, zero_infinity=zero_infinity)
    return loss

def cal_loss(logits, len_logits, gold, smoothing=0.0):
    """Calculate the CTC loss; the `smoothing` argument is currently unused."""
    n_class = logits.size(-1)
    target_lengths = gold.ne(0).sum(dim=1).int()
    ctc_log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
    ctc_loss = F.ctc_loss(ctc_log_probs, gold, len_logits, target_lengths,
                          blank=n_class - 1)
    return ctc_loss

def ctc_label_smoothing_loss(log_probs, targets, input_lengths, target_lengths, weights):
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='mean')
    label_smoothing_loss = -((log_probs * weights.to(log_probs.device)).mean())
    return {
        'loss': loss + label_smoothing_loss,
        'ctc_loss': loss,
        'label_smooth_loss': label_smoothing_loss,
    }

def compute_ctc_loss(self, x, y, reduction="mean"):
    """
    Args:
        x: log_probs, a list of (t, d) tensors
        y: labels, a list of (t',) tensors
    Return:
        loss
    """
    xl = torch.tensor(list(map(len, x)))
    yl = torch.tensor(list(map(len, y)))
    x = pad_sequence(x, False)  # -> (t, b, c)
    y = pad_sequence(y, True)   # -> (b, s)
    return F.ctc_loss(x, y, xl, yl, self.blank, reduction, True)

def forward(self, logits, input_lengths, targets, target_lengths):
    # log_logits : (T, batch_size, n_class), the shape required by ctc_loss
    log_logits = logits.log_softmax(-1).transpose(0, 1)
    loss = F.ctc_loss(log_logits, targets, input_lengths, target_lengths,
                      blank=self.blank_index, reduction='mean', zero_infinity=True)
    return loss

def forward(
    self,
    log_p_attn: torch.Tensor,
    ilens: torch.Tensor,
    olens: torch.Tensor,
    blank_prob: float = np.e**-1,
) -> torch.Tensor:
    """Calculate forward propagation.

    Args:
        log_p_attn (Tensor): Batch of log probability of attention matrix
            (B, T_feats, T_text).
        ilens (Tensor): Batch of the lengths of each input (B,).
        olens (Tensor): Batch of the lengths of each target (B,).
        blank_prob (float): Blank symbol probability.

    Returns:
        Tensor: forwardsum loss value.
    """
    B = log_p_attn.size(0)

    # add beta-binomial prior
    bb_prior = self._generate_prior(ilens, olens)
    bb_prior = bb_prior.to(dtype=log_p_attn.dtype, device=log_p_attn.device)
    log_p_attn = log_p_attn + bb_prior

    # a column must be added to the attention matrix to account for
    # the blank token of CTC loss
    # (B, T_feats, T_text + 1)
    log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))

    loss = 0
    for bidx in range(B):
        # construct target sequence.
        # Every text token is mapped to a unique sequence number.
        target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
        cur_log_p_attn_pd = log_p_attn_pd[
            bidx, : olens[bidx], : ilens[bidx] + 1
        ].unsqueeze(1)  # (T_feats, 1, T_text + 1)
        loss += F.ctc_loss(
            log_probs=cur_log_p_attn_pd,
            targets=target_seq,
            input_lengths=olens[bidx : bidx + 1],
            target_lengths=ilens[bidx : bidx + 1],
            zero_infinity=True,
        )
    loss = loss / B
    return loss

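# Standalone sketch of the core trick in the method above: prepend a blank
# column to a (T_feats, T_text) log-attention matrix and run CTC against the
# target sequence 1..T_text, so the loss sums over all monotonic alignments.
# The beta-binomial prior and batching are omitted; sizes are assumptions.
import numpy as np
import torch
import torch.nn.functional as F

T_feats, T_text = 15, 6
log_p_attn = torch.randn(T_feats, T_text).log_softmax(-1)
log_p_attn_pd = F.pad(log_p_attn, (1, 0), value=np.log(np.e ** -1))  # blank column
log_probs = log_p_attn_pd.unsqueeze(1)                 # (T_feats, 1, T_text + 1)
target_seq = torch.arange(1, T_text + 1).unsqueeze(0)  # each token is its own class
loss = F.ctc_loss(log_probs, target_seq,
                  input_lengths=torch.tensor([T_feats]),
                  target_lengths=torch.tensor([T_text]),
                  zero_infinity=True)
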
def ctc_label_smoothing_loss(log_probs, targets, lengths, weights):
    T, N, C = log_probs.shape
    log_probs_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int64)
    loss = ctc_loss(log_probs.to(torch.float32), targets, log_probs_lengths, lengths,
                    reduction='mean')
    label_smoothing_loss = -((log_probs * weights.to(log_probs.device)).mean())
    return {
        'loss': loss + label_smoothing_loss,
        'ctc_loss': loss,
        'label_smooth_loss': label_smoothing_loss,
    }