def forward(
    self,
    speech: torch.Tensor,
    speech_original: torch.Tensor,
    speech_lengths: torch.Tensor,
    speech_original_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_original: (Batch, Length, ...)
        speech_lengths: (Batch, )
        speech_original_lengths: (Batch, )
    """
    batch_size = speech.shape[0]

    # 1. Encoder
    encoder_out, encoder_out_lens, feats_original, dropout_mask = self.encode(
        speech, speech_original, speech_lengths)

    loss = self._calc_predictive_loss(feats_original, encoder_out, dropout_mask)
    stats = dict(loss=loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(self, x, x_lengths):
    x = self.layer1(x)
    x = self.layer2(x)
    retval = {
        "loss": x.mean(),
        "stats": {"loss": x.mean()},
        "weight": len(x),
        "optim_idx": torch.randint(0, 2, [1]),
    }
    return force_gatherable(retval, device=x.device)
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    durations: torch.Tensor,
    durations_lengths: torch.Tensor,
    pitch: torch.Tensor,
    pitch_lengths: torch.Tensor,
    energy: torch.Tensor,
    energy_lengths: torch.Tensor,
    spembs: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    text = text[:, : text_lengths.max()]  # for data-parallel
    speech = speech[:, : speech_lengths.max()]  # for data-parallel
    durations = durations[:, : durations_lengths.max()]  # for data-parallel
    pitch = pitch[:, : pitch_lengths.max()]  # for data-parallel
    energy = energy[:, : energy_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    # Add eos at the end of each sequence
    xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    for i, l in enumerate(text_lengths):
        xs[i, l] = self.eos
    ilens = text_lengths + 1

    ys, ds, ps, es = speech, durations, pitch, energy
    olens = speech_lengths

    before_outs, after_outs, d_outs, p_outs, e_outs = self.fastspeech2._forward(
        xs, ilens, ys, olens, ds, ps, es, spembs=spembs, is_inference=False
    )

    ys = speech.transpose(1, 2)
    y_masks = self._source_mask(olens)
    mu = after_outs.transpose(1, 2)
    # Zero-pad the time axis to a multiple of 4 (presumably required by the
    # diffusion decoder's down-/up-sampling stack).
    if ys.size(2) % 4 != 0:
        ys = torch.cat(
            [ys, torch.zeros([batch_size, self.odim, 4 - ys.size(2) % 4],
                             dtype=ys.dtype, device=ys.device)], dim=2)
        mu = torch.cat(
            [mu, torch.zeros([mu.size(0), self.odim, 4 - mu.size(2) % 4],
                             dtype=mu.dtype, device=mu.device)], dim=2)
        y_masks = torch.cat(
            [y_masks, torch.zeros([y_masks.size(0), 1, 4 - y_masks.size(2) % 4],
                                  dtype=y_masks.dtype, device=y_masks.device)],
            dim=2)
    noise_estimation, z = self.diffusion(ys, y_masks, mu)
    diff_loss = self.criterion(noise_estimation, z, y_masks)
    loss = diff_loss

    stats = dict(
        diff_loss=diff_loss.item(),
        loss=loss.item(),
    )
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    nll, y_lengths = self.nll(text, text_lengths)
    ntokens = y_lengths.sum()
    loss = nll.sum() / ntokens
    stats = dict(loss=loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
    return loss, stats, weight
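# Editor's sketch (not from the original sources): why the LM forward above
# returns `ntokens` as the weight. Under DataParallel, each replica reports a
# token-averaged loss; recombining those averages with the token counts as
# weights recovers the corpus-level mean a single device would have computed.
# `weighted_average` is a hypothetical helper; it assumes the same `torch`
# import as the surrounding snippets.
def weighted_average(losses: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # losses: (num_replicas,) per-replica token-averaged losses
    # weights: (num_replicas,) per-replica token counts
    return (losses * weights).sum() / weights.sum()


# e.g. replicas with losses (2.0, 4.0) over (10, 30) tokens give
# (2.0 * 10 + 4.0 * 30) / 40 = 3.5, not the naive mean 3.0.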
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    src_text: torch.Tensor,
    src_text_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        text: (Batch, Length)
        text_lengths: (Batch,)
        src_text: (Batch, Length)
        src_text_lengths: (Batch,)
        kwargs: "utt_id" is among the input.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
        text.shape[0]
        == text_lengths.shape[0]
        == src_text.shape[0]
        == src_text_lengths.shape[0]
    ), (text.shape, text_lengths.shape, src_text.shape, src_text_lengths.shape)
    batch_size = src_text.shape[0]

    # for data-parallel
    text = text[:, : text_lengths.max()]
    src_text = src_text[:, : src_text_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(src_text, src_text_lengths)

    # 2a. Attention-decoder branch (MT)
    loss_mt_att, acc_mt_att, bleu_mt_att = self._calc_mt_att_loss(
        encoder_out, encoder_out_lens, text, text_lengths
    )

    # 3. Loss computation
    loss = loss_mt_att
    stats = dict(
        loss=loss.detach(),
        acc=acc_mt_att,
        bleu=bleu_mt_att,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward_ilm(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...) not necessary; only used to get the tensor device
        speech_lengths: (Batch, ) not necessary; only used to get the tensor device
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert text.shape[0] == text_lengths.shape[0], (text.shape, text_lengths.shape)
    batch_size = text.shape[0]

    # for data-parallel
    text = text[:, :text_lengths.max()]

    ys_in_pad, ys_out_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
    ys_in_lens = text_lengths + 1
    fake_encoder_out = speech.new_zeros(batch_size, 1, self.encoder._output_size)

    # 1. Forward decoder
    decoder_out, _ = self.decoder.forward_ilm(fake_encoder_out, -1, ys_in_pad,
                                              ys_in_lens)

    # 2. Compute ILM loss
    loss_ilm = self.criterion_att(decoder_out, ys_out_pad)
    ilm_acc = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )
    ilm_ppl = torch.exp(loss_ilm)

    stats = dict(ilm_loss=loss_ilm.detach(), ilm_acc=ilm_acc,
                 ilm_ppl=ilm_ppl.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss_ilm, stats, batch_size),
                                           loss_ilm.device)
    # return the gathered loss (the original returned the pre-gather loss_ilm)
    return loss, stats, weight
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
        kwargs: "utt_id" is among the input.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)
    batch_size = speech.shape[0]

    # for data-parallel
    text = text[:, :text_lengths.max()]

    # 1. Encoder
    encoder_out = self.encode(speech, speech_lengths, text, text_lengths)

    # 2a. Hubert criterion
    loss, acc_mask, acc_unmask = self._calc_hubert_loss(encoder_out)

    stats = dict(
        loss=loss.detach(),
        acc_mask=acc_mask,
        acc_unmask=acc_unmask,
        acc=acc_mask,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def test_force_gatherable_cuda():
    obj = {"a": [torch.tensor([0, 1])]}
    obj2 = force_gatherable(obj, "cuda")
    assert obj2["a"][0].device == torch.device("cuda:0")
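# Editor's companion check to the CUDA test above: a minimal sketch assuming
# the espnet2.torch_utils.device_funcs.force_gatherable semantics, i.e. that
# containers are traversed recursively and Python scalars are wrapped as 1-D
# tensors so DataParallel can gather them. The exact wrapping behavior is an
# assumption; it reuses the `torch` / `force_gatherable` imports of the
# surrounding snippets.
def test_force_gatherable_cpu_scalars():
    obj = {"a": [torch.tensor([0, 1]), 3], "b": 1.5}
    obj2 = force_gatherable(obj, "cpu")
    assert obj2["a"][0].device == torch.device("cpu")
    # int and float leaves are assumed to come back as 1-D tensors
    assert torch.is_tensor(obj2["a"][1]) and obj2["a"][1].shape == (1,)
    assert torch.is_tensor(obj2["b"]) and obj2["b"].shape == (1,)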
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)
    batch_size = speech.shape[0]

    # for data-parallel
    text = text[:, :text_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

    # 2a. Attention-decoder branch
    if self.ctc_weight == 1.0:
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
    else:
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths)

    # 2b. CTC branch
    if self.ctc_weight == 0.0:
        loss_ctc, cer_ctc = None, None
    else:
        loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens,
                                                text, text_lengths)

    # 2c. RNN-T branch
    if self.rnnt_decoder is not None:
        _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text,
                                 text_lengths)

    if self.ctc_weight == 0.0:
        loss = loss_att
    elif self.ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

    stats = dict(
        loss=loss.detach(),
        loss_att=loss_att.detach() if loss_att is not None else None,
        loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
        acc=acc_att,
        cer=cer_att,
        wer=wer_att,
        cer_ctc=cer_ctc,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    src_text: Optional[torch.Tensor],
    src_text_lengths: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch,)
        text: (Batch, Length)
        text_lengths: (Batch,)
        src_text: (Batch, Length)
        src_text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)
    # additional checks with valid src_text
    if src_text is not None:
        assert src_text_lengths.dim() == 1, src_text_lengths.shape
        assert text.shape[0] == src_text.shape[0] == src_text_lengths.shape[0], (
            text.shape,
            src_text.shape,
            src_text_lengths.shape,
        )
    batch_size = speech.shape[0]

    # for data-parallel
    text = text[:, :text_lengths.max()]
    if src_text is not None:
        src_text = src_text[:, :src_text_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

    # 2a. Attention-decoder branch (ST)
    loss_st_att, acc_st_att, bleu_st_att = self._calc_mt_att_loss(
        encoder_out, encoder_out_lens, text, text_lengths, st=True)

    # 2b. CTC branch
    if self.asr_weight > 0:
        assert src_text is not None, "missing source text for asr sub-task of ST"
    if self.asr_weight > 0 and self.mtlalpha > 0:
        loss_asr_ctc, cer_asr_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, src_text, src_text_lengths)
    else:
        # use float literals so the `type(...) is not float` checks below hold
        loss_asr_ctc, cer_asr_ctc = 0.0, None

    # 2c. Attention-decoder branch (extra ASR)
    if self.asr_weight > 0 and self.mtlalpha < 1.0:
        (
            loss_asr_att,
            acc_asr_att,
            cer_asr_att,
            wer_asr_att,
        ) = self._calc_asr_att_loss(encoder_out, encoder_out_lens, src_text,
                                    src_text_lengths)
    else:
        loss_asr_att, acc_asr_att, cer_asr_att, wer_asr_att = 0.0, None, None, None

    # 2d. Attention-decoder branch (extra MT)
    if self.mt_weight > 0:
        loss_mt_att, acc_mt_att = self._calc_mt_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, st=False)
    else:
        loss_mt_att, acc_mt_att = 0.0, None

    # 3. Loss computation
    asr_ctc_weight = self.mtlalpha
    loss_st = loss_st_att
    if asr_ctc_weight == 1.0:
        loss_asr = loss_asr_ctc
    elif asr_ctc_weight == 0.0:
        loss_asr = loss_asr_att
    else:
        loss_asr = (asr_ctc_weight * loss_asr_ctc +
                    (1 - asr_ctc_weight) * loss_asr_att)
    loss_mt = self.mt_weight * loss_mt_att
    loss = ((1 - self.asr_weight - self.mt_weight) * loss_st +
            self.asr_weight * loss_asr + self.mt_weight * loss_mt)

    stats = dict(
        loss=loss.detach(),
        loss_asr=loss_asr.detach() if type(loss_asr) is not float else loss_asr,
        loss_mt=loss_mt.detach() if type(loss_mt) is not float else loss_mt,
        loss_st=loss_st.detach(),
        acc_asr=acc_asr_att,
        acc_mt=acc_mt_att,
        acc=acc_st_att,
        cer_ctc=cer_asr_ctc,
        cer=cer_asr_att,
        wer=wer_asr_att,
        bleu=bleu_st_att,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor = None,
    spk_labels: torch.Tensor = None,
    spk_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, samples)
        speech_lengths: (Batch,) default None for the chunk iterator,
            because the chunk iterator does not return speech_lengths.
            See espnet2/iterators/chunk_iter_factory.py
        spk_labels: (Batch, )
    """
    assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
    batch_size = speech.shape[0]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

    if self.attractor is None:
        # 2a. Decoder (basically a prediction layer after encoder_out)
        pred = self.decoder(encoder_out, encoder_out_lens)
    else:
        # 2b. Encoder-Decoder Attractors
        # Shuffle the chronological order of encoder_out, then calculate attractor
        encoder_out_shuffled = encoder_out.clone()
        for i in range(len(encoder_out_lens)):
            encoder_out_shuffled[i, :encoder_out_lens[i], :] = encoder_out[
                i, torch.randperm(encoder_out_lens[i]), :]
        attractor, att_prob = self.attractor(
            encoder_out_shuffled,
            encoder_out_lens,
            to_device(
                self,
                torch.zeros(encoder_out.size(0),
                            spk_labels.size(2) + 1,
                            encoder_out.size(2)),
            ),
        )
        # Remove the final attractor which does not correspond to a speaker
        # Then multiply the attractors and encoder_out
        pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))

    # 3. Aggregate time-domain labels
    spk_labels, spk_labels_lengths = self.label_aggregator(
        spk_labels, spk_labels_lengths)

    # If the encoder uses conv* as input_layer (i.e., subsampling),
    # the sequence length of 'pred' might be slightly less than the
    # length of 'spk_labels'. Here we force them to be equal.
    length_diff_tolerance = 2
    length_diff = spk_labels.shape[1] - pred.shape[1]
    if length_diff > 0 and length_diff <= length_diff_tolerance:
        spk_labels = spk_labels[:, 0:pred.shape[1], :]

    if self.attractor is None:
        loss_pit, loss_att = None, None
        loss, perm_idx, perm_list, label_perm = self.pit_loss(
            pred, spk_labels, encoder_out_lens)
    else:
        loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
            pred, spk_labels, encoder_out_lens)
        loss_att = self.attractor_loss(att_prob, spk_labels)
        loss = loss_pit + self.attractor_weight * loss_att
    (
        correct,
        num_frames,
        speech_scored,
        speech_miss,
        speech_falarm,
        speaker_scored,
        speaker_miss,
        speaker_falarm,
        speaker_error,
    ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)

    if speech_scored > 0 and num_frames > 0:
        sad_mr, sad_fr, mi, fa, cf, acc, der = (
            speech_miss / speech_scored,
            speech_falarm / speech_scored,
            speaker_miss / speaker_scored,
            speaker_falarm / speaker_scored,
            speaker_error / speaker_scored,
            correct / num_frames,
            (speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
        )
    else:
        sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0

    stats = dict(
        loss=loss.detach(),
        loss_att=loss_att.detach() if loss_att is not None else None,
        loss_pit=loss_pit.detach() if loss_pit is not None else None,
        sad_mr=sad_mr,
        sad_fr=sad_fr,
        mi=mi,
        fa=fa,
        cf=cf,
        acc=acc,
        der=der,
    )

    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    joint_training: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        text (LongTensor): Batch of padded character ids (B, T_text).
        text_lengths (LongTensor): Batch of lengths of each input batch (B,).
        feats (Tensor): Batch of padded target features (B, T_feats, odim).
        feats_lengths (LongTensor): Batch of the lengths of each target (B,).
        spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
        sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
        lids (Optional[Tensor]): Batch of language IDs (B, 1).
        joint_training (bool): Whether to perform joint training with vocoder.

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value if not joint training else model outputs.
    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    feats = feats[:, : feats_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    # Add eos at the end of each sequence
    xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    for i, l in enumerate(text_lengths):
        xs[i, l] = self.eos
    ilens = text_lengths + 1

    ys = feats
    olens = feats_lengths

    # make labels for stop prediction
    labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
    labels = F.pad(labels, [0, 1], "constant", 1.0)

    # calculate tacotron2 outputs
    after_outs, before_outs, logits, att_ws = self._forward(
        xs=xs,
        ilens=ilens,
        ys=ys,
        olens=olens,
        spembs=spembs,
        sids=sids,
        lids=lids,
    )

    # modify mod part of groundtruth
    if self.reduction_factor > 1:
        assert olens.ge(
            self.reduction_factor
        ).all(), "Output length must be greater than or equal to reduction factor."
        olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
        max_out = max(olens)
        ys = ys[:, :max_out]
        labels = labels[:, :max_out]
        labels = torch.scatter(
            labels, 1, (olens - 1).unsqueeze(1), 1.0
        )  # see #3388

    # calculate taco2 loss
    l1_loss, mse_loss, bce_loss = self.taco2_loss(
        after_outs, before_outs, logits, ys, labels, olens
    )
    if self.loss_type == "L1+L2":
        loss = l1_loss + mse_loss + bce_loss
    elif self.loss_type == "L1":
        loss = l1_loss + bce_loss
    elif self.loss_type == "L2":
        loss = mse_loss + bce_loss
    else:
        raise ValueError(f"unknown --loss-type {self.loss_type}")
    stats = dict(
        l1_loss=l1_loss.item(),
        mse_loss=mse_loss.item(),
        bce_loss=bce_loss.item(),
    )

    # calculate attention loss
    if self.use_guided_attn_loss:
        # NOTE(kan-bayashi): length of output for auto-regressive
        # input will be changed when r > 1
        if self.reduction_factor > 1:
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        attn_loss = self.attn_loss(att_ws, ilens, olens_in)
        loss = loss + attn_loss
        stats.update(attn_loss=attn_loss.item())

    if not joint_training:
        stats.update(loss=loss.item())
        loss, stats, weight = force_gatherable(
            (loss, stats, batch_size), loss.device
        )
        return loss, stats, weight
    else:
        return loss, stats, after_outs
def forward_loss(
    self,
    speech_pre: torch.Tensor,
    speech_lengths: torch.Tensor,
    feature_mix: torch.Tensor,
    feature_pre: torch.Tensor,
    others: OrderedDict,
    speech_ref: torch.Tensor,
    noise_ref: torch.Tensor = None,
    dereverb_speech_ref: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    loss = 0.0
    stats = dict()
    o = {}
    for loss_wrapper in self.loss_wrappers:
        criterion = loss_wrapper.criterion
        if isinstance(criterion, TimeDomainLoss):
            if speech_ref[0].dim() == 3:
                # For multi-channel reference,
                # only select one channel as the reference
                speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
            # for time-domain criteria
            l, s, o = loss_wrapper(speech_ref, speech_pre, o)
        elif isinstance(criterion, FrequencyDomainLoss):
            # for time-frequency-domain criteria
            if criterion.compute_on_mask:
                # compute on mask
                tf_ref = criterion.create_mask_label(
                    feature_mix,
                    [self.encoder(sr, speech_lengths)[0] for sr in speech_ref],
                )
                tf_pre = [
                    others["mask_spk{}".format(spk + 1)]
                    for spk in range(self.num_spk)
                ]
            else:
                # compute on spectrum
                if speech_ref[0].dim() == 3:
                    # For multi-channel reference,
                    # only select one channel as the reference
                    speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
                tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref]
                tf_pre = feature_pre

            l, s, o = loss_wrapper(tf_ref, tf_pre, o)

        loss += l * loss_wrapper.weight
        stats.update(s)

    stats["loss"] = loss.detach()

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    batch_size = speech_ref[0].shape[0]
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
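# Editor's sketch of the loss-wrapper contract assumed by forward_loss above:
# each wrapper exposes `criterion` and `weight`, is called with (references,
# estimates, others), and returns (loss, stats, others). This toy wrapper and
# its fixed weighting are hypothetical, for illustration only.
class FixedWeightWrapper:
    def __init__(self, criterion, weight: float = 1.0):
        self.criterion = criterion
        self.weight = weight

    def __call__(self, ref, inf, others):
        # average the criterion over paired (reference, estimate) signals
        loss = sum(self.criterion(r, i) for r, i in zip(ref, inf)) / len(ref)
        stats = {getattr(self.criterion, "name", "loss"): loss.detach()}
        return loss, stats, others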
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)

    # additional checks with valid src_text
    if "src_text" in kwargs:
        src_text = kwargs["src_text"]
        src_text_lengths = kwargs["src_text_lengths"]
        if src_text is not None:
            assert src_text_lengths.dim() == 1, src_text_lengths.shape
            assert (text.shape[0] == src_text.shape[0] ==
                    src_text_lengths.shape[0]), (
                        text.shape,
                        src_text.shape,
                        src_text_lengths.shape,
                    )
    else:
        src_text = None
        src_text_lengths = None

    batch_size = speech.shape[0]

    # clean speech signal
    speech_ref = None
    if self.calc_enh_loss:
        assert "speech_ref1" in kwargs
        speech_ref = [kwargs["speech_ref1"]]  # [(Batch, samples)] x num_spkr

    # Calculating enhancement loss
    utt_id = kwargs.get("utt_id", None)
    bypass_enh_flag, skip_enhloss_flag = False, False
    if utt_id is not None:
        # TODO(xkc): to pass category info and use predefined category list
        if utt_id[0].endswith("SIMU"):
            # For simulated single-/multi-speaker data:
            # feed it to Enhancement and calculate loss_enh
            bypass_enh_flag = False
            skip_enhloss_flag = False
        elif utt_id[0].endswith("REAL"):
            # For single-speaker real data:
            # feed it to Enhancement but without calculating loss_enh
            bypass_enh_flag = False
            skip_enhloss_flag = True
        else:
            # For clean data:
            # feed it to Enhancement, without calculating loss_enh
            bypass_enh_flag = True
            skip_enhloss_flag = True

    if not self.calc_enh_loss:
        skip_enhloss_flag = True

    # Bypass the enhancement module
    if self.training and skip_enhloss_flag and not bypass_enh_flag:
        # For single-speaker real data: possibility to bypass frontend
        if random.random() <= self.bypass_enh_prob:
            bypass_enh_flag = True

    # 1. Enhancement
    # model forward
    loss_enh = None
    if not bypass_enh_flag:
        (
            speech_pre,
            feature_mix,
            feature_pre,
            others,
        ) = self.enh_model.forward_enhance(speech, speech_lengths)
        # loss computation
        if not skip_enhloss_flag:
            loss_enh, _, _ = self.enh_model.forward_loss(
                speech_pre,
                speech_lengths,
                feature_mix,
                feature_pre,
                others,
                speech_ref,
            )
            loss_enh = loss_enh[0]
    else:
        speech_pre = [speech]

    # for data-parallel
    text = text[:, :text_lengths.max()]
    if src_text is not None:
        src_text = src_text[:, :src_text_lengths.max()]

    # 2. ASR or ST
    if isinstance(self.s2t_model, ESPnetASRModel):  # ASR
        loss_asr, stats, weight = self.s2t_model(speech_pre[0], speech_lengths,
                                                 text, text_lengths)
    elif isinstance(self.s2t_model, ESPnetSTModel):  # ST
        loss_asr, stats, weight = self.s2t_model(
            speech_pre[0],
            speech_lengths,
            text,
            text_lengths,
            src_text,
            src_text_lengths,
        )
    else:
        raise NotImplementedError(f"{type(self.s2t_model)} is not supported yet.")

    if loss_enh is not None:
        loss = loss_enh + loss_asr
    else:
        loss = loss_asr

    stats["loss"] = loss.detach() if loss is not None else None
    stats["loss_enh"] = loss_enh.detach() if loss_enh is not None else None

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    noisy_label_flag: bool = False,
    replace_label_flag: bool = True,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    # if self.stat is not None:
    #     hist = self.stat.confid_hist
    #     if hist.sum() != 0:
    #         hist.requires_grad = False
    #         total_sum = hist.sum()
    #         # simple mean testing for alpha = 0.27
    #         z_alpha = 18
    #         self.th = 1 / self.stat.bins * z_alpha
    #     else:
    #         # logging.warning("Prior histogram has {} value!".format(hist.sum()))
    #         self.th = None
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
        speech.shape[0]
        == speech_lengths.shape[0]
        == text.shape[0]
        == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]

    # for data-parallel
    text = text[:, : text_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

    # 2a. Attention-decoder branch
    if self.ctc_weight == 1.0:
        # four values, matching the non-CTC branch below
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
    else:
        if replace_label_flag:
            decoder_meta_out_prob = self._meta_forward(
                speech, speech_lengths, text, text_lengths
            )
        else:
            decoder_meta_out_prob = None
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths,
            replace_label_flag, decoder_meta_out_prob
        )

    # 2b. CTC branch
    if self.ctc_weight == 0.0:
        loss_ctc, cer_ctc = None, None
    else:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

    # 2c. RNN-T branch
    if self.rnnt_decoder is not None:
        _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)

    if self.ctc_weight == 0.0:
        loss = loss_att
    elif self.ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

    stats = dict(
        loss=loss.detach(),
        loss_att=loss_att.detach() if loss_att is not None else None,
        loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
        acc=acc_att,
        cer=cer_att,
        wer=wer_att,
        cer_ctc=cer_ctc,
        # pred_err_att=pred_err_att,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    durations: torch.Tensor,
    durations_lengths: torch.Tensor,
    spembs: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        text (LongTensor): Batch of padded character ids (B, Tmax).
        text_lengths (LongTensor): Batch of lengths of each input (B,).
        speech (Tensor): Batch of padded target features (B, Lmax, odim).
        speech_lengths (LongTensor): Batch of the lengths of each target (B,).
        durations (LongTensor): Batch of padded durations (B, Tmax + 1).
        durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1).
        spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim).

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value.
    """
    text = text[:, :text_lengths.max()]  # for data-parallel
    speech = speech[:, :speech_lengths.max()]  # for data-parallel
    durations = durations[:, :durations_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    # Add eos at the end of each sequence
    xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    for i, l in enumerate(text_lengths):
        xs[i, l] = self.eos
    ilens = text_lengths + 1

    ys, ds = speech, durations
    olens = speech_lengths

    # forward propagation
    before_outs, after_outs, d_outs = self._forward(xs,
                                                    ilens,
                                                    ys,
                                                    olens,
                                                    ds,
                                                    spembs=spembs,
                                                    is_inference=False)

    # modify mod part of groundtruth
    if self.reduction_factor > 1:
        olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
        max_olen = max(olens)
        ys = ys[:, :max_olen]

    # calculate loss
    if self.postnet is None:
        after_outs = None
    l1_loss, duration_loss = self.criterion(after_outs, before_outs, d_outs, ys,
                                            ds, ilens, olens)
    loss = l1_loss + duration_loss

    stats = dict(
        l1_loss=l1_loss.item(),
        duration_loss=duration_loss.item(),
        loss=loss.item(),
    )

    # report extra information
    if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
        stats.update(encoder_alpha=self.encoder.embed[-1].alpha.data.item())
    if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
        stats.update(decoder_alpha=self.decoder.embed[-1].alpha.data.item())

    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech_mix: torch.Tensor,
    speech_mix_lengths: torch.Tensor = None,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech_mix: (Batch, samples) or (Batch, samples, channels)
        speech_ref: (Batch, num_speaker, samples)
            or (Batch, num_speaker, samples, channels)
        speech_mix_lengths: (Batch,), default None for the chunk iterator,
            because the chunk iterator does not return speech_lengths.
            See espnet2/iterators/chunk_iter_factory.py
    """
    # clean speech signal of each speaker
    speech_ref = [
        kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
    ]
    # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
    speech_ref = torch.stack(speech_ref, dim=1)

    if "noise_ref1" in kwargs:
        # noise signal (optional, required when using
        # frontend models with beamforming)
        noise_ref = [
            kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type)
        ]
        # (Batch, num_noise_type, samples) or
        # (Batch, num_noise_type, samples, channels)
        noise_ref = torch.stack(noise_ref, dim=1)
    else:
        noise_ref = None

    # dereverberated (noisy) signal
    # (optional, only used for frontend models with WPE)
    if "dereverb_ref1" in kwargs:
        # dereverberated references (optional, required when using
        # frontend models with WPE)
        dereverb_speech_ref = [
            kwargs["dereverb_ref{}".format(n + 1)]
            for n in range(self.num_spk)
            if "dereverb_ref{}".format(n + 1) in kwargs
        ]
        assert len(dereverb_speech_ref) in (1, self.num_spk), len(
            dereverb_speech_ref
        )
        # (Batch, N, samples) or (Batch, N, samples, channels)
        dereverb_speech_ref = torch.stack(dereverb_speech_ref, dim=1)
    else:
        dereverb_speech_ref = None

    batch_size = speech_mix.shape[0]
    speech_lengths = (
        speech_mix_lengths
        if speech_mix_lengths is not None
        else torch.ones(batch_size).int().fill_(speech_mix.shape[1])
    )
    assert speech_lengths.dim() == 1, speech_lengths.shape
    # Check that batch_size is unified
    assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], (
        speech_mix.shape,
        speech_ref.shape,
        speech_lengths.shape,
    )

    # for data-parallel
    speech_ref = speech_ref[:, :, : speech_lengths.max()]
    speech_mix = speech_mix[:, : speech_lengths.max()]

    loss, speech_pre, others, out_lengths, perm = self._compute_loss(
        speech_mix,
        speech_lengths,
        speech_ref,
        dereverb_speech_ref=dereverb_speech_ref,
        noise_ref=noise_ref,
    )

    # add stats for logging
    if self.loss_type not in ["ci_sdr", "si_snr"]:
        if self.training:
            si_snr = None
        else:
            speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in speech_pre]
            speech_ref = torch.unbind(speech_ref, dim=1)
            if speech_ref[0].dim() == 3:
                # For si_snr loss, only select one channel as the reference
                speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
            # compute si-snr loss
            si_snr_loss, perm = self._permutation_loss(
                speech_ref, speech_pre, self.si_snr_loss, perm=perm
            )
            si_snr = -si_snr_loss.detach()

        stats = dict(
            si_snr=si_snr,
            loss=loss.detach(),
        )
    else:
        if self.loss_type == "ci_sdr":
            stats = dict(ci_sdr=-loss.detach(), loss=loss.detach())
        elif self.loss_type == "si_snr":
            stats = dict(si_snr=-loss.detach(), loss=loss.detach())
        else:
            raise ValueError("Unsupported loss type: %s" % self.loss_type)

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Forward architecture and compute loss(es).

    Args:
        speech: Speech sequences. (B, S)
        speech_lengths: Speech sequences lengths. (B,)
        text: Label ID sequences. (B, L)
        text_lengths: Label ID sequences lengths. (B,)
        kwargs: Contains "utts_id".

    Return:
        loss: Main loss value.
        stats: Task statistics.
        weight: Task weights.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)

    batch_size = speech.shape[0]
    text = text[:, :text_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

    # 2. Transducer-related I/O preparation
    decoder_in, target, t_len, u_len = get_transducer_task_io(
        text,
        encoder_out_lens,
        ignore_id=self.ignore_id,
    )

    # 3. Decoder
    self.decoder.set_device(encoder_out.device)
    decoder_out = self.decoder(decoder_in)

    # 4. Joint Network
    joint_out = self.joint_network(encoder_out.unsqueeze(2),
                                   decoder_out.unsqueeze(1))

    # 5. Losses
    loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
        encoder_out,
        joint_out,
        target,
        t_len,
        u_len,
    )

    loss_ctc, loss_lm = 0.0, 0.0
    if self.use_auxiliary_ctc:
        loss_ctc = self._calc_ctc_loss(
            encoder_out,
            target,
            t_len,
            u_len,
        )

    if self.use_auxiliary_lm_loss:
        loss_lm = self._calc_lm_loss(decoder_out, target)

    loss = (self.transducer_weight * loss_trans +
            self.auxiliary_ctc_weight * loss_ctc +
            self.auxiliary_lm_loss_weight * loss_lm)

    stats = dict(
        loss=loss.detach(),
        loss_transducer=loss_trans.detach(),
        aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
        aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
        cer_transducer=cer_trans,
        wer_transducer=wer_trans,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)
    batch_size = speech.shape[0]

    # For data-parallel
    text = text[:, :text_lengths.max()]

    # Define stats to report
    loss_mlm, acc_mlm = None, None
    loss_ctc, cer_ctc = None, None
    stats = dict()

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        intermediate_outs = encoder_out[1]
        encoder_out = encoder_out[0]

    # 2. CTC branch
    if self.ctc_weight != 0.0:
        loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens,
                                                text, text_lengths)

        # Collect CTC branch stats
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc

    # 2a. Intermediate CTC (optional)
    loss_interctc = 0.0
    if self.interctc_weight != 0.0 and intermediate_outs is not None:
        for layer_idx, intermediate_out in intermediate_outs:
            # we assume intermediate_out has the same length & padding
            # as those of encoder_out
            loss_ic, cer_ic = self._calc_ctc_loss(intermediate_out,
                                                  encoder_out_lens, text,
                                                  text_lengths)
            loss_interctc = loss_interctc + loss_ic

            # Collect Intermediate CTC stats
            stats["loss_interctc_layer{}".format(layer_idx)] = (
                loss_ic.detach() if loss_ic is not None else None)
            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

        loss_interctc = loss_interctc / len(intermediate_outs)

        # calculate whole encoder loss
        loss_ctc = ((1 - self.interctc_weight) * loss_ctc +
                    self.interctc_weight * loss_interctc)

    # 3. MLM decoder branch
    if self.ctc_weight != 1.0:
        loss_mlm, acc_mlm = self._calc_mlm_loss(encoder_out, encoder_out_lens,
                                                text, text_lengths)

    # 4. CTC/MLM loss definition
    if self.ctc_weight == 0.0:
        loss = loss_mlm
    elif self.ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_mlm

    # Collect MLM branch stats
    stats["loss_mlm"] = loss_mlm.detach() if loss_mlm is not None else None
    stats["acc_mlm"] = acc_mlm

    # Collect total loss stats
    stats["loss"] = loss.detach()

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor = None,
    text: torch.Tensor = None,
    text_lengths: torch.Tensor = None,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, ) default None for the chunk iterator,
            because the chunk iterator does not return speech_lengths.
            See espnet2/iterators/chunk_iter_factory.py
        text: (Batch, Length) default None just to keep the argument order
        text_lengths: (Batch,) default None for the same reason as speech_lengths
    """
    if text_lengths is not None:
        assert text_lengths.dim() == 1, text_lengths.shape
    if speech_lengths is not None and text_lengths is not None:
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
    else:
        assert speech.shape[0] == text.shape[0], (speech.shape, text.shape)

    # additional checks with valid src_text
    if "src_text" in kwargs:
        src_text = kwargs["src_text"]
        src_text_lengths = kwargs["src_text_lengths"]
        if src_text is not None:
            assert src_text_lengths.dim() == 1, src_text_lengths.shape
            assert (text.shape[0] == src_text.shape[0] ==
                    src_text_lengths.shape[0]), (
                        text.shape,
                        src_text.shape,
                        src_text_lengths.shape,
                    )
    else:
        src_text = None
        src_text_lengths = None

    batch_size = speech.shape[0]
    speech_lengths = (speech_lengths if speech_lengths is not None else
                      torch.ones(batch_size).int() * speech.shape[1])

    # number of speakers
    # Take the number of speakers from text
    # (= spk_label [Batch, length, num_spk]) if it is 3-D.
    # This is to handle a flexible number of speakers.
    # Used only in the "enh + diar" task for now.
    num_spk = text.shape[2] if text.dim() == 3 else self.enh_model.num_spk

    # clean speech signal of each speaker
    speech_ref = None
    if self.calc_enh_loss:
        assert "speech_ref1" in kwargs
        speech_ref = [
            kwargs["speech_ref{}".format(spk + 1)] for spk in range(num_spk)
        ]
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)
        # for data-parallel
        speech_ref = speech_ref[..., :speech_lengths.max()]
        speech_ref = speech_ref.unbind(dim=1)

    # Calculating enhancement loss
    utt_id = kwargs.get("utt_id", None)
    bypass_enh_flag, skip_enhloss_flag = False, False
    if utt_id is not None and not isinstance(self.s2t_model,
                                             ESPnetDiarizationModel):
        # TODO(xkc): to pass category info and use predefined category list
        if utt_id[0].endswith("SIMU"):
            # For simulated single-/multi-speaker data:
            # feed it to Enhancement and calculate loss_enh
            bypass_enh_flag = False
            skip_enhloss_flag = False
        elif utt_id[0].endswith("REAL"):
            # For single-speaker real data:
            # feed it to Enhancement but without calculating loss_enh
            bypass_enh_flag = False
            skip_enhloss_flag = True
        else:
            # For clean data:
            # feed it to Enhancement, without calculating loss_enh
            bypass_enh_flag = True
            skip_enhloss_flag = True

    if not self.calc_enh_loss:
        skip_enhloss_flag = True

    # Bypass the enhancement module
    if self.training and skip_enhloss_flag and not bypass_enh_flag:
        # For single-speaker real data: possibility to bypass frontend
        if random.random() <= self.bypass_enh_prob:
            bypass_enh_flag = True

    # 1. Enhancement
    # model forward
    loss_enh = None
    if not bypass_enh_flag:
        (
            speech_pre,
            feature_mix,
            feature_pre,
            others,
        ) = self.enh_model.forward_enhance(speech, speech_lengths,
                                           {"num_spk": num_spk})
        # loss computation
        if not skip_enhloss_flag:
            loss_enh, _, _ = self.enh_model.forward_loss(
                speech_pre,
                speech_lengths,
                feature_mix,
                feature_pre,
                others,
                speech_ref,
            )
            loss_enh = loss_enh[0]
    else:
        speech_pre = [speech]

    # for data-parallel
    if text_lengths is not None:
        text = text[:, :text_lengths.max()]
    if src_text is not None:
        src_text = src_text[:, :src_text_lengths.max()]

    # 2. ASR, ST, or DIAR
    if isinstance(self.s2t_model, ESPnetASRModel):  # ASR
        loss_asr, stats, weight = self.s2t_model(speech_pre[0], speech_lengths,
                                                 text, text_lengths)
    elif isinstance(self.s2t_model, ESPnetSTModel):  # ST
        loss_asr, stats, weight = self.s2t_model(
            speech_pre[0],
            speech_lengths,
            text,
            text_lengths,
            src_text,
            src_text_lengths,
        )
    elif isinstance(self.s2t_model, ESPnetDiarizationModel):  # DIAR
        loss_asr, stats, weight = self.s2t_model(
            speech=speech.clone(),
            speech_lengths=speech_lengths,
            spk_labels=text,
            spk_labels_lengths=text_lengths,
            bottleneck_feats=others.get("bottleneck_feats"),
            bottleneck_feats_lengths=others.get("bottleneck_feats_lengths"),
        )
    else:
        raise NotImplementedError(f"{type(self.s2t_model)} is not supported yet.")

    if loss_enh is not None:
        loss = loss_enh + loss_asr
    else:
        loss = loss_asr

    stats["loss"] = loss.detach() if loss is not None else None
    stats["loss_enh"] = loss_enh.detach() if loss_enh is not None else None

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward_loss(
    self,
    speech_pre: torch.Tensor,
    speech_lengths: torch.Tensor,
    feature_mix: torch.Tensor,
    feature_pre: torch.Tensor,
    others: OrderedDict,
    speech_ref: torch.Tensor,
    noise_ref: torch.Tensor = None,
    dereverb_speech_ref: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    # for calculating loss on estimated noise signals
    if getattr(self.separator, "predict_noise", False):
        assert "noise1" in others, others.keys()
    if noise_ref is not None and "noise1" in others:
        for n in range(self.num_noise_type):
            key = "noise{}".format(n + 1)
            others[key] = self.decoder(others[key], speech_lengths)[0]

    # for calculating loss on dereverberated signals
    if getattr(self.separator, "predict_dereverb", False):
        assert "dereverb1" in others, others.keys()
    if dereverb_speech_ref is not None and "dereverb1" in others:
        for spk in range(self.num_spk):
            key = "dereverb{}".format(spk + 1)
            if key in others:
                others[key] = self.decoder(others[key], speech_lengths)[0]

    loss = 0.0
    stats = {}
    o = {}
    for loss_wrapper in self.loss_wrappers:
        criterion = loss_wrapper.criterion
        if getattr(criterion, "only_for_test", False) and self.training:
            continue

        if getattr(criterion, "is_noise_loss", False):
            if noise_ref is None:
                raise ValueError(
                    "No noise reference for training!\n"
                    'Please specify "--use_noise_ref true" in run.sh')
            signal_ref = noise_ref
            signal_pre = [
                others["noise{}".format(n + 1)]
                for n in range(self.num_noise_type)
            ]
        elif getattr(criterion, "is_dereverb_loss", False):
            if dereverb_speech_ref is None:
                raise ValueError(
                    "No dereverberated reference for training!\n"
                    'Please specify "--use_dereverb_ref true" in run.sh')
            signal_ref = dereverb_speech_ref
            signal_pre = [
                others["dereverb{}".format(n + 1)]
                for n in range(self.num_noise_type)
                if "dereverb{}".format(n + 1) in others
            ]
            if len(signal_pre) == 0:
                signal_pre = None
        else:
            signal_ref = speech_ref
            signal_pre = speech_pre

        if isinstance(criterion, TimeDomainLoss):
            assert signal_pre is not None
            sref, spre = self._align_ref_pre_channels(signal_ref,
                                                      signal_pre,
                                                      ch_dim=2,
                                                      force_1ch=True)
            # for time-domain criteria
            l, s, o = loss_wrapper(sref, spre, {**others, **o})
        elif isinstance(criterion, FrequencyDomainLoss):
            sref, spre = self._align_ref_pre_channels(signal_ref,
                                                      signal_pre,
                                                      ch_dim=2,
                                                      force_1ch=False)
            # for time-frequency-domain criteria
            if criterion.compute_on_mask:
                # compute loss on masks
                if getattr(criterion, "is_noise_loss", False):
                    tf_ref, tf_pre = self._get_noise_masks(
                        criterion,
                        feature_mix,
                        speech_ref,
                        signal_ref,
                        signal_pre,
                        speech_lengths,
                        others,
                    )
                elif getattr(criterion, "is_dereverb_loss", False):
                    tf_ref, tf_pre = self._get_dereverb_masks(
                        criterion,
                        feature_mix,
                        noise_ref,
                        signal_ref,
                        signal_pre,
                        speech_lengths,
                        others,
                    )
                else:
                    tf_ref, tf_pre = self._get_speech_masks(
                        criterion,
                        feature_mix,
                        noise_ref,
                        signal_ref,
                        signal_pre,
                        speech_lengths,
                        others,
                    )
            else:
                # compute on spectrum
                tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in sref]
                tf_pre = [self.encoder(sp, speech_lengths)[0] for sp in spre]

            l, s, o = loss_wrapper(tf_ref, tf_pre, {**others, **o})
        else:
            raise NotImplementedError("Unsupported loss type: %s" %
                                      str(criterion))

        loss += l * loss_wrapper.weight
        stats.update(s)

    if self.training and isinstance(loss, float):
        raise AttributeError(
            "At least one criterion must satisfy: only_for_test=False")

    stats["loss"] = loss.detach()

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    batch_size = speech_ref[0].shape[0]
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def _forward_generator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Perform generator forward.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Optimizer index (0 for G and 1 for D).
    """
    # setup
    batch_size = text.size(0)
    speech = speech.unsqueeze(1)

    # calculate generator outputs
    reuse_cache = True
    if not self.cache_generator_outputs or self._cache is None:
        reuse_cache = False
        outs = self.generator(
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            sids=sids,
            spembs=spembs,
            lids=lids,
            **kwargs,
        )
    else:
        outs = self._cache

    # store cache
    if self.training and self.cache_generator_outputs and not reuse_cache:
        self._cache = outs

    # parse outputs
    (
        speech_hat_,
        bin_loss,
        log_p_attn,
        start_idxs,
        d_outs,
        ds,
        p_outs,
        ps,
        e_outs,
        es,
    ) = outs
    speech_ = get_segments(
        x=speech,
        start_idxs=start_idxs * self.generator.upsample_factor,
        segment_size=self.generator.segment_size * self.generator.upsample_factor,
    )

    # calculate discriminator outputs
    p_hat = self.discriminator(speech_hat_)
    with torch.no_grad():
        # do not store discriminator gradient in generator turn
        p = self.discriminator(speech_)

    # calculate losses
    mel_loss = self.mel_loss(speech_hat_, speech_)
    adv_loss = self.generator_adv_loss(p_hat)
    feat_match_loss = self.feat_match_loss(p_hat, p)
    dur_loss, pitch_loss, energy_loss = self.var_loss(
        d_outs, ds, p_outs, ps, e_outs, es, text_lengths
    )
    forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths)

    mel_loss = mel_loss * self.lambda_mel
    adv_loss = adv_loss * self.lambda_adv
    feat_match_loss = feat_match_loss * self.lambda_feat_match
    g_loss = mel_loss + adv_loss + feat_match_loss
    var_loss = (dur_loss + pitch_loss + energy_loss) * self.lambda_var
    align_loss = (forwardsum_loss + bin_loss) * self.lambda_align

    loss = g_loss + var_loss + align_loss

    stats = dict(
        generator_loss=loss.item(),
        generator_g_loss=g_loss.item(),
        generator_var_loss=var_loss.item(),
        generator_align_loss=align_loss.item(),
        generator_g_mel_loss=mel_loss.item(),
        generator_g_adv_loss=adv_loss.item(),
        generator_g_feat_match_loss=feat_match_loss.item(),
        generator_var_dur_loss=dur_loss.item(),
        generator_var_pitch_loss=pitch_loss.item(),
        generator_var_energy_loss=energy_loss.item(),
        generator_align_forwardsum_loss=forwardsum_loss.item(),
        generator_align_bin_loss=bin_loss.item(),
    )

    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

    # reset cache
    if reuse_cache or not self.training:
        self._cache = None

    return {
        "loss": loss,
        "stats": stats,
        "weight": weight,
        "optim_idx": 0,  # needed for trainer
    }
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    joint_training: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        text (LongTensor): Batch of padded character ids (B, Tmax).
        text_lengths (LongTensor): Batch of lengths of each input batch (B,).
        feats (Tensor): Batch of padded target features (B, Lmax, odim).
        feats_lengths (LongTensor): Batch of the lengths of each target (B,).
        spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
        sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
        lids (Optional[Tensor]): Batch of language IDs (B, 1).
        joint_training (bool): Whether to perform joint training with vocoder.

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value if not joint training else model outputs.
    """
    text = text[:, :text_lengths.max()]  # for data-parallel
    feats = feats[:, :feats_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    # Add eos at the end of each sequence
    xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    for i, l in enumerate(text_lengths):
        xs[i, l] = self.eos
    ilens = text_lengths + 1

    ys = feats
    olens = feats_lengths

    # make labels for stop prediction
    labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
    labels = F.pad(labels, [0, 1], "constant", 1.0)

    # calculate transformer outputs
    after_outs, before_outs, logits = self._forward(
        xs=xs,
        ilens=ilens,
        ys=ys,
        olens=olens,
        spembs=spembs,
        sids=sids,
        lids=lids,
    )

    # modify mod part of groundtruth
    olens_in = olens
    if self.reduction_factor > 1:
        assert olens.ge(self.reduction_factor).all(), (
            "Output length must be greater than or equal to reduction factor.")
        olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
        max_olen = max(olens)
        ys = ys[:, :max_olen]
        labels = labels[:, :max_olen]
        labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1),
                               1.0)  # see #3388

    # calculate loss values
    l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs, logits,
                                                ys, labels, olens)
    if self.loss_type == "L1":
        loss = l1_loss + bce_loss
    elif self.loss_type == "L2":
        loss = l2_loss + bce_loss
    elif self.loss_type == "L1+L2":
        loss = l1_loss + l2_loss + bce_loss
    else:
        raise ValueError("unknown --loss-type " + self.loss_type)

    stats = dict(
        l1_loss=l1_loss.item(),
        l2_loss=l2_loss.item(),
        bce_loss=bce_loss.item(),
    )

    # calculate guided attention loss
    if self.use_guided_attn_loss:
        # calculate for encoder
        if "encoder" in self.modules_applied_guided_attn:
            att_ws = []
            for idx, layer_idx in enumerate(
                    reversed(range(len(self.encoder.encoders)))):
                att_ws += [
                    self.encoder.encoders[layer_idx].self_attn.
                    attn[:, :self.num_heads_applied_guided_attn]
                ]
                if idx + 1 == self.num_layers_applied_guided_attn:
                    break
            att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_text, T_text)
            enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
            loss = loss + enc_attn_loss
            stats.update(enc_attn_loss=enc_attn_loss.item())
        # calculate for decoder
        if "decoder" in self.modules_applied_guided_attn:
            att_ws = []
            for idx, layer_idx in enumerate(
                    reversed(range(len(self.decoder.decoders)))):
                att_ws += [
                    self.decoder.decoders[layer_idx].self_attn.
                    attn[:, :self.num_heads_applied_guided_attn]
                ]
                if idx + 1 == self.num_layers_applied_guided_attn:
                    break
            att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_feats, T_feats)
            dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
            loss = loss + dec_attn_loss
            stats.update(dec_attn_loss=dec_attn_loss.item())
        # calculate for encoder-decoder
        if "encoder-decoder" in self.modules_applied_guided_attn:
            att_ws = []
            for idx, layer_idx in enumerate(
                    reversed(range(len(self.decoder.decoders)))):
                att_ws += [
                    self.decoder.decoders[layer_idx].src_attn.
                    attn[:, :self.num_heads_applied_guided_attn]
                ]
                if idx + 1 == self.num_layers_applied_guided_attn:
                    break
            att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_feats, T_text)
            enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in)
            loss = loss + enc_dec_attn_loss
            stats.update(enc_dec_attn_loss=enc_dec_attn_loss.item())

    # report extra information
    if self.use_scaled_pos_enc:
        stats.update(
            encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
            decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
        )

    if not joint_training:
        stats.update(loss=loss.item())
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
    else:
        return loss, stats, after_outs
def forward(
    self,
    speech_mix: torch.Tensor,
    speech_mix_lengths: torch.Tensor,
    speech_ref1: torch.Tensor,
    speech_ref2: torch.Tensor,
    text_ref1: torch.Tensor,
    text_ref2: torch.Tensor,
    text_ref1_lengths: torch.Tensor,
    text_ref2_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Enhancement + Frontend + Encoder + Decoder + Calc loss

    Args:
        speech_mix: (Batch, Length, ...)
        speech_mix_lengths: (Batch, )
        speech_ref1, speech_ref2: (Batch, Length, ...)
        text_ref1, text_ref2: (Batch, Length)
        text_ref1_lengths, text_ref2_lengths: (Batch,)
    """
    assert text_ref1_lengths.dim() == text_ref2_lengths.dim() == 1, (
        text_ref1_lengths.shape,
        text_ref2_lengths.shape,
    )
    # Check that batch_size is unified
    assert (speech_mix.shape[0] == speech_mix_lengths.shape[0] ==
            text_ref1.shape[0] == text_ref1_lengths.shape[0] ==
            text_ref2.shape[0] == text_ref2_lengths.shape[0]), (
                speech_mix.shape,
                speech_mix_lengths.shape,
                text_ref1.shape,
                text_ref1_lengths.shape,
            )
    batch_size = speech_mix.shape[0]

    # for data-parallel: pad both references with the blank index,
    # then truncate to the common maximum length
    text_length_max = max(text_ref1_lengths.max(), text_ref2_lengths.max())
    text_ref1 = torch.cat(
        [
            text_ref1,
            torch.ones(batch_size, text_length_max,
                       dtype=text_ref1.dtype).to(text_ref1.device) *
            self.idx_blank,
        ],
        dim=1,
    )
    text_ref2 = torch.cat(
        [
            text_ref2,
            torch.ones(batch_size, text_length_max,
                       dtype=text_ref1.dtype).to(text_ref1.device) *
            self.idx_blank,
        ],
        dim=1,
    )
    text_ref1 = text_ref1[:, :text_length_max]
    text_ref2 = text_ref2[:, :text_length_max]

    # 0. Enhancement
    # make sure speech_pre is the raw waveform with the same size.
    loss_enh, perm, speech_pre = self.forward_enh(
        speech_mix,
        speech_mix_lengths,
        speech_ref1=speech_ref1,
        speech_ref2=speech_ref2,
    )
    # speech_pre: (bs, num_spk, T)
    assert speech_pre[:, 0].shape == speech_mix.shape

    # Pack the separated speakers into the ASR part.
    speech_pre_all = speech_pre.view(-1, speech_mix.shape[-1])  # (bs*num_spk, T)
    speech_pre_lengths = torch.stack([speech_mix_lengths, speech_mix_lengths],
                                     dim=1).view(-1)
    text_ref_all = torch.stack([text_ref1, text_ref2],
                               dim=1).view(batch_size * 2, -1)
    text_ref_lengths = torch.stack([text_ref1_lengths, text_ref2_lengths],
                                   dim=1).view(-1)

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech_pre_all,
                                                speech_pre_lengths)

    # 2a. Attention-decoder branch
    if self.ctc_weight == 1.0:
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
    else:
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text_ref_all, text_ref_lengths)

    # 2b. CTC branch
    if self.ctc_weight == 0.0:
        loss_ctc, cer_ctc = None, None
    else:
        loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens,
                                                text_ref_all, text_ref_lengths)

    # 2c. RNN-T branch
    if self.rnnt_decoder is not None:
        _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text_ref_all,
                                 text_ref_lengths)

    if self.ctc_weight == 0.0:
        loss_asr = loss_att
    elif self.ctc_weight == 1.0:
        loss_asr = loss_ctc
    else:
        loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

    if self.enh_weight == 0.0:
        loss_enh = None
        loss = loss_asr
    else:
        loss = (1 - self.enh_weight) * loss_asr + self.enh_weight * loss_enh

    stats = dict(
        loss=loss.detach(),
        loss_att=loss_att.detach() if loss_att is not None else None,
        loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
        loss_enh=loss_enh.detach() if loss_enh is not None else None,
        acc=acc_att,
        cer=cer_att,
        wer=wer_att,
        cer_ctc=cer_ctc,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward(
    self,
    speech_mix: torch.Tensor,
    speech_mix_lengths: torch.Tensor = None,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech_mix: (Batch, samples) or (Batch, samples, channels)
        speech_ref: (Batch, num_speaker, samples)
            or (Batch, num_speaker, samples, channels)
        speech_mix_lengths: (Batch,), default None for the chunk iterator,
            because the chunk-iterator does not return speech_lengths.
            See espnet2/iterators/chunk_iter_factory.py
    """
    # clean speech signal of each speaker
    speech_ref = [
        kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
    ]
    # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
    speech_ref = torch.stack(speech_ref, dim=1)

    if "noise_ref1" in kwargs:
        # noise signal (optional, required when using
        # frontend models with beamforming)
        noise_ref = [
            kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type)
        ]
        # (Batch, num_noise_type, samples) or
        # (Batch, num_noise_type, samples, channels)
        noise_ref = torch.stack(noise_ref, dim=1)
    else:
        noise_ref = None

    # dereverberated noisy signal
    # (optional, only used for frontend models with WPE)
    dereverb_speech_ref = kwargs.get("dereverb_ref", None)

    batch_size = speech_mix.shape[0]
    speech_lengths = (
        speech_mix_lengths
        if speech_mix_lengths is not None
        else torch.ones(batch_size).int() * speech_mix.shape[1]
    )
    assert speech_lengths.dim() == 1, speech_lengths.shape
    # Check that batch_size is unified
    assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], (
        speech_mix.shape,
        speech_ref.shape,
        speech_lengths.shape,
    )

    # for data-parallel
    speech_ref = speech_ref[:, :, : speech_lengths.max()]
    speech_mix = speech_mix[:, : speech_lengths.max()]

    if self.loss_type != "si_snr":
        # prepare reference speech and reference spectrum
        speech_ref = torch.unbind(speech_ref, dim=1)
        spectrum_ref = [self.enh_model.stft(sr)[0] for sr in speech_ref]
        # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
        spectrum_ref = [
            ComplexTensor(sr[..., 0], sr[..., 1]) for sr in spectrum_ref
        ]
        spectrum_mix = self.enh_model.stft(speech_mix)[0]
        spectrum_mix = ComplexTensor(spectrum_mix[..., 0], spectrum_mix[..., 1])

        # predict separated speech and masks
        spectrum_pre, tf_length, mask_pre = self.enh_model(
            speech_mix, speech_lengths)

        # compute TF masking loss
        if self.loss_type == "magnitude":
            # compute loss on magnitude spectrum
            magnitude_pre = [abs(ps) for ps in spectrum_pre]
            magnitude_ref = [abs(sr) for sr in spectrum_ref]
            tf_loss, perm = self._permutation_loss(
                magnitude_ref, magnitude_pre, self.tf_mse_loss)
        elif self.loss_type == "spectrum":
            # compute loss on complex spectrum
            tf_loss, perm = self._permutation_loss(
                spectrum_ref, spectrum_pre, self.tf_mse_loss)
        elif self.loss_type.startswith("mask"):
            if self.loss_type == "mask_mse":
                loss_func = self.tf_mse_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)
            assert mask_pre is not None
            mask_pre_ = [
                mask_pre["spk{}".format(spk + 1)] for spk in range(self.num_spk)
            ]
            # prepare ideal masks
            mask_ref = self._create_mask_label(
                spectrum_mix, spectrum_ref, mask_type=self.mask_type)
            # compute TF masking loss
            tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_, loss_func)

            if "dereverb" in mask_pre:
                if dereverb_speech_ref is None:
                    raise ValueError(
                        "No dereverberated reference for training!\n"
                        'Please specify "--use_dereverb_ref true" in run.sh')
                dereverb_spectrum_ref = self.enh_model.stft(
                    dereverb_speech_ref)[0]
                dereverb_spectrum_ref = ComplexTensor(
                    dereverb_spectrum_ref[..., 0], dereverb_spectrum_ref[..., 1])
                # ComplexTensor(B, T, F) or ComplexTensor(B, T, C, F)
                dereverb_mask_ref = self._create_mask_label(
                    spectrum_mix, [dereverb_spectrum_ref],
                    mask_type=self.mask_type)[0]
                tf_loss = tf_loss + loss_func(
                    dereverb_mask_ref, mask_pre["dereverb"]).mean()

            if "noise1" in mask_pre:
                if noise_ref is None:
                    raise ValueError(
                        "No noise reference for training!\n"
                        'Please specify "--use_noise_ref true" in run.sh')
                noise_ref = torch.unbind(noise_ref, dim=1)
                noise_spectrum_ref = [
                    self.enh_model.stft(nr)[0] for nr in noise_ref
                ]
                noise_spectrum_ref = [
                    ComplexTensor(nr[..., 0], nr[..., 1])
                    for nr in noise_spectrum_ref
                ]
                noise_mask_ref = self._create_mask_label(
                    spectrum_mix, noise_spectrum_ref, mask_type=self.mask_type)
                mask_noise_pre = [
                    mask_pre["noise{}".format(n + 1)]
                    for n in range(self.num_noise_type)
                ]
                tf_noise_loss, perm_n = self._permutation_loss(
                    noise_mask_ref, mask_noise_pre, loss_func)
                tf_loss = tf_loss + tf_noise_loss
        else:
            raise ValueError("Unsupported loss type: %s" % self.loss_type)

        if self.training:
            si_snr = None
        else:
            speech_pre = [
                self.enh_model.stft.inverse(ps, speech_lengths)[0]
                for ps in spectrum_pre
            ]
            if speech_ref[0].dim() == 3:
                # For si_snr loss, only select one channel as the reference
                speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
            # compute si-snr loss
            si_snr_loss, perm = self._permutation_loss(
                speech_ref, speech_pre, self.si_snr_loss, perm=perm)
            si_snr = -si_snr_loss.detach()

        loss = tf_loss
        stats = dict(si_snr=si_snr, loss=loss.detach())
    else:
        if speech_ref.dim() == 4:
            # For si_snr loss of multi-channel input,
            # only select one channel as the reference
            speech_ref = speech_ref[..., self.ref_channel]
        speech_pre, speech_lengths, *__ = self.enh_model.forward_rawwav(
            speech_mix, speech_lengths)
        # speech_pre: list[(batch, sample)]
        assert speech_pre[0].dim() == 2, speech_pre[0].dim()
        speech_ref = torch.unbind(speech_ref, dim=1)
        # compute si-snr loss
        si_snr_loss, perm = self._permutation_loss(
            speech_ref, speech_pre, self.si_snr_loss_zeromean)
        si_snr = -si_snr_loss
        loss = si_snr_loss
        stats = dict(si_snr=si_snr.detach(), loss=loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
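# Every _permutation_loss call above implements utterance-level permutation
# invariant training (PIT): the criterion is evaluated under each speaker
# permutation and the cheapest assignment wins per utterance. A minimal
# 2-speaker sketch; pit_loss is an illustrative name, not the ESPnet helper:
from itertools import permutations
import torch

def pit_loss(ref, pre, criterion):
    """ref, pre: lists of (B, ...) tensors; returns (mean min loss, best perms)."""
    perms = list(permutations(range(len(ref))))
    losses = []
    for p in perms:
        # criterion returns a per-utterance (B,) loss for each ref/pre pair
        pair = [criterion(ref[i], pre[j]) for i, j in enumerate(p)]
        losses.append(torch.stack(pair).mean(dim=0))
    losses = torch.stack(losses)           # (num_perms, B)
    min_loss, idx = losses.min(dim=0)      # best permutation per utterance
    return min_loss.mean(), [perms[i] for i in idx.tolist()]

mse = lambda a, b: ((a - b) ** 2).mean(dim=tuple(range(1, a.dim())))
ref = [torch.randn(4, 100), torch.randn(4, 100)]
loss, perm = pit_loss(ref, [ref[1], ref[0]], mse)  # the swap is recovered
assert loss.item() == 0.0 and perm[0] == (1, 0)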
def _forward_generator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
    """Perform generator forward.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Optimizer index (0 for G and 1 for D).
    """
    # setup
    batch_size = text.size(0)
    feats = feats.transpose(1, 2)
    speech = speech.unsqueeze(1)

    # calculate generator outputs
    reuse_cache = True
    if not self.cache_generator_outputs or self._cache is None:
        reuse_cache = False
        outs = self.generator(
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            sids=sids,
            spembs=spembs,
            lids=lids,
        )
    else:
        outs = self._cache

    # store cache
    if self.training and self.cache_generator_outputs and not reuse_cache:
        self._cache = outs

    # parse outputs
    speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
    _, z_p, m_p, logs_p, _, logs_q = outs_
    speech_ = get_segments(
        x=speech,
        start_idxs=start_idxs * self.generator.upsample_factor,
        segment_size=self.generator.segment_size * self.generator.upsample_factor,
    )

    # calculate discriminator outputs
    p_hat = self.discriminator(speech_hat_)
    with torch.no_grad():
        # do not store discriminator gradient in generator turn
        p = self.discriminator(speech_)

    # calculate losses
    with autocast(enabled=False):
        mel_loss = self.mel_loss(speech_hat_, speech_)
        kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
        dur_loss = torch.sum(dur_nll.float())
        adv_loss = self.generator_adv_loss(p_hat)
        feat_match_loss = self.feat_match_loss(p_hat, p)

        mel_loss = mel_loss * self.lambda_mel
        kl_loss = kl_loss * self.lambda_kl
        dur_loss = dur_loss * self.lambda_dur
        adv_loss = adv_loss * self.lambda_adv
        feat_match_loss = feat_match_loss * self.lambda_feat_match
        loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss

    stats = dict(
        generator_loss=loss.item(),
        generator_mel_loss=mel_loss.item(),
        generator_kl_loss=kl_loss.item(),
        generator_dur_loss=dur_loss.item(),
        generator_adv_loss=adv_loss.item(),
        generator_feat_match_loss=feat_match_loss.item(),
    )
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

    # reset cache
    if reuse_cache or not self.training:
        self._cache = None

    return {
        "loss": loss,
        "stats": stats,
        "weight": weight,
        "optim_idx": 0,  # needed for trainer
    }
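# The returned dict (with "optim_idx") presumes a trainer that keeps one
# optimizer per sub-network and routes each loss to the optimizer the forward
# selected. A hedged sketch of that consumption, assuming the trainer simply
# indexes into a list of optimizers (gan_update is a hypothetical helper, not
# the ESPnet trainer API):
import torch

def gan_update(retval: dict, optimizers: list) -> None:
    """Backprop retval['loss'] through the optimizer chosen by optim_idx."""
    idx = retval["optim_idx"]
    optimizers[idx].zero_grad()
    retval["loss"].backward()
    optimizers[idx].step()

# toy usage: two independent "networks"; optim_idx=0 selects the generator step
g = torch.nn.Linear(2, 1)
d = torch.nn.Linear(2, 1)
optims = [torch.optim.Adam(g.parameters()), torch.optim.Adam(d.parameters())]
retval = {"loss": g(torch.randn(4, 2)).mean(), "optim_idx": 0}
gan_update(retval, optims)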
def _forward_discriminator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
) -> Dict[str, Any]:
    """Perform discriminator forward.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,).

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Optimizer index (0 for G and 1 for D).
    """
    # setup
    batch_size = text.size(0)
    speech = speech.unsqueeze(1)

    # calculate generator outputs
    reuse_cache = True
    if not self.cache_generator_outputs or self._cache is None:
        reuse_cache = False
        # calculate text2mel outputs
        text2mel_loss, stats, feats_gen = self.generator["text2mel"](
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            joint_training=True,
            **kwargs,
        )
        # get random segments
        feats_gen_, start_idxs = get_random_segments(
            x=feats_gen.transpose(1, 2),
            x_lengths=feats_lengths,
            segment_size=self.segment_size,
        )
        # calculate vocoder outputs
        speech_hat_ = self.generator["vocoder"](feats_gen_)
        if self.use_pqmf:
            speech_hat_ = self.pqmf.synthesis(speech_hat_)
    else:
        _, _, speech_hat_, start_idxs = self._cache

    # store cache
    if self.cache_generator_outputs and not reuse_cache:
        self._cache = (text2mel_loss, stats, speech_hat_, start_idxs)

    # parse outputs
    speech_ = get_segments(
        x=speech,
        start_idxs=start_idxs * self.generator["vocoder"].upsample_factor,
        segment_size=self.segment_size * self.generator["vocoder"].upsample_factor,
    )

    # calculate discriminator outputs
    p_hat = self.discriminator(speech_hat_.detach())
    p = self.discriminator(speech_)

    # calculate losses
    real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
    loss = real_loss + fake_loss

    stats = dict(
        discriminator_loss=loss.item(),
        real_loss=real_loss.item(),
        fake_loss=fake_loss.item(),
    )
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

    # reset cache
    if reuse_cache or not self.training:
        self._cache = None

    return {
        "loss": loss,
        "stats": stats,
        "weight": weight,
        "optim_idx": 1,  # needed for trainer
    }
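# Both GAN forwards share one caching discipline: when cache_generator_outputs
# is on, the generator runs once and its outputs are replayed on the other
# optimizer's turn, then dropped. A stripped-down sketch of that state machine
# (CachedGenerator is an illustrative stand-in, not ESPnet code):
class CachedGenerator:
    def __init__(self, cache_generator_outputs: bool = True):
        self.cache_generator_outputs = cache_generator_outputs
        self.training = True
        self._cache = None
        self.calls = 0

    def _expensive_forward(self):
        self.calls += 1
        return ("outs", self.calls)

    def step(self):
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self._expensive_forward()
        else:
            outs = self._cache
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = outs          # first (generator) turn: store
        if reuse_cache or not self.training:
            self._cache = None          # second (discriminator) turn: drop
        return outs

m = CachedGenerator()
assert m.step() == m.step()  # D turn replays the G turn's outputs
assert m.calls == 1          # the expensive forward ran only once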
def _forward_discriminator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
    """Perform discriminator forward.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Optimizer index (0 for G and 1 for D).
    """
    # setup
    batch_size = text.size(0)
    feats = feats.transpose(1, 2)
    speech = speech.unsqueeze(1)

    # calculate generator outputs
    reuse_cache = True
    if not self.cache_generator_outputs or self._cache is None:
        reuse_cache = False
        outs = self.generator(
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            sids=sids,
            spembs=spembs,
            lids=lids,
        )
    else:
        outs = self._cache

    # store cache
    if self.cache_generator_outputs and not reuse_cache:
        self._cache = outs

    # parse outputs
    speech_hat_, _, _, start_idxs, *_ = outs
    speech_ = get_segments(
        x=speech,
        start_idxs=start_idxs * self.generator.upsample_factor,
        segment_size=self.generator.segment_size * self.generator.upsample_factor,
    )

    # calculate discriminator outputs
    p_hat = self.discriminator(speech_hat_.detach())
    p = self.discriminator(speech_)

    # calculate losses
    with autocast(enabled=False):
        real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
        loss = real_loss + fake_loss

    stats = dict(
        discriminator_loss=loss.item(),
        discriminator_real_loss=real_loss.item(),
        discriminator_fake_loss=fake_loss.item(),
    )
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

    # reset cache
    if reuse_cache or not self.training:
        self._cache = None

    return {
        "loss": loss,
        "stats": stats,
        "weight": weight,
        "optim_idx": 1,  # needed for trainer
    }
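# Both discriminator forwards slice the real waveform with
# start_idxs * upsample_factor so it aligns with the generated segment, which
# was cut at feature rate and then upsampled. A toy check of that arithmetic
# (get_segments re-implemented here for illustration only):
import torch

def get_segments(x, start_idxs, segment_size):
    return torch.stack(
        [xi[..., s:s + segment_size] for xi, s in zip(x, start_idxs)])

up, seg = 4, 8                                          # upsample factor, frames
feats = torch.arange(2 * 32).view(2, 1, 32).float()     # (B, C, T_feats)
wav = feats.repeat_interleave(up, dim=-1)               # perfectly "upsampled"
starts = torch.tensor([3, 10])
feat_seg = get_segments(feats, starts, seg)
wav_seg = get_segments(wav, starts * up, seg * up)
assert torch.equal(wav_seg, feat_seg.repeat_interleave(up, dim=-1))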
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    spembs: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        text (LongTensor): Batch of padded character ids (B, Tmax).
        text_lengths (LongTensor): Batch of lengths of each input batch (B,).
        speech (Tensor): Batch of padded target features (B, Lmax, odim).
        speech_lengths (LongTensor): Batch of the lengths of each target (B,).
        spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim).

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value.
    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    speech = speech[:, : speech_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    # append eos at the end of each sequence
    xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    for i, l in enumerate(text_lengths):
        xs[i, l] = self.eos
    ilens = text_lengths + 1

    ys = speech
    olens = speech_lengths

    # make labels for stop prediction
    labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
    labels = F.pad(labels, [0, 1], "constant", 1.0)

    # calculate tacotron2 outputs
    after_outs, before_outs, logits, att_ws = self._forward(
        xs, ilens, ys, olens, spembs)

    # modify mod part of groundtruth
    if self.reduction_factor > 1:
        olens = olens.new(
            [olen - olen % self.reduction_factor for olen in olens])
        max_out = max(olens)
        ys = ys[:, :max_out]
        labels = labels[:, :max_out]
        labels[:, -1] = 1.0  # make sure at least one frame has 1

    # calculate taco2 loss
    l1_loss, mse_loss, bce_loss = self.taco2_loss(
        after_outs, before_outs, logits, ys, labels, olens)
    if self.loss_type == "L1+L2":
        loss = l1_loss + mse_loss + bce_loss
    elif self.loss_type == "L1":
        loss = l1_loss + bce_loss
    elif self.loss_type == "L2":
        loss = mse_loss + bce_loss
    else:
        raise ValueError(f"unknown --loss-type {self.loss_type}")
    stats = dict(
        l1_loss=l1_loss.item(),
        mse_loss=mse_loss.item(),
        bce_loss=bce_loss.item(),
    )

    # calculate attention loss
    if self.use_guided_attn_loss:
        # NOTE(kan-bayashi): length of output for auto-regressive
        # input will be changed when r > 1
        if self.reduction_factor > 1:
            olens_in = olens.new(
                [olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        attn_loss = self.attn_loss(att_ws, ilens, olens_in)
        loss = loss + attn_loss
        stats.update(attn_loss=attn_loss.item())

    stats.update(loss=loss.item())
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
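# The stop-token labels above are 0.0 before the last valid frame and 1.0 from
# it onward, with one extra all-stop column appended by F.pad. A
# dependency-free sketch of that construction, with make_pad_mask's semantics
# (True where index >= length) re-implemented inline for illustration:
import torch
import torch.nn.functional as F

olens = torch.tensor([3, 5])
max_len = int((olens - 1).max())
labels = (torch.arange(max_len).unsqueeze(0) >= (olens - 1).unsqueeze(1)).float()
labels = F.pad(labels, [0, 1], "constant", 1.0)  # extra all-stop frame
# -> [[0., 0., 1., 1., 1.],
#     [0., 0., 0., 0., 1.]]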
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
        kwargs: "utt_id" is among the input.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
        speech.shape[0]
        == speech_lengths.shape[0]
        == text.shape[0]
        == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]

    # for data-parallel
    text = text[:, : text_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        intermediate_outs = encoder_out[1]
        encoder_out = encoder_out[0]

    loss_att, acc_att, cer_att, wer_att = None, None, None, None
    loss_ctc, cer_ctc = None, None
    loss_transducer, cer_transducer, wer_transducer = None, None, None
    stats = dict()

    # 1. CTC branch
    if self.ctc_weight != 0.0:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Collect CTC branch stats
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc

    # Intermediate CTC (optional)
    loss_interctc = 0.0
    if self.interctc_weight != 0.0 and intermediate_outs is not None:
        for layer_idx, intermediate_out in intermediate_outs:
            # we assume intermediate_out has the same length & padding
            # as those of encoder_out
            loss_ic, cer_ic = self._calc_ctc_loss(
                intermediate_out, encoder_out_lens, text, text_lengths
            )
            loss_interctc = loss_interctc + loss_ic

            # Collect Intermediate CTC stats
            stats["loss_interctc_layer{}".format(layer_idx)] = (
                loss_ic.detach() if loss_ic is not None else None
            )
            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

        loss_interctc = loss_interctc / len(intermediate_outs)

        # calculate whole encoder loss
        loss_ctc = (
            1 - self.interctc_weight
        ) * loss_ctc + self.interctc_weight * loss_interctc

    if self.use_transducer_decoder:
        # 2a. Transducer decoder branch
        (
            loss_transducer,
            cer_transducer,
            wer_transducer,
        ) = self._calc_transducer_loss(
            encoder_out,
            encoder_out_lens,
            text,
        )

        if loss_ctc is not None:
            loss = loss_transducer + (self.ctc_weight * loss_ctc)
        else:
            loss = loss_transducer

        # Collect Transducer branch stats
        stats["loss_transducer"] = (
            loss_transducer.detach() if loss_transducer is not None else None
        )
        stats["cer_transducer"] = cer_transducer
        stats["wer_transducer"] = wer_transducer
    else:
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

    # Collect total loss stats
    stats["loss"] = loss.detach()

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
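# The loss definition above first mixes the final-layer CTC loss with the
# average of the intermediate CTC losses, then interpolates the result with
# the attention loss. The same arithmetic isolated with toy scalar values:
import torch

ctc_weight, interctc_weight = 0.3, 0.5
loss_ctc_final = torch.tensor(2.0)
loss_interctc = torch.stack([torch.tensor(3.0), torch.tensor(5.0)]).mean()
loss_ctc = (1 - interctc_weight) * loss_ctc_final + interctc_weight * loss_interctc
loss_att = torch.tensor(1.0)
loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
assert torch.isclose(loss, torch.tensor(0.3 * 3.0 + 0.7 * 1.0))  # 1.6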