def test_inverse():
    """Smoke-test ``Stft.inverse`` with and without explicit lengths."""
    stft = Stft()
    wav = torch.randn(2, 400, requires_grad=True)
    spec, _ = stft(wav)

    # First invert with per-utterance lengths, then without any lengths.
    wav_lengths = torch.IntTensor([400, 300])
    raw, _ = stft.inverse(spec, wav_lengths)
    raw, _ = stft.inverse(spec)
def __init__(
    self,
    n_fft: int = 512,
    win_length: int = None,
    hop_length: int = 128,
    rnn_type: str = "blstm",
    layer: int = 3,
    unit: int = 512,
    dropout: float = 0.0,
    num_spk: int = 2,
    nonlinear: str = "sigmoid",
    utt_mvn: bool = False,
    mask_type: str = "IRM",
    loss_type: str = "mask_mse",
):
    """Build the STFT front-end, the RNN and one mask head per speaker.

    Args:
        n_fft: FFT size; the number of frequency bins is ``n_fft // 2 + 1``.
        win_length: window length passed to ``Stft``.
        hop_length: hop size passed to ``Stft``.
        rnn_type: value passed to ``RNN`` as ``typ``.
        layer: value passed to ``RNN`` as ``elayers``.
        unit: value passed to ``RNN`` as ``cdim``/``hdim``; also the input
            size of every mask projection.
        dropout: dropout passed to ``RNN``.
        num_spk: number of mask projection layers to create.
        nonlinear: one of ``"sigmoid"``, ``"relu"``, ``"tanh"``.
        utt_mvn: when true, create an ``UtteranceMVN`` normalizer.
        mask_type: stored on the instance; not interpreted here.
        loss_type: one of ``"mask_mse"``, ``"magnitude"``, ``"spectrum"``.

    Raises:
        ValueError: for an unsupported ``loss_type`` or ``nonlinear``.
    """
    super(TFMaskingNet, self).__init__()

    self.num_spk = num_spk
    self.num_bin = n_fft // 2 + 1
    self.mask_type = mask_type
    self.loss_type = loss_type

    supported_losses = ("mask_mse", "magnitude", "spectrum")
    if loss_type not in supported_losses:
        raise ValueError("Unsupported loss type: %s" % loss_type)

    self.stft = Stft(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    )

    self.utt_mvn = (
        UtteranceMVN(norm_means=True, norm_vars=True) if utt_mvn else None
    )

    self.rnn = RNN(
        idim=self.num_bin,
        elayers=layer,
        cdim=unit,
        hdim=unit,
        dropout=dropout,
        typ=rnn_type,
    )

    # One projection (hidden -> frequency bins) per speaker.
    mask_heads = []
    for _ in range(self.num_spk):
        mask_heads.append(torch.nn.Linear(unit, self.num_bin))
    self.linear = torch.nn.ModuleList(mask_heads)

    if nonlinear not in ("sigmoid", "relu", "tanh"):
        raise ValueError("Not supporting nonlinear={}".format(nonlinear))
    activation_classes = {
        "sigmoid": torch.nn.Sigmoid,
        "relu": torch.nn.ReLU,
        "tanh": torch.nn.Tanh,
    }
    self.nonlinear = activation_classes[nonlinear]()
def test_forward():
    """Check the output shape and the frame lengths of a tiny STFT."""
    stft = Stft(win_length=4, hop_length=2, n_fft=4)
    wav = torch.randn(2, 30)

    spec, _ = stft(wav)
    assert spec.shape == (2, 16, 3, 2)

    wav_lens = torch.tensor([30, 15], dtype=torch.long)
    spec, spec_lens = stft(wav, wav_lens)
    expected_lens = torch.tensor((16, 8), dtype=torch.long)
    assert (spec_lens == expected_lens).all()
def __init__(
    self,
    window_sz=[512],
    hop_sz=None,
    eps=1e-8,
    time_domain_weight=0.5,
    name=None,
    only_for_test=False,
):
    """Create one STFT encoder per (window size, hop size) pair.

    Args:
        window_sz: sequence of even window sizes, one per resolution.
        hop_sz: sequence of hop sizes paired with ``window_sz``; when
            ``None`` each hop is half of the matching window size.
        eps: stored on the instance as ``self.eps``.
        time_domain_weight: stored as ``self.time_domain_weight``.
        name: name passed to the parent constructor; ``"TD_L1_loss"`` is
            used when ``None``.
        only_for_test: forwarded to the parent constructor.
    """
    _name = "TD_L1_loss" if name is None else name
    super(MultiResL1SpecLoss, self).__init__(_name, only_for_test=only_for_test)

    # Copy the sequences: the default for ``window_sz`` is a mutable list
    # shared across calls, so storing it directly would let a mutation of
    # ``self.window_sz`` leak into every later instance.
    window_sz = list(window_sz)
    assert all(x % 2 == 0 for x in window_sz)
    self.window_sz = window_sz

    if hop_sz is None:
        self.hop_sz = [x // 2 for x in window_sz]
    else:
        self.hop_sz = list(hop_sz)

    self.time_domain_weight = time_domain_weight
    self.eps = eps

    self.stft_encoders = torch.nn.ModuleList([])
    for w, h in zip(self.window_sz, self.hop_sz):
        stft_enc = Stft(
            n_fft=w,
            win_length=w,
            hop_length=h,
            window=None,
            center=True,
            normalized=False,
            onesided=True,
        )
        self.stft_encoders.append(stft_enc)
def __init__(
    self,
    fs: Union[int, str] = 22050,
    n_fft: int = 1024,
    win_length: int = None,
    hop_length: int = 256,
    window: Optional[str] = "hann",
    center: bool = True,
    normalized: bool = False,
    onesided: bool = True,
    use_token_averaged_energy: bool = True,
):
    """Store the analysis settings and build the underlying ``Stft``.

    ``fs`` may be given as a string, in which case it is converted with
    ``humanfriendly.parse_size``.  All STFT-related arguments are passed
    to ``Stft`` unchanged.
    """
    assert check_argument_types()
    super().__init__()

    # Accept a string sampling rate and turn it into a number.
    self.fs = humanfriendly.parse_size(fs) if isinstance(fs, str) else fs

    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.window = window
    self.use_token_averaged_energy = use_token_averaged_energy

    stft_conf = dict(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )
    self.stft = Stft(**stft_conf)
def test_gevd(ch):
    """Verify ``Phi_X @ V == Phi_N @ V @ diag(e_val)`` for the GEVD output."""
    stft_conf = dict(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    stft = Stft(**stft_conf)

    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    ilens = torch.LongTensor([16, 12])

    # (B, T, C, F) -> (B, F, C, T)
    real_part, imag_part = torch.unbind(stft(x, ilens)[0], dim=-1)
    X = torch.complex(real_part, imag_part).transpose(-1, -3)
    # (B, F, C, C)
    Phi_X = torch.einsum("...ct,...et->...ce", [X, X.conj()])

    # Draw random noise until its covariance has full rank everywhere.
    while True:
        N = torch.randn_like(X)
        Phi_N = torch.einsum("...ct,...et->...ce", [N, N.conj()])
        if torch.linalg.matrix_rank(Phi_N).eq(ch).all():
            break

    # e_val: (B, F, C)
    # e_vec: (B, F, C, C)
    e_val, e_vec = generalized_eigenvalue_decomposition(Phi_X, Phi_N)
    e_val = e_val.to(dtype=e_vec.dtype)

    lhs = torch.matmul(Phi_X, e_vec)
    rhs = torch.matmul(torch.matmul(Phi_N, e_vec), e_val.diag_embed())
    assert torch.allclose(lhs, rhs)
def __init__(
    self,
    n_fft: int = 1024,
    win_length: int = None,
    hop_length: int = 256,
    window: Optional[str] = "hann",
    center: bool = True,
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: bool = True,
):
    """Store the analysis settings and build the underlying ``Stft``.

    Args:
        n_fft: FFT size, stored as ``self.n_fft`` and passed to ``Stft``.
        win_length: window length, stored and passed to ``Stft``.
        hop_length: hop size, stored and passed to ``Stft``.
        window: window name, stored and passed to ``Stft``.
        center: passed to ``Stft``.
        pad_mode: passed to ``Stft``.
        normalized: passed to ``Stft``.
        onesided: passed to ``Stft``.
    """
    assert check_argument_types()
    super().__init__()
    # ``n_fft`` is assigned once here; the original repeated the same
    # assignment after building ``self.stft``, which had no effect.
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.window = window
    self.stft = Stft(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=onesided,
    )
def __init__(
    self,
    fs: Union[int, str] = 16000,
    n_fft: int = 512,
    win_length: int = None,
    hop_length: int = 128,
    window: Optional[str] = "hann",
    center: bool = True,
    normalized: bool = False,
    onesided: bool = True,
    n_mels: int = 80,
    fmin: int = None,
    fmax: int = None,
    htk: bool = False,
    frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
    apply_stft: bool = True,
):
    """Assemble the optional STFT, the optional ``Frontend`` and ``LogMel``.

    ``fs`` may be a string, in which case it is converted with
    ``humanfriendly.parse_size``.  ``self.stft`` is ``None`` when
    ``apply_stft`` is false, and ``self.frontend`` is ``None`` when
    ``frontend_conf`` is ``None``.
    """
    assert check_argument_types()
    super().__init__()
    if isinstance(fs, str):
        fs = humanfriendly.parse_size(fs)

    # The default value is a dict shared between calls, so work on a copy.
    frontend_conf = copy.deepcopy(frontend_conf)
    self.hop_length = hop_length

    self.stft = (
        Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            normalized=normalized,
            onesided=onesided,
        )
        if apply_stft
        else None
    )
    self.apply_stft = apply_stft

    if frontend_conf is None:
        self.frontend = None
    else:
        num_bins = n_fft // 2 + 1
        self.frontend = Frontend(idim=num_bins, **frontend_conf)

    logmel_conf = dict(
        fs=fs,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        htk=htk,
    )
    self.logmel = LogMel(**logmel_conf)
    self.n_mels = n_mels
    self.frontend_type = "default"
def __init__(
    self,
    n_fft: int = 512,
    win_length: int = None,
    hop_length: int = 128,
    window="hann",
    center: bool = True,
    normalized: bool = False,
    onesided: bool = True,
):
    """Create the wrapped ``Stft`` module from the given settings."""
    super().__init__()
    # Every argument is handed to ``Stft`` under the same keyword.
    stft_conf = dict(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )
    self.stft = Stft(**stft_conf)
def test_get_rtf(ch, mode):
    """Check the RTF estimate against the eigenvector relation.

    The final assertion tests that ``max_eigenvec`` and
    ``mat @ max_eigenvec`` are parallel by comparing their two outer
    products.
    """
    if mode == "evd" and not is_torch_1_9_plus:
        # torch 1.9.0+ is required for "evd" mode
        return

    # Pick native complex tensors for "evd", ComplexTensor otherwise.
    if mode == "evd":
        complex_wrapper, complex_module = torch.complex, torch
    else:
        complex_wrapper, complex_module = ComplexTensor, FC

    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    ilens = torch.LongTensor([16, 12])

    # (B, T, C, F) -> (B, F, C, T)
    spec_parts = torch.unbind(stft(x, ilens)[0], dim=-1)
    X = complex_wrapper(*spec_parts).transpose(-1, -3)
    # (B, F, C, C)
    Phi_X = complex_module.einsum("...ct,...et->...ce", [X, X.conj()])

    # Draw random noise until its covariance has full rank everywhere.
    while True:
        N = complex_wrapper(torch.randn_like(X.real), torch.randn_like(X.imag))
        Phi_N = complex_module.einsum("...ct,...et->...ce", [N, N.conj()])
        if np.all(np.linalg.matrix_rank(Phi_N.numpy()) == ch):
            break

    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, mode=mode, reference_vector=0, iterations=20)
    peak = rtf.abs().max(dim=-2, keepdim=True)
    peak = peak.values if is_torch_1_1_plus else peak[0]
    rtf = rtf / (peak + 1e-15)

    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available after pytorch 1.1.0+
        mat = solve(Phi_X, Phi_N)[0]
        max_eigenvec = solve(rtf, Phi_N)[0]
    else:
        Phi_N_inv = Phi_N.inverse2()
        mat = complex_module.matmul(Phi_N_inv, Phi_X)
        max_eigenvec = complex_module.matmul(Phi_N.inverse2(), rtf)

    factor = complex_module.matmul(mat, max_eigenvec)
    lhs = complex_module.matmul(max_eigenvec, factor.transpose(-1, -2))
    rhs = complex_module.matmul(factor, max_eigenvec.transpose(-1, -2))
    assert complex_module.allclose(lhs, rhs)
class STFTDecoder(AbsDecoder):
    """Inverse-STFT decoder for speech enhancement and separation."""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        """Create the wrapped ``Stft`` module from the given settings."""
        super().__init__()
        stft_conf = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )
        self.stft = Stft(**stft_conf)

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Convert a complex spectrum back to a waveform.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            The waveform and the lengths returned by ``Stft.inverse``.  For
            4-dimensional input the channel axis is folded into the batch
            before the inverse and restored afterwards.

        Raises:
            TypeError: if ``input`` is not a ``ComplexTensor`` and, on
                torch 1.9+, is not a native complex tensor either.
        """
        is_wrapped_complex = isinstance(input, ComplexTensor)
        if (
            not is_wrapped_complex
            and is_torch_1_9_plus
            and not torch.is_complex(input)
        ):
            raise TypeError("Only support complex tensors for stft decoder")

        batch_size = input.size(0)
        multi_channel = input.dim() == 4
        if multi_channel:
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            num_frames, num_freq = input.size(1), input.size(3)
            input = input.transpose(1, 2).reshape(-1, num_frames, num_freq)

        wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            num_samples = wav.size(1)
            wav = wav.reshape(batch_size, -1, num_samples).transpose(1, 2)

        return wav, wav_lens
def __init__(
    self,
    fs: Union[int, str] = 16000,
    n_fft: int = 1024,
    win_length: int = None,
    hop_length: int = 256,
    window: Optional[str] = "hann",
    center: bool = True,
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: bool = True,
    n_mels: int = 80,
    fmin: Optional[int] = 80,
    fmax: Optional[int] = 7600,
    htk: bool = False,
):
    """Store the analysis settings and build ``Stft`` and ``LogMel``.

    ``fs`` may be a string, in which case it is converted with
    ``humanfriendly.parse_size``.  ``LogMel`` is created with
    ``log_base=10.0``.
    """
    assert check_argument_types()
    super().__init__()

    # Accept a string sampling rate and turn it into a number.
    fs = humanfriendly.parse_size(fs) if isinstance(fs, str) else fs

    self.fs = fs
    self.n_mels = n_mels
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.window = window
    self.fmin = fmin
    self.fmax = fmax

    stft_conf = dict(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=onesided,
    )
    self.stft = Stft(**stft_conf)

    logmel_conf = dict(
        fs=fs,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        htk=htk,
        log_base=10.0,
    )
    self.logmel = LogMel(**logmel_conf)
def test_get_rtf(ch):
    """Check the RTF estimate against the eigenvector relation.

    The final assertion tests that ``max_eigenvec`` and
    ``mat @ max_eigenvec`` are parallel by comparing their two outer
    products.
    """
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    n = torch.rand(2, 16, ch, dtype=torch.double)
    ilens = torch.LongTensor([16, 12])

    # (B, T, C, F) -> (B, F, C, T)
    speech_parts = torch.unbind(stft(x, ilens)[0], dim=-1)
    X = ComplexTensor(*speech_parts).transpose(-1, -3)
    noise_parts = torch.unbind(stft(n, ilens)[0], dim=-1)
    N = ComplexTensor(*noise_parts).transpose(-1, -3)

    # (B, F, C, C)
    Phi_X = FC.einsum("...ct,...et->...ce", [X, X.conj()])
    Phi_N = FC.einsum("...ct,...et->...ce", [N, N.conj()])

    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, reference_vector=0, iterations=20)
    peak = rtf.abs().max(dim=-2, keepdim=True)
    peak = peak.values if is_torch_1_1_plus else peak[0]
    rtf = rtf / (peak + 1e-15)

    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available after pytorch 1.1.0+
        mat = FC.solve(Phi_X, Phi_N)[0]
        max_eigenvec = FC.solve(rtf, Phi_N)[0]
    else:
        mat = FC.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = FC.matmul(Phi_N.inverse2(), rtf)

    factor = FC.matmul(mat, max_eigenvec)
    lhs = FC.matmul(max_eigenvec, factor.transpose(-1, -2))
    rhs = FC.matmul(factor, max_eigenvec.transpose(-1, -2))
    assert FC.allclose(lhs, rhs)
class STFTDecoder(AbsDecoder):
    """Inverse-STFT decoder for speech enhancement and separation."""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        """Create the wrapped ``Stft`` module from the given settings."""
        super().__init__()
        stft_conf = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )
        self.stft = Stft(**stft_conf)

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Convert a complex spectrum back to a waveform.

        Args:
            input (ComplexTensor): spectrum [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            The waveform and the lengths returned by ``Stft.inverse``.

        Raises:
            TypeError: if ``input`` is not a ``ComplexTensor`` and, on
                torch 1.9+, is not a native complex tensor either.
        """
        is_wrapped_complex = isinstance(input, ComplexTensor)
        if (
            not is_wrapped_complex
            and is_torch_1_9_plus
            and not torch.is_complex(input)
        ):
            raise TypeError("Only support complex tensors for stft decoder")

        return self.stft.inverse(input, ilens)
