Exemplo n.º 1
0
def perform_WPD_filtering(filter_matrix: ComplexTensor, Y: ComplexTensor,
                          bdelay: int, btaps: int) -> ComplexTensor:
    """Apply a WPD filter to a multi-channel STFT signal.

    The signal is framed along time into ``btaps + 1`` taps per channel,
    the tap axis is reversed, taps and channels are flattened into a single
    axis, and that axis is contracted with the conjugated filter.

    Args:
        filter_matrix: Filter matrix (B, F, (btaps + 1) * C)
        Y: Complex STFT signal with shape (B, F, C, T)
        bdelay: Passed to ``signal_framing`` as its fourth positional
            argument (presumably the prediction delay in frames -- confirm
            against ``signal_framing``).
        btaps: Number of taps minus one; frames of length ``btaps + 1``
            are requested from ``signal_framing``.

    Returns:
        enhanced (ComplexTensor): (B, F, T)
    """
    # (B, F, C, T) --> (B, F, C, T, btaps + 1), zero-padded by signal_framing
    framed = signal_framing(
        Y, btaps + 1, 1, bdelay, do_padding=True, pad_value=0
    )
    framed = FC.reverse(framed, dim=-1)

    n_batch, n_freq, _n_chan, n_frames = Y.shape

    # (B, F, C, T, btaps + 1) --> (B, F, T, btaps + 1, C)
    tap_major = framed.permute(0, 1, 3, 4, 2).contiguous()
    # --> (B, F, T, (btaps + 1) * C)
    stacked = tap_major.view(n_batch, n_freq, n_frames, -1)

    # (B, F, T, (btaps + 1) * C) x (B, F, (btaps + 1) * C) --> (B, F, T)
    return FC.einsum("...tc,...c->...t", [stacked, filter_matrix.conj()])
Exemplo n.º 2
0
def get_covariances(
    Y: ComplexTensor,
    inverse_power: torch.Tensor,
    bdelay: int,
    btaps: int,
    get_vector: bool = False,
) -> ComplexTensor:
    """Compute the power-normalized spatio-temporal covariance matrix.

    The signal is framed along time, the tap axis is reversed, and the
    framed signal is correlated with its ``inverse_power``-weighted
    conjugate, summing over the time axis.

    Args:
        Y: Complex STFT signal with shape (B, F, C, T)
        inverse_power: Weighting factor with shape (B, F, T)
        bdelay: Passed to ``signal_framing`` as its fourth positional
            argument; also determines the first time index that is used.
        btaps: Frames of length ``btaps + 1`` are built.
        get_vector: If True, additionally return the covariance vector.

    Returns:
        Covariance matrix of shape (B, F, (btaps+1) * C, (btaps+1) * C).
        When ``get_vector`` is True, a tuple of that matrix and a
        covariance vector of shape (B, F, btaps + 1, C, C).
    """
    assert inverse_power.dim() == 3, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    n_batch, n_freq, n_chan, n_frames = Y.shape

    # Index of the first time frame that enters the statistics, and the
    # number of frames kept: T' = T - bdelay - btaps + 1.
    first_used = bdelay + btaps - 1
    n_used = n_frames - first_used

    # (B, F, C, T', btaps + 1)
    framed = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)
    framed = framed[..., :n_used, :]
    # Flip the tap axis so its ordering is reversed relative to the framing.
    framed = FC.reverse(framed, dim=-1)

    # Broadcast the (B, F, T') weights over the channel and tap axes.
    weights = inverse_power[..., None, first_used:, None]
    framed_weighted = framed * weights

    # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
    #  -> (B, F, btaps + 1, C, btaps + 1, C)
    cov = FC.einsum("bfdtk,bfetl->bfkdle", (framed, framed_weighted.conj()))

    #  -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
    side = (btaps + 1) * n_chan
    cov = cov.view(n_batch, n_freq, side, side)

    if not get_vector:
        return cov

    # (B, F, C, T', btaps + 1) x (B, F, C, T') -> (B, F, btaps + 1, C, C)
    cov_vector = FC.einsum(
        "bfdtk,bfet->bfked", (framed_weighted, Y[..., first_used:].conj())
    )
    return cov, cov_vector
Exemplo n.º 3
0
def get_correlations(Y: ComplexTensor, inverse_power: torch.Tensor, taps,
                     delay) -> Tuple[ComplexTensor, ComplexTensor]:
    """Compute weighted correlations over a window of ``taps`` frames.

    Args:
        Y : Complex-valued STFT signal with shape (F, C, T)
        inverse_power : Weighting factor with shape (F, T)
        taps (int): Length of the correlation window
        delay (int): Delay for the weighting factor

    Returns:
        Correlation matrix of shape (F, taps*C, taps*C)
        Correlation vector of shape (F, taps, C, C)
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), \
        (inverse_power.size(0), Y.size(0))

    n_freq, n_chan, n_frames = Y.size()

    # The first `offset` frames of Y / inverse_power are skipped; the framed
    # signal keeps T - delay - taps + 1 frames.
    offset = delay + taps - 1

    # Y: (F, C, T) -> framed: (F, C, T', taps)
    framed = signal_framing(Y, frame_length=taps, frame_step=1)
    framed = framed[..., :n_frames - offset, :]
    # Flip the taps axis.
    framed = FC.reverse(framed, dim=-1)

    # Broadcast the (F, T') weights over the channel and tap axes.
    weights = inverse_power[..., None, offset:, None]
    framed_conj_weighted = framed.conj() * weights

    # (F, C, T', taps) x (F, C, T', taps) -> (F, taps, C, taps, C)
    corr_matrix = FC.einsum('fdtk,fetl->fkdle',
                            (framed_conj_weighted, framed))
    # -> (F, taps * C, taps * C)
    corr_matrix = corr_matrix.view(n_freq, taps * n_chan, taps * n_chan)

    # (F, C, T', taps) x (F, C, T') -> (F, taps, C, C)
    corr_vector = FC.einsum('fdtk,fet->fked',
                            (framed_conj_weighted, Y[..., offset:]))

    return corr_matrix, corr_vector
Exemplo n.º 4
0
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
    """Return ``a`` with the order of its elements flipped along ``dim``.

    ``ComplexTensor`` inputs are delegated to ``FC.reverse``; anything else
    is handed to ``torch.flip``.
    """
    if not isinstance(a, ComplexTensor):
        return torch.flip(a, dims=(dim, ))
    return FC.reverse(a, dim=dim)
Exemplo n.º 5
0
def online_wpe_step(input_buffer: ComplexTensor,
                    power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99,
                    taps: int = 10,
                    delay: int = 3):
    """Run a single step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1); the last axis length is
            asserted. The final frame is the one being dereverberated and
            the first ``taps`` frames form the prediction window.
        power: Estimate for the current PSD. It is scaled by ``alpha`` and
            added to a quantity shaped like the leading dims of
            ``input_buffer``, so it must broadcast against them
            (presumably (F,) -- confirm against the caller).
        inv_cov: Current estimate of R^-1. When None, an identity of shape
            (..., C * taps, C * taps) is used.
        filter_taps: Current estimate of the filter taps. When None, zeros
            of shape (..., C * taps, C) are used.
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (..., C)
        Updated estimate of R^-1
        Updated estimate of the filter taps


    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)

    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    n_chan = input_buffer.size(-2)
    lead_shape = input_buffer.size()[:-2]
    n_coef = n_chan * taps

    if inv_cov is None:
        # Start from the identity, broadcast over the leading dims.
        identity = torch.eye(n_coef, dtype=input_buffer.dtype)
        inv_cov = ComplexTensor(identity.expand(*lead_shape, n_coef, n_coef))
    if filter_taps is None:
        zeros = torch.zeros(*lead_shape, n_coef, n_chan,
                            dtype=input_buffer.dtype)
        filter_taps = ComplexTensor(zeros)

    # Drop the newest `delay + 1` frames, flip the remaining `taps` frames
    # in time, then flatten (..., C, taps) -> (..., C * taps).
    past = FC.reverse(input_buffer[..., :-(delay + 1)], dim=-1)
    past = past.view(*lead_shape, -1)

    # Newest frame minus the filter's prediction from the past window.
    newest = input_buffer[..., -1]
    predicted = FC.einsum('...id,...i->...d', (filter_taps.conj(), past))
    pred = newest - predicted

    past_conj = past.conj()
    numerator = FC.einsum('...ij,...j->...i', (inv_cov, past))
    denominator = FC.einsum('...i,...i->...', (past_conj, numerator))
    denominator = denominator + alpha * power
    gain = numerator / denominator[..., None]

    # Update of the inverse covariance estimate, rescaled by alpha.
    correction = FC.einsum('...j,...jm,...i->...im',
                           (past_conj, inv_cov, gain))
    inv_cov_k = inv_cov - correction
    inv_cov_k /= alpha

    # Outer product of the gain with the conjugated prediction error.
    filter_step = FC.einsum('...i,...m->...im', (gain, pred.conj()))
    filter_taps_k = filter_taps + filter_step
    return pred, inv_cov_k, filter_taps_k