Example #1
    def forward(
        self,
        speech: torch.Tensor,
        speech_original: torch.Tensor,
        speech_lengths: torch.Tensor,
        speech_original_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_original: (Batch, Length, ...)
            speech_lengths: (Batch, )
            speech_original_lengths: (Batch, )

        """

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens, feats_original, dropout_mask = self.encode(
            speech, speech_original, speech_lengths)

        loss = self._calc_predictive_loss(feats_original, encoder_out,
                                          dropout_mask)

        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Example #2
 def forward(self, x, x_lengths):
     x = self.layer1(x)
     x = self.layer2(x)
     retval = {
         "loss": x.mean(),
         "stats": {"loss": x.mean()},
         "weight": len(x),
         "optim_idx": torch.randint(0, 2, [1]),
     }
     return force_gatherable(retval, device=x.device)
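
All of these examples route their return values through force_gatherable so that torch.nn.DataParallel can gather losses, statistics, and weights across replicas. Below is a minimal sketch of its behavior, inferred from the "to-device and to-tensor if scalar" comments and the test in Example #8; the real ESPnet implementation (espnet2/torch_utils/device_funcs.py) additionally handles numpy arrays and dataclasses.

import torch

def force_gatherable_sketch(data, device):
    """Recursively move tensors to `device`; wrap scalars as 1-dim tensors."""
    if isinstance(data, dict):
        return {k: force_gatherable_sketch(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(force_gatherable_sketch(v, device) for v in data)
    if isinstance(data, bool):
        return data  # assumption: plain flags pass through unchanged
    if isinstance(data, (int, float)):
        return torch.tensor([data], device=device)
    if isinstance(data, torch.Tensor):
        return (data[None] if data.dim() == 0 else data).to(device)
    return data  # e.g. the None entries that appear in several stats dicts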
Example #3
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor,
        durations_lengths: torch.Tensor,
        pitch: torch.Tensor,
        pitch_lengths: torch.Tensor,
        energy: torch.Tensor,
        energy_lengths: torch.Tensor,
        spembs: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        text = text[:, : text_lengths.max()]  # for data-parallel
        speech = speech[:, : speech_lengths.max()]  # for data-parallel
        durations = durations[:, : durations_lengths.max()]  # for data-parallel
        pitch = pitch[:, : pitch_lengths.max()]  # for data-parallel
        energy = energy[:, : energy_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the end of the sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys, ds, ps, es = speech, durations, pitch, energy
        olens = speech_lengths
        
        before_outs, after_outs, d_outs, p_outs, e_outs = self.fastspeech2._forward(
            xs, ilens, ys, olens, ds, ps, es, spembs=spembs, is_inference=False
        )

        ys = speech.transpose(1, 2)
        y_masks = self._source_mask(olens)
        mu = after_outs.transpose(1, 2)
        
        # pad the time dimension to a multiple of 4 for the diffusion module
        if ys.size(2) % 4 != 0:
            pad = 4 - ys.size(2) % 4
            ys = torch.cat(
                [ys, torch.zeros([batch_size, self.odim, pad],
                                 dtype=ys.dtype, device=ys.device)], dim=2)
            mu = torch.cat(
                [mu, torch.zeros([mu.size(0), self.odim, pad],
                                 dtype=mu.dtype, device=mu.device)], dim=2)
            y_masks = torch.cat(
                [y_masks, torch.zeros([y_masks.size(0), 1, pad],
                                      dtype=y_masks.dtype,
                                      device=y_masks.device)], dim=2)

        noise_estimation, z = self.diffusion(ys, y_masks, mu)
        
        diff_loss = self.criterion(noise_estimation, z, y_masks)
        loss = diff_loss
        stats = dict(
            diff_loss=diff_loss.item(),
            loss=loss.item(),
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
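
The three padding branches above apply the same zero-padding to ys, mu, and y_masks so their time dimension becomes a multiple of 4 for the diffusion module. A hedged helper capturing that pattern (pad_to_multiple is a hypothetical name, not part of the example's class):

import torch

def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int = -1) -> torch.Tensor:
    """Zero-pad x along dim so that x.size(dim) is a multiple of `multiple`."""
    remainder = x.size(dim) % multiple
    if remainder == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = multiple - remainder
    pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)

# usage sketch: ys = pad_to_multiple(ys, 4, dim=2), likewise for mu and y_masks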
Example #4
    def forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, text_lengths)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens),
                                               loss.device)
        return loss, stats, weight
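
Note that Example #4 returns ntokens rather than batch_size as the weight: when per-replica losses are averaged (e.g. under DataParallel), weighting each replica's mean loss by its token count recovers the true per-token loss. A small numeric sketch with hypothetical values:

import torch

per_replica_loss = torch.tensor([2.0, 1.0])      # mean loss per token on each replica
per_replica_tokens = torch.tensor([30.0, 10.0])  # the returned weights

# An unweighted mean would give 1.5; the token-weighted mean gives
# (2.0 * 30 + 1.0 * 10) / 40 = 1.75, the true per-token loss.
true_loss = (per_replica_loss * per_replica_tokens).sum() / per_replica_tokens.sum()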
Example #5
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        src_text: torch.Tensor,
        src_text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
            src_text: (Batch, length)
            src_text_lengths: (Batch,)
            kwargs: "utt_id" is among the input.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            text.shape[0]
            == text_lengths.shape[0]
            == src_text.shape[0]
            == src_text_lengths.shape[0]
        ), (text.shape, text_lengths.shape, src_text.shape, src_text_lengths.shape)

        batch_size = src_text.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]
        src_text = src_text[:, : src_text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(src_text, src_text_lengths)

        # 2a. Attention-decoder branch (MT)
        loss_mt_att, acc_mt_att, bleu_mt_att = self._calc_mt_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. Loss computation
        loss = loss_mt_att

        stats = dict(
            loss=loss.detach(),
            acc=acc_mt_att,
            bleu=bleu_mt_att,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
Example #6
    def forward_ilm(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...) not necessary; only used to get the tensor device
            speech_lengths: (Batch, )   not necessary; only used to get the tensor device
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (text.shape[0] == text_lengths.shape[0]), (text.shape,
                                                          text_lengths.shape)
        batch_size = text.shape[0]

        # for data-parallel
        text = text[:, :text_lengths.max()]

        ys_in_pad, ys_out_pad = add_sos_eos(text, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = text_lengths + 1

        fake_encoder_out = speech.new_zeros(batch_size, 1,
                                            self.encoder._output_size)
        # 1. Forward decoder
        decoder_out, _ = self.decoder.forward_ilm(fake_encoder_out, -1,
                                                  ys_in_pad, ys_in_lens)

        # 2. Compute ilm loss
        loss_ilm = self.criterion_att(decoder_out, ys_out_pad)
        ilm_acc = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        ilm_ppl = torch.exp(loss_ilm)
        stats = dict(ilm_loss=loss_ilm.detach(),
                     ilm_acc=ilm_acc,
                     ilm_ppl=ilm_ppl.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss_ilm, stats, batch_size),
                                               loss_ilm.device)
        return loss_ilm, stats, weight
Example #7
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            kwargs: "utt_id" is among the input.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, :text_lengths.max()]

        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths, text, text_lengths)

        # 2a. Hubert criterion
        loss, acc_mask, acc_unmask = self._calc_hubert_loss(encoder_out)

        stats = dict(
            loss=loss.detach(),
            acc_mask=acc_mask,
            acc_unmask=acc_unmask,
            acc=acc_mask,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Example #8
def test_force_gatherable_cuda():
    obj = {"a": [torch.tensor([0, 1])]}
    obj2 = force_gatherable(obj, "cuda")
    assert obj2["a"][0].device == torch.device("cuda:0")
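
A hedged companion test for the scalar case, mirroring the style of Example #8; the assertions reflect the "to-tensor if scalar" comments in these examples rather than a documented contract.

def test_force_gatherable_cpu_scalar():
    # Python scalars should come back as tensors so DataParallel can gather them
    obj = {"loss": 1.5, "weight": 8}
    obj2 = force_gatherable(obj, "cpu")
    assert isinstance(obj2["loss"], torch.Tensor)
    assert isinstance(obj2["weight"], torch.Tensor)
    assert obj2["loss"].device == torch.device("cpu")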
Example #9
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, :text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths)

        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out,
                                                    encoder_out_lens, text,
                                                    text_lengths)

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text,
                                     text_lengths)

        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 -
                                                 self.ctc_weight) * loss_att

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
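
Examples #9, #15, and #19 all inline the same hybrid-loss interpolation, where the weight-0 and weight-1 branches skip the unused head entirely (its loss is None). A hedged helper naming the pattern (hypothetical function; the examples inline it instead):

from typing import Optional

import torch

def interpolate_losses(loss_a: Optional[torch.Tensor],
                       loss_b: Optional[torch.Tensor],
                       weight_a: float) -> torch.Tensor:
    """loss = w * a + (1 - w) * b; at w in {0, 1} the unused loss may be None."""
    if weight_a == 1.0:
        return loss_a
    if weight_a == 0.0:
        return loss_b
    return weight_a * loss_a + (1 - weight_a) * loss_b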
Example #10
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        src_text: Optional[torch.Tensor],
        src_text_lengths: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
            src_text: (Batch, length)
            src_text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        # additional checks with valid src_text
        if src_text is not None:
            assert src_text_lengths.dim() == 1, src_text_lengths.shape
            assert text.shape[0] == src_text.shape[
                0] == src_text_lengths.shape[0], (
                    text.shape,
                    src_text.shape,
                    src_text_lengths.shape,
                )

        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, :text_lengths.max()]
        if src_text is not None:
            src_text = src_text[:, :src_text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2a. Attention-decoder branch (ST)
        loss_st_att, acc_st_att, bleu_st_att = self._calc_mt_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, st=True)

        # 2b. CTC branch
        if self.asr_weight > 0:
            assert src_text is not None, "missing source text for asr sub-task of ST"

        if self.asr_weight > 0 and self.mtlalpha > 0:
            loss_asr_ctc, cer_asr_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, src_text, src_text_lengths)
        else:
            loss_asr_ctc, cer_asr_ctc = 0, None

        # 2c. Attention-decoder branch (extra ASR)
        if self.asr_weight > 0 and self.mtlalpha < 1.0:
            (
                loss_asr_att,
                acc_asr_att,
                cer_asr_att,
                wer_asr_att,
            ) = self._calc_asr_att_loss(encoder_out, encoder_out_lens,
                                        src_text, src_text_lengths)
        else:
            loss_asr_att, acc_asr_att, cer_asr_att, wer_asr_att = 0, None, None, None

        # 2d. Attention-decoder branch (extra MT)
        if self.mt_weight > 0:
            loss_mt_att, acc_mt_att = self._calc_mt_att_loss(encoder_out,
                                                             encoder_out_lens,
                                                             text,
                                                             text_lengths,
                                                             st=False)
        else:
            loss_mt_att, acc_mt_att = 0, None

        # 3. Loss computation
        asr_ctc_weight = self.mtlalpha
        loss_st = loss_st_att
        if asr_ctc_weight == 1.0:
            loss_asr = loss_asr_ctc
        elif asr_ctc_weight == 0.0:
            loss_asr = loss_asr_att
        else:
            loss_asr = (asr_ctc_weight * loss_asr_ctc +
                        (1 - asr_ctc_weight) * loss_asr_att)
        loss_mt = self.mt_weight * loss_mt_att
        # loss_mt already carries mt_weight; add it directly so the weight
        # is not applied twice
        loss = ((1 - self.asr_weight - self.mt_weight) * loss_st +
                self.asr_weight * loss_asr + loss_mt)

        stats = dict(
            loss=loss.detach(),
            loss_asr=loss_asr.detach()
            if type(loss_asr) is not float else loss_asr,
            loss_mt=loss_mt.detach()
            if type(loss_mt) is not float else loss_mt,
            loss_st=loss_st.detach(),
            acc_asr=acc_asr_att,
            acc_mt=acc_mt_att,
            acc=acc_st_att,
            cer_ctc=cer_asr_ctc,
            cer=cer_asr_att,
            wer=wer_asr_att,
            bleu=bleu_st_att,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Example #11
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, samples)
            speech_lengths: (Batch,) default None for chunk iterator,
                                     because the chunk-iterator does not
                                     have the speech_lengths returned.
                                     see in
                                     espnet2/iterators/chunk_iter_factory.py
            spk_labels: (Batch, )
        """
        assert speech.shape[0] == spk_labels.shape[0], (speech.shape,
                                                        spk_labels.shape)
        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        if self.attractor is None:
            # 2a. Decoder (basically a prediction layer after encoder_out)
            pred = self.decoder(encoder_out, encoder_out_lens)
        else:
            # 2b. Encoder Decoder Attractors
            # Shuffle the chronological order of encoder_out, then calculate attractor
            encoder_out_shuffled = encoder_out.clone()
            for i in range(len(encoder_out_lens)):
                encoder_out_shuffled[i, :encoder_out_lens[i], :] = encoder_out[
                    i, torch.randperm(encoder_out_lens[i]), :]
            attractor, att_prob = self.attractor(
                encoder_out_shuffled,
                encoder_out_lens,
                to_device(
                    self,
                    torch.zeros(encoder_out.size(0),
                                spk_labels.size(2) + 1, encoder_out.size(2)),
                ),
            )
            # Remove the final attractor which does not correspond to a speaker
            # Then multiply the attractors and encoder_out
            pred = torch.bmm(encoder_out,
                             attractor[:, :-1, :].permute(0, 2, 1))
        # 3. Aggregate time-domain labels
        spk_labels, spk_labels_lengths = self.label_aggregator(
            spk_labels, spk_labels_lengths)

        # If encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of 'pred' might be slightly less than the
        # length of 'spk_labels'. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = spk_labels.shape[1] - pred.shape[1]
        if length_diff > 0 and length_diff <= length_diff_tolerance:
            spk_labels = spk_labels[:, 0:pred.shape[1], :]

        if self.attractor is None:
            loss_pit, loss_att = None, None
            loss, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens)
        else:
            loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens)
            loss_att = self.attractor_loss(att_prob, spk_labels)
            loss = loss_pit + self.attractor_weight * loss_att
        (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)

        if speech_scored > 0 and num_frames > 0:
            sad_mr, sad_fr, mi, fa, cf, acc, der = (
                speech_miss / speech_scored,
                speech_falarm / speech_scored,
                speaker_miss / speaker_scored,
                speaker_falarm / speaker_scored,
                speaker_error / speaker_scored,
                correct / num_frames,
                (speaker_miss + speaker_falarm + speaker_error) /
                speaker_scored,
            )
        else:
            sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_pit=loss_pit.detach() if loss_pit is not None else None,
            sad_mr=sad_mr,
            sad_fr=sad_fr,
            mi=mi,
            fa=fa,
            cf=cf,
            acc=acc,
            der=der,
        )

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Example #12
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        joint_training: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, T_text).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            feats (Tensor): Batch of padded target features (B, T_feats, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            joint_training (bool): Whether to perform joint training with vocoder.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.

        """
        text = text[:, : text_lengths.max()]  # for data-parallel
        feats = feats[:, : feats_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the end of the sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = feats
        olens = feats_lengths

        # make labels for stop prediction
        labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
        labels = F.pad(labels, [0, 1], "constant", 1.0)

        # calculate tacotron2 outputs
        after_outs, before_outs, logits, att_ws = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spembs=spembs,
            sids=sids,
            lids=lids,
        )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert olens.ge(
                self.reduction_factor
            ).all(), "Output length must be greater than or equal to reduction factor."
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = torch.scatter(
                labels, 1, (olens - 1).unsqueeze(1), 1.0
            )  # see #3388

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens
        )
        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = mse_loss + bce_loss
        else:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        stats = dict(
            l1_loss=l1_loss.item(),
            mse_loss=mse_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive
            # input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            stats.update(attn_loss=attn_loss.item())

        if not joint_training:
            stats.update(loss=loss.item())
            loss, stats, weight = force_gatherable(
                (loss, stats, batch_size), loss.device
            )
            return loss, stats, weight
        else:
            return loss, stats, after_outs
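
Examples #3, #12, and #16 append eos with a Python loop over the batch. The same effect can be achieved with a single scatter, the value-overload of which Example #12 already uses for its stop labels. A hedged, loop-free sketch (append_eos is a hypothetical helper):

import torch
import torch.nn.functional as F

def append_eos(text: torch.Tensor, text_lengths: torch.Tensor,
               eos: int, padding_idx: int) -> torch.Tensor:
    """Pad one frame with padding_idx, then write eos at index text_lengths[i]."""
    # text: (B, Tmax) int64, text_lengths: (B,) int64 (scatter needs LongTensor)
    xs = F.pad(text, [0, 1], "constant", padding_idx)
    return xs.scatter(1, text_lengths.unsqueeze(1), eos)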
Example #13
    def forward_loss(
        self,
        speech_pre: torch.Tensor,
        speech_lengths: torch.Tensor,
        feature_mix: torch.Tensor,
        feature_pre: torch.Tensor,
        others: OrderedDict,
        speech_ref: torch.Tensor,
        noise_ref: torch.Tensor = None,
        dereverb_speech_ref: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        loss = 0.0
        stats = dict()
        o = {}
        for loss_wrapper in self.loss_wrappers:
            criterion = loss_wrapper.criterion
            if isinstance(criterion, TimeDomainLoss):
                if speech_ref[0].dim() == 3:
                    # For multi-channel reference,
                    # only select one channel as the reference
                    speech_ref = [
                        sr[..., self.ref_channel] for sr in speech_ref
                    ]
                # for the time domain criterions
                l, s, o = loss_wrapper(speech_ref, speech_pre, o)
            elif isinstance(criterion, FrequencyDomainLoss):
                # for the time-frequency domain criterions
                if criterion.compute_on_mask:
                    # compute on mask
                    tf_ref = criterion.create_mask_label(
                        feature_mix,
                        [
                            self.encoder(sr, speech_lengths)[0]
                            for sr in speech_ref
                        ],
                    )
                    tf_pre = [
                        others["mask_spk{}".format(spk + 1)]
                        for spk in range(self.num_spk)
                    ]
                else:
                    # compute on spectrum
                    if speech_ref[0].dim() == 3:
                        # For multi-channel reference,
                        # only select one channel as the reference
                        speech_ref = [
                            sr[..., self.ref_channel] for sr in speech_ref
                        ]
                    tf_ref = [
                        self.encoder(sr, speech_lengths)[0]
                        for sr in speech_ref
                    ]
                    tf_pre = feature_pre

                l, s, o = loss_wrapper(tf_ref, tf_pre, o)
            loss += l * loss_wrapper.weight
            stats.update(s)

        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        batch_size = speech_ref[0].shape[0]
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Example #14
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        # additional checks with valid src_text
        if "src_text" in kwargs:
            src_text = kwargs["src_text"]
            src_text_lengths = kwargs["src_text_lengths"]

            if src_text is not None:
                assert src_text_lengths.dim() == 1, src_text_lengths.shape
                assert (text.shape[0] == src_text.shape[0] ==
                        src_text_lengths.shape[0]), (
                            text.shape,
                            src_text.shape,
                            src_text_lengths.shape,
                        )
        else:
            src_text = None
            src_text_lengths = None

        batch_size = speech.shape[0]

        # clean speech signal
        speech_ref = None
        if self.calc_enh_loss:
            assert "speech_ref1" in kwargs
            speech_ref = [kwargs["speech_ref1"]]  # [(Batch, samples)] x num_spkr

        # Calculating enhancement loss
        utt_id = kwargs.get("utt_id", None)
        bypass_enh_flag, skip_enhloss_flag = False, False
        if utt_id is not None:
            # TODO(xkc): to pass category info and use predefined category list
            if utt_id[0].endswith("SIMU"):
                # For simulated single-/multi-speaker data
                # feed it to Enhancement and calculate loss_enh
                bypass_enh_flag = False
                skip_enhloss_flag = False
            elif utt_id[0].endswith("REAL"):
                # For single-speaker real data
                # feed it to Enhancement but without calculating loss_enh
                bypass_enh_flag = False
                skip_enhloss_flag = True
            else:
                # For clean data
                # feed it to Enhancement, without calculating loss_enh
                bypass_enh_flag = True
                skip_enhloss_flag = True

        if not self.calc_enh_loss:
            skip_enhloss_flag = True

        # Bypass the enhancement module
        if (self.training and skip_enhloss_flag and not bypass_enh_flag
            ):  # For single-speaker real data: possibility to bypass frontend
            if random.random() <= self.bypass_enh_prob:
                bypass_enh_flag = True

        # 1. Enhancement
        # model forward
        loss_enh = None
        if not bypass_enh_flag:
            (
                speech_pre,
                feature_mix,
                feature_pre,
                others,
            ) = self.enh_model.forward_enhance(speech, speech_lengths)
            # loss computation
            if not skip_enhloss_flag:
                loss_enh, _, _ = self.enh_model.forward_loss(
                    speech_pre,
                    speech_lengths,
                    feature_mix,
                    feature_pre,
                    others,
                    speech_ref,
                )
                loss_enh = loss_enh[0]
        else:
            speech_pre = [speech]

        # for data-parallel
        text = text[:, :text_lengths.max()]
        if src_text is not None:
            src_text = src_text[:, :src_text_lengths.max()]

        # 2. ASR or ST
        if isinstance(self.s2t_model, ESPnetASRModel):  # ASR
            loss_asr, stats, weight = self.s2t_model(speech_pre[0],
                                                     speech_lengths, text,
                                                     text_lengths)
        elif isinstance(self.s2t_model, ESPnetSTModel):  # ST
            loss_asr, stats, weight = self.s2t_model(
                speech_pre[0],
                speech_lengths,
                text,
                text_lengths,
                src_text,
                src_text_lengths,
            )
        else:
            raise NotImplementedError(
                f"{type(self.s2t_model)} is not supported yet.")

        if loss_enh is not None:
            loss = loss_enh + loss_asr
        else:
            loss = loss_asr

        stats["loss"] = loss.detach() if loss is not None else None
        stats["loss_enh"] = loss_enh.detach() if loss_enh is not None else None

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Example #15
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        noisy_label_flag: bool = False,
        replace_label_flag: bool = True,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        # if self.stat is not None :
        #     hist = self.stat.confid_hist
        #     if hist.sum() != 0:
        #         hist.requires_grad = False
        #         total_sum = hist.sum()
        #         # simple mean testing for alpha = 0.27
        #         z_alpha = 18
        #         self.th = 1 / self.stat.bins * z_alpha
        #     else:
        #         # logging.warning("Prior histogram has {} value!".format(hist.sum()))
        #         self.th = None

        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
            
        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            if replace_label_flag:
                decoder_meta_out_prob = self._meta_forward(
                    speech,
                    speech_lengths,
                    text,
                    text_lengths
                )
            else:
                decoder_meta_out_prob = None

            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths, 
                replace_label_flag, 
                decoder_meta_out_prob
            )
            
        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
        
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att 

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
            # pred_err_att=pred_err_att,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        
        return loss, stats, weight
Example #16
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor,
        durations_lengths: torch.Tensor,
        spembs: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            durations (LongTensor): Batch of padded durations (B, Tmax + 1).
            durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1).
            spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value.

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        speech = speech[:, :speech_lengths.max()]  # for data-parallel
        durations = durations[:, :durations_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the end of the sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys, ds = speech, durations
        olens = speech_lengths

        # forward propagation
        before_outs, after_outs, d_outs = self._forward(xs,
                                                        ilens,
                                                        ys,
                                                        olens,
                                                        ds,
                                                        spembs=spembs,
                                                        is_inference=False)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # calculate loss
        if self.postnet is None:
            after_outs = None
        l1_loss, duration_loss = self.criterion(after_outs, before_outs,
                                                d_outs, ys, ds, ilens, olens)
        loss = l1_loss + duration_loss

        stats = dict(
            l1_loss=l1_loss.item(),
            duration_loss=duration_loss.item(),
            loss=loss.item(),
        )

        # report extra information
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(), )
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(), )

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Example #17
    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_ref: (Batch, num_speaker, samples)
                        or (Batch, num_speaker, samples, channels)
            speech_mix_lengths: (Batch,), default None for chunk iterator,
                            because the chunk-iterator does not have the
                            speech_lengths returned. see in
                            espnet2/iterators/chunk_iter_factory.py
        """
        # clean speech signal of each speaker
        speech_ref = [
            kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
        ]
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)

        if "noise_ref1" in kwargs:
            # noise signal (optional, required when using
            # frontend models with beamforming)
            noise_ref = [
                kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type)
            ]
            # (Batch, num_noise_type, samples) or
            # (Batch, num_noise_type, samples, channels)
            noise_ref = torch.stack(noise_ref, dim=1)
        else:
            noise_ref = None

        # dereverberated (noisy) signal
        # (optional, only used for frontend models with WPE)
        if "dereverb_ref1" in kwargs:
            dereverb_speech_ref = [
                kwargs["dereverb_ref{}".format(n + 1)]
                for n in range(self.num_spk)
                if "dereverb_ref{}".format(n + 1) in kwargs
            ]
            assert len(dereverb_speech_ref) in (1, self.num_spk), len(
                dereverb_speech_ref
            )
            # (Batch, N, samples) or (Batch, N, samples, channels)
            dereverb_speech_ref = torch.stack(dereverb_speech_ref, dim=1)
        else:
            dereverb_speech_ref = None

        batch_size = speech_mix.shape[0]
        speech_lengths = (
            speech_mix_lengths
            if speech_mix_lengths is not None
            else torch.ones(batch_size).int().fill_(speech_mix.shape[1])
        )
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # Check that batch_size is unified
        assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], (
            speech_mix.shape,
            speech_ref.shape,
            speech_lengths.shape,
        )

        # for data-parallel
        speech_ref = speech_ref[:, :, : speech_lengths.max()]
        speech_mix = speech_mix[:, : speech_lengths.max()]

        loss, speech_pre, others, out_lengths, perm = self._compute_loss(
            speech_mix,
            speech_lengths,
            speech_ref,
            dereverb_speech_ref=dereverb_speech_ref,
            noise_ref=noise_ref,
        )

        # add stats for logging
        if self.loss_type not in ["ci_sdr", "si_snr"]:
            if self.training:
                si_snr = None
            else:
                speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in speech_pre]
                speech_ref = torch.unbind(speech_ref, dim=1)
                if speech_ref[0].dim() == 3:
                    # For si_snr loss, only select one channel as the reference
                    speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
                # compute si-snr loss
                si_snr_loss, perm = self._permutation_loss(
                    speech_ref, speech_pre, self.si_snr_loss, perm=perm
                )
                si_snr = -si_snr_loss.detach()

            stats = dict(
                si_snr=si_snr,
                loss=loss.detach(),
            )
        else:
            if self.loss_type == "ci_sdr":
                stats = dict(ci_sdr=-loss.detach(), loss=loss.detach())
            elif self.loss_type == "si_snr":
                stats = dict(si_snr=-loss.detach(), loss=loss.detach())
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
Example #18
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.

        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        batch_size = speech.shape[0]
        text = text[:, :text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in)

        # 4. Joint Network
        joint_out = self.joint_network(encoder_out.unsqueeze(2),
                                       decoder_out.unsqueeze(1))

        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_lm = 0.0, 0.0

        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        loss = (self.transducer_weight * loss_trans +
                self.auxiliary_ctc_weight * loss_ctc +
                self.auxiliary_lm_loss_weight * loss_lm)

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)

        return loss, stats, weight
Example #19
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # For data-parallel
        text = text[:, :text_lengths.max()]

        # Define stats to report
        loss_mlm, acc_mlm = None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # 2. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out,
                                                    encoder_out_lens, text,
                                                    text_lengths)

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach(
            ) if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # 2a. Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(intermediate_out,
                                                      encoder_out_lens, text,
                                                      text_lengths)
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None)
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (1 - self.interctc_weight
                        ) * loss_ctc + self.interctc_weight * loss_interctc

        # 3. MLM decoder branch
        if self.ctc_weight != 1.0:
            loss_mlm, acc_mlm = self._calc_mlm_loss(encoder_out,
                                                    encoder_out_lens, text,
                                                    text_lengths)

        # 4. CTC/MLM loss definition
        if self.ctc_weight == 0.0:
            loss = loss_mlm
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 -
                                                 self.ctc_weight) * loss_mlm

        # Collect MLM branch stats
        stats["loss_mlm"] = loss_mlm.detach() if loss_mlm is not None else None
        stats["acc_mlm"] = acc_mlm

        # Collect total loss stats
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
Пример #20
0
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        text: torch.Tensor = None,
        text_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, ) default None for chunk iterator,
                                      because the chunk-iterator does not
                                      have the speech_lengths returned.
                                      see in
                                      espnet2/iterators/chunk_iter_factory.py
            text: (Batch, Length) default None just to keep the argument order
            text_lengths: (Batch,) default None for the same reason as speech_lengths
        """
        if text_lengths is not None:
            assert text_lengths.dim() == 1, text_lengths.shape
        if speech_lengths is not None and text_lengths is not None:
            # Check that batch_size is unified
            assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0]
                    == text_lengths.shape[0]), (speech.shape,
                                                speech_lengths.shape,
                                                text.shape, text_lengths.shape)
        else:
            assert speech.shape[0] == text.shape[0], (speech.shape, text.shape)

        # additional checks with valid src_text
        if "src_text" in kwargs:
            src_text = kwargs["src_text"]
            src_text_lengths = kwargs["src_text_lengths"]

            if src_text is not None:
                assert src_text_lengths.dim() == 1, src_text_lengths.shape
                assert (text.shape[0] == src_text.shape[0] ==
                        src_text_lengths.shape[0]), (
                            text.shape,
                            src_text.shape,
                            src_text_lengths.shape,
                        )
        else:
            src_text = None
            src_text_lengths = None

        batch_size = speech.shape[0]
        speech_lengths = (speech_lengths if speech_lengths is not None else
                          torch.ones(batch_size).int() * speech.shape[1])

        # number of speakers
        # Take the number of speakers from text
        # (= spk_label [Batch, length, num_spk] ) if it is 3-D.
        # This is to handle flexible number of speakers.
        # Used only in "enh + diar" task for now.
        num_spk = text.shape[2] if text.dim() == 3 else self.enh_model.num_spk

        # clean speech signal of each speaker
        speech_ref = None
        if self.calc_enh_loss:
            assert "speech_ref1" in kwargs
            speech_ref = [
                kwargs["speech_ref{}".format(spk + 1)]
                for spk in range(num_spk)
            ]
            # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
            speech_ref = torch.stack(speech_ref, dim=1)
            # for data-parallel
            speech_ref = speech_ref[..., :speech_lengths.max()]
            speech_ref = speech_ref.unbind(dim=1)

        # Calculating enhancement loss
        utt_id = kwargs.get("utt_id", None)
        bypass_enh_flag, skip_enhloss_flag = False, False
        if utt_id is not None and not isinstance(self.s2t_model,
                                                 ESPnetDiarizationModel):
            # TODO(xkc): to pass category info and use predefined category list
            if utt_id[0].endswith("SIMU"):
                # For simulated single-/multi-speaker data
                # feed it to Enhancement and calculate loss_enh
                bypass_enh_flag = False
                skip_enhloss_flag = False
            elif utt_id[0].endswith("REAL"):
                # For single-speaker real data
                # feed it to Enhancement but without calculating loss_enh
                bypass_enh_flag = False
                skip_enhloss_flag = True
            else:
                # For clean data
                # feed it to Enhancement, without calculating loss_enh
                bypass_enh_flag = True
                skip_enhloss_flag = True

        if not self.calc_enh_loss:
            skip_enhloss_flag = True

        # Bypass the enhancement module
        if (self.training and skip_enhloss_flag and not bypass_enh_flag
            ):  # For single-speaker real data: possibility to bypass frontend
            if random.random() <= self.bypass_enh_prob:
                bypass_enh_flag = True

        # 1. Enhancement
        # model forward
        loss_enh = None
        if not bypass_enh_flag:
            (
                speech_pre,
                feature_mix,
                feature_pre,
                others,
            ) = self.enh_model.forward_enhance(speech, speech_lengths,
                                               {"num_spk": num_spk})
            # loss computation
            if not skip_enhloss_flag:
                loss_enh, _, _ = self.enh_model.forward_loss(
                    speech_pre,
                    speech_lengths,
                    feature_mix,
                    feature_pre,
                    others,
                    speech_ref,
                )
                loss_enh = loss_enh[0]
        else:
            speech_pre = [speech]

        # for data-parallel
        if text_lengths is not None:
            text = text[:, :text_lengths.max()]
        if src_text is not None:
            src_text = src_text[:, :src_text_lengths.max()]

        # 2. ASR or ST
        if isinstance(self.s2t_model, ESPnetASRModel):  # ASR
            loss_asr, stats, weight = self.s2t_model(speech_pre[0],
                                                     speech_lengths, text,
                                                     text_lengths)
        elif isinstance(self.s2t_model, ESPnetSTModel):  # ST
            loss_asr, stats, weight = self.s2t_model(
                speech_pre[0],
                speech_lengths,
                text,
                text_lengths,
                src_text,
                src_text_lengths,
            )
        elif isinstance(self.s2t_model, ESPnetDiarizationModel):  # DIAR
            loss_asr, stats, weight = self.s2t_model(
                speech=speech.clone(),
                speech_lengths=speech_lengths,
                spk_labels=text,
                spk_labels_lengths=text_lengths,
                bottleneck_feats=others.get("bottleneck_feats"),
                bottleneck_feats_lengths=others.get(
                    "bottleneck_feats_lengths"),
            )
        else:
            raise NotImplementedError(
                f"{type(self.s2t_model)} is not supported yet.")

        if loss_enh is not None:
            loss = loss_enh + loss_asr
        else:
            loss = loss_asr

        stats["loss"] = loss.detach() if loss is not None else None
        stats["loss_enh"] = loss_enh.detach() if loss_enh is not None else None

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
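
Every example in this collection funnels its results through force_gatherable so that torch.nn.DataParallel can gather losses across replicas. A minimal, self-contained sketch of that contract (the import path is the usual ESPnet2 location; the toy values are arbitrary):

import torch
from espnet2.torch_utils.device_funcs import force_gatherable

loss = torch.tensor(1.23, requires_grad=True)   # scalar training loss
stats = {"loss": loss.detach(), "acc": 0.9}     # plain Python scalars are allowed
batch_size = 8

# Scalars are promoted to tensors and everything is moved to loss.device,
# which is what DataParallel needs in order to gather per-replica results.
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
# weight now holds the batch size as a tensor, used to weight running averages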
Example #21
    def forward_loss(
        self,
        speech_pre: torch.Tensor,
        speech_lengths: torch.Tensor,
        feature_mix: torch.Tensor,
        feature_pre: torch.Tensor,
        others: OrderedDict,
        speech_ref: torch.Tensor,
        noise_ref: torch.Tensor = None,
        dereverb_speech_ref: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        # for calculating loss on estimated noise signals
        if getattr(self.separator, "predict_noise", False):
            assert "noise1" in others, others.keys()
        if noise_ref is not None and "noise1" in others:
            for n in range(self.num_noise_type):
                key = "noise{}".format(n + 1)
                others[key] = self.decoder(others[key], speech_lengths)[0]
        # for calculating loss on dereverberated signals
        if getattr(self.separator, "predict_dereverb", False):
            assert "dereverb1" in others, others.keys()
        if dereverb_speech_ref is not None and "dereverb1" in others:
            for spk in range(self.num_spk):
                key = "dereverb{}".format(spk + 1)
                if key in others:
                    others[key] = self.decoder(others[key], speech_lengths)[0]

        loss = 0.0
        stats = {}
        o = {}
        for loss_wrapper in self.loss_wrappers:
            criterion = loss_wrapper.criterion
            if getattr(criterion, "only_for_test", False) and self.training:
                continue
            if getattr(criterion, "is_noise_loss", False):
                if noise_ref is None:
                    raise ValueError(
                        "No noise reference for training!\n"
                        'Please specify "--use_noise_ref true" in run.sh')
                signal_ref = noise_ref
                signal_pre = [
                    others["noise{}".format(n + 1)]
                    for n in range(self.num_noise_type)
                ]
            elif getattr(criterion, "is_dereverb_loss", False):
                if dereverb_speech_ref is None:
                    raise ValueError(
                        "No dereverberated reference for training!\n"
                        'Please specify "--use_dereverb_ref true" in run.sh')
                signal_ref = dereverb_speech_ref
                signal_pre = [
                    others["dereverb{}".format(n + 1)]
                    for n in range(self.num_noise_type)
                    if "dereverb{}".format(n + 1) in others
                ]
                if len(signal_pre) == 0:
                    signal_pre = None
            else:
                signal_ref = speech_ref
                signal_pre = speech_pre

            if isinstance(criterion, TimeDomainLoss):
                assert signal_pre is not None
                sref, spre = self._align_ref_pre_channels(signal_ref,
                                                          signal_pre,
                                                          ch_dim=2,
                                                          force_1ch=True)
                # for time-domain criteria
                l, s, o = loss_wrapper(sref, spre, {**others, **o})
            elif isinstance(criterion, FrequencyDomainLoss):
                sref, spre = self._align_ref_pre_channels(signal_ref,
                                                          signal_pre,
                                                          ch_dim=2,
                                                          force_1ch=False)
                # for time-frequency-domain criteria
                if criterion.compute_on_mask:
                    # compute loss on masks
                    if getattr(criterion, "is_noise_loss", False):
                        tf_ref, tf_pre = self._get_noise_masks(
                            criterion,
                            feature_mix,
                            speech_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                        )
                    elif getattr(criterion, "is_dereverb_loss", False):
                        tf_ref, tf_pre = self._get_dereverb_masks(
                            criterion,
                            feature_mix,
                            noise_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                        )
                    else:
                        tf_ref, tf_pre = self._get_speech_masks(
                            criterion,
                            feature_mix,
                            noise_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                        )
                else:
                    # compute on spectrum
                    tf_ref = [
                        self.encoder(sr, speech_lengths)[0] for sr in sref
                    ]
                    tf_pre = [
                        self.encoder(sp, speech_lengths)[0] for sp in spre
                    ]

                l, s, o = loss_wrapper(tf_ref, tf_pre, {**others, **o})
            else:
                raise NotImplementedError("Unsupported loss type: %s" %
                                          str(criterion))

            loss += l * loss_wrapper.weight
            stats.update(s)

        if self.training and isinstance(loss, float):
            raise AttributeError(
                "At least one criterion must satisfy: only_for_test=False")
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        batch_size = speech_ref[0].shape[0]
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
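
The loss_wrapper loop above is, at its core, a weighted sum over criteria with per-criterion stats. A stripped-down sketch of that accumulation pattern (ToyWrapper and accumulate are hypothetical stand-ins, not ESPnet classes):

import torch

class ToyWrapper:
    """Stand-in for a loss wrapper: a criterion plus a mixing weight."""

    def __init__(self, criterion, weight):
        self.criterion, self.weight = criterion, weight

def accumulate(wrappers, ref, pre, training=True):
    loss, stats = 0.0, {}
    for w in wrappers:
        if getattr(w.criterion, "only_for_test", False) and training:
            continue  # validation-only criteria are skipped during training
        l = w.criterion(ref, pre)
        loss = loss + l * w.weight
        stats[type(w.criterion).__name__] = l.detach()
    if training and isinstance(loss, float):
        raise AttributeError("at least one criterion must run in training")
    return loss, stats

wrappers = [ToyWrapper(torch.nn.MSELoss(), 1.0),
            ToyWrapper(torch.nn.L1Loss(), 0.5)]
ref, pre = torch.randn(4, 16), torch.randn(4, 16)
loss, stats = accumulate(wrappers, ref, pre)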
Example #22
    def _forward_generator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                **kwargs,
            )
        else:
            outs = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        (
            speech_hat_,
            bin_loss,
            log_p_attn,
            start_idxs,
            d_outs,
            ds,
            p_outs,
            ps,
            e_outs,
            es,
        ) = outs
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_)
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            p = self.discriminator(speech_)

        # calculate losses
        mel_loss = self.mel_loss(speech_hat_, speech_)
        adv_loss = self.generator_adv_loss(p_hat)
        feat_match_loss = self.feat_match_loss(p_hat, p)
        dur_loss, pitch_loss, energy_loss = self.var_loss(
            d_outs, ds, p_outs, ps, e_outs, es, text_lengths
        )
        forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths)

        mel_loss = mel_loss * self.lambda_mel
        adv_loss = adv_loss * self.lambda_adv
        feat_match_loss = feat_match_loss * self.lambda_feat_match
        g_loss = mel_loss + adv_loss + feat_match_loss
        var_loss = (dur_loss + pitch_loss + energy_loss) * self.lambda_var
        align_loss = (forwardsum_loss + bin_loss) * self.lambda_align

        loss = g_loss + var_loss + align_loss

        stats = dict(
            generator_loss=loss.item(),
            generator_g_loss=g_loss.item(),
            generator_var_loss=var_loss.item(),
            generator_align_loss=align_loss.item(),
            generator_g_mel_loss=mel_loss.item(),
            generator_g_adv_loss=adv_loss.item(),
            generator_g_feat_match_loss=feat_match_loss.item(),
            generator_var_dur_loss=dur_loss.item(),
            generator_var_pitch_loss=pitch_loss.item(),
            generator_var_energy_loss=energy_loss.item(),
            generator_align_forwardsum_loss=forwardsum_loss.item(),
            generator_align_bin_loss=bin_loss.item(),
        )

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 0,  # needed for trainer
        }
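
get_segments only has to slice fixed-length windows out of (B, C, T) tensors at per-utterance start indices. A minimal re-implementation of the idea for reference (a sketch, not the ESPnet source):

import torch

def get_segments_sketch(x, start_idxs, segment_size):
    """Slice (B, C, T) -> (B, C, segment_size) windows at start_idxs."""
    b, c, _ = x.size()
    segments = x.new_zeros(b, c, segment_size)
    for i, start in enumerate(start_idxs):
        segments[i] = x[i, :, start:start + segment_size]
    return segments

x = torch.arange(2 * 1 * 20, dtype=torch.float32).view(2, 1, 20)
print(get_segments_sketch(x, torch.tensor([0, 4]), segment_size=8).shape)
# torch.Size([2, 1, 8])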
Example #23
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        joint_training: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            feats (Tensor): Batch of padded target features (B, Lmax, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            joint_training (bool): Whether to perform joint training with vocoder.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        feats = feats[:, :feats_lengths.max()]  # for data-parallel
        batch_size = text.size(0)

        # Add eos at the end of the sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = feats
        olens = feats_lengths

        # make labels for stop prediction
        labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
        labels = F.pad(labels, [0, 1], "constant", 1.0)

        # calculate transformer outputs
        after_outs, before_outs, logits = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spembs=spembs,
            sids=sids,
            lids=lids,
        )

        # modify the mod part of the groundtruth
        olens_in = olens
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(), \
                "Output length must be greater than or equal to reduction factor."
            olens_in = olens.new(
                [olen // self.reduction_factor for olen in olens])
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1),
                                   1.0)  # see #3388

        # calculate loss values
        l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs,
                                                    logits, ys, labels, olens)
        if self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = l2_loss + bce_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss + bce_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)

        stats = dict(
            l1_loss=l1_loss.item(),
            l2_loss=l2_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate guided attention loss
        if self.use_guided_attn_loss:
            # calculate for encoder
            if "encoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.encoder.encoders)))):
                    att_ws += [
                        self.encoder.encoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_text, T_text)
                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
                loss = loss + enc_attn_loss
                stats.update(enc_attn_loss=enc_attn_loss.item())
            # calculate for decoder
            if "decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_feats, T_feats)
                dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
                loss = loss + dec_attn_loss
                stats.update(dec_attn_loss=dec_attn_loss.item())
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].src_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_feats, T_text)
                enc_dec_attn_loss = self.attn_criterion(
                    att_ws, ilens, olens_in)
                loss = loss + enc_dec_attn_loss
                stats.update(enc_dec_attn_loss=enc_dec_attn_loss.item())

        # report extra information
        if self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
            )

        if not joint_training:
            stats.update(loss=loss.item())
            loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                                   loss.device)
            return loss, stats, weight
        else:
            return loss, stats, after_outs
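
The eos-append and stop-label construction at the top of this forward can be checked in isolation. A small sketch under the same conventions (padding_idx=0 and eos=1 are arbitrary choices here):

import torch
import torch.nn.functional as F

text = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])  # ids padded with 0
text_lengths = torch.tensor([3, 2])
padding_idx, eos = 0, 1

xs = F.pad(text, [0, 1], "constant", padding_idx)  # make room for eos
for i, l in enumerate(text_lengths):
    xs[i, l] = eos
ilens = text_lengths + 1
print(xs)  # tensor([[5, 6, 7, 1, 0], [8, 9, 1, 0, 0]])

# Stop labels: 1.0 from the last output frame onward.  For these lengths this
# matches the make_pad_mask(olens - 1) + F.pad(..., 1.0) recipe above.
olens = torch.tensor([4, 3])
labels = (torch.arange(int(olens.max()))[None, :] >= (olens - 1)[:, None]).float()
print(labels)  # tensor([[0., 0., 0., 1.], [0., 0., 1., 1.]])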
Example #24
    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: torch.Tensor,
        speech_ref1: torch.Tensor,
        speech_ref2: torch.Tensor,
        text_ref1: torch.Tensor,
        text_ref2: torch.Tensor,
        text_ref1_lengths: torch.Tensor,
        text_ref2_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Enhancement + Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_ref1_lengths.dim() == text_ref2_lengths.dim() == 1, (
            text_ref1_lengths.shape,
            text_ref2_lengths.shape,
        )
        # Check that batch_size is unified
        assert (speech_mix.shape[0] == speech_mix_lengths.shape[0] ==
                text_ref1.shape[0] == text_ref1_lengths.shape[0] ==
                text_ref2.shape[0] == text_ref2_lengths.shape[0]), (
                    speech_mix.shape,
                    speech_mix_lengths.shape,
                    text_ref1.shape,
                    text_ref1_lengths.shape,
                    text_ref2.shape,
                    text_ref2_lengths.shape,
                )
        batch_size = speech_mix.shape[0]

        # for data-parallel
        text_length_max = max(text_ref1_lengths.max(), text_ref2_lengths.max())
        text_ref1 = torch.cat(
            [
                text_ref1,
                torch.ones(batch_size, text_length_max,
                           dtype=text_ref1.dtype).to(text_ref1.device) *
                self.idx_blank,
            ],
            dim=1,
        )
        text_ref2 = torch.cat(
            [
                text_ref2,
                torch.ones(batch_size, text_length_max,
                           dtype=text_ref1.dtype).to(text_ref1.device) *
                self.idx_blank,
            ],
            dim=1,
        )
        text_ref1 = text_ref1[:, :text_length_max]
        text_ref2 = text_ref2[:, :text_length_max]

        # 0. Enhancement
        # make sure speech_pre is the raw waveform with the same size
        loss_enh, perm, speech_pre = self.forward_enh(
            speech_mix,
            speech_mix_lengths,
            speech_ref1=speech_ref1,
            speech_ref2=speech_ref2,
        )
        # speech_pre: (bs,num_spk,T)
        assert speech_pre[:, 0].shape == speech_mix.shape

        # Pack the separated speakers into the ASR part.
        speech_pre_all = speech_pre.view(
            -1, speech_mix.shape[-1])  # (bs*num_spk, T)
        speech_pre_lengths = torch.stack(
            [speech_mix_lengths, speech_mix_lengths], dim=1).view(-1)
        text_ref_all = torch.stack([text_ref1, text_ref2],
                                   dim=1).view(batch_size * 2, -1)
        text_ref_lengths = torch.stack([text_ref1_lengths, text_ref2_lengths],
                                       dim=1).view(-1)

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech_pre_all,
                                                    speech_pre_lengths)

        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text_ref_all, text_ref_lengths)

        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out,
                                                    encoder_out_lens,
                                                    text_ref_all,
                                                    text_ref_lengths)

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens,
                                     text_ref_all, text_ref_lengths)

        if self.ctc_weight == 0.0:
            loss_asr = loss_att
        elif self.ctc_weight == 1.0:
            loss_asr = loss_ctc
        else:
            loss_asr = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att

        if self.enh_weight == 0.0:
            loss_enh = None
            loss = loss_asr
        else:
            loss = (1 -
                    self.enh_weight) * loss_asr + self.enh_weight * loss_enh

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            loss_enh=loss_enh.detach() if loss_enh is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
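
The packing step above folds the two separated speakers into the batch axis, so the ASR branch sees 2 * batch_size utterances. A tiny sketch showing the resulting row ordering (values arbitrary):

import torch

batch_size, T = 3, 10
speech_pre = torch.randn(batch_size, 2, T)     # (bs, num_spk, T)
speech_pre_all = speech_pre.view(-1, T)        # (bs * num_spk, T)
# Row order is [utt0/spk1, utt0/spk2, utt1/spk1, utt1/spk2, ...]
assert torch.equal(speech_pre_all[0], speech_pre[0, 0])
assert torch.equal(speech_pre_all[1], speech_pre[0, 1])

text_ref1 = torch.randint(0, 5, (batch_size, 7))
text_ref2 = torch.randint(0, 5, (batch_size, 7))
text_ref_all = torch.stack([text_ref1, text_ref2], dim=1).view(batch_size * 2, -1)
assert torch.equal(text_ref_all[1], text_ref2[0])  # same interleaving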
Example #25
    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_ref: (Batch, num_speaker, samples)
                        or (Batch, num_speaker, samples, channels)
            speech_mix_lengths: (Batch,), default None for the chunk iterator,
                            because the chunk iterator does not return
                            speech_lengths; see
                            espnet2/iterators/chunk_iter_factory.py
        """
        # clean speech signal of each speaker
        speech_ref = [
            kwargs["speech_ref{}".format(spk + 1)]
            for spk in range(self.num_spk)
        ]
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)

        if "noise_ref1" in kwargs:
            # noise signal (optional, required when using
            # frontend models with beamforming)
            noise_ref = [
                kwargs["noise_ref{}".format(n + 1)]
                for n in range(self.num_noise_type)
            ]
            # (Batch, num_noise_type, samples) or
            # (Batch, num_noise_type, samples, channels)
            noise_ref = torch.stack(noise_ref, dim=1)
        else:
            noise_ref = None

        # dereverberated noisy signal
        # (optional, only used for frontend models with WPE)
        dereverb_speech_ref = kwargs.get("dereverb_ref", None)

        batch_size = speech_mix.shape[0]
        speech_lengths = (speech_mix_lengths if speech_mix_lengths is not None
                          else torch.ones(batch_size).int() *
                          speech_mix.shape[1])
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # Check that batch_size is unified
        assert speech_mix.shape[0] == speech_ref.shape[
            0] == speech_lengths.shape[0], (
                speech_mix.shape,
                speech_ref.shape,
                speech_lengths.shape,
            )

        # for data-parallel
        speech_ref = speech_ref[:, :, :speech_lengths.max()]
        speech_mix = speech_mix[:, :speech_lengths.max()]

        if self.loss_type != "si_snr":
            # prepare reference speech and reference spectrum
            speech_ref = torch.unbind(speech_ref, dim=1)
            spectrum_ref = [self.enh_model.stft(sr)[0] for sr in speech_ref]

            # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
            spectrum_ref = [
                ComplexTensor(sr[..., 0], sr[..., 1]) for sr in spectrum_ref
            ]
            spectrum_mix = self.enh_model.stft(speech_mix)[0]
            spectrum_mix = ComplexTensor(spectrum_mix[..., 0],
                                         spectrum_mix[..., 1])

            # predict separated speech and masks
            spectrum_pre, tf_length, mask_pre = self.enh_model(
                speech_mix, speech_lengths)

            # compute TF masking loss
            if self.loss_type == "magnitude":
                # compute loss on magnitude spectrum
                magnitude_pre = [abs(ps) for ps in spectrum_pre]
                magnitude_ref = [abs(sr) for sr in spectrum_ref]
                tf_loss, perm = self._permutation_loss(magnitude_ref,
                                                       magnitude_pre,
                                                       self.tf_mse_loss)
            elif self.loss_type == "spectrum":
                # compute loss on complex spectrum
                tf_loss, perm = self._permutation_loss(spectrum_ref,
                                                       spectrum_pre,
                                                       self.tf_mse_loss)
            elif self.loss_type.startswith("mask"):
                if self.loss_type == "mask_mse":
                    loss_func = self.tf_mse_loss
                else:
                    raise ValueError("Unsupported loss type: %s" %
                                     self.loss_type)

                assert mask_pre is not None
                mask_pre_ = [
                    mask_pre["spk{}".format(spk + 1)]
                    for spk in range(self.num_spk)
                ]

                # prepare ideal masks
                mask_ref = self._create_mask_label(spectrum_mix,
                                                   spectrum_ref,
                                                   mask_type=self.mask_type)

                # compute TF masking loss
                tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_,
                                                       loss_func)

                if "dereverb" in mask_pre:
                    if dereverb_speech_ref is None:
                        raise ValueError(
                            "No dereverberated reference for training!\n"
                            'Please specify "--use_dereverb_ref true" in run.sh'
                        )

                    dereverb_spectrum_ref = self.enh_model.stft(
                        dereverb_speech_ref)[0]
                    dereverb_spectrum_ref = ComplexTensor(
                        dereverb_spectrum_ref[..., 0],
                        dereverb_spectrum_ref[..., 1])
                    # ComplexTensor(B, T, F) or ComplexTensor(B, T, C, F)
                    dereverb_mask_ref = self._create_mask_label(
                        spectrum_mix, [dereverb_spectrum_ref],
                        mask_type=self.mask_type)[0]

                    tf_loss = (tf_loss + loss_func(
                        dereverb_mask_ref, mask_pre["dereverb"]).mean())

                if "noise1" in mask_pre:
                    if noise_ref is None:
                        raise ValueError(
                            "No noise reference for training!\n"
                            'Please specify "--use_noise_ref true" in run.sh')

                    noise_ref = torch.unbind(noise_ref, dim=1)
                    noise_spectrum_ref = [
                        self.enh_model.stft(nr)[0] for nr in noise_ref
                    ]
                    noise_spectrum_ref = [
                        ComplexTensor(nr[..., 0], nr[..., 1])
                        for nr in noise_spectrum_ref
                    ]
                    noise_mask_ref = self._create_mask_label(
                        spectrum_mix,
                        noise_spectrum_ref,
                        mask_type=self.mask_type)

                    mask_noise_pre = [
                        mask_pre["noise{}".format(n + 1)]
                        for n in range(self.num_noise_type)
                    ]
                    tf_noise_loss, perm_n = self._permutation_loss(
                        noise_mask_ref, mask_noise_pre, loss_func)
                    tf_loss = tf_loss + tf_noise_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            if self.training:
                si_snr = None
            else:
                speech_pre = [
                    self.enh_model.stft.inverse(ps, speech_lengths)[0]
                    for ps in spectrum_pre
                ]
                if speech_ref[0].dim() == 3:
                    # For si_snr loss, only select one channel as the reference
                    speech_ref = [
                        sr[..., self.ref_channel] for sr in speech_ref
                    ]
                # compute si-snr loss
                si_snr_loss, perm = self._permutation_loss(speech_ref,
                                                           speech_pre,
                                                           self.si_snr_loss,
                                                           perm=perm)
                si_snr = -si_snr_loss.detach()

            loss = tf_loss

            stats = dict(
                si_snr=si_snr,
                loss=loss.detach(),
            )
        else:
            if speech_ref.dim() == 4:
                # For si_snr loss of multi-channel input,
                # only select one channel as the reference
                speech_ref = speech_ref[..., self.ref_channel]

            speech_pre, speech_lengths, *__ = self.enh_model.forward_rawwav(
                speech_mix, speech_lengths)
            # speech_pre: list[(batch, sample)]
            assert speech_pre[0].dim() == 2, speech_pre[0].dim()
            speech_ref = torch.unbind(speech_ref, dim=1)

            # compute si-snr loss
            si_snr_loss, perm = self._permutation_loss(
                speech_ref, speech_pre, self.si_snr_loss_zeromean)
            si_snr = -si_snr_loss
            loss = si_snr_loss
            stats = dict(si_snr=si_snr.detach(), loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
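
_permutation_loss implements utterance-level permutation-invariant training (PIT): every speaker permutation is scored and the cheapest one is kept per utterance. A minimal itertools-based sketch of the idea (not the ESPnet implementation):

from itertools import permutations

import torch

def pit_loss(ref, pre, criterion):
    """Per-utterance PIT: pick the speaker permutation minimizing the loss.

    ref, pre: lists of (B, ...) tensors, one entry per speaker.
    criterion: pairwise loss returning a (B,) tensor.
    """
    perms = list(permutations(range(len(ref))))
    losses = torch.stack(
        [torch.stack([criterion(r, pre[p]) for r, p in zip(ref, perm)]).mean(0)
         for perm in perms],
        dim=1,
    )                                   # (B, n_perm)
    loss, perm_idx = losses.min(dim=1)  # best permutation per utterance
    return loss.mean(), perm_idx

mse = lambda a, b: ((a - b) ** 2).mean(dim=-1)
ref = [torch.randn(4, 100), torch.randn(4, 100)]
pre = [ref[1].clone(), ref[0].clone()]  # speakers deliberately swapped
loss, perm_idx = pit_loss(ref, pre, mse)
# loss ~ 0, and perm_idx selects the (1, 0) permutation for every utterance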
Example #26
    def _forward_generator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        feats = feats.transpose(1, 2)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
        _, z_p, m_p, logs_p, _, logs_q = outs_
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_)
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            p = self.discriminator(speech_)

        # calculate losses
        with autocast(enabled=False):
            mel_loss = self.mel_loss(speech_hat_, speech_)
            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
            dur_loss = torch.sum(dur_nll.float())
            adv_loss = self.generator_adv_loss(p_hat)
            feat_match_loss = self.feat_match_loss(p_hat, p)

            mel_loss = mel_loss * self.lambda_mel
            kl_loss = kl_loss * self.lambda_kl
            dur_loss = dur_loss * self.lambda_dur
            adv_loss = adv_loss * self.lambda_adv
            feat_match_loss = feat_match_loss * self.lambda_feat_match
            loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss

        stats = dict(
            generator_loss=loss.item(),
            generator_mel_loss=mel_loss.item(),
            generator_kl_loss=kl_loss.item(),
            generator_dur_loss=dur_loss.item(),
            generator_adv_loss=adv_loss.item(),
            generator_feat_match_loss=feat_match_loss.item(),
        )

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 0,  # needed for trainer
        }
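
The bookkeeping around self._cache lets one expensive generator forward serve both the generator turn and the discriminator turn. A minimal sketch of just that caching pattern (the class and the toy generator are hypothetical):

import torch

class CachedGAN:
    """Run the (expensive) generator once per generator/discriminator pair."""

    def __init__(self, generator, cache_generator_outputs=True):
        self.generator = generator
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None
        self.training = True

    def generator_outs(self, x):
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(x)            # the expensive forward
        else:
            outs = self._cache                  # second turn reuses it
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = outs                  # keep for the other turn
        if reuse_cache or not self.training:
            self._cache = None                  # consumed (or eval mode): drop
        return outs

gan = CachedGAN(generator=lambda x: x * 2)
a = gan.generator_outs(torch.ones(3))  # computes and caches
b = gan.generator_outs(torch.ones(3))  # reuses the cache, then clears it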
Example #27
    def _forward_discriminator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, Any]:
        """Perform discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            # calculate text2mel outputs
            text2mel_loss, stats, feats_gen = self.generator["text2mel"](
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                joint_training=True,
                **kwargs,
            )
            # get random segments
            feats_gen_, start_idxs = get_random_segments(
                x=feats_gen.transpose(1, 2),
                x_lengths=feats_lengths,
                segment_size=self.segment_size,
            )
            # calculate vocoder outputs
            speech_hat_ = self.generator["vocoder"](feats_gen_)
            if self.use_pqmf:
                speech_hat_ = self.pqmf.synthesis(speech_hat_)
        else:
            _, _, speech_hat_, start_idxs = self._cache

        # store cache
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = (text2mel_loss, stats, speech_hat_, start_idxs)

        # parse outputs
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator["vocoder"].upsample_factor,
            segment_size=self.segment_size *
            self.generator["vocoder"].upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_.detach())
        p = self.discriminator(speech_)

        # calculate losses
        real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
        loss = real_loss + fake_loss

        stats = dict(
            discriminator_loss=loss.item(),
            real_loss=real_loss.item(),
            fake_loss=fake_loss.item(),
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 1,  # needed for trainer
        }
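
discriminator_adv_loss above returns separate real and fake terms. One common instantiation is the least-squares (LSGAN) objective; a sketch of that variant (ESPnet supports several adversarial losses, so this is only illustrative):

import torch

def lsgan_d_loss(p_hat, p):
    """LSGAN discriminator objective: push fakes toward 0, reals toward 1."""
    fake_loss = (p_hat ** 2).mean()
    real_loss = ((p - 1.0) ** 2).mean()
    return real_loss, fake_loss

p_hat = torch.rand(4, 1, 100)  # discriminator scores on generated audio
p = torch.rand(4, 1, 100)      # discriminator scores on real audio
real_loss, fake_loss = lsgan_d_loss(p_hat, p)
loss = real_loss + fake_loss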
Example #28
    def _forward_discriminator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Perform discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        feats = feats.transpose(1, 2)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        speech_hat_, _, _, start_idxs, *_ = outs
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_.detach())
        p = self.discriminator(speech_)

        # calculate losses
        with autocast(enabled=False):
            real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
            loss = real_loss + fake_loss

        stats = dict(
            discriminator_loss=loss.item(),
            discriminator_real_loss=real_loss.item(),
            discriminator_fake_loss=fake_loss.item(),
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 1,  # needed for trainer
        }
Example #29
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spembs: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value.

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        speech = speech[:, :speech_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the end of the sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = speech
        olens = speech_lengths

        # make labels for stop prediction
        labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
        labels = F.pad(labels, [0, 1], "constant", 1.0)

        # calculate tacotron2 outputs
        after_outs, before_outs, logits, att_ws = self._forward(
            xs, ilens, ys, olens, spembs)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels,
                                                      olens)
        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = mse_loss + bce_loss
        else:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        stats = dict(
            l1_loss=l1_loss.item(),
            mse_loss=mse_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive
            # input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            stats.update(attn_loss=attn_loss.item())

        stats.update(loss=loss.item())

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
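
The reduction-factor bookkeeping (olen - olen % r) trims the targets to a multiple of r so the decoder can emit r frames per step. In isolation (r = 2 is arbitrary):

import torch

reduction_factor = 2
olens = torch.tensor([7, 5])
ys = torch.randn(2, 7, 80)

olens = olens - olens % reduction_factor  # same as the list comprehension above
max_out = int(olens.max())                # 6
ys = ys[:, :max_out]                      # (2, 6, 80): ragged tail frames dropped
print(olens)                              # tensor([6, 4])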
Example #30
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            kwargs: "utt_id" is among the input.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer

        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
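
The intermediate-CTC mixing near the top of this forward reduces to averaging the per-layer losses and interpolating with the final-layer CTC loss. A sketch of just that arithmetic (weights and values arbitrary):

import torch

interctc_weight = 0.3
loss_ctc = torch.tensor(2.0)                           # final-layer CTC loss
layer_losses = [torch.tensor(2.4), torch.tensor(2.2)]  # intermediate layers

loss_interctc = sum(layer_losses) / len(layer_losses)  # average over layers
loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc
print(loss_ctc)  # tensor(2.0900)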