def __init__(
    self,
    n_fft: int = 512,
    win_length: int = None,
    hop_length: int = 128,
    window="hann",
    center: bool = True,
    normalized: bool = False,
    onesided: bool = True,
    use_builtin_complex: bool = True,
):
    """Create the wrapped ``Stft`` and record the output dimension.

    The output dimension is ``n_fft // 2 + 1`` for a one-sided spectrum
    and ``n_fft`` otherwise.
    """
    super().__init__()
    stft_conf = dict(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )
    self.stft = Stft(**stft_conf)

    if onesided:
        self._output_dim = n_fft // 2 + 1
    else:
        self._output_dim = n_fft
    self.use_builtin_complex = use_builtin_complex
def test_repr():
    """Printing a default ``Stft`` must not raise."""
    layer = Stft()
    print(layer)
class TFMaskingTransformer(AbsEnhancement):
    """TF-masking speech separation network with a Transformer encoder."""

    def __init__(
        self,
        n_fft: int = 256,
        win_length: int = None,
        hop_length: int = 128,
        dnn_type: str = "transformer",
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
        d_model: int = 256,
        nhead: int = 4,
        linear_units: int = 2048,
        num_layers: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "linear",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
    ):
        """Build the STFT front-end, the encoder and one mask head per speaker.

        Args:
            n_fft: FFT size; the number of frequency bins is
                ``n_fft // 2 + 1``.
            win_length: window length passed to ``Stft``.
            hop_length: hop size passed to ``Stft``.
            dnn_type: accepted for interface compatibility; not used here.
            dropout: accepted for interface compatibility; not used here.
            num_spk: number of mask projection layers to create.
            nonlinear: one of ``"sigmoid"``, ``"relu"``, ``"tanh"``.
            utt_mvn: when true, create an ``UtteranceMVN`` normalizer.
            mask_type: stored on the instance; not interpreted here.
            loss_type: one of ``"mask_mse"``, ``"magnitude"``, ``"spectrum"``.
            d_model: encoder output size and input size of the mask heads.
            nhead, linear_units, num_layers, dropout_rate,
            positional_dropout_rate, attention_dropout_rate, input_layer,
            pos_enc_class, normalize_before, concat_after,
            positionwise_layer_type, positionwise_conv_kernel_size,
            padding_idx: forwarded to ``TransformerEncoder``.

        Raises:
            ValueError: for an unsupported ``loss_type`` or ``nonlinear``.
        """
        super(TFMaskingTransformer, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
        else:
            self.utt_mvn = None

        # ``dropout_rate`` and ``pos_enc_class`` used to be accepted but
        # dropped on the floor, so non-default values had no effect.
        # NOTE(review): assumes TransformerEncoder accepts both keywords
        # (the parameter list here mirrors its signature) -- confirm.
        self.encoder = TransformerEncoder(
            input_size=self.num_bin,
            output_size=d_model,
            attention_heads=nhead,
            linear_units=linear_units,
            num_blocks=num_layers,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            input_layer=input_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
            concat_after=concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            padding_idx=padding_idx,
        )

        # One projection (encoder output -> frequency bins) per speaker.
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(d_model, self.num_bin) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Predict one mask per speaker and, when needed, masked spectra.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted_spectrums: ``None`` while training with a loss type
                starting with ``"mask"``; otherwise a list with one masked
                spectrum (input phase times masked magnitude) per speaker.
            flens (torch.Tensor): frame lengths returned by the STFT.
            masks (OrderedDict): predicted masks keyed ``"spk1"``,
                ``"spk2"``, ... in speaker order.
        """
        # wave -> stft -> magnitude spectrum
        input_spectrum, flens = self.stft(input, ilens)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        input_magnitude = abs(input_spectrum)
        # The small constant (10e-12, i.e. 1e-11) guards against division by zero.
        input_phase = input_spectrum / (input_magnitude + 10e-12)

        # apply utt mvn
        if self.utt_mvn:
            input_magnitude_mvn, _ = self.utt_mvn(input_magnitude, flens)
        else:
            input_magnitude_mvn = input_magnitude

        # predict masks for each speaker; the lengths reported by the
        # encoder are not used, the STFT frame lengths are returned instead
        x, _, _ = self.encoder(input_magnitude_mvn, flens)

        masks = [self.nonlinear(linear(x)) for linear in self.linear]

        if self.training and self.loss_type.startswith("mask"):
            # Only the masks are needed for the loss in this configuration.
            predicted_spectrums = None
        else:
            # apply mask
            predict_magnitude = [input_magnitude * m for m in masks]
            predicted_spectrums = [input_phase * pm for pm in predict_magnitude]

        masks = OrderedDict(
            zip(["spk{}".format(i + 1) for i in range(len(masks))], masks)
        )
        return predicted_spectrums, flens, masks

    def forward_rawwav(
        self, input: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward`` and convert the predicted spectra to waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted_wavs: ``None`` when ``forward`` produced no spectra;
                a list of waveforms when it produced a list; otherwise a
                single waveform.
            ilens (torch.Tensor): the input lengths, returned unchanged.
            masks (OrderedDict): the masks returned by ``forward``.
        """
        # predict spectrum for each speaker
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker input
            predicted_wavs = [
                self.stft.inverse(ps, ilens)[0] for ps in predicted_spectrums
            ]
        else:
            # single-speaker input
            predicted_wavs = self.stft.inverse(predicted_spectrums, ilens)[0]

        return predicted_wavs, ilens, masks
def test_backward_not_leaf_in():
    """Backpropagate through ``Stft`` when its input is not a leaf tensor."""
    stft = Stft()
    leaf = torch.randn(2, 400, requires_grad=True)
    # Adding a constant makes the STFT input a non-leaf tensor.
    shifted = leaf + 2
    spec, _ = stft(shifted)
    spec.sum().backward()
def __init__(
    self,
    n_fft: int = 256,
    win_length: int = None,
    hop_length: int = 128,
    dnn_type: str = "transformer",
    dropout: float = 0.0,
    num_spk: int = 2,
    nonlinear: str = "sigmoid",
    utt_mvn: bool = False,
    mask_type: str = "IRM",
    loss_type: str = "mask_mse",
    d_model: int = 256,
    nhead: int = 4,
    linear_units: int = 2048,
    num_layers: int = 6,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    attention_dropout_rate: float = 0.0,
    input_layer: Optional[str] = "linear",
    pos_enc_class=PositionalEncoding,
    normalize_before: bool = True,
    concat_after: bool = False,
    positionwise_layer_type: str = "linear",
    positionwise_conv_kernel_size: int = 1,
    padding_idx: int = -1,
):
    """Build the STFT front-end, the encoder and one mask head per speaker.

    Args:
        n_fft: FFT size; the number of frequency bins is ``n_fft // 2 + 1``.
        win_length: window length passed to ``Stft``.
        hop_length: hop size passed to ``Stft``.
        dnn_type: accepted for interface compatibility; not used here.
        dropout: accepted for interface compatibility; not used here.
        num_spk: number of mask projection layers to create.
        nonlinear: one of ``"sigmoid"``, ``"relu"``, ``"tanh"``.
        utt_mvn: when true, create an ``UtteranceMVN`` normalizer.
        mask_type: stored on the instance; not interpreted here.
        loss_type: one of ``"mask_mse"``, ``"magnitude"``, ``"spectrum"``.
        d_model: encoder output size and input size of the mask heads.
        nhead, linear_units, num_layers, dropout_rate,
        positional_dropout_rate, attention_dropout_rate, input_layer,
        pos_enc_class, normalize_before, concat_after,
        positionwise_layer_type, positionwise_conv_kernel_size,
        padding_idx: forwarded to ``TransformerEncoder``.

    Raises:
        ValueError: for an unsupported ``loss_type`` or ``nonlinear``.
    """
    super(TFMaskingTransformer, self).__init__()

    self.num_spk = num_spk
    self.num_bin = n_fft // 2 + 1
    self.mask_type = mask_type
    self.loss_type = loss_type
    if loss_type not in ("mask_mse", "magnitude", "spectrum"):
        raise ValueError("Unsupported loss type: %s" % loss_type)

    self.stft = Stft(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    )

    if utt_mvn:
        self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
    else:
        self.utt_mvn = None

    # ``dropout_rate`` and ``pos_enc_class`` used to be accepted but dropped
    # on the floor, so non-default values had no effect.
    # NOTE(review): assumes TransformerEncoder accepts both keywords (the
    # parameter list here mirrors its signature) -- confirm.
    self.encoder = TransformerEncoder(
        input_size=self.num_bin,
        output_size=d_model,
        attention_heads=nhead,
        linear_units=linear_units,
        num_blocks=num_layers,
        dropout_rate=dropout_rate,
        positional_dropout_rate=positional_dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        input_layer=input_layer,
        pos_enc_class=pos_enc_class,
        normalize_before=normalize_before,
        concat_after=concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        padding_idx=padding_idx,
    )

    # One projection (encoder output -> frequency bins) per speaker.
    self.linear = torch.nn.ModuleList(
        [torch.nn.Linear(d_model, self.num_bin) for _ in range(self.num_spk)]
    )

    if nonlinear not in ("sigmoid", "relu", "tanh"):
        raise ValueError("Not supporting nonlinear={}".format(nonlinear))

    self.nonlinear = {
        "sigmoid": torch.nn.Sigmoid(),
        "relu": torch.nn.ReLU(),
        "tanh": torch.nn.Tanh(),
    }[nonlinear]
def __init__(
    self,
    num_spk: int = 1,
    normalize_input: bool = False,
    mask_type: str = "IPM^2",
    # STFT options
    n_fft: int = 512,
    win_length: int = None,
    hop_length: int = 128,
    center: bool = True,
    window: Optional[str] = "hann",
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: bool = True,
    # Dereverberation options
    use_wpe: bool = False,
    wnet_type: str = "blstmp",
    wlayers: int = 3,
    wunits: int = 300,
    wprojs: int = 320,
    wdropout_rate: float = 0.0,
    taps: int = 5,
    delay: int = 3,
    use_dnn_mask_for_wpe: bool = True,
    # Beamformer options
    use_beamformer: bool = True,
    bnet_type: str = "blstmp",
    blayers: int = 3,
    bunits: int = 300,
    bprojs: int = 320,
    badim: int = 320,
    ref_channel: int = -1,
    use_noise_mask: bool = True,
    beamformer_type="mvdr",
    bdropout_rate=0.0,
):
    """Build the STFT plus the optional WPE and beamformer sub-modules.

    ``self.wpe`` is ``None`` unless ``use_wpe`` is true and
    ``self.beamformer`` is ``None`` unless ``use_beamformer`` is true.
    ``taps`` and ``delay`` are shared by both sub-modules.
    """
    super(BeamformerNet, self).__init__()

    self.mask_type = mask_type
    self.num_spk = num_spk
    self.num_bin = n_fft // 2 + 1

    stft_conf = dict(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        center=center,
        window=window,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=onesided,
    )
    self.stft = Stft(**stft_conf)

    self.normalize_input = normalize_input
    self.use_beamformer = use_beamformer
    self.use_wpe = use_wpe

    if not self.use_wpe:
        self.wpe = None
    else:
        # One iteration when a DNN estimates the power; two iterations
        # when running as conventional WPE without the DNN estimator.
        iterations = 1 if use_dnn_mask_for_wpe else 2
        self.wpe = DNN_WPE(
            wtype=wnet_type,
            widim=self.num_bin,
            wunits=wunits,
            wprojs=wprojs,
            wlayers=wlayers,
            taps=taps,
            delay=delay,
            dropout_rate=wdropout_rate,
            iterations=iterations,
            use_dnn_mask=use_dnn_mask_for_wpe,
        )

    self.ref_channel = ref_channel

    if not self.use_beamformer:
        self.beamformer = None
    else:
        self.beamformer = DNN_Beamformer(
            btype=bnet_type,
            bidim=self.num_bin,
            bunits=bunits,
            bprojs=bprojs,
            blayers=blayers,
            num_spk=num_spk,
            use_noise_mask=use_noise_mask,
            dropout_rate=bdropout_rate,
            badim=badim,
            ref_channel=ref_channel,
            beamformer_type=beamformer_type,
            btaps=taps,
            bdelay=delay,
        )
class BeamformerNet(AbsEnhancement):
    """TF-masking based beamformer with optional WPE dereverberation."""

    def __init__(
        self,
        num_spk: int = 1,
        normalize_input: bool = False,
        mask_type: str = "IPM^2",
        # STFT options
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        center: bool = True,
        window: Optional[str] = "hann",
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        beamformer_type="mvdr",
        bdropout_rate=0.0,
    ):
        """Build the STFT plus the optional WPE and beamformer sub-modules.

        ``self.wpe`` is ``None`` unless ``use_wpe`` is true and
        ``self.beamformer`` is ``None`` unless ``use_beamformer`` is true.
        ``taps`` and ``delay`` are shared by both sub-modules.
        """
        super(BeamformerNet, self).__init__()

        self.mask_type = mask_type
        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1

        stft_conf = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
        )
        self.stft = Stft(**stft_conf)

        self.normalize_input = normalize_input
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if not self.use_wpe:
            self.wpe = None
        else:
            # One iteration when a DNN estimates the power; two iterations
            # when running as conventional WPE without the DNN estimator.
            iterations = 1 if use_dnn_mask_for_wpe else 2
            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=self.num_bin,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )

        self.ref_channel = ref_channel

        if not self.use_beamformer:
            self.beamformer = None
        else:
            self.beamformer = DNN_Beamformer(
                btype=bnet_type,
                bidim=self.num_bin,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                btaps=taps,
                bdelay=delay,
            )

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Enhance the input in the STFT domain.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced: a tensor, or a list of tensors when the beamformer
                returns a list, with real and imaginary parts stacked on a
                new last axis.
            flens: frame lengths, as updated by the enabled sub-modules.
            masks (OrderedDict): may contain ``"dereverb"`` (WPE mask),
                ``"spk1"`` ... ``"spkN"`` (beamformer masks) and
                ``"noise1"`` (when the beamformer returns an extra mask).

        Raises:
            ValueError: if the spectrum is neither 3- nor 4-dimensional.
        """
        # wave -> stft -> complex spectrum
        stft_out, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(stft_out[..., 0], stft_out[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        num_dims = input_spectrum.dim()
        if num_dims not in (3, 4):
            raise ValueError(
                "Invalid spectrum dimension: {}".format(input_spectrum.shape)
            )

        enhanced = input_spectrum
        masks = OrderedDict()

        if num_dims == 3:
            # single-channel input
            if self.use_wpe:
                # (B, T, F): add a singleton channel axis for WPE
                enhanced, flens, mask_w = self.wpe(input_spectrum.unsqueeze(-2), flens)
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)
        else:
            # multi-channel input
            # 1. WPE
            if self.use_wpe:
                # (B, T, C, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk_idx in range(self.num_spk):
                    masks["spk{}".format(spk_idx + 1)] = masks_b[spk_idx]
                # Any extra mask after the speaker masks is the noise mask.
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            stacked = []
            for enh in enhanced:
                stacked.append(torch.stack([enh.real, enh.imag], dim=-1))
            enhanced = stacked
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()

        return enhanced, flens, masks

    def forward_rawwav(self, input: torch.Tensor, ilens: torch.Tensor):
        """Run ``forward`` and convert the result to waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted_wavs: one waveform, or a list of waveforms when
                ``forward`` returns a list.
            ilens (torch.Tensor): the input lengths, returned unchanged.
            masks (OrderedDict): the masks returned by ``forward``.
        """
        enhanced, flens, masks = self.forward(input, ilens)

        def to_wav(spectrum):
            # Keep only the waveform; the lengths from inverse() are dropped.
            return self.stft.inverse(spectrum, ilens)[0]

        if isinstance(enhanced, list):
            # multi-speaker input
            predicted_wavs = [to_wav(ps) for ps in enhanced]
        else:
            # single-speaker input
            predicted_wavs = to_wav(enhanced)

        return predicted_wavs, ilens, masks
class TFMaskingNet(AbsEnhancement):
    """TF-masking speech separation network with an RNN mask estimator."""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        rnn_type: str = "blstm",
        layer: int = 3,
        unit: int = 512,
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
    ):
        """Build the STFT front-end, the RNN and one mask head per speaker.

        Raises:
            ValueError: for an unsupported ``loss_type`` or ``nonlinear``.
        """
        super(TFMaskingNet, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type

        supported_losses = ("mask_mse", "magnitude", "spectrum")
        if loss_type not in supported_losses:
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        self.utt_mvn = (
            UtteranceMVN(norm_means=True, norm_vars=True) if utt_mvn else None
        )

        self.rnn = RNN(
            idim=self.num_bin,
            elayers=layer,
            cdim=unit,
            hdim=unit,
            dropout=dropout,
            typ=rnn_type,
        )

        # One projection (hidden -> frequency bins) per speaker.
        mask_heads = []
        for _ in range(self.num_spk):
            mask_heads.append(torch.nn.Linear(unit, self.num_bin))
        self.linear = torch.nn.ModuleList(mask_heads)

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))
        activation_classes = {
            "sigmoid": torch.nn.Sigmoid,
            "relu": torch.nn.ReLU,
            "tanh": torch.nn.Tanh,
        }
        self.nonlinear = activation_classes[nonlinear]()

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Predict one mask per speaker and, when needed, masked spectra.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted_spectrums: ``None`` while training with a loss type
                starting with ``"mask"``; otherwise a list with one masked
                spectrum (input phase times masked magnitude) per speaker.
            flens (torch.Tensor): frame lengths as returned by the RNN.
            masks (OrderedDict): predicted masks keyed ``"spk1"``,
                ``"spk2"``, ... in speaker order.
        """
        # wave -> stft -> magnitude spectrum
        stft_out, flens = self.stft(input, ilens)
        input_spectrum = ComplexTensor(stft_out[..., 0], stft_out[..., 1])
        input_magnitude = abs(input_spectrum)
        input_phase = input_spectrum / (input_magnitude + 10e-12)

        # apply utt mvn
        if self.utt_mvn:
            input_magnitude_mvn, _ = self.utt_mvn(input_magnitude, flens)
        else:
            input_magnitude_mvn = input_magnitude

        # predict masks for each speaker
        hidden, flens, _ = self.rnn(input_magnitude_mvn, flens)
        masks = [self.nonlinear(head(hidden)) for head in self.linear]

        mask_only = self.training and self.loss_type.startswith("mask")
        if mask_only:
            predicted_spectrums = None
        else:
            # apply mask
            predict_magnitude = [input_magnitude * m for m in masks]
            predicted_spectrums = [input_phase * pm for pm in predict_magnitude]

        speaker_keys = ["spk{}".format(i + 1) for i in range(len(masks))]
        masks = OrderedDict(zip(speaker_keys, masks))

        return predicted_spectrums, flens, masks

    def forward_rawwav(
        self, input: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward`` and convert the predicted spectra to waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted_wavs: ``None`` when ``forward`` produced no spectra;
                a list of waveforms when it produced a list; otherwise a
                single waveform.
            ilens (torch.Tensor): the input lengths, returned unchanged.
            masks (OrderedDict): the masks returned by ``forward``.
        """
        # predict spectrum for each speaker
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        def to_wav(spectrum):
            # Keep only the waveform; the lengths from inverse() are dropped.
            return self.stft.inverse(spectrum, ilens)[0]

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker input
            predicted_wavs = [to_wav(ps) for ps in predicted_spectrums]
        else:
            # single-speaker input
            predicted_wavs = to_wav(predicted_spectrums)

        return predicted_wavs, ilens, masks