def forward(
    self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward function

    Args:
        x: (B, L, ...)
        ilens: (B,)
    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    norm_means = self.norm_means
    norm_vars = self.norm_vars
    self.mean = self.mean.to(x.device, x.dtype)
    self.std = self.std.to(x.device, x.dtype)
    mask = make_pad_mask(ilens, x, 1)

    # feat: (B, T, D)
    if norm_means:
        if x.requires_grad:
            x = x - self.mean
        else:
            x -= self.mean
    if x.requires_grad:
        x = x.masked_fill(mask, 0.0)
    else:
        x.masked_fill_(mask, 0.0)

    if norm_vars:
        x /= self.std

    return x, ilens
def forward(
    self,
    feat: torch.Tensor,
    ilens: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
    mel_feat = torch.matmul(feat, self.melmat)
    mel_feat = torch.clamp(mel_feat, min=1e-10)

    if self.log_base is None:
        logmel_feat = mel_feat.log()
    elif self.log_base == 2.0:
        logmel_feat = mel_feat.log2()
    elif self.log_base == 10.0:
        logmel_feat = mel_feat.log10()
    else:
        logmel_feat = mel_feat.log() / torch.log(self.log_base)

    # Zero padding
    if ilens is not None:
        logmel_feat = logmel_feat.masked_fill(
            make_pad_mask(ilens, logmel_feat, 1), 0.0
        )
    else:
        ilens = feat.new_full(
            [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
        )
    return logmel_feat, ilens
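# The matmul above assumes self.melmat is stored as (n_freq, n_mels), i.e. the
# transpose of the (n_mels, n_freq) filterbank that librosa produces. A minimal
# sketch of how such a matrix could be built and applied, assuming librosa is
# available; the sampling rate, FFT size, and mel count are illustrative, not
# the extractor's actual constructor arguments.
import librosa
import torch

fs, n_fft, n_mels = 16000, 512, 80  # illustrative front-end settings

# librosa returns (n_mels, 1 + n_fft // 2); transpose to (n_freq, n_mels) so that
# (B, T, n_freq) @ (n_freq, n_mels) -> (B, T, n_mels).
melmat = torch.from_numpy(
    librosa.filters.mel(sr=fs, n_fft=n_fft, n_mels=n_mels).T
).float()

feat = torch.rand(2, 50, n_fft // 2 + 1)  # fake power spectrogram (B, T, D1)
logmel = torch.clamp(feat @ melmat, min=1e-10).log()
print(logmel.shape)  # torch.Size([2, 50, 80])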
def label_aggregate(
    self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """label_aggregate function.

    Args:
        input: (Batch, Nsamples, Label_dim)
        input_lengths: (Batch)
    Returns:
        output: (Batch, Frames, Label_dim)
    """
    bs = input.size(0)
    max_length = input.size(1)
    label_dim = input.size(2)

    # NOTE(jiatong):
    #   The default behaviour of label aggregation is compatible with
    #   torch.stft about framing and padding.

    # Step1: center padding
    if self.center:
        pad = self.win_length // 2
        max_length = max_length + 2 * pad
        input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
        input[:, :pad, :] = input[:, pad : (2 * pad), :]
        input[:, (max_length - pad) : max_length, :] = input[
            :, (max_length - 2 * pad) : (max_length - pad), :
        ]
    nframe = (max_length - self.win_length) // self.hop_length + 1

    # Step2: framing
    output = input.as_strided(
        (bs, nframe, self.win_length, label_dim),
        (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
    )

    # Step3: aggregate label
    # (bs, nframe, self.win_length, label_dim) => (bs, nframe)
    _tmp = output.sum(dim=-1, keepdim=False).float()
    output = _tmp[:, :, self.win_length // 2]

    # Step4: process lengths
    if input_lengths is not None:
        if self.center:
            pad = self.win_length // 2
            input_lengths = input_lengths + 2 * pad

        olens = (input_lengths - self.win_length) // self.hop_length + 1
        output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
    else:
        olens = None

    return output, olens
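# The NOTE above says the framing is compatible with torch.stft; a quick
# self-contained check of the frame-count arithmetic (window and hop values
# are arbitrary, chosen only for illustration).
import torch

n_samples, win_length, hop_length = 100, 8, 4
pad = win_length // 2  # center padding on both sides
nframe = (n_samples + 2 * pad - win_length) // hop_length + 1

# torch.stft with center=True and the same window/hop yields the same count.
stft_frames = torch.stft(
    torch.zeros(1, n_samples),
    n_fft=win_length,
    hop_length=hop_length,
    win_length=win_length,
    center=True,
    return_complex=True,
).size(-1)
assert nframe == stft_frames == 26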
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:
    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
    # Zero padding
    if x.requires_grad:
        x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
    else:
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    # mean: (B, 1, D)
    mean = x.sum(dim=1, keepdim=True) / ilens_

    if norm_means:
        x -= mean

        if norm_vars:
            var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            # NOTE: std is already the standard deviation, so divide by std
            # (not std.sqrt()) to actually normalize the variance; this also
            # matches the norm_means=False branch below.
            x = x / std
        return x, ilens
    else:
        if norm_vars:
            y = x - mean
            y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
            var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x /= std
        return x, ilens
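# A small sketch of what utterance_mvn computes, written with only plain torch
# so the padding mask is built inline instead of via make_pad_mask; the tensor
# values and lengths are purely illustrative.
import torch

x = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [0.0, 0.0]],
                  [[2.0, 2.0], [4.0, 6.0], [0.0, 0.0], [0.0, 0.0]]])  # (B=2, T=4, D=2)
ilens = torch.tensor([3, 2])

mask = torch.arange(x.size(1))[None, :, None] >= ilens[:, None, None]  # pad positions
ilens_ = ilens.to(x.dtype).view(-1, 1, 1)

mean = x.sum(dim=1, keepdim=True) / ilens_     # per-utterance mean over valid frames
x = (x - mean).masked_fill(mask, 0.0)          # mean-normalize, keep padding at zero
var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
x = x / torch.clamp(var.sqrt(), min=1e-20)     # variance-normalize

print(x)  # padded frames remain exactly zero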
def inverse(
    self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    norm_means = self.norm_means
    norm_vars = self.norm_vars
    self.mean = self.mean.to(x.device, x.dtype)
    self.std = self.std.to(x.device, x.dtype)
    mask = make_pad_mask(ilens, x, 1)

    if x.requires_grad:
        x = x.masked_fill(mask, 0.0)
    else:
        x.masked_fill_(mask, 0.0)

    if norm_vars:
        x *= self.std

    # feat: (B, T, D)
    if norm_means:
        x += self.mean
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    return x, ilens
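# inverse() above undoes the mean/std normalization applied by the matching
# forward(). A plain-torch round trip with toy global statistics (not the
# class itself, and without padding) illustrating that relationship.
import torch

mean = torch.tensor([1.0, -2.0])
std = torch.tensor([0.5, 4.0])
x = torch.randn(1, 10, 2)

normalized = (x - mean) / std       # what forward() does with norm_means/norm_vars on
restored = normalized * std + mean  # what inverse() does

assert torch.allclose(restored, x, atol=1e-6)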
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    label: torch.Tensor,
    label_lengths: torch.Tensor,
    midi: torch.Tensor,
    midi_lengths: torch.Tensor,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    joint_training: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        text (LongTensor): Batch of padded character ids (B, Tmax).
        text_lengths (LongTensor): Batch of lengths of each input batch (B,).
        feats (Tensor): Batch of padded target features (B, Lmax, odim).
        feats_lengths (LongTensor): Batch of the lengths of each target (B,).
        spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
        sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
        lids (Optional[Tensor]): Batch of language IDs (B, 1).
        joint_training (bool): Whether to perform joint training with vocoder.

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value if not joint training else model outputs.
    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    feats = feats[:, : feats_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    # Add eos at the last of sequence
    xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    for i, l in enumerate(text_lengths):
        xs[i, l] = self.eos
    ilens = text_lengths + 1

    ys = feats
    olens = feats_lengths

    # make labels for stop prediction
    labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
    labels = F.pad(labels, [0, 1], "constant", 1.0)

    # calculate transformer outputs
    after_outs, before_outs, logits = self._forward(
        xs=xs,
        ilens=ilens,
        ys=ys,
        olens=olens,
        spembs=spembs,
        sids=sids,
        lids=lids,
    )

    # modify mod part of groundtruth
    olens_in = olens
    if self.reduction_factor > 1:
        assert olens.ge(
            self.reduction_factor
        ).all(), "Output length must be greater than or equal to reduction factor."
        olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
        max_olen = max(olens)
        ys = ys[:, :max_olen]
        labels = labels[:, :max_olen]
        labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)  # see #3388

    # calculate loss values
    l1_loss, l2_loss, bce_loss = self.criterion(
        after_outs, before_outs, logits, ys, labels, olens
    )
    if self.loss_type == "L1":
        loss = l1_loss + bce_loss
    elif self.loss_type == "L2":
        loss = l2_loss + bce_loss
    elif self.loss_type == "L1+L2":
        loss = l1_loss + l2_loss + bce_loss
    else:
        raise ValueError("unknown --loss-type " + self.loss_type)

    stats = dict(
        l1_loss=l1_loss.item(),
        l2_loss=l2_loss.item(),
        bce_loss=bce_loss.item(),
    )

    # calculate guided attention loss
    if self.use_guided_attn_loss:
        # calculate for encoder
        if "encoder" in self.modules_applied_guided_attn:
            att_ws = []
            for idx, layer_idx in enumerate(
                reversed(range(len(self.encoder.encoders)))
            ):
                att_ws += [
                    self.encoder.encoders[layer_idx].self_attn.attn[
                        :, : self.num_heads_applied_guided_attn
                    ]
                ]
                if idx + 1 == self.num_layers_applied_guided_attn:
                    break
            att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_text, T_text)
            enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
            loss = loss + enc_attn_loss
            stats.update(enc_attn_loss=enc_attn_loss.item())
        # calculate for decoder
        if "decoder" in self.modules_applied_guided_attn:
            att_ws = []
            for idx, layer_idx in enumerate(
                reversed(range(len(self.decoder.decoders)))
            ):
                att_ws += [
                    self.decoder.decoders[layer_idx].self_attn.attn[
                        :, : self.num_heads_applied_guided_attn
                    ]
                ]
                if idx + 1 == self.num_layers_applied_guided_attn:
                    break
            att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_feats, T_feats)
            dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
            loss = loss + dec_attn_loss
            stats.update(dec_attn_loss=dec_attn_loss.item())
        # calculate for encoder-decoder
        if "encoder-decoder" in self.modules_applied_guided_attn:
            att_ws = []
            for idx, layer_idx in enumerate(
                reversed(range(len(self.decoder.decoders)))
            ):
                att_ws += [
                    self.decoder.decoders[layer_idx].src_attn.attn[
                        :, : self.num_heads_applied_guided_attn
                    ]
                ]
                if idx + 1 == self.num_layers_applied_guided_attn:
                    break
            att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_feats, T_text)
            enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in)
            loss = loss + enc_dec_attn_loss
            stats.update(enc_dec_attn_loss=enc_dec_attn_loss.item())

    # report extra information
    if self.use_scaled_pos_enc:
        stats.update(
            encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
            decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
        )

    if not joint_training:
        stats.update(loss=loss.item())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    else:
        return loss, stats, after_outs
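# The reduction-factor bookkeeping above shrinks the decoder-input lengths to
# olen // r, trims the targets to a multiple of r, and re-marks the last valid
# frame as a stop target via torch.scatter. A toy run of the same arithmetic
# (lengths and r are illustrative).
import torch

olens = torch.tensor([7, 10])
r = 3

olens_in = torch.tensor([olen // r for olen in olens])          # decoder-input lengths
olens_trim = torch.tensor([olen - olen % r for olen in olens])  # trimmed target lengths
print(olens_in, olens_trim)  # tensor([2, 3]) tensor([6, 9])

# After trimming, the last valid frame of each utterance is set back to 1.0.
labels = torch.zeros(2, int(olens_trim.max()))
labels = torch.scatter(labels, 1, (olens_trim - 1).unsqueeze(1), 1.0)
print(labels)
# tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 0., 0., 0., 1.]])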
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    label: torch.Tensor,
    label_lengths: torch.Tensor,
    midi: torch.Tensor,
    midi_lengths: torch.Tensor,
    tempo: torch.Tensor,
    tempo_lengths: torch.Tensor,
    ds: torch.Tensor,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    flag_IsValid: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        text (LongTensor): Batch of padded character ids (B, Tmax).
        text_lengths (LongTensor): Batch of lengths of each input batch (B,).
        feats (Tensor): Batch of padded target features (B, Lmax, odim).
        feats_lengths (LongTensor): Batch of the lengths of each target (B,).
        ds (LongTensor): Batch of padded durations (B, T_text + 1).
        spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
        sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
        lids (Optional[Tensor]): Batch of language IDs (B, 1).

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value.
    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    feats = feats[:, : feats_lengths.max()]  # for data-parallel
    midi = midi[:, : midi_lengths.max()]  # for data-parallel
    label = label[:, : label_lengths.max()]  # for data-parallel
    tempo = tempo[:, : tempo_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    label_emb = self.phone_encode_layer(label)
    midi_emb = self.midi_encode_layer(midi)
    tempo_emb = self.tempo_encode_layer(
        tempo
    )  # FIXME(Nan): the tempo of singing tacotron is BPM, should change later.

    # NOTE: avoid torch.tensor() on an existing tensor and a hard-coded .cuda();
    # move/cast the duration tensor to the embedding's device instead.
    ds_tensor = ds.unsqueeze(-1).to(device=label_emb.device, dtype=torch.float32)
    ds_emb = self.duration_encode_layer(ds_tensor)

    content_input = torch.cat(
        [label_emb, midi_emb], dim=-1
    )  # cat these two (or all four) into content_enc
    duration_tempo = torch.cat([tempo_emb, ds_emb], dim=-1)

    att_input = None
    if self.atype != "GDCA_location":
        att_input = torch.cat([content_input, duration_tempo], dim=-1)

    # TODO(Nan): add start & end token
    # # Add eos at the last of sequence
    # xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    # for i, l in enumerate(text_lengths):
    #     xs[i, l] = self.eos
    # ilens = text_lengths + 1

    ys = feats
    olens = feats_lengths
    ilens = label_lengths

    # make labels for stop prediction
    # TODO(Nan): change name
    stop_labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
    stop_labels = F.pad(stop_labels, [0, 1], "constant", 1.0)

    # calculate tacotron2 outputs
    after_outs, before_outs, logits, att_ws = self._forward(
        content_input, att_input, duration_tempo, ilens, ys, olens, spembs
    )

    # modify mod part of groundtruth
    if self.reduction_factor > 1:
        olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
        max_out = max(olens)
        ys = ys[:, :max_out]
        stop_labels = stop_labels[:, :max_out]
        stop_labels[:, -1] = 1.0  # make sure at least one frame has 1
    else:
        ys = feats
        olens = feats_lengths

    # calculate taco2 loss
    l1_loss, mse_loss, bce_loss = self.taco2_loss(
        after_outs, before_outs, logits, ys, stop_labels, olens
    )
    if self.loss_type == "L1+L2":
        loss = l1_loss + mse_loss + bce_loss
    elif self.loss_type == "L1":
        loss = l1_loss + bce_loss
    elif self.loss_type == "L2":
        loss = mse_loss + bce_loss
    else:
        raise ValueError(f"unknown --loss-type {self.loss_type}")

    stats = dict(
        l1_loss=l1_loss.item(),
        mse_loss=mse_loss.item(),
        bce_loss=bce_loss.item(),
    )

    # calculate attention loss
    if self.use_guided_attn_loss:
        # NOTE(kan-bayashi): length of output for auto-regressive
        # input will be changed when r > 1
        if self.reduction_factor > 1:
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        attn_loss = self.attn_loss(att_ws, ilens, olens_in)
        loss = loss + attn_loss
        stats.update(attn_loss=attn_loss.item())

    stats.update(loss=loss.item())

    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    if not flag_IsValid:
        # train stage
        return loss, stats, weight
    else:
        # validation stage
        return loss, stats, weight, after_outs[:, : olens.max()], ys, olens
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    label: torch.Tensor,
    label_lengths: torch.Tensor,
    midi: torch.Tensor,
    midi_lengths: torch.Tensor,
    tempo: Optional[torch.Tensor] = None,
    tempo_lengths: Optional[torch.Tensor] = None,
    ds: torch.Tensor = None,
    flag_IsValid: bool = False,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        text (LongTensor): Batch of padded character ids (B, Tmax).
        text_lengths (LongTensor): Batch of lengths of each input batch (B,).
        feats (Tensor): Batch of padded target features (B, Lmax, odim).
        feats_lengths (LongTensor): Batch of the lengths of each target (B,).
        label
        label_lengths
        midi
        midi_lengths
        spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
        sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
        lids (Optional[Tensor]): Batch of language IDs (B, 1).

    GS Fix:
        arguments from forward func. V.S. **batch from muskit_model.py
        label == durations

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value if not joint training else model outputs.
    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    feats = feats[:, : feats_lengths.max()]  # for data-parallel
    midi = midi[:, : midi_lengths.max()]  # for data-parallel
    label = label[:, : label_lengths.max()]  # for data-parallel
    # tempo = label[:, : tempo_lengths.max()]  # for data-parallel

    batch_size = text.size(0)

    phone_emb, _ = self.phone_encoder(text)
    midi_emb = self.midi_encoder_input_layer(midi)

    label_emb = self.length_regulator(phone_emb, ds)
    # label_emb = self.enc_postnet(
    #     phone_emb, label, text
    # )

    midi_emb = F.leaky_relu(self.fc_midi(midi_emb))

    if self.embed_integration_type == "add":
        hs = label_emb + midi_emb
    else:
        hs = torch.cat((label_emb, midi_emb), dim=-1)
        # hs = F.leaky_relu(self.projection(hs))

    # integrate spk & lang embeddings
    if self.spks is not None:
        sid_embs = self.sid_emb(sids.view(-1))
        hs = hs + sid_embs.unsqueeze(1)
    if self.langs is not None:
        lid_embs = self.lid_emb(lids.view(-1))
        hs = hs + lid_embs.unsqueeze(1)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, spembs)

    pos_emb = self.pos(hs)
    pos_out = F.leaky_relu(self.fc_pos(pos_emb))
    hs = hs + pos_out
    # logging.info(f"Tao - hs:{hs.shape}")

    # decoder
    # mel_output, att_weight = self.decoder(
    zs = self.decoder(
        hs, pos=(~make_pad_mask(midi_lengths)).to(device=hs.device)
    )  # True mask
    # mel_output2 = self.double_mel(mel_output)

    zs = zs[:, self.reduction_factor - 1 :: self.reduction_factor]
    before_outs = F.leaky_relu(self.feat_out(zs).view(zs.size(0), -1, self.odim))
    # logging.info(f"mel_output:{mel_output}")

    # zs = self.postnet(zs.transpose(1, 2))
    # zs = zs.transpose(1, 2)

    # (B, T_feats//r, odim * r) -> (B, T_feats//r * r, odim)
    # postnet -> (B, T_feats//r * r, odim)
    if self.postnet is None:
        after_outs = before_outs
    else:
        after_outs = before_outs + self.postnet(
            before_outs.transpose(1, 2)
        ).transpose(1, 2)

    # modify mod part of groundtruth
    if self.reduction_factor > 1:
        assert feats_lengths.ge(
            self.reduction_factor
        ).all(), "Output length must be greater than or equal to reduction factor."
        olens = feats_lengths.new(
            [olen - olen % self.reduction_factor for olen in feats_lengths]
        )
        max_olen = max(olens)
        ys = feats[:, :max_olen]
    else:
        ys = feats
        olens = feats_lengths

    # calculate loss values
    l1_loss, l2_loss = self.criterion(
        after_outs[:, : olens.max()], before_outs[:, : olens.max()], ys, olens
    )

    if self.loss_type == "L1":
        loss = l1_loss
    elif self.loss_type == "L2":
        loss = l2_loss
    elif self.loss_type == "L1+L2":
        loss = l1_loss + l2_loss
    else:
        raise ValueError("unknown --loss-type " + self.loss_type)

    stats = dict(
        loss=loss.item(),
        l1_loss=l1_loss.item(),
        l2_loss=l2_loss.item(),
    )

    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    if not flag_IsValid:
        return loss, stats, weight
    else:
        return loss, stats, weight, after_outs[:, : olens.max()], ys, olens
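# The strided slice and the view above implement the usual reduction-factor
# trick: the decoder keeps every r-th frame and feat_out predicts r output
# frames per step, which the view unfolds back to the full frame rate. A toy
# shape check (all dimensions are illustrative).
import torch

B, T, adim, odim, r = 2, 12, 16, 4, 3

zs = torch.randn(B, T, adim)
zs = zs[:, r - 1 :: r]                      # keep every r-th frame -> (B, T // r, adim)
feat_out = torch.nn.Linear(adim, odim * r)  # predicts r output frames per decoder step

before_outs = feat_out(zs).view(B, -1, odim)
print(zs.shape, before_outs.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 12, 4])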
def forward(
    self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """STFT forward function.

    Args:
        input: (Batch, Nsamples) or (Batch, Nsample, Channels)
        ilens: (Batch)
    Returns:
        output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
    """
    bs = input.size(0)
    if input.dim() == 3:
        multi_channel = True
        # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
        input = input.transpose(1, 2).reshape(-1, input.size(1))
    else:
        multi_channel = False

    # NOTE(kamo):
    #   The default behaviour of torch.stft is compatible with librosa.stft
    #   about padding and scaling.
    #   Note that it's different from scipy.signal.stft

    # output: (Batch, Freq, Frames, 2=real_imag)
    # or (Batch, Channel, Freq, Frames, 2=real_imag)
    if self.window is not None:
        window_func = getattr(torch, f"{self.window}_window")
        window = window_func(self.win_length, dtype=input.dtype, device=input.device)
    else:
        window = None
    output = torch.stft(
        input,
        n_fft=self.n_fft,
        win_length=self.win_length,
        hop_length=self.hop_length,
        center=self.center,
        window=window,
        normalized=self.normalized,
        onesided=self.onesided,
    )
    # output: (Batch, Freq, Frames, 2=real_imag)
    # -> (Batch, Frames, Freq, 2=real_imag)
    output = output.transpose(1, 2)
    if multi_channel:
        # output: (Batch * Channel, Frames, Freq, 2=real_imag)
        # -> (Batch, Frame, Channel, Freq, 2=real_imag)
        output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)

    if ilens is not None:
        if self.center:
            pad = self.win_length // 2
            ilens = ilens + 2 * pad

        olens = (ilens - self.win_length) // self.hop_length + 1
        output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
    else:
        olens = None

    return output, olens
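# The window lookup above relies on torch exposing constructors named
# "<name>_window" (hann_window, hamming_window, ...). A short sketch of the
# same pattern and the resulting (Batch, Frames, Freq, 2) layout; parameters
# are illustrative, and return_complex=True plus view_as_real is used because
# newer PyTorch no longer returns the trailing real/imag dimension by default.
import torch

n_fft, win_length, hop_length = 512, 400, 160
window = getattr(torch, "hann_window")(win_length)  # same lookup pattern as above

wav = torch.randn(2, 16000)  # (Batch, Nsamples), fake audio
spec = torch.stft(
    wav, n_fft=n_fft, win_length=win_length, hop_length=hop_length,
    window=window, center=True, return_complex=True,
)
spec = torch.view_as_real(spec).transpose(1, 2)  # (Batch, Frames, Freq, 2=real_imag)
print(spec.shape)  # torch.Size([2, 101, 257, 2])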
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    label: torch.Tensor,
    label_lengths: torch.Tensor,
    midi: torch.Tensor,
    midi_lengths: torch.Tensor,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    joint_training: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate forward propagation.

    Args:
        TBD
        text (LongTensor): Batch of padded character ids (B, T_text).
        text_lengths (LongTensor): Batch of lengths of each input batch (B,).
        feats (Tensor): Batch of padded target features (B, T_feats, odim).
        feats_lengths (LongTensor): Batch of the lengths of each target (B,).
        spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
        sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
        lids (Optional[Tensor]): Batch of language IDs (B, 1).
        joint_training (bool): Whether to perform joint training with vocoder.

    Returns:
        Tensor: Loss scalar value.
        Dict: Statistics to be monitored.
        Tensor: Weight value if not joint training else model outputs.
    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    feats = feats[:, : feats_lengths.max()]  # for data-parallel
    batch_size = text.size(0)

    # Add eos at the last of sequence
    xs = F.pad(text, [0, 1], "constant", self.padding_idx)
    for i, l in enumerate(text_lengths):
        xs[i, l] = self.eos
    ilens = text_lengths + 1

    ys = feats
    olens = feats_lengths

    # make labels for stop prediction
    labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
    labels = F.pad(labels, [0, 1], "constant", 1.0)

    # calculate tacotron2 outputs
    after_outs, before_outs, logits, att_ws = self._forward(
        xs=xs,
        ilens=ilens,
        ys=ys,
        olens=olens,
        spembs=spembs,
        sids=sids,
        lids=lids,
    )

    # modify mod part of groundtruth
    if self.reduction_factor > 1:
        assert olens.ge(
            self.reduction_factor
        ).all(), "Output length must be greater than or equal to reduction factor."
        olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
        max_out = max(olens)
        ys = ys[:, :max_out]
        labels = labels[:, :max_out]
        labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)  # see #3388

    # calculate taco2 loss
    l1_loss, mse_loss, bce_loss = self.taco2_loss(
        after_outs, before_outs, logits, ys, labels, olens
    )
    if self.loss_type == "L1+L2":
        loss = l1_loss + mse_loss + bce_loss
    elif self.loss_type == "L1":
        loss = l1_loss + bce_loss
    elif self.loss_type == "L2":
        loss = mse_loss + bce_loss
    else:
        raise ValueError(f"unknown --loss-type {self.loss_type}")

    stats = dict(
        l1_loss=l1_loss.item(),
        mse_loss=mse_loss.item(),
        bce_loss=bce_loss.item(),
    )

    # calculate attention loss
    if self.use_guided_attn_loss:
        # NOTE(kan-bayashi): length of output for auto-regressive
        # input will be changed when r > 1
        if self.reduction_factor > 1:
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        attn_loss = self.attn_loss(att_ws, ilens, olens_in)
        loss = loss + attn_loss
        stats.update(attn_loss=attn_loss.item())

    if not joint_training:
        stats.update(loss=loss.item())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    else:
        return loss, stats, after_outs
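# A plain-torch illustration of the stop-prediction targets built above:
# make_pad_mask(olens - 1) is replaced by an inline arange comparison, and the
# extra padded column marks "stop" after the final frame. Lengths are
# illustrative.
import torch
import torch.nn.functional as F

olens = torch.tensor([3, 5])
T = int((olens - 1).max())

labels = (torch.arange(T)[None, :] >= (olens - 1)[:, None]).float()  # 1.0 from frame olen-1 on
labels = F.pad(labels, [0, 1], "constant", 1.0)                      # one extra all-stop column
print(labels)
# tensor([[0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 1.]])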