Exemple #1
0
def test_complex_impl_consistency():
    """Check that native torch complex ops and ComplexTensor ops agree.

    Returns early (no checks) when `is_torch_1_9_plus` is false. Uses the
    module-level complex NumPy array `mat_np` as the batched test matrix.
    """
    if not is_torch_1_9_plus:
        return

    real_np, imag_np = mat_np.real, mat_np.imag
    native_mat = torch.complex(torch.from_numpy(real_np), torch.from_numpy(imag_np))
    wrapped_mat = ComplexTensor(torch.from_numpy(real_np), torch.from_numpy(imag_np))

    batch_size, dim = native_mat.shape[0], native_mat.shape[-1]
    native_vec = torch.complex(
        torch.rand(batch_size, dim), torch.rand(batch_size, dim)
    ).type_as(native_mat)
    wrapped_vec = ComplexTensor(native_vec.real, native_vec.imag)

    # Column-vector views used by matmul / solve.
    native_col = native_vec.unsqueeze(-1)
    wrapped_col = wrapped_vec.unsqueeze(-1)

    comparisons = [
        (abs(native_mat), abs(wrapped_mat)),
        (inverse(native_mat), inverse(wrapped_mat)),
        (matmul(native_mat, native_col), matmul(wrapped_mat, wrapped_col)),
        (solve(native_col, native_mat), solve(wrapped_col, wrapped_mat)),
        (
            einsum("bec,bc->be", native_mat, native_vec),
            einsum("bec,bc->be", wrapped_mat, wrapped_vec),
        ),
    ]
    for native_res, wrapped_res in comparisons:
        np.testing.assert_allclose(
            native_res.numpy(), wrapped_res.numpy(), atol=1e-6
        )
Exemple #2
0
def get_rtf(
    psd_speech,
    psd_noise,
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
):
    """Estimate the relative transfer function (RTF).

    Power-method algorithm:
        1) rtf = reference_vector
        2) repeat `iterations` times:
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # optional normalization, skipped here
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]
    Note: step 4) (normalization at the reference channel) is not done here.

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": generalized eigenvalue decomposition
                   (native torch complex tensors only)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
    Returns:
        rtf (torch.complex64/ComplexTensor): (..., F, C, 1)
    Raises:
        ValueError: if `mode` is neither "power" nor "evd".
    """
    if mode == "evd":
        assert (
            is_torch_1_9_plus
            and is_torch_complex_tensor(psd_speech)
            and is_torch_complex_tensor(psd_noise)
        )
        eig_vectors = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1]
        # Last column of the eigenvector matrix, kept as a (..., C, 1) column.
        return matmul(psd_noise, eig_vectors[..., -1, None])
    if mode != "power":
        raise ValueError("Unknown mode: %s" % mode)

    # phi = psd_noise^-1 @ psd_speech
    phi = (
        solve(psd_speech, psd_noise)
        if use_torch_solver
        else matmul(inverse(psd_noise), psd_speech)
    )
    if isinstance(reference_vector, int):
        # phi @ e_ref is simply column `reference_vector` of phi.
        rtf = phi[..., reference_vector, None]
    else:
        rtf = matmul(phi, reference_vector[..., None, :, None])
    for _ in range(iterations - 2):
        rtf = matmul(phi, rtf)
    # psd_noise @ phi == psd_speech, so this single product performs the last
    # power iteration and step 3) at once (no explicit psd_noise multiply).
    # NOTE(review): for iterations < 2 the result equals that of iterations == 2.
    return matmul(psd_speech, rtf)
Exemple #3
0
def get_lcmv_vector_with_rtf(
    psd_n: Union[torch.Tensor, ComplexTensor],
    rtf_mat: Union[torch.Tensor, ComplexTensor],
    reference_vector: Union[int, torch.Tensor, None] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the LCMV (Linearly Constrained Minimum Variance) vector
        calculated with RTF:

        h = (Npsd^-1 @ rtf_mat) @ (rtf_mat^H @ Npsd^-1 @ rtf_mat)^-1 @ p

    Reference:
        H. L. Van Trees, “Optimum array processing: Part IV of detection, estimation,
        and modulation theory,” John Wiley & Sons, 2004. (Chapter 6.7)

    Args:
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        rtf_mat (torch.complex64/ComplexTensor):
            RTF matrix (..., F, C, num_spk)
        reference_vector (torch.Tensor or int): (..., num_spk) or scalar
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant passed to the diagonal loading
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # Npsd^-1 @ rtf_mat: (..., C_1, C_2) x (..., C_2, num_spk) -> (..., C_1, num_spk)
    npsd_inv_rtf = (
        solve(rtf_mat, psd_n)
        if use_torch_solver
        else matmul(inverse(psd_n), rtf_mat)
    )
    # rtf_mat^H @ Npsd^-1 @ rtf_mat: (..., num_spk, num_spk)
    gram = matmul(rtf_mat.conj().transpose(-1, -2), npsd_inv_rtf)

    if not isinstance(reference_vector, int):
        # NOTE(review): the tensor is handed to solve() unchanged, so it
        # presumably must already broadcast as a right-hand side against
        # `gram` (the default None is not handled) — confirm with callers.
        weights = solve(reference_vector, gram)
    else:
        # gram^-1 @ e_ref is column `reference_vector` of gram^-1.
        weights = inverse(gram)[..., reference_vector, None]
    return matmul(npsd_inv_rtf, weights).squeeze(-1)
Exemple #4
0
def get_WPD_filter(
    Phi: Union[torch.Tensor, ComplexTensor],
    Rf: Union[torch.Tensor, ComplexTensor],
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of Rf
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant for diagonal loading and the trace division

    Returns:
        filter_matrix (torch.complex64/ComplexTensor): (B, F, (btaps + 1) * C)
    """
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # Rf^-1 @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    if not use_torch_solver:
        numerator = matmul(inverse(Rf), Phi)
    else:
        numerator = solve(Phi, Rf)

    # FC.trace() is used because torch.trace has no batch support before
    # PyTorch 1.9.0. (...,) -> (..., 1, 1) so it broadcasts over the matrix.
    normalizer = FC.trace(numerator)[..., None, None] + eps
    ws = numerator / normalizer

    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) == (B, F, (btaps + 1) * C)
    return einsum("...fec,...c->...fe", ws, reference_vector)
Exemple #5
0
def get_rtf(
    psd_speech,
    psd_noise,
    reference_vector: Union[int, torch.Tensor, None] = None,
    iterations: int = 3,
    use_torch_solver: bool = True,
):
    """Estimate the relative transfer function (RTF) with the power method.

    Algorithm:
        1) rtf = reference_vector
        2) repeat `iterations` times:
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # optional normalization, skipped here
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]
    Note: step 4) (normalization at the reference channel) is not done here.

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
    Returns:
        rtf (torch.complex64/ComplexTensor): (..., F, C, 1)
    """
    # phi = psd_noise^-1 @ psd_speech
    if use_torch_solver:
        phi = solve(psd_speech, psd_noise)
    else:
        phi = matmul(inverse(psd_noise), psd_speech)

    if isinstance(reference_vector, int):
        # phi @ e_ref is column `reference_vector` of phi, as (..., C, 1).
        rtf = phi[..., reference_vector, None]
    else:
        # NOTE(review): the default `None` reaches this branch and cannot be
        # indexed; callers presumably always pass an int or a tensor.
        rtf = matmul(phi, reference_vector[..., None, :, None])

    for _ in range(iterations - 2):
        rtf = matmul(phi, rtf)
    # psd_noise @ phi == psd_speech: one product covers the last power
    # iteration together with step 3).
    return matmul(psd_speech, rtf)
Exemple #6
0
def get_mvdr_vector(
    psd_s,
    psd_n,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant for diagonal loading and the trace division
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # Npsd^-1 @ Spsd
    if not use_torch_solver:
        numerator = matmul(inverse(psd_n), psd_s)
    else:
        numerator = solve(psd_s, psd_n)

    # FC.trace() is used because torch.trace has no batch support before
    # PyTorch 1.9.0. (...,) -> (..., 1, 1) so it broadcasts over the matrix.
    normalizer = FC.trace(numerator)[..., None, None] + eps
    ws = numerator / normalizer

    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return einsum("...fec,...c->...fe", ws, reference_vector)
Exemple #7
0
def get_WPD_filter_v2(
    Phi: Union[torch.Tensor, ComplexTensor],
    Rf: Union[torch.Tensor, ComplexTensor],
    reference_vector: torch.Tensor,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector (v2).

        This implementation is more efficient than `get_WPD_filter` as
        it skips unnecessary computation with zeros: only the first C
        columns of Rf^-1 interact with the (un-padded) speech PSD.

    Args:
        Phi (torch.complex64/ComplexTensor): (B, F, C, C)
            is speech PSD.
        Rf (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            is the reference_vector.
        diagonal_loading (bool):
            Whether to add a tiny term to the diagonal of Rf
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant for diagonal loading and the trace division

    Returns:
        filter_matrix (torch.complex64/ComplexTensor): (B, F, (btaps+1) * C)
    """
    num_channels = reference_vector.shape[-1]
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # First C columns of Rf^-1: (B, F, (btaps+1) * C, C)
    inv_Rf_head = inverse(Rf)[..., :num_channels]
    # (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = matmul(inv_Rf_head, Phi)

    # Trace of the top C x C block equals tr(Rf^-1 @ zero-padded Phi).
    # FC.trace() is used because torch.trace has no batch support before
    # PyTorch 1.9.0.
    normalizer = FC.trace(numerator[..., :num_channels, :])[..., None, None] + eps
    # (..., (btaps+1) * C, C) / (..., 1, 1)
    ws = numerator / normalizer

    # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) == (B, F, (btaps+1) * C)
    return einsum("...fec,...c->...fe", ws, reference_vector)
Exemple #8
0
def test_matmul(real_vec):
    """`matmul` must agree with each backend's own complex matmul.

    Always checks the ComplexTensor/FC backend; also checks native torch
    complex tensors when `is_torch_1_9_plus` is true. With `real_vec`, the
    right operand is a real tensor compared against its zero-imaginary
    complex counterpart.
    """
    backends = [(ComplexTensor, FC)]
    if is_torch_1_9_plus:
        backends.append((torch.complex, torch))

    for make_complex, backend in backends:
        mat = make_complex(torch.rand(2, 3, 3), torch.rand(2, 3, 3))
        if not real_vec:
            vec = make_complex(torch.rand(2, 3, 1), torch.rand(2, 3, 1))
            vec_as_complex = vec
        else:
            vec = torch.rand(2, 3, 1)
            vec_as_complex = make_complex(vec, torch.zeros_like(vec))
        assert backend.allclose(
            matmul(mat, vec), backend.matmul(mat, vec_as_complex)
        )
Exemple #9
0
def get_filter_matrix_conj(
    correlation_matrix: Union[torch.Tensor, ComplexTensor],
    correlation_vector: Union[torch.Tensor, ComplexTensor],
    eps: float = 1e-10,
) -> Union[torch.Tensor, ComplexTensor]:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    The regularized matrix ``correlation_matrix + eps * I`` is built out of
    place, so the caller's ``correlation_matrix`` is not modified.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps: diagonal loading added to ``correlation_matrix`` before inversion

    Returns:
        filter_matrix_conj (torch.complex/ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = (
        correlation_vector.permute(0, 2, 1, 3).contiguous().view(F, C, taps * C)
    )

    eye = torch.eye(
        correlation_matrix.size(-1),
        dtype=correlation_matrix.dtype,
        device=correlation_matrix.device,
    )
    # Prepend singleton dims so the identity broadcasts over leading dims.
    shape = (
        tuple(1 for _ in range(correlation_matrix.dim() - 2))
        + correlation_matrix.shape[-2:]
    )
    eye = eye.view(*shape)
    # Out of place on purpose: `+=` would silently modify the caller's tensor
    # and can break autograd when the input is needed for backward.
    correlation_matrix = correlation_matrix + eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = matmul(
        correlation_vector, inv_correlation_matrix.transpose(-1, -2)
    )

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
Exemple #10
0
def get_mfmvdr_vector(
    gammax,
    Phi,
    use_torch_solver: bool = True,
    eps: float = EPS,
):
    """Compute conventional MFMPDR/MFMVDR filter.

        h = (Phi^-1 @ gammax) / Re(gammax^H @ Phi^-1 @ gammax)

    Args:
        gammax (torch.complex64/ComplexTensor): (..., L, N)
        Phi (torch.complex64/ComplexTensor): (..., L, N, N)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        eps (float): small constant added to the (real) denominator
    Returns:
        beamforming_vector (torch.complex64/ComplexTensor): (..., L, N)
    """
    gammax_col = gammax.unsqueeze(-1)
    # Phi^-1 @ gammax, flattened back to (..., L, N)
    if not use_torch_solver:
        numerator = matmul(inverse(Phi), gammax_col)
    else:
        numerator = solve(gammax_col, Phi)
    numerator = numerator.squeeze(-1)

    # gammax^H @ Phi^-1 @ gammax for every (..., L)
    denominator = einsum("...d,...d->...", gammax.conj(), numerator)
    return numerator / (denominator.real.unsqueeze(-1) + eps)
Exemple #11
0
def get_mwf_vector(
    psd_s,
    psd_n,
    reference_vector: Union[torch.Tensor, int],
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the MWF (Multi-channel Wiener Filter) vector:

        h = (Npsd^-1 @ Spsd) @ u

    Args:
        psd_s (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_n (torch.complex64/ComplexTensor):
            power-normalized observation covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant passed to the diagonal loading
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # Npsd^-1 @ Spsd
    ws = solve(psd_s, psd_n) if use_torch_solver else matmul(inverse(psd_n), psd_s)

    if not isinstance(reference_vector, int):
        # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
        return einsum("...fec,...c->...fe", ws, reference_vector)
    # ws @ e_ref is just column `reference_vector` of ws.
    return ws[..., reference_vector]
Exemple #12
0
def get_WPD_filter_with_rtf(
    psd_observed_bar: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    psd_noise: Union[torch.Tensor, ComplexTensor],
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector calculated with RTF.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (R^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

        where R is `psd_observed_bar` and vbar is the RTF zero-padded to
        the size of R.

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (torch.complex64/ComplexTensor):
            stacked observation covariance matrix; its last dimension sets
            the length of the zero-padded RTF
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int):
            reference channel for normalizing the RTF
        use_torch_solver (bool):
            Whether to use `solve` instead of `inverse`
        diagonal_loading (bool):
            Whether to add a tiny term to the diagonal of psd_noise
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant for diagonal loading and the denominator
    Returns:
        beamform_vector (torch.complex64/ComplexTensor):
            (..., F, D) where D = psd_observed_bar.shape[-1], i.e. (K+1)*C
    Raises:
        ValueError: if `psd_speech` is neither a ComplexTensor nor a native
            torch complex tensor.
    """
    if isinstance(psd_speech, ComplexTensor):
        pad_func = FC.pad
    elif is_torch_complex_tensor(psd_speech):
        pad_func = torch.nn.functional.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support.")

    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    # `reference_vector` is passed by keyword: the `get_rtf` variant whose
    # third positional parameter is `mode` would otherwise receive it as
    # `mode` and fail with "Unknown mode".
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector=reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # Zero-pad the channel axis: (B, F, C, 1) -> (B, F, (K+1)*C, 1)
    rtf = pad_func(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant",
                   0)
    rtf_vec = rtf.squeeze(-1)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = solve(rtf, psd_observed_bar).squeeze(-1)
    else:
        numerator = matmul(inverse(psd_observed_bar), rtf).squeeze(-1)
    # vbar^H @ R^-1 @ vbar: (...,)
    denominator = einsum("...d,...d->...", rtf_vec.conj(), numerator)
    if normalize_ref_channel is not None:
        # Scale by conj(vbar[ref]) to normalize the RTF at the reference channel.
        scale = rtf_vec[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
Exemple #13
0
def get_mvdr_vector_with_rtf(
    psd_n: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    psd_noise: Union[torch.Tensor, ComplexTensor],
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the MVDR (Minimum Variance Distortionless Response) vector
        calculated with RTF:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (torch.complex64/ComplexTensor):
            observation/noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of
            psd_noise (used for the RTF estimation)
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant for diagonal loading and the denominator
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    # `reference_vector` is passed by keyword: the `get_rtf` variant whose
    # third positional parameter is `mode` would otherwise receive it as
    # `mode` and fail with "Unknown mode".
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector=reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )
    rtf_vec = rtf.squeeze(-1)

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = solve(rtf, psd_n).squeeze(-1)
    else:
        numerator = matmul(inverse(psd_n), rtf).squeeze(-1)
    # rtf^H @ Npsd^-1 @ rtf: (...,)
    denominator = einsum("...d,...d->...", rtf_vec.conj(), numerator)
    if normalize_ref_channel is not None:
        # Scale by conj(rtf[ref]) to normalize the RTF at the reference channel.
        scale = rtf_vec[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
Exemple #14
0
def get_gev_vector(
    psd_noise: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the generalized eigenvalue (GEV) beamformer vector:

        psd_speech @ h = lambda * psd_noise @ h

    Reference:
        Blind acoustic beamforming based on generalized eigenvalue decomposition;
        E. Warsitz and R. Haeb-Umbach, 2007.

    Args:
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition (only for torch builtin complex tensors)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of
            psd_noise
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant passed to the diagonal loading
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C),
            unit-norm along C and passed through `gev_phase_correction`
    Raises:
        ValueError: if `mode` is neither "power" nor "evd".
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if mode == "power":
        # phi = psd_noise^-1 @ psd_speech
        if use_torch_solver:
            phi = solve(psd_speech, psd_noise)
        else:
            phi = matmul(inverse(psd_noise), psd_speech)
        # First power iteration: phi @ u (column `reference_vector` for an int).
        e_vec = (phi[..., reference_vector,
                     None] if isinstance(reference_vector, int) else matmul(
                         phi, reference_vector[..., None, :, None]))
        for _ in range(iterations - 1):
            e_vec = matmul(phi, e_vec)
            # e_vec = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
        e_vec = e_vec.squeeze(-1)
    elif mode == "evd":
        assert (is_torch_1_9_plus and is_torch_complex_tensor(psd_speech)
                and is_torch_complex_tensor(psd_noise))
        # e_vec = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1][...,-1]
        C = psd_noise.size(-1)
        e_vec = psd_noise.new_zeros(psd_noise.shape[:-1])
        # Decompose frequency by frequency so a failure in one bin does not
        # discard the others.
        for f in range(psd_noise.shape[-3]):
            try:
                # Last column of the eigenvector matrix.
                e_vec[..., f, :] = generalized_eigenvalue_decomposition(
                    psd_speech[..., f, :, :], psd_noise[..., f, :, :])[1][...,
                                                                          -1]
            except RuntimeError:
                # port from github.com/fgnt/nn-gev/blob/master/fgnt/beamforming.py#L106
                print(
                    "GEV beamformer: LinAlg error for frequency {}".format(f),
                    flush=True,
                )
                # The trace has shape (...,); the trailing axis makes it
                # broadcast over the channel axis of the (..., C) slice
                # instead of being aligned with it.
                e_vec[..., f, :] = (
                    psd_noise.new_ones(e_vec[..., f, :].shape)
                    / FC.trace(psd_noise[..., f, :, :])[..., None]
                    * C
                )
    else:
        raise ValueError("Unknown mode: %s" % mode)

    beamforming_vector = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
    beamforming_vector = gev_phase_correction(beamforming_vector)
    return beamforming_vector
Exemple #15
0
def get_rank1_mwf_vector(
    psd_speech,
    psd_noise,
    reference_vector: Union[torch.Tensor, int],
    denoising_weight: float = 1.0,
    approx_low_rank_psd_speech: bool = False,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the R1-MWF (Rank-1 Multi-channel Wiener Filter) vector

        h = (Npsd^-1 @ Spsd) / (mu + Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        [1] Rank-1 constrained multichannel Wiener filter for speech recognition in
        noisy environments; Z. Wang et al, 2018
        https://hal.inria.fr/hal-01634449/document
        [2] Low-rank approximation based multichannel Wiener filter algorithms for
        noise reduction with application in cochlear implants; R. Serizel, 2014
        https://ieeexplore.ieee.org/document/6730918

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        denoising_weight (float): a trade-off parameter (mu) between noise
            reduction and speech distortion.
            A larger value leads to more noise reduction at the expense of more speech
            distortion.
            When `denoising_weight = 0`, it corresponds to MVDR beamformer.
        approx_low_rank_psd_speech (bool): whether to replace original input psd_speech
            with its low-rank approximation as in [1]
        iterations (int): number of iterations in power method, only used when
            `approx_low_rank_psd_speech = True`
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of
            psd_noise
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant for diagonal loading and trace divisions
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    # Loading is applied exactly once, before either path uses psd_noise.
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if approx_low_rank_psd_speech:
        # (B, F, C, 1)
        recon_vec = get_rtf(
            psd_speech,
            psd_noise,
            mode="power",
            iterations=iterations,
            reference_vector=reference_vector,
            use_torch_solver=use_torch_solver,
        )
        # Eq. (25) in Ref[1]: rank-1 PSD rescaled to the original trace.
        rank1_psd = matmul(recon_vec, recon_vec.conj().transpose(-1, -2))
        sigma_speech = FC.trace(psd_speech) / (FC.trace(rank1_psd) + eps)
        # c.f. Eq. (62) in Ref[2]
        psd_speech = rank1_psd * sigma_speech[..., None, None]

    # Npsd^-1 @ Spsd
    numerator = (
        solve(psd_speech, psd_noise)
        if use_torch_solver
        else matmul(inverse(psd_noise), psd_speech)
    )

    # FC.trace() is used because torch.trace has no batch support before
    # PyTorch 1.9.0. ws: (..., C, C) / (..., 1, 1) -> (..., C, C)
    ws = numerator / (denoising_weight + FC.trace(numerator)[..., None, None] +
                      eps)

    if not isinstance(reference_vector, int):
        # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
        return einsum("...fec,...c->...fe", ws, reference_vector)
    # ws @ e_ref is just column `reference_vector` of ws.
    return ws[..., reference_vector]
Exemple #16
0
def get_sdw_mwf_vector(
    psd_speech,
    psd_noise,
    reference_vector: Union[torch.Tensor, int],
    denoising_weight: float = 1.0,
    approx_low_rank_psd_speech: bool = False,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Return the SDW-MWF (Speech Distortion Weighted Multi-channel Wiener Filter) vector

        h = (Spsd + mu * Npsd)^-1 @ Spsd @ u

    Reference:
        [1] Spatially pre-processed speech distortion weighted multi-channel Wiener
        filtering for noise reduction; A. Spriet et al, 2004
        https://dl.acm.org/doi/abs/10.1016/j.sigpro.2004.07.028
        [2] Rank-1 constrained multichannel Wiener filter for speech recognition in
        noisy environments; Z. Wang et al, 2018
        https://hal.inria.fr/hal-01634449/document
        [3] Low-rank approximation based multichannel Wiener filter algorithms for
        noise reduction with application in cochlear implants; R. Serizel, 2014
        https://ieeexplore.ieee.org/document/6730918

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        denoising_weight (float): a trade-off parameter (mu) between noise
            reduction and speech distortion.
            A larger value leads to more noise reduction at the expense of more speech
            distortion.
            The plain MWF is obtained with `denoising_weight = 1` (by default).
        approx_low_rank_psd_speech (bool): whether to replace original input psd_speech
            with its low-rank approximation as in [2]
        iterations (int): number of iterations in power method, only used when
            `approx_low_rank_psd_speech = True`
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of
            the matrix being inverted (Spsd + mu * Npsd)
        diag_eps (float): regularization weight used for diagonal loading
        eps (float): small constant for diagonal loading and trace divisions
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if approx_low_rank_psd_speech:
        # NOTE(review): on this path psd_noise is loaded here and the combined
        # matrix below is loaded again; looks deliberate (get_rtf inverts
        # psd_noise) but confirm.
        if diagonal_loading:
            psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

        # (B, F, C, 1)
        recon_vec = get_rtf(
            psd_speech,
            psd_noise,
            mode="power",
            iterations=iterations,
            reference_vector=reference_vector,
            use_torch_solver=use_torch_solver,
        )
        # Eq. (25) in Ref[2]: rank-1 PSD rescaled to the original trace.
        rank1_psd = matmul(recon_vec, recon_vec.conj().transpose(-1, -2))
        sigma_speech = FC.trace(psd_speech) / (FC.trace(rank1_psd) + eps)
        # c.f. Eq. (62) in Ref[3]
        psd_speech = rank1_psd * sigma_speech[..., None, None]

    # Spsd + mu * Npsd
    weighted_psd = psd_speech + denoising_weight * psd_noise
    if diagonal_loading:
        weighted_psd = tik_reg(weighted_psd, reg=diag_eps, eps=eps)

    ws = (
        solve(psd_speech, weighted_psd)
        if use_torch_solver
        else matmul(inverse(weighted_psd), psd_speech)
    )

    if not isinstance(reference_vector, int):
        # (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
        return einsum("...fec,...c->...fe", ws, reference_vector)
    # ws @ e_ref is just column `reference_vector` of ws.
    return ws[..., reference_vector]