def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector.

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    If ``psd_n`` cannot be inverted directly, up to two retries are made.
    Each retry divides both PSDs by a further factor (1e5, then 1e11;
    the divisions accumulate) and adds 1e-4 to the diagonal of ``psd_n``.
    Because both PSDs are divided by the same factor, the factor cancels
    in ``inv(psd_n) @ psd_s``. Its practical effect is to make the fixed
    1e-4 diagonal loading large relative to the original PSD scale.

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float): Added to the trace in the normalization denominator
            to avoid division by zero. It is NOT used for diagonal loading.

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)

    Raises:
        RuntimeError: If ``psd_n`` is still not invertible after all
            fallback attempts. The last inversion error is chained as
            the cause.
    """
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    # Broadcastable identity: (1, ..., 1, C, C)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)

    # None = try the unmodified matrix first; each later entry is an extra
    # division applied on top of the previous attempt's (already scaled and
    # loaded) PSDs, matching the original nested-try fallback.
    last_error = None
    for scale in (None, 10e+4, 10e+10):
        if scale is not None:
            psd_n = psd_n / scale + 1e-4 * eye
            psd_s = psd_s / scale
        try:
            psd_n_i = psd_n.inverse()
        except RuntimeError as err:  # torch raises this for singular input
            last_error = err
        else:
            break
    else:
        raise RuntimeError('psd not invertable.') from last_error

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n_i, psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor) -> ComplexTensor:
    """Calculate the (conjugate) filter matrix from correlation statistics.

    The stacked filter is ``correlation_vector @ inverse(correlation_matrix)^T``,
    computed batch-wise over the leading F axis. It is then unstacked into
    one (C, C) matrix per tap. No regularization is applied before the
    inversion.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    n_freq, n_taps, n_ch, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    stacked_vector = correlation_vector.permute(0, 2, 1, 3).contiguous()
    stacked_vector = stacked_vector.view(n_freq, n_ch, n_taps * n_ch)

    inv_transposed = correlation_matrix.inverse().transpose(-1, -2)

    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(stacked_vector, inv_transposed)

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    unstacked = stacked_filter_conj.view(n_freq, n_ch, n_taps, n_ch)
    return unstacked.permute(0, 2, 3, 1)
def test_inverse():
    """Multiplying a complex matrix by its inverse should give the identity.

    Only the first entry along the leading axis is compared. The real part
    must match eye(10) and the imaginary part must be zero, both within
    atol=1e-11.
    """
    # NOTE(review): presumably _get_complex_array returns a random complex
    # array of shape (1, 10, 10); confirm against its definition.
    matrix = ComplexTensor(_get_complex_array(1, 10, 10))
    product = matrix @ matrix.inverse()

    expected_real = numpy.eye(10)
    expected_imag = numpy.zeros((10, 10))
    numpy.testing.assert_allclose(
        product.real.numpy()[0], expected_real, atol=1e-11)
    numpy.testing.assert_allclose(
        product.imag.numpy()[0], expected_imag, atol=1e-11)
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate the (conjugate) filter matrix from correlation statistics.

    The stacked filter is
    ``correlation_vector @ inverse(correlation_matrix + eps * I)^T``,
    computed batch-wise over the leading F axis. It is then unstacked into
    one (C, C) matrix per tap.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C).
            Not modified by this function.
        correlation_vector : Correlation vector (F, taps, C, C)
        eps: Diagonal loading added to ``correlation_matrix`` before
            inversion.

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    # Identity broadcastable against correlation_matrix: (1, ..., 1, D, D)
    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        correlation_matrix.shape[-2:]
    eye = eye.view(*shape)
    # Out-of-place on purpose: the former `+=` modified the caller's tensor,
    # so each call added another eps to shared statistics.
    loaded_correlation_matrix = correlation_matrix + eps * eye
    inv_correlation_matrix = loaded_correlation_matrix.inverse()

    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj