Example #1
 def forward(self, inputs, targets, inputs_len, targets_len):
     '''
     Args:
         inputs (acoustic features): [N, T, D]
         targets (phoneme label sequence): [N, U]
         inputs_len: [N]
         targets_len: [N]
     Returns:
         loss: scalar mean RNN-T loss over the batch
     '''
     # Encode the acoustic frames; the encoder may also update the frame lengths
     enc_state, inputs_len = self.encoder(inputs, inputs_len)
     # Prepend the blank symbol to the labels for the prediction network
     dec_state, _ = self.decoder(F.pad(targets, pad=[1, 0, 0, 0], value=self.blank_idx))

     # Tile encoder states [N, T, 1, H] and prediction states [N, 1, U+1, H]
     # so the joint network sees every (t, u) pair
     dec_state = dec_state.unsqueeze(1)
     enc_state = enc_state.unsqueeze(2)
     t = enc_state.size(1)
     u = dec_state.size(2)
     dec_state = dec_state.repeat([1, t, 1, 1])
     enc_state = enc_state.repeat([1, 1, u, 1])
     concat_state = torch.cat([enc_state, dec_state], dim=-1)  # [N, T, U+1, 2H]

     # Joint network + log-softmax; the loss expects log-probabilities
     logits = self.out(self.tanh(self.joint(concat_state)))
     logits = F.log_softmax(logits, dim=-1)
     loss = rnnt_loss(logits, targets.int(), inputs_len.int(),
                      targets_len.int(), blank=self.blank_idx)
     return loss.mean()
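The two repeat calls above materialize full [N, T, U+1, H] copies before the concatenation. A common variant (a sketch, not taken from this snippet's codebase) uses expand instead, which creates broadcast views and defers the one real allocation to the cat:

    import torch

    def joint_concat(enc_state, dec_state):
        # enc_state: [N, T, H], dec_state: [N, U+1, H]
        n, t, h = enc_state.shape
        u = dec_state.size(1)
        enc = enc_state.unsqueeze(2).expand(n, t, u, h)  # view, no copy
        dec = dec_state.unsqueeze(1).expand(n, t, u, h)  # view, no copy
        return torch.cat([enc, dec], dim=-1)             # [N, T, U+1, 2H]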
Example #2
    def forward_transducer(self, eouts, elens, ys):
        """Compute Transducer loss.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            ys (list): length `B`; each element is a label sequence of length `[L]`
        Returns:
            loss (FloatTensor): `[1]`

        """
        # Prepend <sos> (sharing the <eos> ID) to build the prediction network input
        _ys = [
            np2tensor(np.fromiter(y, dtype=np.int64), eouts.device) for y in ys
        ]
        ylens = np2tensor(np.fromiter([y.size(0) for y in _ys],
                                      dtype=np.int32))
        eos = eouts.new_zeros((1, ), dtype=torch.int64).fill_(self.eos)
        ys_in = pad_list([torch.cat([eos, y], dim=0) for y in _ys],
                         self.pad)  # `[B, L+1]`
        ys_out = pad_list(_ys, self.blank)  # `[B, L]`

        # Update prediction network
        ys_emb = self.dropout_emb(self.embed(ys_in))
        dout, _ = self.recurrency(ys_emb, None)

        # Compute output distribution
        logits = self.joint(eouts, dout)  # `[B, T, L+1, vocab]`

        # Compute Transducer loss
        log_probs = torch.log_softmax(logits, dim=-1)
        assert log_probs.size(2) == ys_out.size(1) + 1
        if self.device_id >= 0:
            ys_out = ys_out.to(eouts.device)
            elens = elens.to(eouts.device)
            ylens = ylens.to(eouts.device)
            import warp_rnnt
            loss = warp_rnnt.rnnt_loss(log_probs,
                                       ys_out.int(),
                                       elens,
                                       ylens,
                                       average_frames=False,
                                       reduction='mean',
                                       gather=False)
        else:
            import warprnnt_pytorch
            # NOTE: instantiating RNNTLoss once in __init__ would avoid
            # re-creating it on every call
            self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()
            loss = self.warprnnt_loss(log_probs, ys_out.int(), elens, ylens)
            # NOTE: Transducer loss has already been normalized by bs
            # NOTE: index 0 is reserved for blank in warprnnt_pytorch

        return loss
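np2tensor and pad_list are helpers from the surrounding codebase rather than PyTorch builtins. A minimal sketch of the behavior this snippet assumes from them (assumed signatures, not the originals):

    import torch

    def np2tensor(array, device=None):
        # assumed: wrap a NumPy array as a tensor on `device`
        return torch.from_numpy(array).to(device)

    def pad_list(ys, pad_value):
        # assumed: right-pad a list of 1-D tensors into a [B, L_max] batch
        out = ys[0].new_full((len(ys), max(y.size(0) for y in ys)), pad_value)
        for i, y in enumerate(ys):
            out[i, :y.size(0)] = y
        return out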
Example #3
 def cal_transducer_loss(self,
                         model_output,
                         target,
                         frame_length,
                         target_length,
                         type='rnnt'):
     # NOTE: `type` is currently unused; `t` below is the torch module
     # (i.e. the file does `import torch as t`)
     log_prob = t.nn.functional.log_softmax(model_output, dim=-1)
     rnn_t_loss = rnnt_loss(log_probs=log_prob,
                            labels=target.int(),
                            frames_lengths=frame_length.int(),
                            labels_lengths=target_length.int(),
                            reduction='mean')
     return rnn_t_loss
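The keyword names (log_probs, frames_lengths, labels_lengths) match the warp_rnnt package's rnnt_loss. A self-contained call with made-up sizes, illustrating the expected [B, T, U+1, V] layout (warp_rnnt is a CUDA implementation, so the tensors live on the GPU):

    import torch
    from warp_rnnt import rnnt_loss

    B, T, U, V = 2, 50, 10, 30  # batch, frames, label length, vocab size
    logits = torch.randn(B, T, U + 1, V, device='cuda', requires_grad=True)
    log_probs = torch.log_softmax(logits, dim=-1)
    labels = torch.randint(1, V, (B, U), dtype=torch.int32, device='cuda')
    frame_lens = torch.full((B,), T, dtype=torch.int32, device='cuda')
    label_lens = torch.full((B,), U, dtype=torch.int32, device='cuda')

    loss = rnnt_loss(log_probs, labels, frame_lens, label_lens, reduction='mean')
    loss.backward()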
Example #4
    for xs, ys, xn, yn in progress:

        optimizer.zero_grad()

        # Move the batch to the GPU (non_blocking copies overlap compute
        # when the host tensors are pinned)
        xs = xs.cuda(non_blocking=True)
        ys = ys.cuda(non_blocking=True)
        xn = xn.cuda(non_blocking=True)
        yn = yn.cuda(non_blocking=True)

        # Forward pass; zs holds the joint-network log-probs, [B, T, U+1, V]
        zs, xs, xn = model(xs, ys, xn, yn)

        # rnnt_loss expects the labels as [B, U]
        ys = ys.t().contiguous()

        loss = rnnt_loss(zs,
                         ys,
                         xn,
                         yn,
                         average_frames=False,
                         reduction="mean")
        loss.backward()

        # Clip the gradient norm at 100 to stabilize training
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 100)

        optimizer.step()

        err.update(loss.item())
        grd.update(grad_norm)

        progress.set_description('epoch %d %s %s' % (epoch + 1, err, grd))

    model.eval()
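err and grd are running-average trackers and progress is a tqdm-style iterator; none of them is defined in the snippet. A minimal stand-in consistent with the .update() and string-formatting usage above (hypothetical, not the original class):

    class AverageMeter:
        """Tracks a running average and renders it for the progress bar."""
        def __init__(self, name):
            self.name = name
            self.sum = 0.0
            self.count = 0

        def update(self, value):
            self.sum += float(value)
            self.count += 1

        def __str__(self):
            return '%s %.4f' % (self.name, self.sum / max(self.count, 1))

    err = AverageMeter('loss')
    grd = AverageMeter('grad')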
Example #5
            # Per-hypothesis rewards (symbol accuracy), K samples per utterance
            rewards = relu(SymAcc).reshape(K, -1).cuda()

            rewards_mean = rewards.mean().item()

            # Subtract the mean over the K samples as a variance-reducing baseline
            rewards -= rewards.mean(dim=0)

            elu(rewards, alpha=gamma, inplace=True)

            # Reshape the sampled hypotheses and their lengths to [K, B, ...]
            hs_k = hs_k.reshape(K, len(xs), -1)
            hn_k = hn_k.reshape(K, len(xs))

        model.train()

        # Supervised RNN-T loss on the ground-truth transcription
        zs, xs, xn = model(xs, ys.t(), xn, yn)

        loss1 = rnnt_loss(zs, ys, xn, yn).mean()

        # Mean entropy of the output distribution (zs holds log-probabilities)
        loss2 = -(zs.exp() * zs).sum(dim=-1).mean()

        # Score each of the K sampled hypotheses under the current model
        for k in range(K):

            ys = hs_k[k]
            yn = hn_k[k]

            # Trim padding beyond the longest hypothesis in this group
            ys = ys[:, :yn.max()].contiguous()

            zs = model.forward_language(ys.t(), yn)

            zs = model.forward_joint(xs, zs)

            # Per-utterance negative log-likelihood of the sampled hypothesis
            nll = rnnt_loss(zs, ys, xn, yn)
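loss2 in this snippet works because zs already holds log-probabilities: for log-probs z, -(exp(z) * z).sum(-1) is exactly the entropy H = -sum_v p(v) log p(v). A quick standalone check:

    import torch

    zs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # log-probabilities
    entropy = -(zs.exp() * zs).sum(dim=-1)              # H = -sum p * log p
    assert torch.all(entropy >= 0)                      # entropy is non-negative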
Example #6
    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        """Compute the overall Transducer loss, optionally adding auxiliary
        CTC and knowledge-distillation (KD) terms."""
        loss = 0
        loss_dict = {}

        # Prediction network
        douts, _ = self.recurrency(ys_in, dstate=None)

        # Joint network
        logits = self.joint(eouts, douts)  # (B, T, L + 1, vocab)
        log_probs = torch.log_softmax(logits, dim=-1)
        assert log_probs.size(2) == ys.size(1) + 1

        # NOTE: rnnt_loss only accepts ys, elens, ylens with torch.int
        loss_rnnt = warp_rnnt.rnnt_loss(
            log_probs,
            ys.int(),
            elens.int(),
            ylens.int(),
            average_frames=False,
            reduction="mean",
            blank=self.blank_id,
            gather=False,
        )
        loss += loss_rnnt  # main loss
        loss_dict["loss_rnnt"] = loss_rnnt

        if self.mtl_ctc_weight > 0:
            # NOTE: KD is not applied to auxiliary CTC
            loss_ctc, _, _ = self.ctc(eouts=eouts,
                                      elens=elens,
                                      ys=ys,
                                      ylens=ylens,
                                      soft_labels=None)
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        if self.kd_weight > 0 and soft_labels is not None:
            if self.kd_type == "word":
                loss_kd = self.transducer_kd_loss(logits, soft_labels, elens,
                                                  ylens)
            elif self.kd_type == "align":
                aligns = self.forced_aligner(log_probs, elens, ys, ylens)
                loss_kd = self.transducer_kd_loss(logits, ys, soft_labels,
                                                  aligns, elens, ylens)
            else:
                # Fail fast instead of hitting an undefined loss_kd below
                raise ValueError('unsupported kd_type: %s' % self.kd_type)

            loss_dict["loss_kd"] = loss_kd

            if self.reduce_main_loss_kd:
                loss = (1 - self.kd_weight) * loss + self.kd_weight * loss_kd
            else:
                loss += self.kd_weight * loss_kd

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits
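self.joint combines the encoder and prediction-network states into the [B, T, L+1, vocab] grid consumed by the loss. A minimal sketch of such a joint network (hypothetical layer names; the additive combination shown here is one common choice, not necessarily this repo's):

    import torch
    import torch.nn as nn

    class JointNetwork(nn.Module):
        def __init__(self, enc_units, dec_units, joint_units, vocab):
            super().__init__()
            self.w_enc = nn.Linear(enc_units, joint_units)
            self.w_dec = nn.Linear(dec_units, joint_units)
            self.out = nn.Linear(joint_units, vocab)

        def forward(self, eouts, douts):
            # eouts: [B, T, enc_units] -> [B, T, 1, joint_units]
            # douts: [B, L+1, dec_units] -> [B, 1, L+1, joint_units]
            e = self.w_enc(eouts).unsqueeze(2)
            d = self.w_dec(douts).unsqueeze(1)
            return self.out(torch.tanh(e + d))  # [B, T, L+1, vocab]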
Example #7
    def forward_rnnt(self, eouts, elens, ys):
        """Compute XE loss for the attention-based sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, dec_n_units]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
        Returns:
            loss (FloatTensor): `[1]`

        """
        # Prepend <sos> and, when end-pointing is enabled, append <eos>
        eos = eouts.new_zeros(1).fill_(self.eos).long()
        if self.end_pointing:
            _ys = [
                np2tensor(np.fromiter(y + [self.eos], dtype=np.int64),
                          self.device_id) for y in ys
            ]
        else:
            _ys = [
                np2tensor(np.fromiter(y, dtype=np.int64), self.device_id)
                for y in ys
            ]
        ylens = np2tensor(np.fromiter([y.size(0) for y in _ys],
                                      dtype=np.int32))
        ys_in = pad_list([torch.cat([eos, y], dim=0) for y in _ys], self.pad)
        ys_out = pad_list(_ys, self.blank)

        # Update prediction network
        ys_emb = self.dropout_emb(self.embed(ys_in))
        dout, _ = self.recurrency(ys_emb, None)

        # Compute output distribution
        logits = self.joint(eouts, dout)

        # Compute Transducer loss
        log_probs = torch.log_softmax(logits, dim=-1)
        if self.device_id >= 0:
            ys_out = ys_out.cuda(self.device_id)
            elens = elens.cuda(self.device_id)
            ylens = ylens.cuda(self.device_id)

        assert log_probs.size(2) == ys_out.size(1) + 1
        # loss = self.warprnnt_loss(log_probs, ys_out.int(), elens, ylens)
        # NOTE: Transducer loss has already been normalized by bs
        # NOTE: index 0 is reserved for blank in warprnnt_pytorch
        import warp_rnnt
        loss = warp_rnnt.rnnt_loss(log_probs,
                                   ys_out.int(),
                                   elens,
                                   ylens,
                                   average_frames=False,
                                   reduction='mean',
                                   gather=False)

        # Label smoothing for Transducer
        # if self.lsm_prob > 0:
        #     loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(logits,
        #                                                       ylens=elens,
        #                                                       size_average=True) * self.lsm_prob
        # TODO(hirofumi): this leads to out of memory

        return loss
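The <sos>/<eos> bookkeeping above yields shifted input and target sequences for the prediction network. A toy trace with a hypothetical symbol ID (eos = 2):

    import torch

    eos = 2                      # hypothetical <eos>/<sos> symbol ID
    y = torch.tensor([5, 6, 7])  # a toy label sequence

    _y = torch.cat([y, torch.tensor([eos])])       # end_pointing: append <eos>
    ys_in = torch.cat([torch.tensor([eos]), _y])   # prepend <sos> (= <eos> ID)
    ys_out = _y                                    # transducer targets

    print(ys_in.tolist())   # [2, 5, 6, 7, 2] -> prediction network input
    print(ys_out.tolist())  # [5, 6, 7, 2]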