def test_complex_impl_consistency():
    """Check that the wrapped ops give the same results for torch-native
    complex tensors and for ComplexTensor inputs."""
    if not is_torch_1_9_plus:
        # torch.complex requires PyTorch 1.9+
        return
    real = torch.from_numpy(mat_np.real)
    imag = torch.from_numpy(mat_np.imag)
    mat_th = torch.complex(real, imag)
    mat_ct = ComplexTensor(real, imag)
    bs, rank = mat_th.shape[0], mat_th.shape[-1]
    vec_th = torch.complex(torch.rand(bs, rank), torch.rand(bs, rank)).type_as(mat_th)
    vec_ct = ComplexTensor(vec_th.real, vec_th.imag)
    col_th = vec_th.unsqueeze(-1)
    col_ct = vec_ct.unsqueeze(-1)
    # Each pair is (native-complex result, ComplexTensor result) of the same op.
    pairs = [
        (abs(mat_th), abs(mat_ct)),
        (inverse(mat_th), inverse(mat_ct)),
        (matmul(mat_th, col_th), matmul(mat_ct, col_ct)),
        (solve(col_th, mat_th), solve(col_ct, mat_ct)),
        (
            einsum("bec,bc->be", mat_th, vec_th),
            einsum("bec,bc->be", mat_ct, vec_ct),
        ),
    ]
    for out_th, out_ct in pairs:
        np.testing.assert_allclose(out_th.numpy(), out_ct.numpy(), atol=1e-6)
