Beispiel #1
0
def test_complex_impl_consistency():
    """Check that native complex tensors and ComplexTensor give the same results.

    The module-level matrix ``mat_np`` and one random vector are wrapped twice:
    once with ``torch.complex`` and once with ``ComplexTensor``. The outputs of
    ``abs``, ``inverse``, ``matmul``, ``solve`` and ``einsum`` on both
    representations must agree within an absolute tolerance of 1e-6.

    The test returns immediately when ``is_torch_1_9_plus`` is false.
    """
    if not is_torch_1_9_plus:
        return

    native_mat = torch.complex(
        torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag)
    )
    wrapped_mat = ComplexTensor(
        torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag)
    )

    batch_size, rank = native_mat.shape[0], native_mat.shape[-1]
    rand_real = torch.rand(batch_size, rank)
    rand_imag = torch.rand(batch_size, rank)
    native_vec = torch.complex(rand_real, rand_imag).type_as(native_mat)
    # The wrapped vector reuses the native one's parts so both hold equal values.
    wrapped_vec = ComplexTensor(native_vec.real, native_vec.imag)

    native_col = native_vec.unsqueeze(-1)
    wrapped_col = wrapped_vec.unsqueeze(-1)

    comparisons = [
        (abs(native_mat), abs(wrapped_mat)),
        (inverse(native_mat), inverse(wrapped_mat)),
        (matmul(native_mat, native_col), matmul(wrapped_mat, wrapped_col)),
        (solve(native_col, native_mat), solve(wrapped_col, wrapped_mat)),
        (
            einsum("bec,bc->be", native_mat, native_vec),
            einsum("bec,bc->be", wrapped_mat, wrapped_vec),
        ),
    ]
    for from_native, from_wrapped in comparisons:
        np.testing.assert_allclose(
            from_native.numpy(), from_wrapped.numpy(), atol=1e-6
        )
Beispiel #2
0
def test_get_rtf(ch, mode):
    """Check that ``get_rtf`` yields an eigenvector direction of Phi_N^-1 @ Phi_X.

    A speech covariance ``Phi_X`` is built from the STFT of the first ``ch``
    channels of ``random_speech`` and a full-rank noise covariance ``Phi_N``
    from Gaussian noise. With ``v`` obtained by applying Phi_N^-1 to the
    normalized RTF and ``f = (Phi_N^-1 @ Phi_X) @ v``, the test asserts that
    ``v @ f^T == f @ v^T``.

    Args:
        ch: number of channels kept from ``random_speech``.
        mode: forwarded to ``get_rtf``; "evd" uses native torch complex
            tensors, any other value uses ``ComplexTensor``.
    """
    if mode == "evd":
        if not is_torch_1_9_plus:
            # torch 1.9.0+ is required for "evd" mode
            return
        complex_wrapper, complex_module = torch.complex, torch
    else:
        complex_wrapper, complex_module = ComplexTensor, FC

    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    ilens = torch.LongTensor([16, 12])
    spectrum = stft(x, ilens)[0]
    # (B, T, C, F) -> (B, F, C, T)
    X = complex_wrapper(*torch.unbind(spectrum, dim=-1)).transpose(-1, -3)

    covariance = "...ct,...et->...ce"
    # (B, F, C, C)
    Phi_X = complex_module.einsum(covariance, [X, X.conj()])

    # Redraw the noise until every noise covariance matrix has rank `ch`.
    while True:
        N = complex_wrapper(torch.randn_like(X.real), torch.randn_like(X.imag))
        Phi_N = complex_module.einsum(covariance, [N, N.conj()])
        if np.all(np.linalg.matrix_rank(Phi_N.numpy()) == ch):
            break

    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, mode=mode, reference_vector=0, iterations=20)
    peak = rtf.abs().max(dim=-2, keepdim=True)
    if is_torch_1_1_plus:
        rtf = rtf / (peak.values + 1e-15)
        # torch.solve is required, which is only available after pytorch 1.1.0+
        # NOTE(review): `[0]` indexes the leading dimension of whatever `solve`
        # returns -- confirm whether that is intended.
        mat = solve(Phi_X, Phi_N)[0]
        max_eigenvec = solve(rtf, Phi_N)[0]
    else:
        rtf = rtf / (peak[0] + 1e-15)
        noise_inverse = Phi_N.inverse2()
        mat = complex_module.matmul(noise_inverse, Phi_X)
        max_eigenvec = complex_module.matmul(noise_inverse, rtf)

    # rtf is expected to be approximately Phi_N @ MaxEigVec(Phi_N^-1 @ Phi_X)
    factor = complex_module.matmul(mat, max_eigenvec)
    left = complex_module.matmul(max_eigenvec, factor.transpose(-1, -2))
    right = complex_module.matmul(factor, max_eigenvec.transpose(-1, -2))
    assert complex_module.allclose(left, right)
Beispiel #3
0
def get_lcmv_vector_with_rtf(
    psd_n: Union[torch.Tensor, ComplexTensor],
    rtf_mat: Union[torch.Tensor, ComplexTensor],
    reference_vector: Union[int, torch.Tensor, None] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the LCMV (Linearly Constrained Minimum Variance) vector from RTFs.

    The filter is computed as

        h = (Npsd^-1 @ rtf_mat) @ (rtf_mat^H @ Npsd^-1 @ rtf_mat)^-1 @ p

    Reference:
        H. L. Van Trees, "Optimum array processing: Part IV of detection,
        estimation, and modulation theory," John Wiley & Sons, 2004.
        (Chapter 6.7)

    Args:
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        rtf_mat (torch.complex64/ComplexTensor):
            RTF matrix (..., F, C, num_spk)
        reference_vector (torch.Tensor or int): (..., num_spk) or scalar.
            An int selects one column of the inverted constraint matrix;
            any other value is handed to ``solve`` as the right-hand side.
            NOTE(review): the default ``None`` is passed to ``solve``
            unchanged -- callers should supply a value.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize psd_n with ``tik_reg``
        diag_eps (float): passed to ``tik_reg`` as ``reg``
        eps (float): passed to ``tik_reg`` as ``eps``
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    noise_cov = (
        tik_reg(psd_n, reg=diag_eps, eps=eps) if diagonal_loading else psd_n
    )

    # Npsd^-1 @ rtf_mat: (..., C_1, C_2) x (..., C_2, num_spk) -> (..., C_1, num_spk)
    whitened_rtf = (
        solve(rtf_mat, noise_cov)
        if use_torch_solver
        else matmul(inverse(noise_cov), rtf_mat)
    )
    # rtf_mat^H @ Npsd^-1 @ rtf_mat: (..., num_spk, num_spk)
    constraint = matmul(rtf_mat.conj().transpose(-1, -2), whitened_rtf)

    if isinstance(reference_vector, int):
        selector = inverse(constraint)[..., reference_vector, None]
    else:
        selector = solve(reference_vector, constraint)
    return matmul(whitened_rtf, selector).squeeze(-1)
Beispiel #4
0
def get_rtf(
    psd_speech,
    psd_noise,
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
):
    """Calculate the relative transfer function (RTF)

    Algorithm of power method, as implemented below
    (phi = psd_noise^-1 @ psd_speech):
        1) rtf = phi @ u, where u is `reference_vector`; for an int this is
           simply column `reference_vector` of phi
        2) for i in range(iterations - 2):
             rtf = phi @ rtf
             (the optional normalization rtf / ||rtf||_2 is skipped)
        3) rtf = psd_speech @ rtf
           Mathematically psd_speech = psd_noise @ phi, so this single product
           stands for one more power iteration followed by the multiplication
           with psd_noise.
    In total phi is effectively applied max(iterations, 2) times; values of
    `iterations` below 2 behave like 2 because the loop body never runs.
    Note: Normalization at the reference channel
    (rtf = rtf / rtf[..., ref_channel, :]) is not performed here.

    In "evd" mode the RTF is psd_noise times the last column of the second
    value returned by `generalized_eigenvalue_decomposition`.
    NOTE(review): presumably that column is the principal generalized
    eigenvector -- confirm against the helper's ordering.

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition
        reference_vector (torch.Tensor or int): (..., C) or scalar;
            only used in "power" mode
        iterations (int): number of iterations in power method;
            only used in "power" mode
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`;
            only used in "power" mode
    Returns:
        rtf (torch.complex64/ComplexTensor): (..., F, C, 1)
    Raises:
        ValueError: if `mode` is neither "power" nor "evd".
        AssertionError: in "evd" mode, unless `is_torch_1_9_plus` is true and
            both inputs are torch complex tensors.
    """
    if mode == "power":
        # phi = psd_noise^-1 @ psd_speech, via a solver or an explicit inverse
        if use_torch_solver:
            phi = solve(psd_speech, psd_noise)
        else:
            phi = matmul(inverse(psd_noise), psd_speech)
        # First application of phi: pick one column (int reference) or
        # multiply by the reference vector reshaped to (..., 1, C, 1).
        rtf = (
            phi[..., reference_vector, None]
            if isinstance(reference_vector, int)
            else matmul(phi, reference_vector[..., None, :, None])
        )
        # Two applications are covered by the initialization above and the
        # final product below, hence `iterations - 2` here.
        for _ in range(iterations - 2):
            rtf = matmul(phi, rtf)
            # rtf = rtf / complex_norm(rtf, dim=-1, keepdim=True)
        # psd_speech @ rtf == psd_noise @ (phi @ rtf) mathematically
        rtf = matmul(psd_speech, rtf)
    elif mode == "evd":
        # This branch only supports native torch complex tensors.
        assert (
            is_torch_1_9_plus
            and is_torch_complex_tensor(psd_speech)
            and is_torch_complex_tensor(psd_noise)
        )
        e_vec = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1]
        # Last column of e_vec, kept as a (..., C, 1) column vector
        rtf = matmul(psd_noise, e_vec[..., -1, None])
    else:
        raise ValueError("Unknown mode: %s" % mode)
    return rtf
Beispiel #5
0
def get_WPD_filter(
    Phi: Union[torch.Tensor, ComplexTensor],
    Rf: Union[torch.Tensor, ComplexTensor],
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer, computed here as

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize Rf with ``tik_reg``
        diag_eps (float): passed to ``tik_reg`` as ``reg``
        eps (float): passed to ``tik_reg`` as ``eps`` and also added to the
            trace before dividing

    Returns:
        filter_matrix (torch.complex64/ComplexTensor): (B, F, (btaps + 1) * C)
    """
    covariance = tik_reg(Rf, reg=diag_eps, eps=eps) if diagonal_loading else Rf

    # Rf^-1 @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    ratio = (
        solve(Phi, covariance)
        if use_torch_solver
        else matmul(inverse(covariance), Phi)
    )

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    trace = FC.trace(ratio)[..., None, None]
    # (..., C, C) / (..., 1, 1) -> (..., C, C)
    weights = ratio / (trace + eps)

    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return einsum("...fec,...c->...fe", weights, reference_vector)
Beispiel #6
0
def get_mvdr_vector(
    psd_s,
    psd_n,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize psd_n with ``tik_reg``
        diag_eps (float): passed to ``tik_reg`` as ``reg``
        eps (float): passed to ``tik_reg`` as ``eps`` and also added to the
            trace before dividing
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # Npsd^-1 @ Spsd, either through a linear solve or an explicit inverse
    if use_torch_solver:
        snr_matrix = solve(psd_s, psd_n)
    else:
        noise_inverse = inverse(psd_n)
        snr_matrix = matmul(noise_inverse, psd_s)

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    normalizer = FC.trace(snr_matrix)[..., None, None] + eps
    # (..., C, C) / (..., 1, 1) -> (..., C, C)
    weights = snr_matrix / normalizer

    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return einsum("...fec,...c->...fe", weights, reference_vector)
Beispiel #7
0
def test_solve(real_vec):
    """Check the ``solve`` helper against each backend's own ``solve``.

    ``ComplexTensor``/``FC`` is always tested; ``torch.complex``/``torch`` is
    added when ``is_torch_1_9_plus`` is true. With ``real_vec`` set, the native
    torch backend is given a real-valued right-hand side, and the reference
    result is computed from the same values with a zero imaginary part.

    Args:
        real_vec: whether to use a real right-hand side for the native backend.
    """
    backends = [(ComplexTensor, FC)]
    if is_torch_1_9_plus:
        backends.append((torch.complex, torch))

    for complex_wrapper, complex_module in backends:
        mat = complex_wrapper(
            torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag)
        )
        if real_vec and complex_wrapper is not ComplexTensor:
            vec = torch.rand(2, 3, 1)
            vec2 = complex_wrapper(vec, torch.zeros_like(vec))
        else:
            vec = complex_wrapper(torch.rand(2, 3, 1), torch.rand(2, 3, 1))
            vec2 = vec

        actual = solve(vec, mat)
        expected = complex_module.solve(vec2, mat)[0]
        assert complex_module.allclose(actual, expected)
Beispiel #8
0
def get_rtf(
    psd_speech,
    psd_noise,
    reference_vector: Union[int, torch.Tensor, None] = None,
    iterations: int = 3,
    use_torch_solver: bool = True,
):
    """Calculate the relative transfer function (RTF) using the power method.

    With phi = psd_noise^-1 @ psd_speech the steps are:
        1) rtf = phi @ u, where u is `reference_vector` (for an int, column
           `reference_vector` of phi)
        2) repeat (iterations - 2) times: rtf = phi @ rtf
           (the optional normalization by ||rtf||_2 is skipped)
        3) rtf = psd_speech @ rtf, which mathematically equals
           psd_noise @ (phi @ rtf)
    Note: Normalization at the reference channel
    (rtf = rtf / rtf[..., ref_channel, :]) is not performed here.

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar.
            NOTE(review): the default ``None`` is indexed like a tensor
            below -- callers should supply a value.
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
    Returns:
        rtf (torch.complex64/ComplexTensor): (..., F, C, 1)
    """
    phi = (
        solve(psd_speech, psd_noise)
        if use_torch_solver
        else matmul(inverse(psd_noise), psd_speech)
    )

    if isinstance(reference_vector, int):
        estimate = phi[..., reference_vector, None]
    else:
        estimate = matmul(phi, reference_vector[..., None, :, None])

    # The initialization above and the final product below each account for
    # one application of phi, so only `iterations - 2` remain.
    for _step in range(iterations - 2):
        estimate = matmul(phi, estimate)
    return matmul(psd_speech, estimate)
Beispiel #9
0
def get_mfmvdr_vector(gammax,
                      Phi,
                      use_torch_solver: bool = True,
                      eps: float = EPS):
    """Compute conventional MFMPDR/MFMVDR filter.

    The result is (Phi^-1 @ gammax) divided by the real part of
    gammax^H @ Phi^-1 @ gammax (plus ``eps``).

    Args:
        gammax (torch.complex64/ComplexTensor): (..., L, N)
        Phi (torch.complex64/ComplexTensor): (..., L, N, N)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        eps (float): added to the denominator before dividing
    Returns:
        beamforming_vector (torch.complex64/ComplexTensor): (..., L, N)
    """
    gamma_column = gammax.unsqueeze(-1)
    if use_torch_solver:
        solved = solve(gamma_column, Phi)
    else:
        solved = matmul(inverse(Phi), gamma_column)
    # (..., L, N)
    numerator = solved.squeeze(-1)

    # gammax^H @ Phi^-1 @ gammax, one scalar per leading index
    quadratic_form = einsum("...d,...d->...", gammax.conj(), numerator)
    scale = quadratic_form.real.unsqueeze(-1) + eps
    return numerator / scale
Beispiel #10
0
def get_mwf_vector(
    psd_s,
    psd_n,
    reference_vector: Union[torch.Tensor, int],
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the MWF (Minimum Multi-channel Wiener Filter) vector:

        h = (Npsd^-1 @ Spsd) @ u

    Args:
        psd_s (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_n (torch.complex64/ComplexTensor):
            power-normalized observation covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar.
            An int selects a single column of Npsd^-1 @ Spsd.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize psd_n with ``tik_reg``
        diag_eps (float): passed to ``tik_reg`` as ``reg``
        eps (float): passed to ``tik_reg`` as ``eps``
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: D400
    noise_cov = (
        tik_reg(psd_n, reg=diag_eps, eps=eps) if diagonal_loading else psd_n
    )

    weights = (
        solve(psd_s, noise_cov)
        if use_torch_solver
        else matmul(inverse(noise_cov), psd_s)
    )

    if isinstance(reference_vector, int):
        return weights[..., reference_vector]
    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return einsum("...fec,...c->...fe", weights, reference_vector)
Beispiel #11
0
def get_WPD_filter_with_rtf(
    psd_observed_bar: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    psd_noise: Union[torch.Tensor, ComplexTensor],
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector calculated with RTF.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (Rbar^-1 @ vbar) / (vbar^H @ Rbar^-1 @ vbar)

        where Rbar is `psd_observed_bar` and vbar is the RTF estimated from
        `psd_speech` / `psd_noise`, zero-padded to the size of Rbar.

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (torch.complex64/ComplexTensor):
            stacked observation covariance matrix
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar;
            forwarded to `get_rtf`
        normalize_ref_channel (int):
            reference channel for normalizing the RTF; when given, the
            filter is scaled by the conjugate of the RTF at that channel
        use_torch_solver (bool):
            Whether to use `solve` instead of `inverse`
        diagonal_loading (bool):
            Whether to regularize psd_noise with `tik_reg` before the RTF
            estimation (psd_observed_bar is used as given)
        diag_eps (float): passed to `tik_reg` as `reg`
        eps (float): passed to `tik_reg` as `eps` and also added to the
            denominator before dividing
    Returns:
        beamform_vector (torch.complex64/ComplexTensor):
            (..., F, D) with D = psd_observed_bar.shape[-1]
    Raises:
        ValueError: if psd_speech is neither a ComplexTensor nor a torch
            complex tensor.
    """
    if isinstance(psd_speech, ComplexTensor):
        pad_func = FC.pad
    elif is_torch_complex_tensor(psd_speech):
        pad_func = torch.nn.functional.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support.")

    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    # `reference_vector` is passed by keyword so that it cannot land in a
    # different parameter slot (e.g. `mode`) of `get_rtf`.
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector=reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # Zero-pad the channel axis (second to last) at its end so the RTF
    # matches the size of psd_observed_bar: (B, F, (K+1)*C, 1)
    rtf = pad_func(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant",
                   0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = solve(rtf, psd_observed_bar).squeeze(-1)
    else:
        numerator = matmul(inverse(psd_observed_bar), rtf).squeeze(-1)
    # vbar^H @ Rbar^-1 @ vbar; only its real part is used below
    denominator = einsum("...d,...d->...", rtf.squeeze(-1).conj(), numerator)
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
Beispiel #12
0
def get_mvdr_vector_with_rtf(
    psd_n: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    psd_noise: Union[torch.Tensor, ComplexTensor],
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the MVDR (Minimum Variance Distortionless Response) vector
        calculated with RTF:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

        where Npsd is `psd_n` and rtf is estimated from `psd_speech` and
        `psd_noise` with `get_rtf`.

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar;
            forwarded to `get_rtf`
        normalize_ref_channel (int): reference channel for normalizing the
            RTF; when given, the filter is scaled by the conjugate of the
            RTF at that channel
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize psd_noise with
            `tik_reg` before the RTF estimation (psd_n is used as given)
        diag_eps (float): passed to `tik_reg` as `reg`
        eps (float): passed to `tik_reg` as `eps` and also added to the
            denominator before dividing
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    # `reference_vector` is passed by keyword so that it cannot land in a
    # different parameter slot (e.g. `mode`) of `get_rtf`.
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector=reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = solve(rtf, psd_n).squeeze(-1)
    else:
        numerator = matmul(inverse(psd_n), rtf).squeeze(-1)
    # rtf^H @ Npsd^-1 @ rtf; only its real part is used below
    denominator = einsum("...d,...d->...", rtf.squeeze(-1).conj(), numerator)
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
Beispiel #13
0
def get_gev_vector(
    psd_noise: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the generalized eigenvalue (GEV) beamformer vector:

        psd_speech @ h = lambda * psd_noise @ h

    The estimated vector is divided by its norm over the last axis and then
    passed through `gev_phase_correction`.

    Note that the noise covariance comes first in the argument list here.

    Reference:
        Blind acoustic beamforming based on generalized eigenvalue decomposition;
        E. Warsitz and R. Haeb-Umbach, 2007.

    Args:
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method; psd_noise^-1 @ psd_speech is applied once
                to the reference and then (iterations - 1) more times
            "evd": eigenvalue decomposition (only for torch builtin complex
                tensors), computed separately for each frequency; if the
                decomposition raises RuntimeError for a frequency, a message
                is printed and a constant vector C / trace(psd_noise) is
                used for that frequency instead
        reference_vector (torch.Tensor or int): (..., C) or scalar;
            only used in "power" mode
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`;
            only used in "power" mode
        diagonal_loading (bool): Whether to regularize psd_noise with `tik_reg`
        diag_eps (float): passed to `tik_reg` as `reg`
        eps (float): passed to `tik_reg` as `eps`
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    Raises:
        ValueError: if `mode` is neither "power" nor "evd".
        AssertionError: in "evd" mode, unless `is_torch_1_9_plus` is true and
            both inputs are torch complex tensors.
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if mode == "power":
        # phi = psd_noise^-1 @ psd_speech, via a solver or an explicit inverse
        if use_torch_solver:
            phi = solve(psd_speech, psd_noise)
        else:
            phi = matmul(inverse(psd_noise), psd_speech)
        # First application of phi: pick one column (int reference) or
        # multiply by the reference vector reshaped to (..., 1, C, 1).
        e_vec = (phi[..., reference_vector,
                     None] if isinstance(reference_vector, int) else matmul(
                         phi, reference_vector[..., None, :, None]))
        for _ in range(iterations - 1):
            e_vec = matmul(phi, e_vec)
            # Per-step normalization is disabled; kept for reference:
            # e_vec = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
        # (..., F, C, 1) -> (..., F, C)
        e_vec = e_vec.squeeze(-1)
    elif mode == "evd":
        # This branch only supports native torch complex tensors.
        assert (is_torch_1_9_plus and is_torch_complex_tensor(psd_speech)
                and is_torch_complex_tensor(psd_noise))
        # Batched one-shot variant, kept for reference:
        # e_vec = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1][...,-1]
        # Output buffer of shape (..., F, C), filled one frequency at a time so
        # a failure at one frequency does not abort the others.
        e_vec = psd_noise.new_zeros(psd_noise.shape[:-1])
        for f in range(psd_noise.shape[-3]):
            try:
                # Last column of the second returned value.
                # NOTE(review): presumably the principal generalized
                # eigenvector -- confirm against the helper's ordering.
                e_vec[..., f, :] = generalized_eigenvalue_decomposition(
                    psd_speech[..., f, :, :], psd_noise[..., f, :, :])[1][...,
                                                                          -1]
            except RuntimeError:
                # port from github.com/fgnt/nn-gev/blob/master/fgnt/beamforming.py#L106
                print(
                    "GEV beamformer: LinAlg error for frequency {}".format(f),
                    flush=True,
                )
                C = psd_noise.size(-1)
                # Fallback: all-ones vector scaled by C / trace(psd_noise)
                e_vec[...,
                      f, :] = (psd_noise.new_ones(e_vec[..., f, :].shape) /
                               FC.trace(psd_noise[..., f, :, :]) * C)
    else:
        raise ValueError("Unknown mode: %s" % mode)

    # Normalize over the channel axis, then apply the phase correction helper
    beamforming_vector = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
    beamforming_vector = gev_phase_correction(beamforming_vector)
    return beamforming_vector
Beispiel #14
0
def get_rank1_mwf_vector(
    psd_speech,
    psd_noise,
    reference_vector: Union[torch.Tensor, int],
    denoising_weight: float = 1.0,
    approx_low_rank_psd_speech: bool = False,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the R1-MWF (Rank-1 Multi-channel Wiener Filter) vector

        h = (Npsd^-1 @ Spsd) / (mu + Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        [1] Rank-1 constrained multichannel Wiener filter for speech recognition in
        noisy environments; Z. Wang et al, 2018
        https://hal.inria.fr/hal-01634449/document
        [2] Low-rank approximation based multichannel Wiener filter algorithms for
        noise reduction with application in cochlear implants; R. Serizel, 2014
        https://ieeexplore.ieee.org/document/6730918

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        denoising_weight (float): a trade-off parameter between noise reduction and
            speech distortion.
            A larger value leads to more noise reduction at the expense of more speech
            distortion.
            When `denoising_weight = 0`, it corresponds to MVDR beamformer.
        approx_low_rank_psd_speech (bool): whether to replace original input psd_speech
            with its low-rank approximation as in [1]
        iterations (int): number of iterations in power method, only used when
            `approx_low_rank_psd_speech = True`
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize psd_noise with `tik_reg`
        diag_eps (float): passed to `tik_reg` as `reg`
        eps (float): passed to `tik_reg` as `eps`; also added to the trace
            denominators
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    # The noise covariance is regularized once, before any use.
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if approx_low_rank_psd_speech:
        # (B, F, C, 1)
        steering = get_rtf(
            psd_speech,
            psd_noise,
            mode="power",
            iterations=iterations,
            reference_vector=reference_vector,
            use_torch_solver=use_torch_solver,
        )
        # Eq. (25) in Ref[1]
        rank1_psd = matmul(steering, steering.conj().transpose(-1, -2))
        # Rescale so the rank-1 matrix keeps the trace of the original one
        trace_ratio = FC.trace(psd_speech) / (FC.trace(rank1_psd) + eps)
        # c.f. Eq. (62) in Ref[2]
        psd_speech = rank1_psd * trace_ratio[..., None, None]

    snr_matrix = (
        solve(psd_speech, psd_noise)
        if use_torch_solver
        else matmul(inverse(psd_noise), psd_speech)
    )

    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    trace = FC.trace(snr_matrix)[..., None, None]
    # (..., C, C) / (..., 1, 1) -> (..., C, C)
    weights = snr_matrix / (denoising_weight + trace + eps)

    if isinstance(reference_vector, int):
        return weights[..., reference_vector]
    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return einsum("...fec,...c->...fe", weights, reference_vector)
Beispiel #15
0
def get_sdw_mwf_vector(
    psd_speech,
    psd_noise,
    reference_vector: Union[torch.Tensor, int],
    denoising_weight: float = 1.0,
    approx_low_rank_psd_speech: bool = False,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the SDW-MWF (Speech Distortion Weighted Multi-channel Wiener Filter) vector

        h = (Spsd + mu * Npsd)^-1 @ Spsd @ u

    Reference:
        [1] Spatially pre-processed speech distortion weighted multi-channel Wiener
        filtering for noise reduction; A. Spriet et al, 2004
        https://dl.acm.org/doi/abs/10.1016/j.sigpro.2004.07.028
        [2] Rank-1 constrained multichannel Wiener filter for speech recognition in
        noisy environments; Z. Wang et al, 2018
        https://hal.inria.fr/hal-01634449/document
        [3] Low-rank approximation based multichannel Wiener filter algorithms for
        noise reduction with application in cochlear implants; R. Serizel, 2014
        https://ieeexplore.ieee.org/document/6730918

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        denoising_weight (float): a trade-off parameter between noise reduction and
            speech distortion.
            A larger value leads to more noise reduction at the expense of more speech
            distortion.
            The plain MWF is obtained with `denoising_weight = 1` (by default).
        approx_low_rank_psd_speech (bool): whether to replace original input psd_speech
            with its low-rank approximation as in [2]
        iterations (int): number of iterations in power method, only used when
            `approx_low_rank_psd_speech = True`
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize with `tik_reg`; applied
            to Spsd + mu * Npsd, and additionally to psd_noise before the
            low-rank approximation when that option is enabled
        diag_eps (float): passed to `tik_reg` as `reg`
        eps (float): passed to `tik_reg` as `eps`; also added to the trace
            denominator of the low-rank rescaling
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if approx_low_rank_psd_speech:
        if diagonal_loading:
            psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

        # (B, F, C, 1)
        steering = get_rtf(
            psd_speech,
            psd_noise,
            mode="power",
            iterations=iterations,
            reference_vector=reference_vector,
            use_torch_solver=use_torch_solver,
        )
        # Eq. (25) in Ref[2]
        rank1_psd = matmul(steering, steering.conj().transpose(-1, -2))
        # Rescale so the rank-1 matrix keeps the trace of the original one
        trace_ratio = FC.trace(psd_speech) / (FC.trace(rank1_psd) + eps)
        # c.f. Eq. (62) in Ref[3]
        psd_speech = rank1_psd * trace_ratio[..., None, None]

    # Spsd + mu * Npsd
    mixture_psd = psd_speech + denoising_weight * psd_noise
    if diagonal_loading:
        mixture_psd = tik_reg(mixture_psd, reg=diag_eps, eps=eps)

    weights = (
        solve(psd_speech, mixture_psd)
        if use_torch_solver
        else matmul(inverse(mixture_psd), psd_speech)
    )

    if isinstance(reference_vector, int):
        return weights[..., reference_vector]
    return einsum("...fec,...c->...fe", weights, reference_vector)