def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179. https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        eps (float): stabilizer added to the trace denominator.

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    try:
        inv_Rf = inv(Rf)
    except Exception:
        # Inversion failed (Rf presumably ill-conditioned/singular): retry
        # after down-scaling Rf and Phi by the same factor and adding a small
        # random complex matrix to Rf. Scaling both by the same factor leaves
        # Rf^-1 @ Phi unchanged (before the random perturbation), so this
        # mainly improves the conditioning of the inversion itself.
        try:
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real))
                * 1e-4
            )
            # NOTE(review): 10e4 == 1e5 (and 10e10 == 1e11 below) — possibly
            # intended as 1e4 / 1e10; kept as-is to preserve behavior.
            Rf = Rf / 10e4
            Phi = Phi / 10e4
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)
        except Exception:
            # Last resort: much stronger scaling and a larger perturbation.
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real))
                * 1e-1
            )
            Rf = Rf / 10e10
            Phi = Phi / 10e10
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [inv_Rf, Phi])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps + 1) * C)
    return beamform_vector
def get_mvdr_vector(psd_s: ComplexTensor,
                    psd_n: ComplexTensor,
                    reference_vector: torch.Tensor,
                    eps: float = 1e-15) -> ComplexTensor:
    """Return the MVDR(Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): (..., F, C, C) speech covariance matrix
        psd_n (ComplexTensor): (..., F, C, C) noise covariance matrix
        reference_vector (torch.Tensor): (..., C)
        eps (float): regularizer for psd_n and the trace denominator
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps to the diagonal so psd_n is safely invertible.
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    # BUGFIX: use an out-of-place add. `psd_n += eps * eye` mutated the
    # caller's tensor, compounding the regularization across repeated calls
    # (tik_reg in this module deliberately uses the out-of-place form too).
    psd_n = psd_n + eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum('...fec,...c->...fe', [ws, reference_vector])
    return beamform_vector
def get_WPD_filter(
    Phi: Union[torch.Tensor, ComplexTensor],
    Rf: Union[torch.Tensor, ComplexTensor],
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179. https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of Rf
        diag_eps (float): regularization factor for diagonal loading
        eps (float): stabilizer added to the trace denominator

    Returns:
        filter_matrix (torch.complex64/ComplexTensor): (B, F, (btaps + 1) * C)
    """
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # Rf^-1 @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    rf_inv_phi = solve(Phi, Rf) if use_torch_solver else matmul(inverse(Rf), Phi)

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    # Normalize by the batched trace: (..., C, C) / (..., 1, 1)
    weight_mat = rf_inv_phi / (FC.trace(rf_inv_phi)[..., None, None] + eps)

    # Apply the reference vector: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # result shape: (B, F, (btaps + 1) * C)
    return einsum("...fec,...c->...fe", weight_mat, reference_vector)
def get_mvdr_vector(
    psd_s,
    psd_n,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization factor for diagonal loading
        eps (float): stabilizer added to the trace denominator
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # Npsd^-1 @ Spsd, via a linear solve or an explicit inverse.
    if use_torch_solver:
        inv_psd_n_psd_s = solve(psd_s, psd_n)
    else:
        inv_psd_n_psd_s = matmul(inverse(psd_n), psd_s)

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    # Normalize by the batched trace: (..., C, C) / (..., 1, 1)
    denom = FC.trace(inv_psd_n_psd_s)[..., None, None] + eps
    weight_mat = inv_psd_n_psd_s / denom

    # Apply the reference vector: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return einsum("...fec,...c->...fe", weight_mat, reference_vector)
def get_WPD_filter_v2(
    Phi: Union[torch.Tensor, ComplexTensor],
    Rf: Union[torch.Tensor, ComplexTensor],
    reference_vector: torch.Tensor,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector (v2).

    This implementation is more efficient than `get_WPD_filter` as
    it skips unnecessary computation with zeros.

    Args:
        Phi (torch.complex64/ComplexTensor): (B, F, C, C)
            is speech PSD.
        Rf (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            is the reference_vector.
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of Rf
        diag_eps (float): regularization factor for diagonal loading
        eps (float): stabilizer added to the trace denominator

    Returns:
        filter_matrix (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C)
    """
    num_channels = reference_vector.shape[-1]
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # Only the first C columns of Rf^-1 contribute; the remaining columns
    # would be multiplied with the zero padding that v1 carries around.
    # (B, F, (btaps+1) * C, C)
    inv_Rf_first_cols = inverse(Rf)[..., :num_channels]

    # (..., (btaps+1) * C, C) x (..., C, C) -> (..., (btaps+1) * C, C)
    prod = matmul(inv_Rf_first_cols, Phi)

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    # Normalize by the trace of the top C x C sub-block.
    trace_val = FC.trace(prod[..., :num_channels, :])[..., None, None]
    weight_mat = prod / (trace_val + eps)

    # Apply the reference vector: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # result shape: (B, F, (btaps+1) * C)
    return einsum("...fec,...c->...fe", weight_mat, reference_vector)
def get_WPD_filter_v2(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the WPD vector (v2).

    This implementation is more efficient than `get_WPD_filter` as
    it skips unnecessary computation with zeros.

    Args:
        Phi (ComplexTensor): (B, F, C, C)
            is speech PSD.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            is the reference_vector.
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of Rf
        diag_eps (float): regularization factor for diagonal loading
        eps (float): stabilizer added to the trace denominator

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
    """
    num_channels = reference_vector.shape[-1]
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # Only the first C columns of Rf^-1 contribute; the remaining columns
    # would be multiplied with the zero padding that v1 carries around.
    # (B, F, (btaps+1) * C, C)
    inv_Rf_first_cols = Rf.inverse2()[..., :num_channels]

    # (..., (btaps+1) * C, C) x (..., C, C) -> (..., (btaps+1) * C, C)
    prod = FC.matmul(inv_Rf_first_cols, Phi)

    # Normalize by the trace of the top C x C sub-block.
    trace_val = FC.trace(prod[..., :num_channels, :])[..., None, None]
    weight_mat = prod / (trace_val + eps)

    # Apply the reference vector: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # result shape: (B, F, (btaps+1) * C)
    return FC.einsum("...fec,...c->...fe", [weight_mat, reference_vector])
def tik_reg(mat, reg: float = 1e-8, eps: float = 1e-8):
    """Perform Tikhonov regularization (only modifying real part).

    Adds ``(reg * tr(mat) + eps)`` times the identity to ``mat``.

    Args:
        mat (torch.complex64/ComplexTensor): input matrix (..., C, C)
        reg (float): regularization factor
        eps (float): keeps the loading positive even for an all-zero matrix
    Returns:
        ret (torch.complex64/ComplexTensor): regularized matrix (..., C, C)
    """
    dim = mat.size(-1)
    # Batched identity broadcast/tiled to match `mat`: (..., C, C)
    identity = torch.eye(dim, dtype=mat.dtype, device=mat.device)
    view_shape = [1] * (mat.dim() - 2) + [dim, dim]
    identity = identity.view(*view_shape).repeat(*mat.shape[:-2], 1, 1)
    # Scale the loading term by the trace; no gradient flows through it.
    with torch.no_grad():
        loading = FC.trace(mat).real[..., None, None] * reg + eps
    # Out-of-place add so the caller's matrix is left untouched.
    return mat + loading * identity
def test_trace():
    """F.trace must match numpy.trace on a random complex matrix."""
    mat = ComplexTensor(_get_complex_array(10, 10))
    expected = numpy.trace(mat.numpy())
    actual = F.trace(mat).numpy()
    numpy.testing.assert_allclose(expected, actual)
def trace(a: Union[torch.Tensor, ComplexTensor]):
    """Return the batched trace over the last two dimensions of ``a``."""
    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    return FC.trace(a)
def get_gev_vector(
    psd_noise: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the generalized eigenvalue (GEV) beamformer vector:

        psd_speech @ h = lambda * psd_noise @ h

    Reference:
        Blind acoustic beamforming based on generalized eigenvalue decomposition;
        E. Warsitz and R. Haeb-Umbach, 2007.

    Args:
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition (only for torch builtin complex tensors)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_noise
        diag_eps (float): regularization factor for diagonal loading
        eps (float): stabilizer for the low-rank fallback below

    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if mode == "power":
        # Power iteration on phi = psd_noise^-1 @ psd_speech: its dominant
        # eigenvector solves the generalized eigenvalue problem above.
        if use_torch_solver:
            phi = solve(psd_speech, psd_noise)
        else:
            phi = matmul(inverse(psd_noise), psd_speech)
        # Initial vector: either one column of phi (int reference)
        # or phi applied to the given reference vector.
        e_vec = (
            phi[..., reference_vector, None]
            if isinstance(reference_vector, int)
            else matmul(phi, reference_vector[..., None, :, None])
        )
        for _ in range(iterations - 1):
            e_vec = matmul(phi, e_vec)
            # Per-iteration normalization is skipped here (applied once at
            # the end instead); kept for reference:
            # e_vec = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
        e_vec = e_vec.squeeze(-1)
    elif mode == "evd":
        # Exact solution via generalized EVD; requires native complex
        # tensors and PyTorch >= 1.9.
        assert (
            is_torch_1_9_plus
            and is_torch_complex_tensor(psd_speech)
            and is_torch_complex_tensor(psd_noise)
        )
        # Batched one-shot version for reference:
        # e_vec = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1][...,-1]
        e_vec = psd_noise.new_zeros(psd_noise.shape[:-1])
        # Decompose per frequency bin so a failure in one bin does not
        # abort the whole batch.
        for f in range(psd_noise.shape[-3]):
            try:
                # [1] selects the eigenvectors; [..., -1] is the one with
                # the largest eigenvalue.
                e_vec[..., f, :] = generalized_eigenvalue_decomposition(
                    psd_speech[..., f, :, :], psd_noise[..., f, :, :]
                )[1][..., -1]
            except RuntimeError:
                # port from github.com/fgnt/nn-gev/blob/master/fgnt/beamforming.py#L106
                print(
                    "GEV beamformer: LinAlg error for frequency {}".format(f),
                    flush=True,
                )
                C = psd_noise.size(-1)
                # Fallback: a flat (all-ones) vector scaled by C / tr(psd_noise).
                e_vec[..., f, :] = (
                    psd_noise.new_ones(e_vec[..., f, :].shape)
                    / FC.trace(psd_noise[..., f, :, :])
                    * C
                )
    else:
        raise ValueError("Unknown mode: %s" % mode)

    # Normalize to unit norm, then remove the arbitrary per-frequency phase.
    beamforming_vector = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
    beamforming_vector = gev_phase_correction(beamforming_vector)
    return beamforming_vector
def get_rank1_mwf_vector( psd_speech, psd_noise, reference_vector: Union[torch.Tensor, int], denoising_weight: float = 1.0, approx_low_rank_psd_speech: bool = False, iterations: int = 3, use_torch_solver: bool = True, diagonal_loading: bool = True, diag_eps: float = 1e-7, eps: float = 1e-8, ): """Return the R1-MWF (Rank-1 Multi-channel Wiener Filter) vector h = (Npsd^-1 @ Spsd) / (mu + Tr(Npsd^-1 @ Spsd)) @ u Reference: [1] Rank-1 constrained multichannel Wiener filter for speech recognition in noisy environments; Z. Wang et al, 2018 https://hal.inria.fr/hal-01634449/document [2] Low-rank approximation based multichannel Wiener filter algorithms for noise reduction with application in cochlear implants; R. Serizel, 2014 https://ieeexplore.ieee.org/document/6730918 Args: psd_speech (torch.complex64/ComplexTensor): speech covariance matrix (..., F, C, C) psd_noise (torch.complex64/ComplexTensor): noise covariance matrix (..., F, C, C) reference_vector (torch.Tensor or int): (..., C) or scalar denoising_weight (float): a trade-off parameter between noise reduction and speech distortion. A larger value leads to more noise reduction at the expense of more speech distortion. When `denoising_weight = 0`, it corresponds to MVDR beamformer. 
approx_low_rank_psd_speech (bool): whether to replace original input psd_speech with its low-rank approximation as in [1] iterations (int): number of iterations in power method, only used when `approx_low_rank_psd_speech = True` use_torch_solver (bool): Whether to use `solve` instead of `inverse` diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n diag_eps (float): eps (float): Returns: beamform_vector (torch.complex64/ComplexTensor): (..., F, C) """ # noqa: H405, D205, D400 if approx_low_rank_psd_speech: if diagonal_loading: psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps) # (B, F, C, 1) recon_vec = get_rtf( psd_speech, psd_noise, mode="power", iterations=iterations, reference_vector=reference_vector, use_torch_solver=use_torch_solver, ) # Eq. (25) in Ref[1] psd_speech_r1 = matmul(recon_vec, recon_vec.conj().transpose(-1, -2)) sigma_speech = FC.trace(psd_speech) / (FC.trace(psd_speech_r1) + eps) psd_speech_r1 = psd_speech_r1 * sigma_speech[..., None, None] # c.f. Eq. (62) in Ref[2] psd_speech = psd_speech_r1 elif diagonal_loading: psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps) if use_torch_solver: numerator = solve(psd_speech, psd_noise) else: numerator = matmul(inverse(psd_noise), psd_speech) # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not # support bacth processing. Use FC.trace() as fallback. # ws: (..., C, C) / (...,) -> (..., C, C) ws = numerator / (denoising_weight + FC.trace(numerator)[..., None, None] + eps) # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) if isinstance(reference_vector, int): beamform_vector = ws[..., reference_vector] else: beamform_vector = einsum("...fec,...c->...fe", ws, reference_vector) return beamform_vector
def get_sdw_mwf_vector(
    psd_speech,
    psd_noise,
    reference_vector: Union[torch.Tensor, int],
    denoising_weight: float = 1.0,
    approx_low_rank_psd_speech: bool = False,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the SDW-MWF (Speech Distortion Weighted Multi-channel Wiener Filter) vector

        h = (Spsd + mu * Npsd)^-1 @ Spsd @ u

    Reference:
        [1] Spatially pre-processed speech distortion weighted multi-channel Wiener
        filtering for noise reduction; A. Spriet et al, 2004
        https://dl.acm.org/doi/abs/10.1016/j.sigpro.2004.07.028
        [2] Rank-1 constrained multichannel Wiener filter for speech recognition in
        noisy environments; Z. Wang et al, 2018
        https://hal.inria.fr/hal-01634449/document
        [3] Low-rank approximation based multichannel Wiener filter algorithms for
        noise reduction with application in cochlear implants; R. Serizel, 2014
        https://ieeexplore.ieee.org/document/6730918

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        denoising_weight (float): a trade-off parameter between noise reduction and
            speech distortion.
            A larger value leads to more noise reduction at the expense of more
            speech distortion.
            The plain MWF is obtained with `denoising_weight = 1` (by default).
        approx_low_rank_psd_speech (bool): whether to replace original input
            psd_speech with its low-rank approximation as in [2]
        iterations (int): number of iterations in power method, only used when
            `approx_low_rank_psd_speech = True`
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization factor for diagonal loading
        eps (float): stabilizer for trace-based normalizations

    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if approx_low_rank_psd_speech:
        if diagonal_loading:
            psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

        # Steering-vector estimate via the power method: (B, F, C, 1)
        rtf_vec = get_rtf(
            psd_speech,
            psd_noise,
            mode="power",
            iterations=iterations,
            reference_vector=reference_vector,
            use_torch_solver=use_torch_solver,
        )
        # Rank-1 outer-product reconstruction: Eq. (25) in Ref[2]
        rank1_psd = matmul(rtf_vec, rtf_vec.conj().transpose(-1, -2))
        # Rescale so the reconstruction keeps the original total power (trace).
        power_ratio = FC.trace(psd_speech) / (FC.trace(rank1_psd) + eps)
        # c.f. Eq. (62) in Ref[3]
        psd_speech = rank1_psd * power_ratio[..., None, None]

    # Denominator matrix: Spsd + mu * Npsd, regularized if requested.
    psd_n = psd_speech + denoising_weight * psd_noise
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # (Spsd + mu * Npsd)^-1 @ Spsd, via a linear solve or an explicit inverse.
    ws = solve(psd_speech, psd_n) if use_torch_solver else matmul(inverse(psd_n), psd_speech)

    # Apply the reference: pick one column (int) or project on the vector.
    if isinstance(reference_vector, int):
        return ws[..., reference_vector]
    return einsum("...fec,...c->...fe", ws, reference_vector)