def test_get_rtf(ch: int, mode: str):
    """Test get_rtf against the generalized eigenvector characterization.

    Builds speech/noise covariance matrices from an STFT of ``random_speech``
    (speech) and random noise, estimates the RTF with ``get_rtf``, and checks
    that the estimate is (up to scale) the max eigenvector of
    Phi_N^-1 @ Phi_X mapped through Phi_N.

    Args:
        ch: number of channels taken from ``random_speech``
        mode: RTF estimation mode, "power" or "evd"
    """
    if not is_torch_1_9_plus and mode == "evd":
        # torch 1.9.0+ is required for "evd" mode
        return
    # Select the complex representation matching the mode: native torch
    # complex tensors for "evd", ComplexTensor (+ its functional module FC)
    # otherwise.
    if mode == "evd":
        complex_wrapper = torch.complex
        complex_module = torch
    else:
        complex_wrapper = ComplexTensor
        complex_module = FC
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    ilens = torch.LongTensor([16, 12])
    # (B, T, C, F) -> (B, F, C, T)
    X = complex_wrapper(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(
        -1, -3)
    # Speech spatial covariance: (B, F, C, C)
    Phi_X = complex_module.einsum("...ct,...et->...ce", [X, X.conj()])
    # Redraw the noise until its covariance is full-rank, so that
    # Phi_N is invertible in the checks below.
    is_singular = True
    while is_singular:
        N = complex_wrapper(torch.randn_like(X.real), torch.randn_like(X.imag))
        Phi_N = complex_module.einsum("...ct,...et->...ce", [N, N.conj()])
        is_singular = not np.all(np.linalg.matrix_rank(Phi_N.numpy()) == ch)
    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, mode=mode, reference_vector=0, iterations=20)
    # Normalize by the largest magnitude per frequency (API differs by version:
    # .values exists only on torch 1.1+).
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)
    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available after pytorch 1.1.0+
        mat = solve(Phi_X, Phi_N)[0]
        max_eigenvec = solve(rtf, Phi_N)[0]
    else:
        mat = complex_module.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = complex_module.matmul(Phi_N.inverse2(), rtf)
    # If max_eigenvec is an eigenvector of `mat`, then `factor` is collinear
    # with it, which makes the two outer products below equal.
    factor = complex_module.matmul(mat, max_eigenvec)
    assert complex_module.allclose(
        complex_module.matmul(max_eigenvec, factor.transpose(-1, -2)),
        complex_module.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
def get_lcmv_vector_with_rtf(
    psd_n: Union[torch.Tensor, ComplexTensor],
    rtf_mat: Union[torch.Tensor, ComplexTensor],
    reference_vector: Union[int, torch.Tensor, None] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Compute the LCMV (Linearly Constrained Minimum Variance) beamformer.

    The filter realizes:
        h = (Npsd^-1 @ rtf_mat) @ (rtf_mat^H @ Npsd^-1 @ rtf_mat)^-1 @ p

    Reference:
        H. L. Van Trees, "Optimum array processing: Part IV of detection,
        estimation, and modulation theory," John Wiley & Sons, 2004.
        (Chapter 6.7)

    Args:
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        rtf_mat (torch.complex64/ComplexTensor):
            RTF matrix (..., F, C, num_spk)
        reference_vector (torch.Tensor or int): (..., num_spk) or scalar
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # Npsd^-1 @ rtf_mat: (..., C_1, C_2) x (..., C_2, num_spk) -> (..., C_1, num_spk)
    num = (
        solve(rtf_mat, psd_n)
        if use_torch_solver
        else matmul(inverse(psd_n), rtf_mat)
    )
    # rtf_mat^H @ Npsd^-1 @ rtf_mat: (..., num_spk, num_spk)
    denom = matmul(rtf_mat.conj().transpose(-1, -2), num)
    # Select/solve for the distortionless-response weights per speaker.
    if isinstance(reference_vector, int):
        sel = inverse(denom)[..., reference_vector, None]
    else:
        sel = solve(reference_vector, denom)
    return matmul(num, sel).squeeze(-1)
def get_rtf(
    psd_speech,
    psd_noise,
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
):
    """Estimate the relative transfer function (RTF).

    Power-method algorithm:
        1) rtf = reference_vector
        2) for i in range(iterations):
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # this normalization can be skipped
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]

    Note: step 4) (normalization at the reference channel) is NOT performed
    here.

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
    Returns:
        rtf (torch.complex64/ComplexTensor): (..., F, C, 1)
    """
    if mode == "power":
        # phi = psd_noise^-1 @ psd_speech
        phi = (
            solve(psd_speech, psd_noise)
            if use_torch_solver
            else matmul(inverse(psd_noise), psd_speech)
        )
        if isinstance(reference_vector, int):
            rtf = phi[..., reference_vector, None]
        else:
            rtf = matmul(phi, reference_vector[..., None, :, None])
        # One phi application is in the initialization above and one is folded
        # into the final psd_speech multiplication, hence `iterations - 2`.
        for _ in range(iterations - 2):
            rtf = matmul(phi, rtf)
        # Equivalent to psd_noise @ (phi @ rtf); combines steps 2) and 3).
        rtf = matmul(psd_speech, rtf)
    elif mode == "evd":
        assert (
            is_torch_1_9_plus
            and is_torch_complex_tensor(psd_speech)
            and is_torch_complex_tensor(psd_noise)
        )
        # Principal generalized eigenvector (largest eigenvalue is last).
        e_vec = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1]
        rtf = matmul(psd_noise, e_vec[..., -1, None])
    else:
        raise ValueError("Unknown mode: %s" % mode)
    return rtf
def get_WPD_filter(
    Phi: Union[torch.Tensor, ComplexTensor],
    Rf: Union[torch.Tensor, ComplexTensor],
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Compute the WPD convolutional beamformer vector.

    WPD (Weighted Power minimization Distortionless response) filter:
        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        filter_matrix (torch.complex64/ComplexTensor): (B, F, (btaps + 1) * C)
    """
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # Rf^-1 @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    num = solve(Phi, Rf) if use_torch_solver else matmul(inverse(Rf), Phi)

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    # Normalize by the trace: (..., C, C) / (...,) -> (..., C, C)
    ws = num / (FC.trace(num)[..., None, None] + eps)

    # Project onto the reference: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # -> (B, F, (btaps + 1) * C)
    return einsum("...fec,...c->...fe", ws, reference_vector)
def get_mvdr_vector(
    psd_s,
    psd_n,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Compute the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # Npsd^-1 @ Spsd
    num = solve(psd_s, psd_n) if use_torch_solver else matmul(inverse(psd_n), psd_s)

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    # Normalize by the trace: (..., C, C) / (...,) -> (..., C, C)
    ws = num / (FC.trace(num)[..., None, None] + eps)

    # Project onto the reference: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return einsum("...fec,...c->...fe", ws, reference_vector)
def test_solve(real_vec):
    """Check that the unified `solve` wrapper matches each backend's solve."""
    if is_torch_1_9_plus:
        wrappers = [ComplexTensor, torch.complex]
        modules = [FC, torch]
    else:
        wrappers = [ComplexTensor]
        modules = [FC]
    for complex_wrapper, complex_module in zip(wrappers, modules):
        mat = complex_wrapper(
            torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag)
        )
        # ComplexTensor has no real-vector path, so always use a complex vec
        # there; otherwise exercise the real-valued right-hand side too.
        if real_vec and complex_wrapper is not ComplexTensor:
            vec = torch.rand(2, 3, 1)
            vec2 = complex_wrapper(vec, torch.zeros_like(vec))
        else:
            vec = complex_wrapper(torch.rand(2, 3, 1), torch.rand(2, 3, 1))
            vec2 = vec
        assert complex_module.allclose(
            solve(vec, mat), complex_module.solve(vec2, mat)[0]
        )
def get_rtf(
    psd_speech,
    psd_noise,
    reference_vector: Union[int, torch.Tensor, None] = None,
    iterations: int = 3,
    use_torch_solver: bool = True,
):
    """Estimate the relative transfer function (RTF) via the power method.

    Algorithm:
        1) rtf = reference_vector
        2) for i in range(iterations):
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # this normalization can be skipped
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]

    Note: step 4) (normalization at the reference channel) is NOT performed
    here.

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
    Returns:
        rtf (torch.complex64/ComplexTensor): (..., F, C, 1)
    """
    # phi = psd_noise^-1 @ psd_speech
    phi = (
        solve(psd_speech, psd_noise)
        if use_torch_solver
        else matmul(inverse(psd_noise), psd_speech)
    )
    if isinstance(reference_vector, int):
        rtf = phi[..., reference_vector, None]
    else:
        rtf = matmul(phi, reference_vector[..., None, :, None])
    # One phi application is in the initialization above and one is folded
    # into the final psd_speech multiplication, hence `iterations - 2`.
    for _ in range(iterations - 2):
        rtf = matmul(phi, rtf)
    # Equivalent to psd_noise @ (phi @ rtf); combines steps 2) and 3).
    return matmul(psd_speech, rtf)
def get_mfmvdr_vector(gammax, Phi, use_torch_solver: bool = True, eps: float = EPS):
    """Compute conventional MFMPDR/MFMVDR filter.

    Args:
        gammax (torch.complex64/ComplexTensor): (..., L, N)
        Phi (torch.complex64/ComplexTensor): (..., L, N, N)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        eps (float)
    Returns:
        beamforming_vector (torch.complex64/ComplexTensor): (..., L, N)
    """
    rhs = gammax.unsqueeze(-1)
    # Phi^-1 @ gammax: (..., L, N)
    if use_torch_solver:
        num = solve(rhs, Phi).squeeze(-1)
    else:
        num = matmul(inverse(Phi), rhs).squeeze(-1)
    # gammax^H @ Phi^-1 @ gammax (scalar per frame)
    denom = einsum("...d,...d->...", gammax.conj(), num)
    return num / (denom.real.unsqueeze(-1) + eps)
def get_mwf_vector(
    psd_s,
    psd_n,
    reference_vector: Union[torch.Tensor, int],
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Compute the MWF (Minimum Multi-channel Wiener Filter) vector:

        h = (Npsd^-1 @ Spsd) @ u

    Args:
        psd_s (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_n (torch.complex64/ComplexTensor):
            power-normalized observation covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # ws = Npsd^-1 @ Spsd
    ws = solve(psd_s, psd_n) if use_torch_solver else matmul(inverse(psd_n), psd_s)

    # Select/project the reference channel:
    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    if isinstance(reference_vector, int):
        return ws[..., reference_vector]
    return einsum("...fec,...c->...fe", ws, reference_vector)
def get_WPD_filter_with_rtf(
    psd_observed_bar: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    psd_noise: Union[torch.Tensor, ComplexTensor],
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector calculated with RTF.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (torch.complex64/ComplexTensor):
            stacked observation covariance matrix
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """
    if isinstance(psd_speech, ComplexTensor):
        pad_func = FC.pad
    elif is_torch_complex_tensor(psd_speech):
        pad_func = torch.nn.functional.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support.")

    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    # NOTE(review): pass `reference_vector` by keyword. A positional third
    # argument would bind to `mode` in the get_rtf variant whose signature is
    # (psd_speech, psd_noise, mode, reference_vector, ...), silently
    # misinterpreting the argument.
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector=reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # Zero-pad the RTF up to the stacked (taps+1)*C dimension:
    # (B, F, (K+1)*C, 1)
    rtf = pad_func(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = solve(rtf, psd_observed_bar).squeeze(-1)
    else:
        numerator = matmul(inverse(psd_observed_bar), rtf).squeeze(-1)
    # vbar^H @ R^-1 @ vbar (scalar per frequency)
    denominator = einsum("...d,...d->...", rtf.squeeze(-1).conj(), numerator)
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
def get_mvdr_vector_with_rtf(
    psd_n: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    psd_noise: Union[torch.Tensor, ComplexTensor],
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the MVDR (Minimum Variance Distortionless Response) vector
        calculated with RTF:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    # NOTE(review): pass `reference_vector` by keyword. A positional third
    # argument would bind to `mode` in the get_rtf variant whose signature is
    # (psd_speech, psd_noise, mode, reference_vector, ...), silently
    # misinterpreting the argument.
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector=reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = solve(rtf, psd_n).squeeze(-1)
    else:
        numerator = matmul(inverse(psd_n), rtf).squeeze(-1)
    # rtf^H @ Npsd^-1 @ rtf (scalar per frequency)
    denominator = einsum("...d,...d->...", rtf.squeeze(-1).conj(), numerator)
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
def get_gev_vector(
    psd_noise: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Compute the generalized eigenvalue (GEV) beamformer vector:

        psd_speech @ h = lambda * psd_noise @ h

    Reference:
        Blind acoustic beamforming based on generalized eigenvalue
        decomposition; E. Warsitz and R. Haeb-Umbach, 2007.

    Args:
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition (only for torch builtin
                complex tensors)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if mode == "power":
        # phi = psd_noise^-1 @ psd_speech; its principal eigenvector solves
        # the generalized eigenvalue problem above.
        phi = (
            solve(psd_speech, psd_noise)
            if use_torch_solver
            else matmul(inverse(psd_noise), psd_speech)
        )
        if isinstance(reference_vector, int):
            e_vec = phi[..., reference_vector, None]
        else:
            e_vec = matmul(phi, reference_vector[..., None, :, None])
        # Power iteration (per-step normalization is optional and skipped).
        for _ in range(iterations - 1):
            e_vec = matmul(phi, e_vec)
        e_vec = e_vec.squeeze(-1)
    elif mode == "evd":
        assert (
            is_torch_1_9_plus
            and is_torch_complex_tensor(psd_speech)
            and is_torch_complex_tensor(psd_noise)
        )
        # Solve the GEVD frequency by frequency so that a LinAlg failure in
        # one bin does not abort the whole batch.
        e_vec = psd_noise.new_zeros(psd_noise.shape[:-1])
        for f in range(psd_noise.shape[-3]):
            try:
                e_vec[..., f, :] = generalized_eigenvalue_decomposition(
                    psd_speech[..., f, :, :], psd_noise[..., f, :, :]
                )[1][..., -1]
            except RuntimeError:
                # Fallback ported from
                # github.com/fgnt/nn-gev/blob/master/fgnt/beamforming.py#L106
                print(
                    "GEV beamformer: LinAlg error for frequency {}".format(f),
                    flush=True,
                )
                C = psd_noise.size(-1)
                e_vec[..., f, :] = (
                    psd_noise.new_ones(e_vec[..., f, :].shape)
                    / FC.trace(psd_noise[..., f, :, :])
                    * C
                )
    else:
        raise ValueError("Unknown mode: %s" % mode)

    beamforming_vector = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
    return gev_phase_correction(beamforming_vector)
def get_rank1_mwf_vector(
    psd_speech,
    psd_noise,
    reference_vector: Union[torch.Tensor, int],
    denoising_weight: float = 1.0,
    approx_low_rank_psd_speech: bool = False,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Compute the R1-MWF (Rank-1 Multi-channel Wiener Filter) vector

        h = (Npsd^-1 @ Spsd) / (mu + Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        [1] Rank-1 constrained multichannel Wiener filter for speech
        recognition in noisy environments; Z. Wang et al, 2018
        https://hal.inria.fr/hal-01634449/document
        [2] Low-rank approximation based multichannel Wiener filter
        algorithms for noise reduction with application in cochlear implants;
        R. Serizel, 2014
        https://ieeexplore.ieee.org/document/6730918

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        denoising_weight (float): a trade-off parameter between noise
            reduction and speech distortion.
            A larger value leads to more noise reduction at the expense of
            more speech distortion.
            When `denoising_weight = 0`, it corresponds to MVDR beamformer.
        approx_low_rank_psd_speech (bool): whether to replace original input
            psd_speech with its low-rank approximation as in [1]
        iterations (int): number of iterations in power method, only used
            when `approx_low_rank_psd_speech = True`
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if approx_low_rank_psd_speech:
        if diagonal_loading:
            psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

        # Rank-1 reconstruction of the speech PSD from the RTF: (B, F, C, 1)
        recon_vec = get_rtf(
            psd_speech,
            psd_noise,
            mode="power",
            iterations=iterations,
            reference_vector=reference_vector,
            use_torch_solver=use_torch_solver,
        )
        # Eq. (25) in Ref[1]
        psd_speech_r1 = matmul(recon_vec, recon_vec.conj().transpose(-1, -2))
        sigma_speech = FC.trace(psd_speech) / (FC.trace(psd_speech_r1) + eps)
        # c.f. Eq. (62) in Ref[2]
        psd_speech = psd_speech_r1 * sigma_speech[..., None, None]
    elif diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # Npsd^-1 @ Spsd
    num = (
        solve(psd_speech, psd_noise)
        if use_torch_solver
        else matmul(inverse(psd_noise), psd_speech)
    )

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    # Normalize by (mu + trace): (..., C, C) / (...,) -> (..., C, C)
    ws = num / (denoising_weight + FC.trace(num)[..., None, None] + eps)

    # Select/project the reference channel:
    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    if isinstance(reference_vector, int):
        return ws[..., reference_vector]
    return einsum("...fec,...c->...fe", ws, reference_vector)
def get_sdw_mwf_vector(
    psd_speech,
    psd_noise,
    reference_vector: Union[torch.Tensor, int],
    denoising_weight: float = 1.0,
    approx_low_rank_psd_speech: bool = False,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Compute the SDW-MWF (Speech Distortion Weighted Multi-channel Wiener
    Filter) vector

        h = (Spsd + mu * Npsd)^-1 @ Spsd @ u

    Reference:
        [1] Spatially pre-processed speech distortion weighted multi-channel
        Wiener filtering for noise reduction; A. Spriet et al, 2004
        https://dl.acm.org/doi/abs/10.1016/j.sigpro.2004.07.028
        [2] Rank-1 constrained multichannel Wiener filter for speech
        recognition in noisy environments; Z. Wang et al, 2018
        https://hal.inria.fr/hal-01634449/document
        [3] Low-rank approximation based multichannel Wiener filter
        algorithms for noise reduction with application in cochlear implants;
        R. Serizel, 2014
        https://ieeexplore.ieee.org/document/6730918

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        denoising_weight (float): a trade-off parameter between noise
            reduction and speech distortion.
            A larger value leads to more noise reduction at the expense of
            more speech distortion.
            The plain MWF is obtained with `denoising_weight = 1` (by default).
        approx_low_rank_psd_speech (bool): whether to replace original input
            psd_speech with its low-rank approximation as in [2]
        iterations (int): number of iterations in power method, only used
            when `approx_low_rank_psd_speech = True`
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if approx_low_rank_psd_speech:
        if diagonal_loading:
            psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

        # Rank-1 reconstruction of the speech PSD from the RTF: (B, F, C, 1)
        recon_vec = get_rtf(
            psd_speech,
            psd_noise,
            mode="power",
            iterations=iterations,
            reference_vector=reference_vector,
            use_torch_solver=use_torch_solver,
        )
        # Eq. (25) in Ref[2]
        psd_speech_r1 = matmul(recon_vec, recon_vec.conj().transpose(-1, -2))
        sigma_speech = FC.trace(psd_speech) / (FC.trace(psd_speech_r1) + eps)
        # c.f. Eq. (62) in Ref[3]
        psd_speech = psd_speech_r1 * sigma_speech[..., None, None]

    # Combined PSD: Spsd + mu * Npsd
    psd_n = psd_speech + denoising_weight * psd_noise
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    ws = solve(psd_speech, psd_n) if use_torch_solver else matmul(inverse(psd_n), psd_speech)

    # Select/project the reference channel:
    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    if isinstance(reference_vector, int):
        return ws[..., reference_vector]
    return einsum("...fec,...c->...fe", ws, reference_vector)