def forward(self, input: torch.Tensor, ilens: torch.Tensor = None):
    """STFT forward function.

    Args:
        input: (Batch, Nsamples) or (Batch, Nsample, Channels)
        ilens: (Batch) number of valid samples per sequence, or None.

    Returns:
        output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2),
            where the last axis holds (real, imag). When ``ilens`` is given,
            positions flagged by ``make_pad_mask(olens, output, 1)`` (frame
            axis) are filled with 0.0.
        olens: (Batch) number of frames per sequence, or None when
            ``ilens`` is None.
    """
    bs = input.size(0)
    if input.dim() == 3:
        multi_channel = True
        # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
        input = input.transpose(1, 2).reshape(-1, input.size(1))
    else:
        multi_channel = False

    if self.window is not None:
        window_func = getattr(torch, f"{self.window}_window")
        window = window_func(self.win_length, dtype=input.dtype, device=input.device)
    else:
        window = None

    # Ask for a complex result and expose it as a trailing (real, imag) axis.
    # Recent PyTorch releases require return_complex for real inputs, and
    # view_as_real reproduces the (..., 2=real_imag) layout callers expect.
    output = torch.stft(
        input,
        n_fft=self.n_fft,
        win_length=self.win_length,
        hop_length=self.hop_length,
        center=self.center,
        window=window,
        normalized=self.normalized,
        onesided=self.onesided,
        return_complex=True,
    )
    output = torch.view_as_real(output)

    # output: (Batch, Freq, Frames, 2=real_imag)
    # -> (Batch, Frames, Freq, 2=real_imag)
    output = output.transpose(1, 2)
    if multi_channel:
        # output: (Batch * Channel, Frames, Freq, 2=real_imag)
        # -> (Batch, Frame, Channel, Freq, 2=real_imag)
        output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)

    if ilens is not None:
        # torch.stft frames are n_fft samples long (a shorter window is
        # zero-padded to n_fft) and center=True pads n_fft // 2 samples on
        # each side, so the frame count depends on n_fft, not win_length.
        if self.center:
            pad = self.n_fft // 2
            ilens = ilens + 2 * pad
        olens = (ilens - self.n_fft) // self.hop_length + 1
        output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
    else:
        olens = None
    return output, olens
def _forward(self, xs, ilens, ys=None, olens=None, ds=None, ps=None, es=None, speaker_embeddings=None, is_inference=False, alpha=1.0):
    """Run encoder, variance predictors, length regulator, decoder and postnet.

    Args:
        xs: Padded encoder inputs (presumably (B, Tmax) token ids -- TODO confirm).
        ilens: Per-sequence input lengths; used for the encoder mask
            (``self._source_mask``) and the variance-predictor padding mask.
        ys: Not used by this method.
        olens: Per-sequence output lengths. Only used when ``is_inference`` is
            False, to build the decoder mask; divided by
            ``self.reduction_factor`` first when that factor is > 1.
        ds: Durations fed to the length regulator when ``is_inference`` is False.
        ps: Pitch targets embedded via ``self.pitch_embed`` when
            ``is_inference`` is False (transposed on dims 1 and 2 around the
            embedding).
        es: Energy targets embedded via ``self.energy_embed`` when
            ``is_inference`` is False (same transposes as ``ps``).
        speaker_embeddings: Merged into the encoder output when
            ``self.spk_embed_dim`` is not None.
        is_inference: If True, predicted durations/pitch/energy are used;
            otherwise the provided ``ds``/``ps``/``es`` are used.
        alpha: Extra argument passed to ``self.length_regulator`` in
            inference only (presumably a speed/duration scale -- confirm).

    Returns:
        Tuple ``(before_outs, after_outs, d_outs, p_outs, e_outs)``:
        ``before_outs`` is the ``feat_out`` projection reshaped to
        (B, -1, odim); ``after_outs`` is ``before_outs`` plus the postnet
        residual; ``d_outs``, ``p_outs`` and ``e_outs`` are the duration,
        pitch and energy predictor outputs.
    """
    # Encoder.
    x_masks = self._source_mask(ilens)
    hs, _ = self.encoder(xs, x_masks)
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, speaker_embeddings)

    # Padding mask over input positions, shared by all variance predictors.
    d_masks = make_pad_mask(ilens).to(xs.device)

    # Optionally stop gradients from the pitch/energy predictors reaching
    # the encoder by feeding them a detached copy of the encoder output.
    if self.stop_gradient_from_pitch_predictor:
        p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
    else:
        p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
    if self.stop_gradient_from_energy_predictor:
        e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
    else:
        e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

    if is_inference:
        # Use the model's own duration/pitch/energy predictions.
        d_outs = self.duration_predictor.inference(hs, d_masks)
        p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs
        hs = self.length_regulator(hs, d_outs, alpha)
    else:
        # Use the provided targets (teacher forcing).
        d_outs = self.duration_predictor(hs, d_masks)
        p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs
        hs = self.length_regulator(hs, ds)

    # Decoder mask only when output lengths are known (not in inference).
    if olens is not None and not is_inference:
        if self.reduction_factor > 1:
            # Vectorised division: same values as dividing each (non-negative)
            # length separately, without a Python loop over tensor elements.
            olens_in = olens // self.reduction_factor
        else:
            olens_in = olens
        h_masks = self._source_mask(olens_in)
    else:
        h_masks = None

    zs, _ = self.decoder(hs, h_masks)
    before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
    # Postnet predicts a residual on (B, odim, L)-ordered features.
    after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
    return before_outs, after_outs, d_outs, p_outs, e_outs