Code example #1
File: beamformer.py  Project: zqs01/espnet
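All of the snippets in this listing share a small set of imports; a minimal preamble (assuming the torch_complex package is installed) would be:

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor
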
def get_mvdr_vector(psd_s: ComplexTensor,
                    psd_n: ComplexTensor,
                    reference_vector: torch.Tensor,
                    eps: float = 1e-15) -> ComplexTensor:
    """Return the MVDR(Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    psd_n = psd_n + eps * eye  # avoid mutating the caller's tensor in place

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum('...fec,...c->...fe', [ws, reference_vector])
    return beamform_vector
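
A minimal usage sketch of get_mvdr_vector; the sizes and random statistics below are illustrative assumptions, not from the source (real PSDs would come from masked statistics as in code example #6):

# Hypothetical sizes: B=2 utterances, F=65 bins, C=4 channels, T=100 frames.
B, F_, C, T = 2, 65, 4, 100
X = ComplexTensor(torch.randn(B, F_, C, T), torch.randn(B, F_, C, T))
N = ComplexTensor(torch.randn(B, F_, C, T), torch.randn(B, F_, C, T))
psd_s = FC.einsum('...ct,...et->...ce', [X, X.conj()])  # speech PSD stand-in
psd_n = FC.einsum('...ct,...et->...ce', [N, N.conj()])  # noise PSD stand-in
u = torch.zeros(B, C)
u[..., 0] = 1.0  # one-hot vector selecting reference microphone 0
ws = get_mvdr_vector(psd_s, psd_n, u)  # -> (B, F, C)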
Code example #2
def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer, computed as:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        eps (float):

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    try:
        inv_Rf = inv(Rf)
    except Exception:
        try:
            # Rf is (nearly) singular: rescale and add a small random
            # regularization term before retrying the inverse.
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real)) * 1e-4
            )
            Rf = Rf / 10e4
            Phi = Phi / 10e4
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)
        except Exception:
            # Still singular: rescale and regularize more aggressively.
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real)) * 1e-1
            )
            Rf = Rf / 10e10
            Phi = Phi / 10e10
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [inv_Rf, Phi])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps + 1) * C)
    return beamform_vector
Code example #3
def get_covariances(
    Y: ComplexTensor,
    inverse_power: torch.Tensor,
    bdelay: int,
    btaps: int,
    get_vector: bool = False,
) -> ComplexTensor:
    """Calculates the power normalized spatio-temporal

     covariance matrix of the framed signal.

    Args:
        Y : Complex STFT signal with shape (B, F, C, T)
        inverse_power : Weighting factor with shape (B, F, T)
        bdelay (int): Prediction delay in frames
        btaps (int): Number of filter taps
        get_vector (bool): Whether to also return the correlation vector

    Returns:
        Correlation matrix of shape (B, F, (btaps+1) * C, (btaps+1) * C)
        Correlation vector of shape (B, F, btaps + 1, C, C), if get_vector is True
    """
    assert inverse_power.dim() == 3, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    Bs, Fdim, C, T = Y.shape

    # (B, F, C, T - bdelay - btaps + 1, btaps + 1)
    Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[
        ..., : T - bdelay - btaps + 1, :
    ]
    # Reverse along btaps-axis:
    # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-frame_length+1]
    Psi = FC.reverse(Psi, dim=-1)
    Psi_norm = Psi * inverse_power[..., None, bdelay + btaps - 1 :, None]

    # let T' = T - bdelay - btaps + 1
    # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
    #  -> (B, F, btaps + 1, C, btaps + 1, C)
    covariance_matrix = FC.einsum("bfdtk,bfetl->bfkdle", (Psi, Psi_norm.conj()))

    # (B, F, btaps + 1, C, btaps + 1, C)
    #   -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
    covariance_matrix = covariance_matrix.view(
        Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C
    )

    if get_vector:
        # (B, F, C, T', btaps + 1) x (B, F, C, T')
        #    --> (B, F, btaps + 1, C, C)
        covariance_vector = FC.einsum(
            "bfdtk,bfet->bfked", (Psi_norm, Y[..., bdelay + btaps - 1 :].conj())
        )
        return covariance_matrix, covariance_vector
    else:
        return covariance_matrix
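
For intuition, the framing that signal_framing performs can be sketched on a plain real tensor with torch.Tensor.unfold; the actual helper (not shown in this listing) additionally handles ComplexTensor inputs, the bdelay offset, and optional padding:

# Sliding windows of btaps + 1 samples with hop 1.
T, btaps = 8, 2
y = torch.arange(T, dtype=torch.float32)
frames = y.unfold(0, btaps + 1, 1)  # (T - btaps, btaps + 1); frames[t] == y[t : t + btaps + 1]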
Code example #4
def filter_minimum_gain_like(G_min,
                             w,
                             y,
                             alpha=None,
                             k: float = 10.0,
                             eps: float = EPS):
    """Approximate a minimum gain operation.

    speech_estimate = alpha w^H y + (1 - alpha) G_min Y,
    where alpha = 1 / (1 + exp(-2 k x)), x = w^H y - G_min Y

    Args:
        G_min (float): minimum gain
        w (ComplexTensor): filter coefficients (..., L, N)
        y (ComplexTensor): buffered and stacked input (..., L, N)
        alpha: mixing factor
        k (float): scaling in tanh-like function
        eps (float):
    Returns:
        output (ComplexTensor): minimum gain-filtered output
        alpha (float): optional
    """
    # (..., L)
    filtered_input = FC.einsum("...d,...d->...", [w.conj(), y])
    # (..., L)
    Y = y[..., -1]
    return minimum_gain_like(G_min, Y, filtered_input, alpha, k, eps)
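
The helper minimum_gain_like is not part of this listing; the following is a hedged reconstruction from the docstring formula above, treating x as a magnitude difference (an assumption; the real helper may differ):

def minimum_gain_like_sketch(G_min, Y, filtered_input, alpha=None, k=10.0, eps=1e-8):
    # Hypothetical reconstruction; not the original minimum_gain_like.
    # eps is kept only to mirror the call signature used above.
    if alpha is None:
        # Sigmoid-smoothed switch between the filtered output and the gain floor.
        x = filtered_input.abs() - G_min * Y.abs()
        alpha = 1.0 / (1.0 + torch.exp(-2.0 * k * x))
    output = filtered_input * alpha + Y * (1.0 - alpha) * G_min
    return output, alpha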
Code example #5
def einsum(equation, *operands):
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Up to PyTorch 1.9.0, torch.einsum does not support
    # mixing complex and real tensors in its input.
    if len(operands) == 1:
        if isinstance(operands[0], (tuple, list)):
            operands = operands[0]
        complex_module = FC if isinstance(operands[0],
                                          ComplexTensor) else torch
        return complex_module.einsum(equation, *operands)
    elif len(operands) != 2:
        op0 = operands[0]
        same_type = all(op.dtype == op0.dtype for op in operands[1:])
        if same_type:
            _einsum = FC.einsum if isinstance(op0,
                                              ComplexTensor) else torch.einsum
            return _einsum(equation, *operands)
        else:
            raise ValueError(
                "Mixed-dtype einsum is only supported for exactly two operands."
            )
    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        elif not torch.is_complex(b):
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)
        else:
            return torch.einsum(equation, a, b)
    else:
        return torch.einsum(equation, a, b)
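
A short usage sketch of the wrapper's mixed-dtype path (assuming a PyTorch >= 1.9 runtime, so that is_torch_1_9_plus is True):

a = torch.randn(3, 4, dtype=torch.complex64)  # complex operand
b = torch.randn(4, 5)                         # real operand
c = einsum("ij,jk->ik", a, b)  # computed as torch.complex(a.real @ b, a.imag @ b)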
Code example #6
File: beamformer.py  Project: zqs01/espnet
def get_power_spectral_density_matrix(xs: ComplexTensor,
                                      mask: torch.Tensor,
                                      normalization=True,
                                      eps: float = 1e-15) -> ComplexTensor:
    """Return cross-channel power spectral density (PSD) matrix

    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool):
        eps (float):
    Returns:
        psd (ComplexTensor): (..., F, C, C)

    """
    # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C_1, C_2)
    psd_Y = FC.einsum('...ct,...et->...tce', [xs, xs.conj()])

    # Averaging mask along C: (..., C, T) -> (..., T)
    mask = mask.mean(dim=-2)

    # Normalized mask along T: (..., T)
    if normalization:
        # If assuming the tensor is padded with zero, the summation along
        # the time axis is same regardless of the padding length.
        mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)

    # psd: (..., T, C, C)
    psd = psd_Y * mask[..., None, None]
    # (..., T, C, C) -> (..., C, C)
    psd = psd.sum(dim=-3)

    return psd
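
A sketch of computing a masked PSD from a complex spectrogram (shapes are illustrative assumptions):

B, F_, C, T = 2, 65, 4, 100
xs = ComplexTensor(torch.randn(B, F_, C, T), torch.randn(B, F_, C, T))
mask = torch.rand(B, F_, C, T)  # e.g. a DNN-estimated speech mask
psd = get_power_spectral_density_matrix(xs, mask)  # -> (B, F, C, C)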
Code example #7
def perform_WPD_filtering(filter_matrix: ComplexTensor, Y: ComplexTensor,
                          bdelay: int, btaps: int) -> ComplexTensor:
    """Perform WPD filtering.

    Args:
        filter_matrix: Filter matrix (B, F, (btaps + 1) * C)
        Y : Complex STFT signal with shape (B, F, C, T)

    Returns:
        enhanced (ComplexTensor): (B, F, T)
    """
    # (B, F, C, T) --> (B, F, C, T, btaps + 1)
    Ytilde = signal_framing(Y,
                            btaps + 1,
                            1,
                            bdelay,
                            do_padding=True,
                            pad_value=0)
    Ytilde = FC.reverse(Ytilde, dim=-1)

    Bs, Fdim, C, T = Y.shape
    # --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C)
    Ytilde = Ytilde.permute(0, 1, 3, 4, 2).contiguous().view(Bs, Fdim, T, -1)
    # (B, F, T)
    enhanced = FC.einsum("...tc,...c->...t", [Ytilde, filter_matrix.conj()])
    return enhanced
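
Pieced together from the wpd branch of the DNN_Beamformer forward pass (code example #18), the surrounding pipeline reads roughly as follows; the sizes, mask, and reference vector are assumptions:

B, F_, C, T, bdelay, btaps = 2, 65, 4, 100, 3, 5
Y = ComplexTensor(torch.randn(B, F_, C, T), torch.randn(B, F_, C, T))
speech_mask = torch.rand(B, F_, C, T)
power = (Y.real ** 2 + Y.imag ** 2) * speech_mask              # (B, F, C, T)
inverse_power = 1 / torch.clamp(power.mean(dim=-2), min=1e-8)  # (B, F, T)
psd_bar = get_covariances(Y, inverse_power, bdelay, btaps, get_vector=False)
psd_speech = get_power_spectral_density_matrix(Y, speech_mask)
u = torch.zeros(B, C)
u[..., 0] = 1.0
ws = get_WPD_filter_v2(psd_speech, psd_bar, u)          # (B, F, (btaps+1)*C), code example #14
enhanced = perform_WPD_filtering(ws, Y, bdelay, btaps)  # (B, F, T)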
Code example #8
def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the WPD vector.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer, computed as:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    if use_torch_solver:
        numerator = FC.solve(Phi, Rf)[0]
    else:
        numerator = FC.matmul(Rf.inverse2(), Phi)
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps + 1) * C)
    return beamform_vector
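
tik_reg itself is not shown in this listing; the following is a sketch consistent with how it is used here (trace-scaled diagonal loading), not necessarily the exact espnet implementation:

def tik_reg_sketch(mat: ComplexTensor, reg: float = 1e-7, eps: float = 1e-8) -> ComplexTensor:
    # Hypothetical reconstruction: add reg * trace(mat), floored by eps, to the diagonal.
    C = mat.size(-1)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    shape = [1 for _ in range(mat.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    with torch.no_grad():
        epsilon = FC.trace(mat).real[..., None, None] * reg + eps
    return mat + epsilon * eye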
Code example #9
File: test_enh_layers.py  Project: jumon/espnet-1
def test_get_rtf(ch):
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    n = torch.rand(2, 16, ch, dtype=torch.double)
    ilens = torch.LongTensor([16, 12])
    # (B, T, C, F) -> (B, F, C, T)
    X = ComplexTensor(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(-1, -3)
    N = ComplexTensor(*torch.unbind(stft(n, ilens)[0], dim=-1)).transpose(-1, -3)
    # (B, F, C, C)
    Phi_X = FC.einsum("...ct,...et->...ce", [X, X.conj()])
    Phi_N = FC.einsum("...ct,...et->...ce", [N, N.conj()])
    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, reference_vector=0, iterations=20)
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)
    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available in PyTorch 1.1.0+
        mat = FC.solve(Phi_X, Phi_N)[0]
        max_eigenvec = FC.solve(rtf, Phi_N)[0]
    else:
        mat = FC.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = FC.matmul(Phi_N.inverse2(), rtf)
    factor = FC.matmul(mat, max_eigenvec)
    assert FC.allclose(
        FC.matmul(max_eigenvec, factor.transpose(-1, -2)),
        FC.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
Code example #10
def get_correlations(Y: ComplexTensor, inverse_power: torch.Tensor, taps,
                     delay) -> Tuple[ComplexTensor, ComplexTensor]:
    """Calculates weighted correlations of a window of length taps

    Args:
        Y : Complex-valued STFT signal with shape (F, C, T)
        inverse_power : Weighting factor with shape (F, T)
        taps (int): Length of the correlation window
        delay (int): Delay for the weighting factor

    Returns:
        Correlation matrix of shape (F, taps*C, taps*C)
        Correlation vector of shape (F, taps, C, C)
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), \
        (inverse_power.size(0), Y.size(0))

    F, C, T = Y.size()

    # Y: (F, C, T) -> Psi: (F, C, T', taps), with T' = T - delay - taps + 1
    Psi = signal_framing(Y, frame_length=taps,
                         frame_step=1)[..., :T - delay - taps + 1, :]
    # Reverse along taps-axis
    Psi = FC.reverse(Psi, dim=-1)
    Psi_conj_norm = \
        Psi.conj() * inverse_power[..., None, delay + taps - 1:, None]

    # (F, C, T, taps) x (F, C, T, taps) -> (F, taps, C, taps, C)
    correlation_matrix = FC.einsum('fdtk,fetl->fkdle', (Psi_conj_norm, Psi))
    # (F, taps, C, taps, C) -> (F, taps * C, taps * C)
    correlation_matrix = correlation_matrix.view(F, taps * C, taps * C)

    # (F, C, T, taps) x (F, C, T) -> (F, taps, C, C)
    correlation_vector = FC.einsum('fdtk,fet->fked',
                                   (Psi_conj_norm, Y[..., delay + taps - 1:]))

    return correlation_matrix, correlation_vector
Code example #11
def get_power_spectral_density_matrix(
        complex_tensor: ComplexTensor) -> ComplexTensor:
    """
    Cross-channel power spectral density (PSD) matrix

    Args:
        complex_tensor: [..., F, C, T]

    Returns:
        psd: [..., F, C, C]
    """
    # outer product: [..., C_1, T] x [..., C_2, T] => [..., T, C_1, C_2]
    return FC.einsum("...ct,...et->...tce",
                     [complex_tensor, complex_tensor.conj()])
Code example #12
def apply_crf_filter(cRM_filter: ComplexTensor,
                     mix: ComplexTensor) -> ComplexTensor:
    """
    Apply a complex ratio filter (cRF)

    Args:
        cRM_filter: complex Ratio Filter
        mix: mixture

    Returns:
        [B, C, F, T]
    """
    # [B, F, T, Filter_delay] x [B, C, F, Filter_delay, T] => [B, C, F, T]
    es = FC.einsum("bftd, bcfdt -> bcft", [cRM_filter.conj(), mix])
    return es
Code example #13
File: pytorch_wpe.py  Project: nttcslab-sp/dnn_wpe
def perform_filter_operation_v2(Y: ComplexTensor,
                                filter_matrix_conj: ComplexTensor,
                                taps, delay) -> ComplexTensor:
    """perform_filter_operation_v2

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    # Y_tilde: (taps, F, C, T)
    Y_tilde = FC.stack([FC.pad(Y[:, :, :T - delay - i], (delay + i, 0),
                               mode='constant', value=0)
                        for i in range(taps)],
                       dim=0)
    reverb_tail = FC.einsum('fpde,pfdt->fet', (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail
Code example #14
def get_WPD_filter_v2(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the WPD vector (v2).

        This implementation is more efficient than `get_WPD_filter` as
        it skips unnecessary computation with zeros.

    Args:
        Phi (ComplexTensor): (B, F, C, C)
            is speech PSD.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            is the reference_vector.
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
    """
    C = reference_vector.shape[-1]
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)
    inv_Rf = Rf.inverse2()
    # (B, F, (btaps+1) * C, C)
    inv_Rf_pruned = inv_Rf[..., :C]
    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.matmul(inv_Rf_pruned, Phi)
    # ws: (..., (btaps+1) * C, C) / (...,) -> (..., (btaps+1) * C, C)
    ws = numerator / (FC.trace(numerator[..., :C, :])[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps+1) * C)
    return beamform_vector
Code example #15
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    if use_torch_solver:
        numerator = FC.solve(psd_s, psd_n)[0]
    else:
        numerator = FC.matmul(psd_n.inverse2(), psd_s)
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector
Code example #16
def get_mfmvdr_vector(gammax,
                      Phi,
                      use_torch_solver: bool = True,
                      eps: float = EPS):
    """Compute conventional MFMPDR/MFMVDR filter.

    Args:
        gammax (ComplexTensor): (..., L, N)
        Phi (ComplexTensor): (..., L, N, N)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        eps (float):
    Returns:
        beamforming_vector (ComplexTensor): (..., L, N)
    """
    # (..., L, N)
    if use_torch_solver:
        numerator = FC.solve(gammax.unsqueeze(-1), Phi)[0].squeeze(-1)
    else:
        numerator = FC.matmul(Phi.inverse2(), gammax.unsqueeze(-1)).squeeze(-1)
    denominator = FC.einsum("...d,...d->...", [gammax.conj(), numerator])
    return numerator / (denominator.real.unsqueeze(-1) + eps)
Code example #17
File: pytorch_wpe.py  Project: nttcslab-sp/dnn_wpe
def perform_filter_operation(Y: ComplexTensor,
                             filter_matrix_conj: ComplexTensor, taps, delay) \
        -> ComplexTensor:
    """perform_filter_operation

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    reverb_tail = ComplexTensor(torch.zeros_like(Y.real),
                                torch.zeros_like(Y.real))
    for tau_minus_delay in range(taps):
        new = FC.einsum('fde,fdt->fet',
                        (filter_matrix_conj[:, tau_minus_delay, :, :],
                         Y[:, :, :T - delay - tau_minus_delay]))
        new = FC.pad(new, (delay + tau_minus_delay, 0),
                     mode='constant', value=0)
        reverb_tail = reverb_tail + new

    return Y - reverb_tail
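
perform_filter_operation above and perform_filter_operation_v2 (code example #13) compute the same reverberation tail; a quick sanity check with assumed shapes:

F_, C, T, taps, delay = 9, 2, 32, 3, 2
Y = ComplexTensor(torch.randn(F_, C, T), torch.randn(F_, C, T))
G_conj = ComplexTensor(torch.randn(F_, taps, C, C), torch.randn(F_, taps, C, C))
out_loop = perform_filter_operation(Y, G_conj, taps, delay)
out_vec = perform_filter_operation_v2(Y, G_conj, taps, delay)
assert FC.allclose(out_loop, out_vec, atol=1e-5)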
Code example #18
File: dnn_beamformer.py  Project: yistLin/espnet
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F), double precision
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """
        def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.float(), ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            if beamformer_type in ("mpdr", "mvdr"):
                ws = get_mvdr_vector(psd_speech, psd_n, u.double())
                enhanced = apply_beamforming_vector(ws, data)
            elif beamformer_type == "wpd":
                ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
                enhanced = perform_WPD_filtering(ws, data, self.bdelay,
                                                 self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    beamformer_type))

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data.float(), ilens)
        assert self.nmask == len(masks)
        # floor masks with self.eps to increase numerical stability
        masks = [torch.clamp(m, min=self.eps) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            psd_speech = get_power_spectral_density_matrix(
                data, mask_speech.double())
            if self.beamformer_type == "mvdr":
                # psd of noise
                psd_n = get_power_spectral_density_matrix(
                    data, mask_noise.double())
            elif self.beamformer_type == "mpdr":
                # psd of observed signal
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power_speech = (data.real**2 +
                                data.imag**2) * mask_speech.double()
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speech = power_speech.mean(dim=-2)
                inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
                # covariance of expanded observed speech
                psd_n = get_covariances(data,
                                        inverse_power,
                                        self.bdelay,
                                        self.btaps,
                                        get_vector=False)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                             self.beamformer_type)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask)
                for mask in mask_speech
            ]
            if self.beamformer_type == "mvdr":
                # psd of noise
                if mask_noise is not None:
                    psd_n = get_power_spectral_density_matrix(data, mask_noise)
            elif self.beamformer_type == "mpdr":
                # psd of observed speech
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power = data.real**2 + data.imag**2
                power_speeches = [power * mask for mask in mask_speech]
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
                inverse_poweres = [
                    1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
                ]
                # covariance of expanded observed speech
                psd_n = [
                    get_covariances(data,
                                    inv_ps,
                                    self.bdelay,
                                    self.btaps,
                                    get_vector=False)
                    for inv_ps in inverse_poweres
                ]
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced = []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    psd_noise = sum(psd_speeches)
                    if mask_noise is not None:
                        psd_noise = psd_noise + psd_n

                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_noise, self.beamformer_type)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(data, ilens, psd_speech, psd_n,
                                               self.beamformer_type)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_n[i], self.beamformer_type)
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(
                            self.beamformer_type))
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
Code example #19
def get_mvdr_vector_with_rtf(
    psd_n: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector
        calculated with RTF:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_n)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...",
                            [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
Code example #20
def apply_beamforming_vector(beamforming_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    # [B, F, T, C] x [B, F, C, T] => [B, F, T]
    # Each frequency is processed independently (no coupling across frequencies).
    es = FC.einsum("bftc, bfct -> bft", [beamforming_vector.conj(), mix])
    return es
Code example #21
import numpy
import pytest

import torch_complex.functional as F
from torch_complex.tensor import ComplexTensor


def _get_complex_array(*shape):
    return numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)


@pytest.mark.parametrize('nop,top',
                         [(numpy.concatenate, F.cat), (numpy.stack, F.stack),
                          (lambda x: numpy.einsum('ai,ij,jk->ak', *x),
                           lambda x: F.einsum('ai,ij,jk->ak', x))])
def test_operation(nop, top):
    if top is None:
        top = nop
    n1 = _get_complex_array(10, 10)
    n2 = _get_complex_array(10, 10)
    n3 = _get_complex_array(10, 10)
    t1 = ComplexTensor(n1.copy())
    t2 = ComplexTensor(n2.copy())
    t3 = ComplexTensor(n3.copy())

    x = nop([n1, n2, n3])
    y = top([t1, t2, t3])
    y = y.numpy()
    numpy.testing.assert_allclose(x, y)

Code example #22
File: beamformer.py  Project: zqs01/espnet
def apply_beamforming_vector(beamform_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    # (..., C) x (..., C, T) -> (..., T)
    es = FC.einsum('...c,...ct->...t', [beamform_vector.conj(), mix])
    return es
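
Combined with the PSD helper and the MVDR filter above, this yields the standard mask-based MVDR pipeline (mirroring the single-speaker mvdr branch of code example #18; xs, mask_speech, mask_noise, and u are assumed given):

psd_speech = get_power_spectral_density_matrix(xs, mask_speech)
psd_noise = get_power_spectral_density_matrix(xs, mask_noise)
ws = get_mvdr_vector(psd_speech, psd_noise, u)  # (..., F, C)
enhanced = apply_beamforming_vector(ws, xs)     # (..., F, T)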
Code example #23
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer, computed as:

        h = (Rf^-1 @ vbar) / (vbar^H @ Rf^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...",
                            [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
Code example #24
def online_wpe_step(input_buffer: ComplexTensor,
                    power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99,
                    taps: int = 10,
                    delay: int = 3):
    """One step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1)
        power: Estimate of the current PSD for this frame (F,)
        inv_cov: Current estimate of R^-1
        filter_taps: Current estimate of the filter taps (F, taps * C, C)
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (F, C)
        Updated estimate of R^-1
        Updated estimate of the filter taps


    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)

    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    C = input_buffer.size(-2)

    if inv_cov is None:
        inv_cov = ComplexTensor(
            torch.eye(C * taps, dtype=input_buffer.dtype).expand(
                *input_buffer.size()[:-2], C * taps, C * taps))
    if filter_taps is None:
        filter_taps = ComplexTensor(
            torch.zeros(*input_buffer.size()[:-2],
                        C * taps,
                        C,
                        dtype=input_buffer.dtype))

    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, T) -> (..., C * T)
    window = window.view(*input_buffer.size()[:-2], -1)
    pred = input_buffer[..., -1] - FC.einsum('...id,...i->...d',
                                             (filter_taps.conj(), window))

    numerator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = \
        FC.einsum('...i,...i->...', (window.conj(), numerator)) + alpha * power
    kalman_gain = numerator / denominator[..., None]

    inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im',
                                    (window.conj(), inv_cov, kalman_gain))
    inv_cov_k /= alpha

    filter_taps_k = \
        filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj()))
    return pred, inv_cov_k, filter_taps_k
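
A hedged streaming driver for online_wpe_step; the buffering scheme and the per-frame power shape (F,) are assumptions, and X and power stand in for the doctest's stft / get_power_online inputs, which are not defined in this listing:

taps, delay = 10, 3
Q, G = None, None
dereverbed = []
for t in range(taps + delay, T):
    frame_buffer = X[..., t - taps - delay : t + 1]  # (F, C, taps + delay + 1)
    pred, Q, G = online_wpe_step(frame_buffer, power[..., t], Q, G,
                                 alpha=0.99, taps=taps, delay=delay)
    dereverbed.append(pred)  # each pred has shape (F, C)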
Code example #25
import numpy
import pytest

import torch_complex.functional as F
from torch_complex.tensor import ComplexTensor


def _get_complex_array(*shape):
    return numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)


@pytest.mark.parametrize(
    "nop,top",
    [
        (numpy.concatenate, F.cat),
        (numpy.stack, F.stack),
        (
            lambda x: numpy.einsum("ai,ij,jk->ak", *x),
            lambda x: F.einsum("ai,ij,jk->ak", x),
        ),
    ],
)
def test_operation(nop, top):
    if top is None:
        top = nop
    n1 = _get_complex_array(10, 10)
    n2 = _get_complex_array(10, 10)
    n3 = _get_complex_array(10, 10)
    t1 = ComplexTensor(n1.copy())
    t2 = ComplexTensor(n2.copy())
    t3 = ComplexTensor(n3.copy())

    x = nop([n1, n2, n3])
    y = top([t1, t2, t3])
    y = y.numpy()
    numpy.testing.assert_allclose(x, y)
Code example #26
    def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor,
        powers: Union[List[torch.Tensor], None] = None,
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """
        def apply_beamforming(data,
                              ilens,
                              psd_n,
                              psd_speech,
                              psd_distortion=None):
            """Beamforming with the provided statistics.

            Args:
                data (ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
                psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C)
            Return:
                enhanced (ComplexTensor): (B, F, T)
                ws (ComplexTensor): (B, F, C) or (B, F, (btaps+1)*C)
            """
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                u = u.double()
            else:
                if self.beamformer_type.endswith("_souden"):
                    # (optional) Create onehot vector for fixed reference microphone
                    u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                    device=data.device,
                                    dtype=torch.double)
                    u[..., self.ref_channel].fill_(1)
                else:
                    # for simplifying computation in RTF-based beamforming
                    u = self.ref_channel

            if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
                ws = get_mvdr_vector_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type in ("mpdr_souden", "mvdr_souden",
                                          "wmpdr_souden"):
                ws = get_mvdr_vector(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type == "wpd":
                ws = get_WPD_filter_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(ws, data.double(),
                                                 self.bdelay, self.btaps)
            elif self.beamformer_type == "wpd_souden":
                ws = get_WPD_filter_v2(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(ws, data.double(),
                                                 self.bdelay, self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = data.double()

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            if self.beamformer_type.startswith(
                    "wmpdr") or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real**2 + data_d.imag**2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = (power_input * mask_speech.double()).mean(dim=-2)
                else:
                    assert len(powers) == 1, len(powers)
                    powers = powers[0]
                inverse_power = 1 / torch.clamp(powers, min=self.eps)

            psd_speech = get_power_spectral_density_matrix(
                data_d, mask_speech.double())
            if mask_noise is not None and (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double())
            if self.beamformer_type == "mvdr":
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_noise,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "mvdr_souden":
                enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                                 psd_speech)
            elif self.beamformer_type == "mpdr":
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "mpdr_souden":
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                                 psd_speech)
            elif self.beamformer_type == "wmpdr":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :],
                     data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "wmpdr_souden":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :],
                     data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                                 psd_speech)
            elif self.beamformer_type == "wpd":
                psd_observed_bar = get_covariances(data_d,
                                                   inverse_power,
                                                   self.bdelay,
                                                   self.btaps,
                                                   get_vector=False)
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed_bar,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "wpd_souden":
                psd_observed_bar = get_covariances(data_d,
                                                   inverse_power,
                                                   self.bdelay,
                                                   self.btaps,
                                                   get_vector=False)
                enhanced, ws = apply_beamforming(data, ilens, psd_observed_bar,
                                                 psd_speech)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            if self.beamformer_type.startswith(
                    "wmpdr") or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real**2 + data_d.imag**2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = [(power_input * m.double()).mean(dim=-2)
                              for m in mask_speech]
                else:
                    assert len(powers) == self.num_spk, len(powers)
                inverse_power = [
                    1 / torch.clamp(p, min=self.eps) for p in powers
                ]

            psd_speeches = [
                get_power_spectral_density_matrix(data_d, mask.double())
                for mask in mask_speech
            ]
            if mask_noise is not None and (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double())
            if self.beamformer_type in ("mpdr", "mpdr_souden"):
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
            elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
                psd_observed = [
                    FC.einsum(
                        "...ct,...et->...ce",
                        [data_d * inv_p[..., None, :],
                         data_d.conj()],
                    ) for inv_p in inverse_power
                ]
            elif self.beamformer_type in ("wpd", "wpd_souden"):
                psd_observed_bar = [
                    get_covariances(data_d,
                                    inv_p,
                                    self.bdelay,
                                    self.btaps,
                                    get_vector=False)
                    for inv_p in inverse_power
                ]

            enhanced, ws = [], []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                if (self.beamformer_type == "mvdr_souden"
                        or not self.beamformer_type.endswith("_souden")):
                    psd_noise_i = (
                        psd_noise + sum(psd_speeches)
                        if mask_noise is not None
                        else sum(psd_speeches)
                    )
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    enh, w = apply_beamforming(data,
                                               ilens,
                                               psd_noise_i,
                                               psd_speech,
                                               psd_distortion=psd_noise_i)
                elif self.beamformer_type == "mvdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                               psd_speech)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed,
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "mpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed,
                                               psd_speech)
                elif self.beamformer_type == "wmpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wmpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed[i],
                                               psd_speech)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed_bar[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wpd_souden":
                    enh, w = apply_beamforming(data, ilens,
                                               psd_observed_bar[i], psd_speech)
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(
                            self.beamformer_type))
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)
                ws.append(w)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks