def test_inverse2():
    """Check that a complex matrix multiplied by its `inverse2()` is the identity."""
    size = 10
    mat = ComplexTensor(_get_complex_array(1, size, size))
    product = mat @ mat.inverse2()

    # Only one matrix is in the batch; compare it against I + 0j.
    real_part = product.real.numpy()[0]
    imag_part = product.imag.numpy()[0]
    numpy.testing.assert_allclose(real_part, numpy.eye(size), atol=1e-11)
    numpy.testing.assert_allclose(imag_part, numpy.zeros((size, size)), atol=1e-11)
Exemple #2
0
def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the WPD beamforming filter.

    WPD (Weighted Power minimization Distortionless response) is a
    convolutional beamformer whose filter is obtained as:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            PSD of the zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            reference vector `u` in the formula above.
        use_torch_solver (bool): if True, obtain `Rf^-1 @ Phi` with
            `FC.solve`; otherwise invert `Rf` explicitly with `inverse2()`.
        diagonal_loading (bool): if True, pass `Rf` through `tik_reg`
            before it is used.
        diag_eps (float): forwarded to `tik_reg` as `reg`.
        eps (float): forwarded to `tik_reg` and also added to the trace
            in the denominator.

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    # Optionally regularize the covariance before it is solved against / inverted.
    Rf_used = tik_reg(Rf, reg=diag_eps, eps=eps) if diagonal_loading else Rf

    # Rf^-1 @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    if not use_torch_solver:
        inv_Rf_Phi = FC.matmul(Rf_used.inverse2(), Phi)
    else:
        inv_Rf_Phi = FC.solve(Phi, Rf_used)[0]

    # Trace normalization: (..., C, C) / (...,) -> (..., C, C)
    trace_term = FC.trace(inv_Rf_Phi)[..., None, None]
    weights = inv_Rf_Phi / (trace_term + eps)

    # Apply the reference vector: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # which gives (B, F, (btaps + 1) * C).
    return FC.einsum("...fec,...c->...fe", [weights, reference_vector])
Exemple #3
0
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the MVDR (Minimum Variance Distortionless Response) vector.

    The vector is obtained as:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): if True, obtain `Npsd^-1 @ Spsd` with
            `FC.solve`; otherwise invert `psd_n` explicitly with `inverse2()`.
        diagonal_loading (bool): if True, pass `psd_n` through `tik_reg`
            before it is used.
        diag_eps (float): forwarded to `tik_reg` as `reg`.
        eps (float): forwarded to `tik_reg` and also added to the trace
            in the denominator.

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Optionally regularize the noise covariance before solving / inverting.
    noise_cov = tik_reg(psd_n, reg=diag_eps, eps=eps) if diagonal_loading else psd_n

    # Npsd^-1 @ Spsd, via a linear solve or an explicit inverse.
    if not use_torch_solver:
        inv_n_s = FC.matmul(noise_cov.inverse2(), psd_s)
    else:
        inv_n_s = FC.solve(psd_s, noise_cov)[0]

    # Trace normalization: (..., C, C) / (...,) -> (..., C, C)
    trace_term = FC.trace(inv_n_s)[..., None, None]
    weights = inv_n_s / (trace_term + eps)

    # Apply the reference vector: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return FC.einsum("...fec,...c->...fe", [weights, reference_vector])
Exemple #4
0
def get_WPD_filter_v2(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the WPD beamforming filter (v2).

    More efficient than `get_WPD_filter`: the speech PSD is given without
    zero padding, so only the first C columns of `Rf^-1` take part in the
    product and the multiplications with zeros are skipped.

    Args:
        Phi (ComplexTensor): (B, F, C, C)
            speech PSD.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            reference vector; its last dimension defines C.
        diagonal_loading (bool): if True, pass `Rf` through `tik_reg`
            before it is inverted.
        diag_eps (float): forwarded to `tik_reg` as `reg`.
        eps (float): forwarded to `tik_reg` and also added to the trace
            in the denominator.

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
    """
    num_channels = reference_vector.shape[-1]

    # Optionally regularize the covariance, then invert it explicitly.
    Rf_used = tik_reg(Rf, reg=diag_eps, eps=eps) if diagonal_loading else Rf
    # Keep only the first C columns of the inverse: (B, F, (btaps+1) * C, C)
    inv_Rf_cols = Rf_used.inverse2()[..., :num_channels]

    # (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    product = FC.matmul(inv_Rf_cols, Phi)

    # The trace is taken over the top C x C block of the product:
    # (..., (btaps+1) * C, C) / (...,) -> (..., (btaps+1) * C, C)
    trace_term = FC.trace(product[..., :num_channels, :])[..., None, None]
    weights = product / (trace_term + eps)

    # Apply the reference vector: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # which gives (B, F, (btaps+1) * C).
    return FC.einsum("...fec,...c->...fe", [weights, reference_vector])
Exemple #5
0
def get_rtf(
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    reference_vector: Union[int, torch.Tensor, None] = None,
    iterations: int = 3,
    use_torch_solver: bool = True,
) -> ComplexTensor:
    """Calculate the relative transfer function (RTF) using the power method.

    Algorithm:
        1) rtf = reference_vector
        2) for i in range(iterations):
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # this normalization can be skipped
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]
    Note: 4) Normalization at the reference channel is not performed here.

    Args:
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor or int): (..., C) or scalar.
            An int selects the corresponding column of
            `psd_noise^-1 @ psd_speech`; a tensor is multiplied with it.
            Although the default is None, a value must be supplied.
        iterations (int): number of iterations in power method. Values
            below 2 give the same result as 2, because the loop below
            simply does not run.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`

    Returns:
        rtf (ComplexTensor): (..., F, C, 1)

    Raises:
        TypeError: if `reference_vector` is None.
    """
    # Fail fast with a clear message instead of the opaque
    # "'NoneType' object is not subscriptable" raised further down.
    if reference_vector is None:
        raise TypeError(
            "reference_vector must be an int channel index or a torch.Tensor, "
            "but got None"
        )

    # phi = psd_noise^-1 @ psd_speech
    if use_torch_solver:
        phi = FC.solve(psd_speech, psd_noise)[0]
    else:
        phi = FC.matmul(psd_noise.inverse2(), psd_speech)

    # First application of phi to the reference vector.
    if isinstance(reference_vector, int):
        # A one-hot reference reduces to picking one column of phi.
        rtf = phi[..., reference_vector, None]
    else:
        rtf = FC.matmul(phi, reference_vector[..., None, :, None])

    # One application of phi was done above, and the closing product uses
    # psd_speech (= psd_noise @ phi), which folds the last application
    # together with step 3. Hence only `iterations - 2` applications remain.
    for _ in range(iterations - 2):
        rtf = FC.matmul(phi, rtf)
        # The per-iteration L2 normalization of step 2 is skipped.
    rtf = FC.matmul(psd_speech, rtf)
    return rtf
Exemple #6
0
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Compute the WPD beamforming filter from an estimated RTF.

    WPD (Weighted Power minimization Distortionless response) is a
    convolutional beamformer whose filter is obtained as:

        h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    where `vbar` is the RTF returned by `get_rtf`, zero-padded to the size
    of `psd_observed_bar`.

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF;
            when None, no such scaling is applied.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): if True, pass `psd_noise` through `tik_reg`
            before the RTF is estimated.
        diag_eps (float): forwarded to `tik_reg` as `reg`.
        eps (float): forwarded to `tik_reg` and also added to the denominator.

    Returns:
        beamform_vector (ComplexTensor): (..., F, D), where D is the last
            dimension of `psd_observed_bar`.
    """
    num_channels = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # Steering vector estimate: (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # Zero-pad the RTF up to the stacked dimension: (B, F, (K+1)*C, 1)
    pad_rows = psd_observed_bar.shape[-1] - num_channels
    rtf = FC.pad(rtf, (0, 0, 0, pad_rows), "constant", 0)
    rtf_vec = rtf.squeeze(-1)

    # R^-1 @ vbar: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if not use_torch_solver:
        inv_R_rtf = FC.matmul(psd_observed_bar.inverse2(), rtf)
    else:
        inv_R_rtf = FC.solve(rtf, psd_observed_bar)[0]
    numerator = inv_R_rtf.squeeze(-1)

    # vbar^H @ R^-1 @ vbar; only its real part is used for the division.
    quad_form = FC.einsum("...d,...d->...", [rtf_vec.conj(), numerator])
    denominator = quad_form.real.unsqueeze(-1) + eps

    if normalize_ref_channel is None:
        return numerator / denominator

    # Scale by the conjugated RTF entry of the chosen reference channel.
    scale = rtf_vec[..., normalize_ref_channel, None].conj()
    return numerator * scale / denominator
Exemple #7
0
def get_mvdr_vector_with_rtf(
    psd_n: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Compute the MVDR vector from an estimated RTF.

    The MVDR (Minimum Variance Distortionless Response) vector is obtained as:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    where `rtf` is the relative transfer function returned by `get_rtf`.

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF;
            when None, no such scaling is applied.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): if True, pass `psd_noise` through `tik_reg`
            before the RTF is estimated.
        diag_eps (float): forwarded to `tik_reg` as `reg`.
        eps (float): forwarded to `tik_reg` and also added to the denominator.

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # Steering vector estimate: (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )
    rtf_vec = rtf.squeeze(-1)

    # Npsd^-1 @ rtf: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if not use_torch_solver:
        inv_n_rtf = FC.matmul(psd_n.inverse2(), rtf)
    else:
        inv_n_rtf = FC.solve(rtf, psd_n)[0]
    numerator = inv_n_rtf.squeeze(-1)

    # rtf^H @ Npsd^-1 @ rtf; only its real part is used for the division.
    quad_form = FC.einsum("...d,...d->...", [rtf_vec.conj(), numerator])
    denominator = quad_form.real.unsqueeze(-1) + eps

    if normalize_ref_channel is None:
        return numerator / denominator

    # Scale by the conjugated RTF entry of the chosen reference channel.
    scale = rtf_vec[..., normalize_ref_channel, None].conj()
    return numerator * scale / denominator