Example #1
0
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the MVDR(Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    If ``psd_n`` cannot be inverted directly, up to two retries are made.
    Each retry divides both PSDs by a common factor (on top of any previous
    retry) and then adds ``1e-4 * I`` to ``psd_n``. Dividing both PSDs by
    the same factor leaves ``Npsd^-1 @ Spsd`` unchanged. Only the diagonal
    loading alters the result.

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float): Added to the trace in the denominator to avoid a
            division by zero.
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    Raises:
        Exception: If ``psd_n`` is still not invertible after all retries.
            It is chained from the last underlying ``RuntimeError``.
    """
    C = psd_n.size(-1)
    # Identity shaped (1, ..., 1, C, C) so it broadcasts against psd_n.
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)

    try:
        psd_n_i = psd_n.inverse()
    except RuntimeError as first_error:
        last_error = first_error
        # Fallback ladder (10e+4 == 1e5, 10e+10 == 1e11). Stages are
        # cumulative: the second one starts from the output of the first.
        for scale in (10e+4, 10e+10):
            psd_n = psd_n / scale + 1e-4 * eye
            psd_s = psd_s / scale
            try:
                psd_n_i = psd_n.inverse()
                break
            except RuntimeError as err:
                last_error = err
        else:
            raise Exception('psd not invertable.') from last_error

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n_i, psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector
Example #2
0
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor) -> ComplexTensor:
    """Compute the (conjugate) filter matrix from correlation statistics.

    Every frequency bin along the leading F axis is handled in one batched
    call.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    n_freq, n_taps, n_chan, _ = correlation_vector.size()

    # Bring the channel axis forward and fold taps into the last axis:
    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    permuted = correlation_vector.permute(0, 2, 1, 3).contiguous()
    flat_vector = permuted.view(n_freq, n_chan, n_taps * n_chan)

    inv_matrix = correlation_matrix.inverse()
    inv_matrix_t = inv_matrix.transpose(-1, -2)

    # (F, C1, taps * C2) x (F, taps * C2, taps * C2) -> (F, C1, taps * C2)
    stacked = FC.matmul(flat_vector, inv_matrix_t)

    # Unfold and reorder: (F, C1, taps * C2) -> (F, C1, taps, C2)
    # -> (F, taps, C2, C1)
    unfolded = stacked.view(n_freq, n_chan, n_taps, n_chan)
    return unfolded.permute(0, 2, 3, 1)
def test_inverse():
    """Check that A @ A.inverse() is the identity for a ComplexTensor."""
    size = 10
    tolerance = 1e-11
    matrix = ComplexTensor(_get_complex_array(1, size, size))
    product = matrix @ matrix.inverse()

    # Real part must be the identity matrix.
    real_first = product.real.numpy()[0]
    numpy.testing.assert_allclose(real_first, numpy.eye(size), atol=tolerance)

    # Imaginary part must vanish.
    imag_first = product.imag.numpy()[0]
    numpy.testing.assert_allclose(imag_first,
                                  numpy.zeros((size, size)),
                                  atol=tolerance)
Example #4
0
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate the (conjugate) filter matrix based on correlations.

    Every frequency bin along the leading F axis is processed in one batched
    call. ``eps * I`` is added to a copy of ``correlation_matrix`` before
    inversion. The caller's tensor is left unmodified.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps: Diagonal loading added to ``correlation_matrix`` before it is
            inverted.

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    # Identity shaped (1, ..., 1, D, D) so it broadcasts against the matrix.
    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        tuple(correlation_matrix.shape[-2:])
    eye = eye.view(*shape)
    # Out-of-place on purpose. ``+=`` would modify the caller's tensor, so
    # eps would accumulate across repeated calls, and it would fail on
    # leaf tensors that require grad.
    correlation_matrix = correlation_matrix + eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C1, taps * C2) x (F, taps * C2, taps * C2) -> (F, C1, taps * C2)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj