def test_inverse2():
    """Check that ``ComplexTensor.inverse2`` produces a right inverse.

    A complex tensor built from ``_get_complex_array(1, 10, 10)`` is multiplied
    by its ``inverse2()``; the first element of the product must have a real
    part close to the 10x10 identity and an imaginary part close to zero.
    """
    size = 10
    matrix = ComplexTensor(_get_complex_array(1, size, size))
    product = matrix @ matrix.inverse2()

    # Only the first (index 0) matrix of the product is inspected.
    real_part = product.real.numpy()[0]
    imag_part = product.imag.numpy()[0]

    numpy.testing.assert_allclose(real_part, numpy.eye(size), atol=1e-11)
    numpy.testing.assert_allclose(imag_part, numpy.zeros((size, size)), atol=1e-11)
def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the WPD beamforming filter.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. The filter is obtained as:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            PSD of the zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            reference vector ``u`` in the formula above.
        use_torch_solver (bool): If True, obtain ``Rf^-1 @ Phi`` through
            ``FC.solve``; otherwise multiply ``Rf.inverse2()`` by ``Phi``.
        diagonal_loading (bool): If True, ``Rf`` is first passed through
            ``tik_reg`` (diagonal loading).
        diag_eps (float): Forwarded to ``tik_reg`` as ``reg``.
        eps (float): Forwarded to ``tik_reg`` and also added to the trace
            in the denominator to avoid division by zero.

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    covariance = tik_reg(Rf, reg=diag_eps, eps=eps) if diagonal_loading else Rf

    # ratio = Rf^-1 @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    if use_torch_solver:
        ratio = FC.solve(Phi, covariance)[0]
    else:
        ratio = FC.matmul(covariance.inverse2(), Phi)

    # Normalize by the trace, broadcast over the two matrix dimensions:
    # (..., C, C) / (..., 1, 1) -> (..., C, C)
    trace_term = FC.trace(ratio)[..., None, None] + eps
    weights = ratio / trace_term

    # Apply the reference vector:
    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return FC.einsum("...fec,...c->...fe", [weights, reference_vector])
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the MVDR (Minimum Variance Distortionless Response) vector.

    The vector is obtained as:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_n (ComplexTensor): observation/noise covariance matrix
            (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): If True, obtain ``psd_n^-1 @ psd_s`` through
            ``FC.solve``; otherwise multiply ``psd_n.inverse2()`` by ``psd_s``.
        diagonal_loading (bool): If True, ``psd_n`` is first passed through
            ``tik_reg`` (diagonal loading).
        diag_eps (float): Forwarded to ``tik_reg`` as ``reg``.
        eps (float): Forwarded to ``tik_reg`` and also added to the trace
            in the denominator to avoid division by zero.

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    noise_cov = tik_reg(psd_n, reg=diag_eps, eps=eps) if diagonal_loading else psd_n

    # ratio = psd_n^-1 @ psd_s
    if use_torch_solver:
        ratio = FC.solve(psd_s, noise_cov)[0]
    else:
        ratio = FC.matmul(noise_cov.inverse2(), psd_s)

    # Normalize by the trace: (..., C, C) / (..., 1, 1) -> (..., C, C)
    trace_term = FC.trace(ratio)[..., None, None] + eps
    weights = ratio / trace_term

    # Apply the reference vector:
    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return FC.einsum("...fec,...c->...fe", [weights, reference_vector])
def get_WPD_filter_v2(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the WPD filter (v2) without multiplying by padded zeros.

    This variant is more efficient than `get_WPD_filter`: instead of using a
    zero-padded speech PSD, it keeps only the first C columns of ``Rf^-1``
    and multiplies them by the unpadded (C x C) speech PSD.

    Args:
        Phi (ComplexTensor): (B, F, C, C)
            speech PSD.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            reference vector; its last dimension defines C.
        diagonal_loading (bool): If True, ``Rf`` is first passed through
            ``tik_reg`` (diagonal loading).
        diag_eps (float): Forwarded to ``tik_reg`` as ``reg``.
        eps (float): Forwarded to ``tik_reg`` and also added to the trace
            in the denominator to avoid division by zero.

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
    """
    num_channels = reference_vector.shape[-1]
    covariance = tik_reg(Rf, reg=diag_eps, eps=eps) if diagonal_loading else Rf

    # Keep only the first C columns of the inverse: (B, F, (btaps+1) * C, C)
    inverse_head = covariance.inverse2()[..., :num_channels]

    # (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    product = FC.matmul(inverse_head, Phi)

    # The trace is taken over the top C x C block of the product only.
    trace_term = FC.trace(product[..., :num_channels, :])[..., None, None] + eps

    # (..., (btaps+1) * C, C) / (..., 1, 1) -> (..., (btaps+1) * C, C)
    weights = product / trace_term

    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1), i.e. (B, F, (btaps+1) * C)
    return FC.einsum("...fec,...c->...fe", [weights, reference_vector])
def get_rtf(
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    reference_vector: Union[int, torch.Tensor, None] = None,
    iterations: int = 3,
    use_torch_solver: bool = True,
) -> ComplexTensor:
    """Estimate the relative transfer function (RTF) with the power method.

    Algorithm:
        1) rtf = reference_vector
        2) for i in range(iterations):
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # this normalization can be skipped
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]
    Note: 4) Normalization at the reference channel is not performed here.

    Args:
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar.
            An int selects the corresponding column of
            ``psd_noise^-1 @ psd_speech`` as the starting vector. Although the
            default is ``None``, that value is not handled here: anything
            that is not an int is indexed like a tensor.
        iterations (int): number of iterations in power method
        use_torch_solver (bool): If True, obtain ``psd_noise^-1 @ psd_speech``
            through ``FC.solve``; otherwise use ``psd_noise.inverse2()``.

    Returns:
        rtf (ComplexTensor): (..., F, C, 1)
    """
    # phi = psd_noise^-1 @ psd_speech
    if use_torch_solver:
        phi = FC.solve(psd_speech, psd_noise)[0]
    else:
        phi = FC.matmul(psd_noise.inverse2(), psd_speech)

    # First application of phi: either pick one column (int reference) or
    # multiply by the reference vector reshaped to a column (..., 1, C, 1).
    if isinstance(reference_vector, int):
        rtf = phi[..., reference_vector, None]
    else:
        rtf = FC.matmul(phi, reference_vector[..., None, :, None])

    # Only ``iterations - 2`` further applications are needed: one already
    # happened above, and the final ``psd_speech @ rtf`` stands in for
    # ``psd_noise @ phi @ rtf`` because phi is psd_noise^-1 @ psd_speech.
    remaining = iterations - 2
    for _ in range(remaining):
        rtf = FC.matmul(phi, rtf)
        # rtf = rtf / complex_norm(rtf)

    return FC.matmul(psd_speech, rtf)
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Compute the WPD filter from an RTF estimate.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. The filter is obtained as:

        h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    where ``vbar`` is the RTF from ``get_rtf`` zero-padded to the size of
    ``psd_observed_bar``.

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF.
            When given, the result is scaled by the conjugate of the RTF entry
            at this channel.
        use_torch_solver (bool): If True, use ``FC.solve`` instead of an
            explicit ``inverse2()``.
        diagonal_loading (bool): If True, ``psd_noise`` is first passed through
            ``tik_reg`` (diagonal loading).
        diag_eps (float): Forwarded to ``tik_reg`` as ``reg``.
        eps (float): Forwarded to ``tik_reg`` and also added to the
            denominator to avoid division by zero.

    Returns:
        beamform_vector (ComplexTensor): (..., F, (K+1) * C), where the last
            dimension equals ``psd_observed_bar.shape[-1]``.
    """
    num_channels = psd_noise.size(-1)
    noise_cov = (
        tik_reg(psd_noise, reg=diag_eps, eps=eps) if diagonal_loading else psd_noise
    )

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        noise_cov,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # Zero-pad the channel axis up to the stacked size: (B, F, (K+1)*C, 1)
    extra_rows = psd_observed_bar.shape[-1] - num_channels
    rtf = FC.pad(rtf, (0, 0, 0, extra_rows), "constant", 0)
    rtf_vec = rtf.squeeze(-1)

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)

    # vbar^H @ R^-1 @ vbar; only its real part is used for the division.
    quadratic_form = FC.einsum("...d,...d->...", [rtf_vec.conj(), numerator])
    divisor = quadratic_form.real.unsqueeze(-1) + eps

    if normalize_ref_channel is None:
        return numerator / divisor

    scale = rtf_vec[..., normalize_ref_channel, None].conj()
    return numerator * scale / divisor
def get_mvdr_vector_with_rtf(
    psd_n: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the MVDR vector from an RTF estimate.

    MVDR is the Minimum Variance Distortionless Response beamformer. The
    vector is obtained as:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (ComplexTensor): observation/noise covariance matrix
            (..., F, C, C)
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF.
            When given, the result is scaled by the conjugate of the RTF entry
            at this channel.
        use_torch_solver (bool): If True, use ``FC.solve`` instead of an
            explicit ``inverse2()``.
        diagonal_loading (bool): If True, ``psd_noise`` (the matrix used for
            the RTF estimate, not ``psd_n``) is first passed through
            ``tik_reg`` (diagonal loading).
        diag_eps (float): Forwarded to ``tik_reg`` as ``reg``.
        eps (float): Forwarded to ``tik_reg`` and also added to the
            denominator to avoid division by zero.

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    noise_cov = (
        tik_reg(psd_noise, reg=diag_eps, eps=eps) if diagonal_loading else psd_noise
    )

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        noise_cov,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )
    rtf_vec = rtf.squeeze(-1)

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_n)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1)

    # rtf^H @ Npsd^-1 @ rtf; only its real part is used for the division.
    quadratic_form = FC.einsum("...d,...d->...", [rtf_vec.conj(), numerator])
    divisor = quadratic_form.real.unsqueeze(-1) + eps

    if normalize_ref_channel is None:
        return numerator / divisor

    scale = rtf_vec[..., normalize_ref_channel, None].conj()
    return numerator * scale / divisor