# NOTE: imports reconstructed for self-containedness; the module paths below
# follow ESPnet's espnet2 package layout and may need adjusting for the exact
# version of the codebase this file belongs to.
from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.abs_enh import AbsEnhancement
from espnet2.enh.layers.dnn_beamformer import DNN_Beamformer
from espnet2.enh.layers.dnn_wpe import DNN_WPE
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet2.layers.stft import Stft
class NeuralBeamformer(AbsSeparator):
    def __init__(
        self,
        input_dim: int,
        num_spk: int = 1,
        loss_type: str = "mask_mse",
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        wnonlinear: str = "crelu",
        multi_source_wpe: bool = True,
        wnormalization: bool = False,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        bnonlinear: str = "sigmoid",
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        bdropout_rate: float = 0.0,
        shared_power: bool = True,
        # For numerical stability
        diagonal_loading: bool = True,
        diag_eps_wpe: float = 1e-7,
        diag_eps_bf: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres_wpe: float = 1e-6,
        flooring_thres_bf: float = 1e-6,
        use_torch_solver: bool = True,
    ):
        super().__init__()

        self._num_spk = num_spk
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "spectrum", "spectrum_log", "magnitude"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use a DNN to estimate the power spectrum for WPE
                iterations = 1
            else:
                # Perform conventional WPE, without a DNN mask estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=input_dim,
                wlayers=wlayers,
                wunits=wunits,
                wprojs=wprojs,
                dropout_rate=wdropout_rate,
                taps=taps,
                delay=delay,
                use_dnn_mask=use_dnn_mask_for_wpe,
                nmask=1 if multi_source_wpe else num_spk,
                nonlinear=wnonlinear,
                iterations=iterations,
                normalization=wnormalization,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_wpe,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_wpe,
                use_torch_solver=use_torch_solver,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                bidim=input_dim,
                btype=bnet_type,
                blayers=blayers,
                bunits=bunits,
                bprojs=bprojs,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                nonlinear=bnonlinear,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                rtf_iterations=rtf_iterations,
                btaps=taps,
                bdelay=delay,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_bf,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_bf,
                use_torch_solver=use_torch_solver,
            )
        else:
            self.beamformer = None

        # Share speech powers between WPE and beamforming (wMPDR/WPD)
        self.shared_power = shared_power and use_wpe

    def forward(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.complex64/ComplexTensor):
                mixed speech [Batch, Frames, Channel, Freq]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel):
                List[torch.complex64/ComplexTensor]
            output lengths
            other predicted data: OrderedDict[
                'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
                'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input.dim() in (3, 4), input.dim()
        enhanced = input
        others = OrderedDict()

        if (
            self.training
            and self.loss_type is not None
            and self.loss_type.startswith("mask")
        ):
            # Only estimate masks during training, to save memory
            if self.use_wpe:
                if input.dim() == 3:
                    mask_w, ilens = self.wpe.predict_mask(input.unsqueeze(-2), ilens)
                    mask_w = mask_w.squeeze(-2)
                elif input.dim() == 4:
                    mask_w, ilens = self.wpe.predict_mask(input, ilens)

                if mask_w is not None:
                    if isinstance(mask_w, list):
                        # single-source WPE (one mask per speaker)
                        for spk in range(self.num_spk):
                            others["mask_dereverb{}".format(spk + 1)] = mask_w[spk]
                    else:
                        # multi-source WPE (a single shared mask)
                        others["mask_dereverb1"] = mask_w

            if self.use_beamformer and input.dim() == 4:
                others_b, ilens = self.beamformer.predict_mask(input, ilens)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

            return None, ilens, others

        else:
            powers = None
            # Perform both mask estimation and enhancement
            if input.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(
                        input.unsqueeze(-2), ilens
                    )
                    if isinstance(enhanced, list):
                        # single-source WPE
                        enhanced = [enh.squeeze(-2) for enh in enhanced]
                        if mask_w is not None:
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk].squeeze(-2)
                    else:
                        # multi-source WPE
                        enhanced = enhanced.squeeze(-2)
                        if mask_w is not None:
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                    if mask_w is not None:
                        if isinstance(enhanced, list):
                            # single-source WPE
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk]
                        else:
                            # multi-source WPE
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w

                # 2. Beamformer
                if self.use_beamformer:
                    # Speech powers estimated by WPE are only reusable for the
                    # wMPDR/WPD beamformers that share them; otherwise discard.
                    if (
                        not (
                            self.beamformer.beamformer_type.startswith("wmpdr")
                            or self.beamformer.beamformer_type.startswith("wpd")
                        )
                        or not self.shared_power
                        or (self.wpe.nmask == 1 and self.num_spk > 1)
                    ):
                        powers = None

                    # enhanced: (B, T, C, F) -> (B, T, F)
                    if isinstance(enhanced, list):
                        # outputs of single-source WPE
                        raise NotImplementedError(
                            "Single-source WPE is not supported with beamformer "
                            "in multi-speaker cases."
                        )
                    else:
                        # output of multi-source WPE
                        enhanced, ilens, others_b = self.beamformer(
                            enhanced, ilens, powers=powers
                        )
                    for spk in range(self.num_spk):
                        others["mask_spk{}".format(spk + 1)] = others_b[spk]
                    if len(others_b) > self.num_spk:
                        others["mask_noise1"] = others_b[self.num_spk]

        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others

    @property
    def num_spk(self):
        return self._num_spk
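
# A minimal usage sketch (not part of the original module): it exercises both
# the mask-only training path and the enhancement path of NeuralBeamformer.
# It assumes an ESPnet environment where DNN_WPE/DNN_Beamformer resolve, and a
# 257-bin spectrum (n_fft=512); all shapes and hyperparameters are illustrative.
def _demo_neural_beamformer() -> None:
    separator = NeuralBeamformer(input_dim=257, num_spk=2, use_wpe=True)

    # Mixed multi-channel spectrum: (Batch, Frames, Channels, Freq)
    spec = ComplexTensor(torch.randn(2, 100, 4, 257), torch.randn(2, 100, 4, 257))
    ilens = torch.LongTensor([100, 80])

    # Training with a mask-based loss: only masks are returned (saves memory).
    separator.train()
    enhanced, olens, others = separator(spec, ilens)
    assert enhanced is None and "mask_spk1" in others

    # Inference: a list of single-channel enhanced spectra, one per speaker.
    separator.eval()
    with torch.no_grad():
        enhanced, olens, others = separator(spec, ilens)
    assert len(enhanced) == separator.num_spk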
class BeamformerNet(AbsEnhancement):
    """TF-masking-based beamformer."""

    def __init__(
        self,
        num_spk: int = 1,
        normalize_input: bool = False,
        mask_type: str = "IPM^2",
        loss_type: str = "mask_mse",
        # STFT options
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        center: bool = True,
        window: Optional[str] = "hann",
        normalized: bool = False,
        onesided: bool = True,
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        wnonlinear: str = "crelu",
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        bnonlinear: str = "sigmoid",
        beamformer_type: str = "mvdr",
        bdropout_rate: float = 0.0,
    ):
        super().__init__()

        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            normalized=normalized,
            onesided=onesided,
        )

        self.normalize_input = normalize_input
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use a DNN to estimate the power spectrum for WPE
                iterations = 1
            else:
                # Perform conventional WPE, without a DNN mask estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=self.num_bin,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
                nonlinear=wnonlinear,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=bnet_type,
                bidim=self.num_bin,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                nonlinear=bnonlinear,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                btaps=taps,
                bdelay=delay,
            )
        else:
            self.beamformer = None

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> STFT -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input_spectrum.dim() in (3, 4), input_spectrum.dim()
        enhanced = input_spectrum
        masks = OrderedDict()

        if self.training and self.loss_type.startswith("mask"):
            # Only estimate masks during training
            if self.use_wpe:
                if input_spectrum.dim() == 3:
                    mask_w, flens = self.wpe.predict_mask(
                        input_spectrum.unsqueeze(-2), flens
                    )
                    mask_w = mask_w.squeeze(-2)
                elif input_spectrum.dim() == 4:
                    if self.use_beamformer:
                        enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                    else:
                        mask_w, flens = self.wpe.predict_mask(input_spectrum, flens)

                if mask_w is not None:
                    masks["dereverb"] = mask_w

            if self.use_beamformer and input_spectrum.dim() == 4:
                masks_b, flens = self.beamformer.predict_mask(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

            return None, flens, masks

        else:
            # Perform both mask estimation and enhancement
            if input_spectrum.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(
                        input_spectrum.unsqueeze(-2), flens
                    )
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w

                # 2. Beamformer
                if self.use_beamformer:
                    # enhanced: (B, T, C, F) -> (B, T, F)
                    enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                    for spk in range(self.num_spk):
                        masks["spk{}".format(spk + 1)] = masks_b[spk]
                    if len(masks_b) > self.num_spk:
                        masks["noise1"] = masks_b[self.num_spk]

        # Convert ComplexTensor to torch.Tensor: (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [
                torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced
            ]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()

        return enhanced, flens, masks

    def forward_rawwav(self, input: torch.Tensor, ilens: torch.Tensor):
        """Output waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech wavs (single-channel):
                torch.Tensor(Batch, Nsamples), or List[torch.Tensor(Batch, Nsamples)]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # Predict the spectrum for each speaker, then invert back to waveform
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker input
            predicted_wavs = [
                self.stft.inverse(ps, ilens)[0] for ps in predicted_spectrums
            ]
        else:
            # single-speaker input
            predicted_wavs = self.stft.inverse(predicted_spectrums, ilens)[0]

        return predicted_wavs, ilens, masks
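
# A minimal usage sketch (not part of the original module): it feeds a
# fabricated 4-channel mixture through BeamformerNet, once on the complex
# spectrum and once straight to time domain via forward_rawwav. Batch size,
# sample counts, and channel count are made up for illustration.
def _demo_beamformer_net() -> None:
    model = BeamformerNet(num_spk=1, use_beamformer=True)

    # Mixed speech: (Batch, Nsample, Channel)
    mix = torch.randn(2, 16000, 4)
    ilens = torch.LongTensor([16000, 12000])

    model.eval()
    with torch.no_grad():
        # Enhanced complex spectrum, real/imag stacked: (B, T, F, 2)
        enhanced, flens, masks = model(mix, ilens)
        # Time-domain output via the inverse STFT: (B, Nsamples)
        wavs, olens, masks = model.forward_rawwav(mix, ilens)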