import torch

from espnet2.layers.stft import Stft  # assumed import path (ESPnet2 layout)


def test_inverse():
    layer = Stft()
    x = torch.randn(2, 400, requires_grad=True)
    y, _ = layer(x)
    x_lengths = torch.IntTensor([400, 300])
    raw, _ = layer.inverse(y, x_lengths)
    raw, _ = layer.inverse(y)
class STFTDecoder(AbsDecoder):
    """STFT decoder for speech enhancement and separation"""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]
        """
        if not isinstance(input, ComplexTensor) and (
            is_torch_1_9_plus and not torch.is_complex(input)
        ):
            raise TypeError("Only support complex tensors for stft decoder")

        bs = input.size(0)
        if input.dim() == 4:
            multi_channel = True
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            input = input.transpose(1, 2).reshape(-1, input.size(1), input.size(3))
        else:
            multi_channel = False

        wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

        return wav, wav_lens
class STFTDecoder(AbsDecoder):
    """STFT decoder for speech enhancement and separation"""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]
        """
        if not isinstance(input, ComplexTensor) and (
            is_torch_1_9_plus and not torch.is_complex(input)
        ):
            raise TypeError("Only support complex tensors for stft decoder")

        wav, wav_lens = self.stft.inverse(input, ilens)

        return wav, wav_lens
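# Illustrative usage sketch added for clarity (not in the original source); it
# mirrors the style of test_inverse() above and assumes STFTDecoder, Stft,
# ComplexTensor and torch are importable in this scope. The test name and the
# chosen tensor sizes are hypothetical.
def test_stft_decoder_roundtrip():
    decoder = STFTDecoder(n_fft=512, hop_length=128)
    wav = torch.randn(2, 16000)
    ilens = torch.LongTensor([16000, 12000])
    # analysis with the decoder's own Stft layer, then synthesis via forward()
    spec, flens = decoder.stft(wav, ilens)            # (Batch, Frames, Freq, 2)
    spec = ComplexTensor(spec[..., 0], spec[..., 1])  # pack real/imag parts
    recon, recon_lens = decoder(spec, ilens)          # (Batch, Nsamples)
    assert recon.shape[0] == wav.shape[0]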
class BeamformerNet(AbsEnhancement):
    """TF Masking based beamformer."""

    def __init__(
        self,
        num_spk: int = 1,
        normalize_input: bool = False,
        mask_type: str = "IPM^2",
        # STFT options
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        center: bool = True,
        window: Optional[str] = "hann",
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        beamformer_type="mvdr",
        bdropout_rate=0.0,
    ):
        super(BeamformerNet, self).__init__()

        self.mask_type = mask_type
        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
        )

        self.normalize_input = normalize_input
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # use a DNN for power estimation
                iterations = 1
            else:
                # perform conventional WPE, without a DNN estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=self.num_bin,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=bnet_type,
                bidim=self.num_bin,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                btaps=taps,
                bdelay=delay,
            )
        else:
            self.beamformer = None

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel): torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> stft -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        enhanced = input_spectrum
        masks = OrderedDict()

        if input_spectrum.dim() == 3:
            # single-channel input
            if self.use_wpe:
                # (B, T, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum.unsqueeze(-2), flens)
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)

        elif input_spectrum.dim() == 4:
            # multi-channel input
            # 1. WPE
            if self.use_wpe:
                # (B, T, C, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

        else:
            raise ValueError(
                "Invalid spectrum dimension: {}".format(input_spectrum.shape)
            )

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()
        return enhanced, flens, masks

    def forward_rawwav(self, input: torch.Tensor, ilens: torch.Tensor):
        """Output with waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech wavs (single-channel):
                torch.Tensor(Batch, Nsamples), or List[torch.Tensor(Batch, Nsamples)]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        enhanced, flens, masks = self.forward(input, ilens)
        if isinstance(enhanced, list):
            # multi-speaker output
            predicted_wavs = [self.stft.inverse(ps, ilens)[0] for ps in enhanced]
        else:
            # single-speaker output
            predicted_wavs = self.stft.inverse(enhanced, ilens)[0]

        return predicted_wavs, ilens, masks
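# Illustrative usage sketch added for clarity (not in the original source); the
# tensor shapes follow the docstrings of BeamformerNet.forward/forward_rawwav,
# and the constructor arguments, test name and sizes are illustrative only.
def test_beamformer_net_forward():
    model = BeamformerNet(num_spk=1, use_wpe=False, use_beamformer=True)
    mix = torch.randn(2, 16000, 4)              # [Batch, Nsample, Channel]
    ilens = torch.LongTensor([16000, 14000])
    enhanced, flens, masks = model(mix, ilens)  # (B, T, F, 2), frame lengths, masks
    wavs, wav_lens, masks = model.forward_rawwav(mix, ilens)
    assert "spk1" in masks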
class TFMaskingNet(AbsEnhancement):
    """TF Masking Speech Separation Net."""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        rnn_type: str = "blstm",
        layer: int = 3,
        unit: int = 512,
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
    ):
        super(TFMaskingNet, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
        else:
            self.utt_mvn = None

        self.rnn = RNN(
            idim=self.num_bin,
            elayers=layer,
            cdim=unit,
            hdim=unit,
            dropout=dropout,
            typ=rnn_type,
        )

        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(unit, self.num_bin) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            separated (list[ComplexTensor]): [(B, T, F), ...]
            ilens (torch.Tensor): (B,)
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> stft -> magnitude spectrum
        input_spectrum, flens = self.stft(input, ilens)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        input_magnitude = abs(input_spectrum)
        input_phase = input_spectrum / (input_magnitude + 10e-12)

        # apply utterance-level MVN
        if self.utt_mvn:
            input_magnitude_mvn, fle = self.utt_mvn(input_magnitude, flens)
        else:
            input_magnitude_mvn = input_magnitude

        # predict masks for each speaker
        x, flens, _ = self.rnn(input_magnitude_mvn, flens)
        masks = []
        for linear in self.linear:
            y = linear(x)
            y = self.nonlinear(y)
            masks.append(y)

        if self.training and self.loss_type.startswith("mask"):
            predicted_spectrums = None
        else:
            # apply mask
            predict_magnitude = [input_magnitude * m for m in masks]
            predicted_spectrums = [input_phase * pm for pm in predict_magnitude]

        masks = OrderedDict(
            zip(["spk{}".format(i + 1) for i in range(len(masks))], masks)
        )
        return predicted_spectrums, flens, masks

    def forward_rawwav(
        self, input: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Output with waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech [Batch, num_speaker, sample]
            output lengths
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # predict spectrum for each speaker
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker output
            predicted_wavs = [
                self.stft.inverse(ps, ilens)[0] for ps in predicted_spectrums
            ]
        else:
            # single-speaker output
            predicted_wavs = self.stft.inverse(predicted_spectrums, ilens)[0]

        return predicted_wavs, ilens, masks
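# Illustrative usage sketch added for clarity (not in the original source); in
# eval() mode the model returns masked complex spectra, so forward_rawwav can
# resynthesize one waveform per speaker. Test name and sizes are hypothetical.
def test_tf_masking_net_separation():
    model = TFMaskingNet(num_spk=2)
    model.eval()
    mix = torch.randn(2, 16000)                  # [Batch, sample]
    ilens = torch.LongTensor([16000, 12000])
    spectrums, flens, masks = model(mix, ilens)  # list of num_spk ComplexTensors
    wavs, wav_lens, masks = model.forward_rawwav(mix, ilens)
    assert len(wavs) == 2 and set(masks) == {"spk1", "spk2"}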
class TFMaskingTransformer(AbsEnhancement):
    """TF Masking Speech Separation Net with a Transformer encoder."""

    def __init__(
        self,
        n_fft: int = 256,
        win_length: int = None,
        hop_length: int = 128,
        dnn_type: str = "transformer",
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
        d_model: int = 256,
        nhead: int = 4,
        linear_units: int = 2048,
        num_layers: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "linear",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
    ):
        super(TFMaskingTransformer, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
        else:
            self.utt_mvn = None

        # Transformer encoder replaces the RNN used in TFMaskingNet
        self.encoder = TransformerEncoder(
            input_size=self.num_bin,
            output_size=d_model,
            attention_heads=nhead,
            linear_units=linear_units,
            num_blocks=num_layers,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            input_layer=input_layer,
            normalize_before=normalize_before,
            concat_after=concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            padding_idx=padding_idx,
        )

        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(d_model, self.num_bin) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            separated (list[ComplexTensor]): [(B, T, F), ...]
            ilens (torch.Tensor): (B,)
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> stft -> magnitude spectrum
        input_spectrum, flens = self.stft(input, ilens)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        input_magnitude = abs(input_spectrum)
        input_phase = input_spectrum / (input_magnitude + 10e-12)

        # apply utterance-level MVN
        if self.utt_mvn:
            input_magnitude_mvn, fle = self.utt_mvn(input_magnitude, flens)
        else:
            input_magnitude_mvn = input_magnitude

        # predict masks for each speaker
        x, olens, _ = self.encoder(input_magnitude_mvn, flens)
        masks = []
        for linear in self.linear:
            y = linear(x)
            y = self.nonlinear(y)
            masks.append(y)

        if self.training and self.loss_type.startswith("mask"):
            predicted_spectrums = None
        else:
            # apply mask
            predict_magnitude = [input_magnitude * m for m in masks]
            predicted_spectrums = [input_phase * pm for pm in predict_magnitude]

        masks = OrderedDict(
            zip(["spk{}".format(i + 1) for i in range(len(masks))], masks)
        )
        return predicted_spectrums, flens, masks

    def forward_rawwav(
        self, input: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Output with waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech [Batch, num_speaker, sample]
            output lengths
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # predict spectrum for each speaker
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker output
            predicted_wavs = [
                self.stft.inverse(ps, ilens)[0] for ps in predicted_spectrums
            ]
        else:
            # single-speaker output
            predicted_wavs = self.stft.inverse(predicted_spectrums, ilens)[0]

        return predicted_wavs, ilens, masks
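# Illustrative usage sketch added for clarity (not in the original source); it
# exercises the Transformer-based variant the same way as TFMaskingNet above.
# Test name and constructor values are hypothetical.
def test_tf_masking_transformer_separation():
    model = TFMaskingTransformer(num_spk=2, d_model=256, num_layers=2)
    model.eval()
    mix = torch.randn(2, 16000)                  # [Batch, sample]
    ilens = torch.LongTensor([16000, 12000])
    spectrums, flens, masks = model(mix, ilens)  # list of num_spk ComplexTensors
    wavs, wav_lens, masks = model.forward_rawwav(mix, ilens)
    assert len(wavs) == 2 and set(masks) == {"spk1", "spk2"}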