    def forward(self, inputs, inputs_length, targets, targets_length, rles,
                rles_length):
        # inputs: N x L x c
        # targets: N x L
        enc_state, _ = self.encoder(inputs, inputs_length)  # [N, L, o]
        forward_out = self.tanh(self.forward_layer(enc_state))  # [N,L,256]
        base_logits = self.base_project_layer(forward_out).log_softmax(
            dim=-1)  # [N,L,6]
        rle_logits = self.rle_project_layer(forward_out).log_softmax(
            dim=-1)  # [N,L,10]

        # print("base logits:", base_logits.transpose(0, 1).shape)
        # print("base target:", targets.shape)
        # print(inputs_length)
        # print(targets_length)
        # print("rle logits:", rle_logits.transpose(0, 1).shape)
        # print("rles target:", rles.shape)
        # print(inputs_length)
        # print(targets_length)
        base_loss = F.ctc_loss(base_logits.transpose(0, 1), targets.int(),
                               inputs_length.int(), targets_length.int())
        rle_loss = F.ctc_loss(rle_logits.transpose(0, 1), rles.int(),
                              inputs_length.int(), rles_length.int())
        # print(targets,base_logits.shape)
        # print(base_loss, rle_loss)

        return base_loss, rle_loss
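
For reference, a minimal, self-contained sketch of the shape conventions that `F.ctc_loss` expects and that the snippets below all follow: log-probabilities in (T, N, C) layout, padded integer targets, and per-sample length tensors (the sizes here are arbitrary).

import torch
import torch.nn.functional as F

T, N, C, S = 50, 4, 6, 12                                    # time steps, batch, classes (incl. blank), max target length
log_probs = torch.randn(T, N, C).log_softmax(-1)             # (T, N, C) log-probabilities
targets = torch.randint(1, C, (N, S), dtype=torch.long)      # labels in 1..C-1; 0 is the default blank
input_lengths = torch.full((N,), T, dtype=torch.long)        # (N,) valid frames per sample
target_lengths = torch.randint(1, S + 1, (N,), dtype=torch.long)  # (N,) valid labels per sample
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)  # scalar, mean reduction by default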
Example #2
def valid(model, optimizer, epoch, dataloader, characters):
    model.eval()
    with tqdm(dataloader) as pbar, torch.no_grad():
        loss_sum = 0
        acc_sum = 0
        for batch_index, (data, target, input_lengths,
                          target_lengths) in enumerate(pbar):
            data, target = data.cuda(), target.cuda()

            output = model(data)
            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target, input_lengths,
                              target_lengths)

            loss = loss.item()
            acc = calc_acc(target, output, characters)

            loss_sum += loss
            acc_sum += acc

            loss_mean = loss_sum / (batch_index + 1)
            acc_mean = acc_sum / (batch_index + 1)

            pbar.set_description(
                f'Test : {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f} ')
Example #3
def ctc_loss(acts,
             labels,
             act_lens,
             label_lens,
             num_symbols=None,
             context_order=1,
             normalize_by_dim=None,
             allow_nonblank_selfloops=True,
             loop_using_symbol_repetitions=False,
             eval_repeats_in_context=False,
             other_data_in_batch=None):
    for condition in [
            context_order == 1, normalize_by_dim is None,
            not eval_repeats_in_context, allow_nonblank_selfloops,
            not loop_using_symbol_repetitions, not other_data_in_batch
    ]:
        assert condition, "Option not supported in this loss"
    # F.ctc_loss doesn't check these conditions and may segfault later on
    assert acts.size(0) == act_lens[0]
    assert labels.max() < acts.size(2)
    return F.ctc_loss(F.log_softmax(acts, dim=-1).contiguous(),
                      labels.cuda(),
                      act_lens,
                      label_lens,
                      reduction='mean',
                      zero_infinity=False)
Example #4
def sequence_ctc_loss_with_logits(
    logits: torch.FloatTensor,
    logit_mask: Union[torch.FloatTensor, torch.BoolTensor],
    targets: torch.LongTensor,
    target_mask: Union[torch.FloatTensor, torch.BoolTensor],
    blank_index: torch.LongTensor
) -> torch.FloatTensor:

    # lengths : (batch_size, )
    # calculated by counting number of mask
    logit_lengths = (logit_mask.bool()).long().sum(1)
    target_lengths = (target_mask.bool()).long().sum(1)

    # log_logits : (T, batch_size, n_class), this kind of shape is required for ctc_loss
    #log_logits = logits + (logit_mask.unsqueeze(-1) + 1e-45).log()
    log_logits = logits.log_softmax(-1).transpose(0, 1)
    targets = targets.long()

    if (logit_lengths < target_lengths).sum() > 0:
        raise ValueError(
            "The length of the predicted alignment is shorter than the target length; "
            "increase the upsample factor.")

    loss = F.ctc_loss(log_logits,
                      targets,
                      logit_lengths,
                      target_lengths,
                      blank=blank_index,
                      reduction='mean')

    return loss
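
A minimal usage sketch for `sequence_ctc_loss_with_logits` above, with made-up tensor sizes (an illustration, not from the original repository); the boolean masks mark valid frames and target tokens and are converted to lengths inside the function.

batch, T, n_class, S = 2, 30, 10, 8
logits = torch.randn(batch, T, n_class)                      # unnormalized scores, (batch, T, n_class)
logit_mask = torch.ones(batch, T, dtype=torch.bool)          # every frame is valid here
targets = torch.randint(1, n_class, (batch, S))              # labels 1..n_class-1
target_mask = torch.ones(batch, S, dtype=torch.bool)         # every target token is valid here
loss = sequence_ctc_loss_with_logits(logits, logit_mask, targets, target_mask, blank_index=0)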
Example #5
    def compute_loss(
        self,
        model_output: torch.Tensor,
        target: List[str],
    ) -> torch.Tensor:
        """Compute CTC loss for the model.

        Args:
            model_output: predicted logits of the model
            target: list of ground-truth label strings for the batch

        Returns:
            The loss of the model on the batch
        """
        gt, seq_len = self.build_target(target)
        batch_len = model_output.shape[0]
        input_length = model_output.shape[1] * torch.ones(size=(batch_len, ),
                                                          dtype=torch.int32)
        # N x T x C -> T x N x C
        logits = model_output.permute(1, 0, 2)
        probs = F.log_softmax(logits, dim=-1)
        ctc_loss = F.ctc_loss(
            probs,
            torch.from_numpy(gt),
            input_length,
            torch.tensor(seq_len, dtype=torch.int),
            len(self.vocab),
            zero_infinity=True,
        )

        return ctc_loss
Example #6
    def forward(self):
        a = torch.randn(3, 2)
        b = torch.rand(3, 2)
        c = torch.rand(3)
        log_probs = torch.randn(50, 16, 20).log_softmax(2).detach()
        targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
        input_lengths = torch.full((16, ), 50, dtype=torch.long)
        target_lengths = torch.randint(10, 30, (16, ), dtype=torch.long)
        return len((
            F.binary_cross_entropy(torch.sigmoid(a), b),
            F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
            F.poisson_nll_loss(a, b),
            F.cosine_embedding_loss(a, b, c),
            F.cross_entropy(a, b),
            F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
            # F.gaussian_nll_loss(a, b, torch.ones(5, 1)), # ENTER is not supported in mobile module
            F.hinge_embedding_loss(a, b),
            F.kl_div(a, b),
            F.l1_loss(a, b),
            F.mse_loss(a, b),
            F.margin_ranking_loss(c, c, c),
            F.multilabel_margin_loss(self.x, self.y),
            F.multilabel_soft_margin_loss(self.x, self.y),
            F.multi_margin_loss(self.x, torch.tensor([3])),
            F.nll_loss(a, torch.tensor([1, 0, 1])),
            F.huber_loss(a, b),
            F.smooth_l1_loss(a, b),
            F.soft_margin_loss(a, b),
            F.triplet_margin_loss(a, b, -b),
            # F.triplet_margin_with_distance_loss(a, b, -b), # can't take variable number of arguments
        ))
Example #7
    def ctc_label_smoothing_loss(self, log_probs, targets, lengths, weights=None):
        T, N, C = log_probs.shape
        if weights is None:  # `weights or ...` would fail on a multi-element tensor
            weights = torch.cat([torch.tensor([0.4]), (0.1 / (C - 1)) * torch.ones(C - 1)])
        log_probs_lengths = torch.full(size=(N, ), fill_value=T, dtype=torch.int64)
        loss = ctc_loss(log_probs.to(torch.float32), targets, log_probs_lengths, lengths, reduction='mean')
        label_smoothing_loss = -((log_probs * weights.to(log_probs.device)).mean())
        return {'total_loss': loss + label_smoothing_loss, 'loss': loss, 'label_smooth_loss': label_smoothing_loss}
Example #8
    def forward(self, x, nx, y=None, ny=None):
        "(B, H, W) batches of mfccs, (B, W) batches of graphemes."

        with amp.autocast(enabled=self.cfg.mixed_precision):
            B, H, W = x.shape

            # first convolution, collapses H
            x = F.pad(x, (self.cfg.padding(), self.cfg.padding(), 0, 0))
            x = F.relu(self.conv(x.unsqueeze(1)))  # add empty channel dim
            x = torch.squeeze(x, 2).permute(0, 2, 1)  # B, W, C

            # dense, gru
            x = F.relu(self.dense_a(x))
            x = F.relu(self.dense_b(x))
            x = F.relu(self.gru(x)[0])

            # sum over last dimension, fwd and bwd
            x = torch.split(x, self.cfg.n_hidden, dim=-1)
            x = torch.sum(torch.stack(x, dim=-1), dim=-1)

            # head
            x = F.relu(self.dense_end(x))
            x = F.log_softmax(x, dim=2)

            # loss
            loss = None
            if y is not None and ny is not None:
                nx = self.cfg.frame_lengths(nx)
                xctc = x.permute(1, 0, 2)  # W, B, C
                loss = F.ctc_loss(xctc, y, nx, ny)

            return x, loss
Example #9
def train(model, optimizer, epoch, dataloader, characters):
    model.train()
    loss_mean = 0
    acc_mean = 0
    with tqdm(dataloader) as pbar:
        for batch_index, (data, target, input_lengths,
                          target_lengths) in enumerate(pbar):
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = model(data)

            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target, input_lengths,
                              target_lengths)

            loss.backward()
            optimizer.step()

            loss = loss.item()
            acc = calc_acc(target, output, characters)

            if batch_index == 0:
                loss_mean = loss
                acc_mean = acc

            loss_mean = 0.1 * loss + 0.9 * loss_mean
            acc_mean = 0.1 * acc + 0.9 * acc_mean

            pbar.set_description(
                f'Epoch: {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f} ')
Example #10
    def loss_score(self, batch, y_pred):
        assert len(y_pred.shape) == 3
        x, y_target = batch
        y_true, x_t_batch, y_t_batch = y_target
        y_pred = F.log_softmax(y_pred, dim=2)
        err = F.ctc_loss(y_pred, y_true, x_t_batch, y_t_batch)
        return err
Example #11
    def _alignment_cost(log_probs, allowed_skips_beg, allowed_skips_end,
                        force_forbid_blank):
        # log_probs is BS x WIN_LEN x NUM_PREDS
        bs, win_len, num_preds = log_probs.size()
        assert win_len >= num_preds
        padded_log_probs = F.pad(log_probs,
                                 (0, 0, allowed_skips_beg, allowed_skips_end),
                                 "constant", 0)
        padded_win_len = win_len + allowed_skips_beg + allowed_skips_end
        fake_ctc_labels = torch.arange(1, num_preds + 1,
                                       dtype=torch.int).expand(bs, num_preds)

        # append impossible BLANK probabilities
        ctc_log_probs = padded_log_probs.permute(1, 0, 2).contiguous()
        if force_forbid_blank:
            ctc_log_probs = torch.cat((torch.empty(
                padded_win_len, bs, 1,
                device=log_probs.device).fill_(-1000), ctc_log_probs), 2)
        # Now ctc_log_probs is win_size x BS x (num_preds + 1)
        assert ctc_log_probs.is_contiguous()

        # normalize the log-probs over num_preds
        # This is required, because ctc returns a bad gradient when given
        # unnormalized log probs
        log_sum_exps = torch.logsumexp(ctc_log_probs, 2, keepdim=True)
        ctc_log_probs = ctc_log_probs - log_sum_exps
        losses = F.ctc_loss(ctc_log_probs,
                            fake_ctc_labels,
                            torch.empty(bs,
                                        dtype=torch.int).fill_(padded_win_len),
                            torch.empty(bs, dtype=torch.int).fill_(num_preds),
                            reduction='none')
        losses = losses - log_sum_exps.squeeze(2).sum(0)
        return losses
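
A hedged call sketch for `_alignment_cost` above, treating it as a standalone function (sizes invented): a batch of two windows of six frame log-probabilities, each aligned against four ordered predictions, with one allowed skip at either end and the blank effectively disabled.

bs, win_len, num_preds = 2, 6, 4
log_probs = torch.randn(bs, win_len, num_preds)
costs = _alignment_cost(log_probs, allowed_skips_beg=1, allowed_skips_end=1,
                        force_forbid_blank=True)             # (bs,) per-window alignment costs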
Example #12
    def test_ctc_loss(self):
        torch.manual_seed(0)

        N = 16  # Batch size
        T = 50  # Input sequence length
        C = 20  # Number of classes (including blank)
        S = 30  # Target sequence length of longest target in batch (padding length)
        S_min = 10  # Minimum target length (only for testing)

        logits = torch.randn(N, T, C)
        targets = torch.randint(1, C, (N, S), dtype=torch.long)
        input_lengths = torch.full((N, ), T, dtype=torch.long)
        target_lengths = torch.randint(S_min, S, (N, ), dtype=torch.long)

        config = CTCLoss.Config()
        config.blank = 0  # Needs to be set to 0 for CuDNN support.
        ctc_loss_fn = CTCLoss(config=config)

        ctc_loss_val = ctc_loss_fn(
            logits,
            targets,
            input_lengths,
            target_lengths,
        )

        # PyTorch CTC loss
        log_probs = logits.permute(1, 0, 2).log_softmax(
            2
        )  # permute to conform to CTC loss input tensor (T,N,C) in PyTorch.
        lib_ctc_loss_val = F.ctc_loss(log_probs, targets, input_lengths,
                                      target_lengths)

        self.assertAlmostEqual(ctc_loss_val.item(), lib_ctc_loss_val.item())
Example #13
def train(epoch):

    net.train()


    closs = []
    for iter_idx, (img, transcr) in enumerate(train_loader):

        optimizer.zero_grad()

        img = Variable(img.to(device))

        # geometrical and morphological deformations
        '''
        rids = torch.BoolTensor(torch.bernoulli(.1 * torch.ones(img.size(0))).bool())
        if sum(rids) > 1:
            u = 4 # 2 ** np.random.randint(2, 5)
            img[rids] = local_deform(img[rids], rm=None, dscale=u, a=np.random.uniform(2.0, 4.0))

        rids = torch.BoolTensor(torch.bernoulli(.1 * torch.ones(img.size(0))).bool())
        if sum(rids) > 1:
            u = 4 # 2 ** np.random.randint(2, 5)
            img[rids] = morphological(img[rids], rm=None, dscale=u)
        '''

        output = net(img.detach()) #.permute(2, 0, 1)

        act_lens = torch.IntTensor(img.size(0)*[output.size(0)]) 
        labels = torch.IntTensor([cdict[c] for c in ''.join(transcr)])
        label_lens = torch.IntTensor([len(t) for t in transcr])

        ls_output = F.log_softmax(output, dim=2).cpu()
        loss_val = F.ctc_loss(ls_output, labels, act_lens, label_lens, zero_infinity=True, reduction='sum') / img.size(0)

        closs += [loss_val.data]

        loss_val.backward()

        optimizer.step()


        # mean running errors??
        if iter_idx % args.display == args.display-1:
            logger.info('Epoch %d, Iteration %d: %f', epoch, iter_idx+1, sum(closs)/len(closs))
            closs = []

            net.eval()

            tst_img, tst_transcr = test_set.__getitem__(np.random.randint(test_set.__len__()))
            print('orig:: ' + tst_transcr)
            with torch.no_grad():
                tst_o = net(Variable(tst_img.to(device)).unsqueeze(0)) #.permute(2, 0, 1)
                tdec = tst_o.argmax(2).permute(1, 0).cpu().numpy().squeeze()
                tt = [v for j, v in enumerate(tdec) if j == 0 or v != tdec[j - 1]]
                print('gdec:: ' + ''.join([icdict[t] for t in tt]).replace('_', ''))

            net.train()
Example #14
def cal_ctc(pred, gold, pre_len, gold_len):
    pred_log_prob = F.log_softmax(pred, -1)
    loss = ctc_loss(pred_log_prob.transpose(0, 1),
                    gold,
                    pre_len,
                    gold_len,
                    blank=1,
                    zero_infinity=True)
    return loss
Example #15
def train_loop(model, optimizer, writer, step, args, hp):
    dataset_train = datasets.get_dataset(hp.train_script, hp, use_spec_aug=hp.use_spec_aug)
    if hp.encoder_type == 'Conformer':
        train_sampler = datasets.LengthsBatchSampler(dataset_train, hp.batch_size * 1500, hp.lengths_file)
        dataloader = DataLoader(dataset_train, batch_sampler=train_sampler, num_workers=2, collate_fn=datasets.collate_fn)
    else:
        train_sampler = DistributedSampler(dataset_train) if args.n_gpus > 1 else None
        dataloader = DataLoader(dataset_train, batch_size=hp.batch_size, shuffle=hp.shuffle, num_workers=2, sampler=train_sampler, collate_fn=datasets.collate_fn, drop_last=True)
    optimizer.zero_grad()
    for d in dataloader:
        step += 1
        if hp.encoder_type == 'Conformer':
            lr = get_transformer_learning_rate(step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print(f'step = {step}')
            print(f'lr = {lr}')
        text, mel_input, pos_text, pos_mel, text_lengths, mel_lengths = d

        text = text.to(DEVICE, non_blocking=True)
        mel_input = mel_input.to(DEVICE, non_blocking=True)
        pos_text = pos_text.to(DEVICE, non_blocking=True)
        pos_mel = pos_mel.to(DEVICE, non_blocking=True)
        text_lengths = text_lengths.to(DEVICE, non_blocking=True)

        if hp.frame_stacking > 1 and hp.encoder_type != 'Wave':
            mel_input, mel_lengths = frame_stacking(mel_input, mel_lengths, hp.frame_stacking)

        predict_ts = model(mel_input, mel_lengths, text, pos_mel)

        if hp.decoder_type == 'Attention':
            loss = label_smoothing_loss(predict_ts, text, text_lengths, hp.T_norm, hp.B_norm)
        elif hp.decoder_type == 'CTC':
            predict_ts = F.log_softmax(predict_ts, dim=2).transpose(0, 1)
            loss = F.ctc_loss(predict_ts, text, mel_lengths, text_lengths, blank=0)

        # backward
        loss.backward()
        # optimizer update
        if step % hp.accum_grad == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), hp.clip)
            optimizer.step()
            optimizer.zero_grad()
            loss.detach()

        if torch.isnan(loss):
            print('loss is nan')
            sys.exit(1)
        if hp.debug_mode == 'tensorboard':
            writer.add_scalar("Loss/train", loss, step)
            print('loss = {}'.format(loss.item()))
        else:
            print('loss = {}'.format(loss.item()))

        sys.stdout.flush()
    return step
Example #16
def ctc_fallback(encoder_outputs, labels, frame_lens, label_lens, blank):
    assert len(encoder_outputs) == len(labels) == len(frame_lens) == len(label_lens)
    skipped_indices = []
    working_indices = []
    for i in range(len(encoder_outputs)):
        if torch.isinf(F.ctc_loss(encoder_outputs[i:i+1].transpose(0, 1), labels[i:i+1], frame_lens[i:i+1], label_lens[i:i+1], blank=blank)):
            skipped_indices.append(i)
        else:
            working_indices.append(i)
    return skipped_indices, working_indices
Example #17
    def loss(output, output_len, targets, targets_len):
        output_trans = output.permute(1, 0, 2)  # needed by the CTCLoss
        loss = F.ctc_loss(output_trans,
                          targets,
                          output_len,
                          targets_len,
                          reduction='none',
                          zero_infinity=True)
        loss /= output_len
        loss = loss.mean()
        return loss
Example #18
def ctc_loss(y_hat, y, y_hat_lens, target_lens):
    loss = 0.0
    y_hat = y_hat if 'tuple' in str(type(y_hat)) else [y_hat]

    for y_hat_i in y_hat:
        #breakpoint()
        loss_i = F.ctc_loss(y_hat_i.cpu(), y.cpu(), y_hat_lens.cpu(),
                            target_lens.cpu())
        loss += loss_i

    return loss
Example #19
def cal_ctc_loss(logits_ctc, len_logits_ctc, targets, target_lengths):
    n_class = logits_ctc.size(-1)
    ctc_log_probs = F.log_softmax(logits_ctc, dim=-1).transpose(0, 1)
    ctc_loss = F.ctc_loss(ctc_log_probs,
                          targets,
                          len_logits_ctc,
                          target_lengths,
                          reduction="none",
                          blank=n_class - 1)

    return ctc_loss.sum()
Example #20
    def forward(self, inputs, inputs_length, targets, targets_length):
        if self.fir_enc_or_not:
            t_inputs, t_inputs_length = self.fir_enc(inputs, inputs_length)
        else:
            t_inputs, t_inputs_length = inputs, inputs_length

        enc_state, _ = self.encoder(t_inputs, t_inputs_length)
        enc_state = enc_state.transpose(0, 1).contiguous()
        loss = F.ctc_loss(enc_state, targets.int(), t_inputs_length.int(),
                          targets_length.int())
        return loss
Example #21
def ctc_loss(preds, targets, voc_size):
    # prepare targets
    target_lengths = (targets != voc_size).long().sum(dim=-1)
    trimmed_targets = [t[:l] for t, l in zip(targets, target_lengths)]
    targets = torch.cat(trimmed_targets)

    x = F.log_softmax(preds, dim=-1)
    input_lengths = torch.full((x.size(1),), x.size(0), dtype=torch.long)
    return F.ctc_loss(
        x, targets, input_lengths, target_lengths,
        blank=voc_size, zero_infinity=True
    )
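
A quick sketch of how the padded-target convention above could be exercised (an illustration, not from the original repository): `preds` is time-major, and `voc_size` serves both as the padding value in `targets` and as the blank index.

voc_size, T, N, S = 5, 20, 3, 7
preds = torch.randn(T, N, voc_size + 1)                      # class index voc_size is the blank
targets = torch.full((N, S), voc_size, dtype=torch.long)     # start fully padded
targets[0, :4] = torch.tensor([1, 2, 3, 1])                  # sample 0 has 4 labels
targets[1, :2] = torch.tensor([0, 4])                        # sample 1 has 2 labels
targets[2, :6] = torch.randint(0, voc_size, (6,))            # sample 2 has 6 labels
loss = ctc_loss(preds, targets, voc_size)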
Example #22
def aug_loss(output, target):

    S = 8  # length of the target sequence
    S_min = 5
    alpha = 0.2

    output = output.squeeze(dim=2)
    output = output.permute(2, 0, 1)  # to (sequence length, batch size, number of classes)

    N = output.shape[1]  # batch size
    T = output.shape[0]  # input sequence length; 20 here, and all samples have the same length

    log_probs = output.log_softmax(2).requires_grad_()

    # compute y~ = y* % 10
    target_ = target % 10

    input_lengths = torch.full(size=(N, ), fill_value=T, dtype=torch.long)
    target_lengths = torch.full(size=(N, ), fill_value=S_min, dtype=torch.long)
    return F.ctc_loss(log_probs, target, input_lengths, target_lengths, blank=20) + \
        alpha * F.ctc_loss(log_probs, target_, input_lengths, target_lengths, blank=20)
Example #23
    def test_ctc_loss(self):
        # force fp32 because _th_normal_ (used by the next line) is not supported for fp16
        log_probs = torch.randn(
            50, 16, 20, device='cuda',
            dtype=torch.float32).log_softmax(2).detach().requires_grad_()
        targets = torch.randint(1,
                                20, (16, 30),
                                device='cuda',
                                dtype=torch.long)
        input_lengths = torch.full((16, ), 50, dtype=torch.long)
        target_lengths = torch.randint(10, 30, (16, ), dtype=torch.long)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
Example #24
    def compute_loss(
        self, model, net_output, sample,
        reduction='mean', zero_infinity=False,
    ):
        log_probs = model.get_normalized_probs(net_output, log_probs=True)
        targets = torch.cat(sample['target']).cpu()  # Expected targets to have CPU Backend
        target_lengths = sample['target_length']
        input_lengths = torch.full((sample['nsentences'],), log_probs.size(0), dtype=torch.int32)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                          blank=self.blank_idx, reduction=reduction,
                          zero_infinity=zero_infinity)
        return loss
Example #25
def cal_loss(logits, len_logits, gold, smoothing=0.0):
    """Calculate cross entropy loss, apply label smoothing if needed.
    """
    n_class = logits.size(-1)
    target_lengths = gold.ne(0).sum(dim=1).int()
    ctc_log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
    ctc_loss = F.ctc_loss(ctc_log_probs,
                          gold,
                          len_logits,
                          target_lengths,
                          blank=n_class - 1)

    return ctc_loss
Example #26
def ctc_label_smoothing_loss(log_probs, targets, input_lengths, target_lengths,
                             weights):
    loss = ctc_loss(log_probs,
                    targets,
                    input_lengths,
                    target_lengths,
                    reduction='mean')
    label_smoothing_loss = -((log_probs * weights.to(log_probs.device)).mean())
    return {
        'loss': loss + label_smoothing_loss,
        'ctc_loss': loss,
        'label_smooth_loss': label_smoothing_loss
    }
Example #27
    def compute_ctc_loss(self, x, y, reduction="mean"):
        """
        Args:
            x: list of log-prob tensors, each (t, d)
            y: list of label tensors, each (t',)
        Return:
            loss
        """
        xl = torch.tensor(list(map(len, x)))
        yl = torch.tensor(list(map(len, y)))
        x = pad_sequence(x, False)  # -> (t b c)
        y = pad_sequence(y, True)  # -> (b s)
        return F.ctc_loss(x, y, xl, yl, self.blank, reduction, True)
Example #28
    def forward(self, logits, input_lengths, targets, target_lengths):
        # lengths : (batch_size, )
        # log_logits : (T, batch_size, n_class), this kind of shape is required for ctc_loss
        # log_logits = logits + (logit_mask.unsqueeze(-1) + 1e-45).log()
        log_logits = logits.log_softmax(-1).transpose(0, 1)
        loss = F.ctc_loss(log_logits,
                          targets,
                          input_lengths,
                          target_lengths,
                          blank=self.blank_index,
                          reduction='mean',
                          zero_infinity=True)
        return loss
Example #29
    def forward(
        self,
        log_p_attn: torch.Tensor,
        ilens: torch.Tensor,
        olens: torch.Tensor,
        blank_prob: float = np.e**-1,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            log_p_attn (Tensor): Batch of log probability of attention matrix
                (B, T_feats, T_text).
            ilens (Tensor): Batch of the lengths of each input (B,).
            olens (Tensor): Batch of the lengths of each target (B,).
            blank_prob (float): Blank symbol probability.

        Returns:
            Tensor: forwardsum loss value.

        """
        B = log_p_attn.size(0)

        # add beta-binomial prior
        bb_prior = self._generate_prior(ilens, olens)
        bb_prior = bb_prior.to(dtype=log_p_attn.dtype, device=log_p_attn.device)
        log_p_attn = log_p_attn + bb_prior

        # a row must be added to the attention matrix to account for
        #    blank token of CTC loss
        # (B,T_feats,T_text+1)
        log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))

        loss = 0
        for bidx in range(B):
            # construct target sequence.
            # Every text token is mapped to a unique sequence number.
            target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
            cur_log_p_attn_pd = log_p_attn_pd[
                bidx, : olens[bidx], : ilens[bidx] + 1
            ].unsqueeze(
                1
            )  # (T_feats,1,T_text+1)
            loss += F.ctc_loss(
                log_probs=cur_log_p_attn_pd,
                targets=target_seq,
                input_lengths=olens[bidx : bidx + 1],
                target_lengths=ilens[bidx : bidx + 1],
                zero_infinity=True,
            )
        loss = loss / B
        return loss
Example #30
def ctc_label_smoothing_loss(log_probs, targets, lengths, weights):
    T, N, C = log_probs.shape
    log_probs_lengths = torch.full(size=(N, ), fill_value=T, dtype=torch.int64)
    loss = ctc_loss(log_probs.to(torch.float32),
                    targets,
                    log_probs_lengths,
                    lengths,
                    reduction='mean')
    label_smoothing_loss = -((log_probs * weights.to(log_probs.device)).mean())
    return {
        'loss': loss + label_smoothing_loss,
        'ctc_loss': loss,
        'label_smooth_loss': label_smoothing_loss
    }
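
Finally, a hedged usage sketch for the label-smoothing wrapper above (the same weighting idea as Example #7; sizes invented, and the inner `ctc_loss` is assumed to resolve to `torch.nn.functional.ctc_loss`).

T, N, C = 40, 8, 5
log_probs = torch.randn(T, N, C).log_softmax(-1)             # (T, N, C) log-probabilities
targets = torch.randint(1, C, (N, 12), dtype=torch.long)     # padded targets, labels 1..C-1
lengths = torch.randint(5, 13, (N,), dtype=torch.long)       # valid labels per sample
weights = torch.cat([torch.tensor([0.4]), (0.1 / (C - 1)) * torch.ones(C - 1)])
out = ctc_label_smoothing_loss(log_probs, targets, lengths, weights)
print(out['loss'], out['ctc_loss'], out['label_smooth_loss'])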