def forward(self, input: ComplexTensor, ilens: torch.Tensor):
    """Convert a complex spectrum back into time-domain waveforms.

    Args:
        input (ComplexTensor): spectrum [Batch, T, F], or multi-channel
            spectrum [Batch, T, C, F]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        tuple: ``(wav, wav_lens)`` as returned by ``self.stft.inverse``.
        For 4-D (multi-channel) input, ``wav`` is rearranged to
        [Batch, Nsamples, C].

    Raises:
        TypeError: if ``is_torch_1_9_plus`` is true and ``input`` is neither a
            ``ComplexTensor`` nor a native complex tensor.
    """
    # Short-circuit order matters: torch.is_complex is only consulted when the
    # input is not a ComplexTensor and torch >= 1.9 is flagged.
    if (not isinstance(input, ComplexTensor)
            and is_torch_1_9_plus
            and not torch.is_complex(input)):
        raise TypeError("Only support complex tensors for stft decoder")

    batch_size = input.size(0)
    is_multi_channel = input.dim() == 4
    if is_multi_channel:
        n_frames, n_freqs = input.size(1), input.size(3)
        # Fold channels into the batch so the inverse STFT sees 3-D input:
        # (Batch, T, C, F) -> (Batch, C, T, F) -> (Batch * C, T, F)
        input = input.transpose(1, 2).reshape(-1, n_frames, n_freqs)

    wav, wav_lens = self.stft.inverse(input, ilens)

    if is_multi_channel:
        # Unfold channels again and put them last:
        # (Batch * C, Nsamples) -> (Batch, C, Nsamples) -> (Batch, Nsamples, C)
        wav = wav.reshape(batch_size, -1, wav.size(1)).transpose(1, 2)
    return wav, wav_lens
def forward(self, xs: ComplexTensor,
            input_lengths: torch.LongTensor) -> torch.Tensor:
    """Estimate a per-channel time-frequency mask from a complex spectrum.

    Args:
        xs (ComplexTensor): complex spectrum (B, C, T, F)
        input_lengths (torch.LongTensor): valid frame counts (B,)

    Returns:
        torch.Tensor: sigmoid mask of shape (B, C, T, F) (last dim is the
        output size of ``self.linear``). Positions flagged by
        ``make_pad_mask(input_lengths, out, length_dim=2)``, i.e. frames past
        each utterance's length, are set to 0.

    Raises:
        AssertionError: if the batch sizes of ``xs`` and ``input_lengths``
            differ.
        NotImplementedError: for an unsupported ``self.feat_type`` or
            ``self.model_type``.
    """
    assert xs.size(0) == input_lengths.size(0), (xs.size(0),
                                                 input_lengths.size(0))
    # xs: (B, C, T, F), complex-valued
    n_channels = xs.size(1)

    # Complex spectrum -> real-valued features.
    if self.feat_type == 'amplitude':
        # (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
    elif self.feat_type == 'power':
        # (B, C, T, F) -> (B, C, T, F)
        xs = xs.real ** 2 + xs.imag ** 2
    elif self.feat_type == 'log_power':
        # (B, C, T, F) -> (B, C, T, F)
        # NOTE(review): bins with zero power (e.g. zero-padded frames, if the
        # caller pads with zeros) give log(0) = -inf; consider a small floor.
        xs = torch.log(xs.real ** 2 + xs.imag ** 2)
    elif self.feat_type == 'concat':
        # (B, C, T, F) -> (B, C, T, 2 * F)
        xs = torch.cat([xs.real, xs.imag], -1)
    else:
        raise NotImplementedError(f'Not implemented: {self.feat_type}')

    # Real features -> hidden representation (B, C, T, D).
    if self.model_type in ('blstm', 'lstm'):
        xs = self.net(xs, input_lengths)
    elif self.model_type == 'cnn':
        if self.channel_independent:
            # Fold channels into the batch and move features to dim 1 for
            # self.net: (B, C, T, F) -> (B * C, T, F) -> (B * C, F, T)
            xs = xs.view(-1, *xs.size()[2:]).transpose(1, 2)
            # (B * C, F, T) -> (B * C, D, T)
            xs = self.net(xs)
            n_hidden, n_frames = xs.size(1), xs.size(2)
            # (B * C, D, T) -> (B * C, T, D) -> (B, C, T, D)
            xs = xs.transpose(1, 2).contiguous().view(
                -1, n_channels, n_frames, n_hidden)
        else:
            # (B, C, T, F) -> (B, C, T, D)
            xs = self.net(xs)
    else:
        raise NotImplementedError(f'Not implemented: {self.model_type}')

    # (B, C, T, D) -> (B, C, T, F)
    out = torch.sigmoid(self.linear(xs))
    # Zero padding. Tensor.masked_fill is out-of-place, so its result must be
    # re-assigned or the padded frames are left unmasked.
    out = out.masked_fill(make_pad_mask(input_lengths, out, length_dim=2), 0)
    return out