def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the WPD (Weighted Power minimization Distortionless response) vector.

    WPD is a convolutional beamformer computed as:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            PSD of the zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            power-normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            one-hot reference selection vector u.
        use_torch_solver (bool): use `solve` instead of an explicit `inverse`
        diagonal_loading (bool): add a tiny term to the diagonal of Rf
        diag_eps (float): scale of the diagonal-loading term
        eps (float): numerical floor for the trace normalization
    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    if diagonal_loading:
        # Tikhonov regularization keeps Rf well-conditioned before inversion
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # inv(Rf) @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    if use_torch_solver:
        inv_rf_phi = FC.solve(Phi, Rf)[0]
    else:
        inv_rf_phi = FC.matmul(Rf.inverse2(), Phi)

    # trace normalization: (..., C, C) / (...,) -> (..., C, C)
    ws = inv_rf_phi / (FC.trace(inv_rf_phi)[..., None, None] + eps)

    # select the reference column: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # result shape: (B, F, (btaps + 1) * C)
    return FC.einsum("...fec,...c->...fe", [ws, reference_vector])
def test_get_rtf(ch):
    """Check that get_rtf's output behaves like the max eigenvector of Phi_N^-1 @ Phi_X."""
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    speech = random_speech[..., :ch]
    noise = torch.rand(2, 16, ch, dtype=torch.double)
    ilens = torch.LongTensor([16, 12])

    # (B, T, C, F) -> (B, F, C, T)
    X = ComplexTensor(*torch.unbind(stft(speech, ilens)[0], dim=-1)).transpose(-1, -3)
    N = ComplexTensor(*torch.unbind(stft(noise, ilens)[0], dim=-1)).transpose(-1, -3)

    # covariance matrices: (B, F, C, C)
    Phi_X = FC.einsum("...ct,...et->...ce", [X, X.conj()])
    Phi_N = FC.einsum("...ct,...et->...ce", [N, N.conj()])

    # (B, F, C, 1), normalized by the largest magnitude per frequency bin
    rtf = get_rtf(Phi_X, Phi_N, reference_vector=0, iterations=20)
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)

    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available after pytorch 1.1.0+
        mat = FC.solve(Phi_X, Phi_N)[0]
        max_eigenvec = FC.solve(rtf, Phi_N)[0]
    else:
        mat = FC.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = FC.matmul(Phi_N.inverse2(), rtf)

    # an eigenvector satisfies mat @ v ∝ v, so these two products must commute
    factor = FC.matmul(mat, max_eigenvec)
    assert FC.allclose(
        FC.matmul(max_eigenvec, factor.transpose(-1, -2)),
        FC.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
def test_solve():
    """Round-trip check: x = solve(s, t) must satisfy t @ x == s."""
    lhs = ComplexTensor(_get_complex_array(1, 10, 10))
    rhs = ComplexTensor(_get_complex_array(1, 10, 4))
    solution, _ = F.solve(rhs, lhs)
    reconstructed = lhs @ solution
    numpy.testing.assert_allclose(
        reconstructed.real.numpy()[0],
        rhs.real.numpy()[0],
        atol=1e-13,
    )
    numpy.testing.assert_allclose(
        reconstructed.imag.numpy()[0],
        rhs.imag.numpy()[0],
        atol=1e-13,
    )
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): one-hot reference selection, (..., C)
        use_torch_solver (bool): use `solve` instead of an explicit `inverse`
        diagonal_loading (bool): add a tiny term to the diagonal of psd_n
        diag_eps (float): scale of the diagonal-loading term
        eps (float): numerical floor for the trace normalization
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        # regularize psd_n so its inversion stays numerically stable
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # inv(psd_n) @ psd_s
    if use_torch_solver:
        inv_n_s = FC.solve(psd_s, psd_n)[0]
    else:
        inv_n_s = FC.matmul(psd_n.inverse2(), psd_s)

    # trace normalization: (..., C, C) / (...,) -> (..., C, C)
    ws = inv_n_s / (FC.trace(inv_n_s)[..., None, None] + eps)

    # select the reference column: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return FC.einsum("...fec,...c->...fe", [ws, reference_vector])
def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
    """Solve the linear equation ax = b.

    Dispatches to the appropriate backend depending on whether the inputs
    are ComplexTensor, native torch complex, or real tensors.
    """
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
    # mixed input with complex and real tensors.
    a_is_ct = isinstance(a, ComplexTensor)
    b_is_ct = isinstance(b, ComplexTensor)

    if a_is_ct or b_is_ct:
        # ComplexTensor path: direct solve only when BOTH operands are ComplexTensor
        if a_is_ct and b_is_ct:
            return FC.solve(b, a, return_LU=False)
        return matmul(inverse(a), b)

    if is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        # native complex path: direct solve only when BOTH operands are complex
        if torch.is_complex(a) and torch.is_complex(b):
            return torch.linalg.solve(a, b)
        return matmul(inverse(a), b)

    # real-tensor path; torch.linalg.solve exists from 1.8 on
    if is_torch_1_8_plus:
        return torch.linalg.solve(a, b)
    return torch.solve(b, a)[0]
def get_mfmvdr_vector(gammax, Phi, use_torch_solver: bool = True, eps: float = EPS):
    """Compute conventional MFMPDR/MFMVDR filter.

    Args:
        gammax (ComplexTensor): (..., L, N) steering-related vector
        Phi (ComplexTensor): (..., L, N, N) covariance matrix
        use_torch_solver (bool): use `solve` instead of an explicit `inverse`
        eps (float): numerical floor for the denominator
    Returns:
        beamforming_vector (ComplexTensor): (..., L, N)
    """
    # inv(Phi) @ gammax: (..., L, N)
    gammax_col = gammax.unsqueeze(-1)
    if use_torch_solver:
        num = FC.solve(gammax_col, Phi)[0].squeeze(-1)
    else:
        num = FC.matmul(Phi.inverse2(), gammax_col).squeeze(-1)
    # gammax^H @ inv(Phi) @ gammax, reduced over the channel axis
    denom = FC.einsum("...d,...d->...", [gammax.conj(), num])
    return num / (denom.real.unsqueeze(-1) + eps)
def get_rtf(
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    reference_vector: Union[int, torch.Tensor, None] = None,
    iterations: int = 3,
    use_torch_solver: bool = True,
) -> ComplexTensor:
    """Calculate the relative transfer function (RTF) using the power method.

    Algorithm:
        1) rtf = reference_vector
        2) for i in range(iterations):
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # this normalization can be skipped
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]
    Note: 4) Normalization at the reference channel is not performed here.

    Args:
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): use `solve` instead of an explicit `inverse`
    Returns:
        rtf (ComplexTensor): (..., F, C, 1)
    """
    # phi = psd_noise^-1 @ psd_speech
    if use_torch_solver:
        phi = FC.solve(psd_speech, psd_noise)[0]
    else:
        phi = FC.matmul(psd_noise.inverse2(), psd_speech)

    # initialize with one application of phi to the reference direction
    if isinstance(reference_vector, int):
        rtf = phi[..., reference_vector, None]
    else:
        rtf = FC.matmul(phi, reference_vector[..., None, :, None])

    # remaining power iterations (two applications are implicit: the init
    # above and the final psd_speech product, which equals psd_noise @ phi @ rtf)
    for _ in range(iterations - 2):
        rtf = FC.matmul(phi, rtf)
        # rtf = rtf / complex_norm(rtf)
    return FC.matmul(psd_speech, rtf)
def test_solve(real_vec):
    """Check solve() against the reference solver for each complex wrapper.

    Args:
        real_vec (bool): if True, pass a real-valued right-hand side where
            the wrapper allows it, to exercise the mixed real/complex path.
    """
    if is_torch_1_9_plus:
        wrappers = [ComplexTensor, torch.complex]
        modules = [FC, torch]
    else:
        wrappers = [ComplexTensor]
        modules = [FC]
    for complex_wrapper, complex_module in zip(wrappers, modules):
        mat = complex_wrapper(
            torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag)
        )
        if not real_vec or complex_wrapper is ComplexTensor:
            vec = complex_wrapper(torch.rand(2, 3, 1), torch.rand(2, 3, 1))
            vec2 = vec
        else:
            vec = torch.rand(2, 3, 1)
            vec2 = complex_wrapper(vec, torch.zeros_like(vec))
        ret = solve(vec, mat)
        if isinstance(vec2, ComplexTensor):
            ret2 = FC.solve(vec2, mat, return_LU=False)
        else:
            # BUG FIX: this was `return torch.linalg.solve(mat, vec2)`, which
            # exited the test early and silently skipped the allclose check
            # for the native torch.complex wrapper.
            ret2 = torch.linalg.solve(mat, vec2)
        assert complex_module.allclose(ret, ret2)
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. As follows:

        h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): use `solve` instead of an explicit `inverse`
        diagonal_loading (bool): add a tiny term to the diagonal of psd_noise
        diag_eps (float): scale of the diagonal-loading term
        eps (float): numerical floor for the denominator
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # estimate the steering vector via the power method: (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )
    # zero-pad to the stacked (convolutional) dimension: (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)

    # numerator = inv(Rf) @ vbar: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)

    # denominator = vbar^H @ inv(Rf) @ vbar
    rtf_vec = rtf.squeeze(-1)
    denominator = FC.einsum("...d,...d->...", [rtf_vec.conj(), numerator])
    if normalize_ref_channel is not None:
        # rescale so the filter is distortionless at the reference channel
        scale = rtf_vec[..., normalize_ref_channel, None].conj()
        return numerator * scale / (denominator.real.unsqueeze(-1) + eps)
    return numerator / (denominator.real.unsqueeze(-1) + eps)
def get_mvdr_vector_with_rtf(
    psd_n: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector
    calculated with RTF:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): use `solve` instead of an explicit `inverse`
        diagonal_loading (bool): add a tiny term to the diagonal of psd_noise
        diag_eps (float): scale of the diagonal-loading term
        eps (float): numerical floor for the denominator
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # estimate the steering vector via the power method: (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # numerator = inv(psd_n) @ rtf: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_n)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1)

    # denominator = rtf^H @ inv(psd_n) @ rtf
    rtf_vec = rtf.squeeze(-1)
    denominator = FC.einsum("...d,...d->...", [rtf_vec.conj(), numerator])
    if normalize_ref_channel is not None:
        # rescale so the filter is distortionless at the reference channel
        scale = rtf_vec[..., normalize_ref_channel, None].conj()
        return numerator * scale / (denominator.real.unsqueeze(-1) + eps)
    return numerator / (denominator.real.unsqueeze(-1) + eps)