def perform_WPD_filtering(filter_matrix: ComplexTensor, Y: ComplexTensor, bdelay: int, btaps: int) -> ComplexTensor:
    """Apply a WPD filter to a multi-channel complex STFT signal.

    For every time frame, the (btaps + 1)-long framed window of all channels
    is flattened into one vector and contracted with the conjugated filter.

    Args:
        filter_matrix: Filter with shape (B, F, (btaps + 1) * C).
        Y: Complex STFT signal with shape (B, F, C, T).
        bdelay: Delay passed through to ``signal_framing``.
        btaps: Number of taps; each frame holds ``btaps + 1`` entries.

    Returns:
        enhanced (ComplexTensor): Filtered signal with shape (B, F, T).
    """
    # Zero-padded framing yields one window per input frame:
    # (B, F, C, T) -> (B, F, C, T, btaps + 1); then flip the tap axis.
    framed = FC.reverse(
        signal_framing(Y, btaps + 1, 1, bdelay, do_padding=True, pad_value=0),
        dim=-1,
    )
    batch, n_freq, _, n_frames = Y.shape

    # (B, F, C, T, btaps + 1) -> (B, F, T, btaps + 1, C)
    #                         -> (B, F, T, (btaps + 1) * C)
    stacked = framed.permute(0, 1, 3, 4, 2).contiguous()
    stacked = stacked.view(batch, n_freq, n_frames, -1)

    # Inner product with the conjugated filter over the stacked axis -> (B, F, T)
    return FC.einsum("...tc,...c->...t", [stacked, filter_matrix.conj()])
def get_covariances(
    Y: ComplexTensor,
    inverse_power: torch.Tensor,
    bdelay: int,
    btaps: int,
    get_vector: bool = False,
) -> Union[ComplexTensor, Tuple[ComplexTensor, ComplexTensor]]:
    """Calculate the power-normalized spatio-temporal covariance of the framed signal.

    Args:
        Y: Complex STFT signal with shape (B, F, C, T).
        inverse_power: Weighting factor with shape (B, F, T).
        bdelay: Delay passed through to ``signal_framing``.
        btaps: Number of taps; each frame holds ``btaps + 1`` entries.
        get_vector: If True, additionally compute the correlation vector.

    Returns:
        If ``get_vector`` is False: the covariance matrix with shape
        (B, F, (btaps + 1) * C, (btaps + 1) * C).
        If ``get_vector`` is True: a tuple ``(covariance_matrix,
        covariance_vector)``, where the vector has shape
        (B, F, btaps + 1, C, C).

    Raises:
        AssertionError: (when assertions are enabled) if ``inverse_power`` is
            not 3-D or its batch size differs from that of ``Y``.
    """
    assert inverse_power.dim() == 3, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    Bs, Fdim, C, T = Y.shape

    # (B, F, C, T - bdelay - btaps + 1, btaps + 1)
    Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[
        ..., : T - bdelay - btaps + 1, :
    ]
    # Reverse along btaps-axis:
    # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-frame_length+1]
    Psi = FC.reverse(Psi, dim=-1)
    # Weights start at index bdelay + btaps - 1 so they line up with the
    # T' = T - bdelay - btaps + 1 frames kept above.
    Psi_norm = Psi * inverse_power[..., None, bdelay + btaps - 1 :, None]

    # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
    # -> (B, F, btaps + 1, C, btaps + 1, C)
    covariance_matrix = FC.einsum("bfdtk,bfetl->bfkdle", (Psi, Psi_norm.conj()))
    # (B, F, btaps + 1, C, btaps + 1, C)
    # -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
    covariance_matrix = covariance_matrix.view(
        Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C
    )

    if get_vector:
        # (B, F, C, T', btaps + 1) x (B, F, C, T')
        # --> (B, F, btaps + 1, C, C)
        covariance_vector = FC.einsum(
            "bfdtk,bfet->bfked", (Psi_norm, Y[..., bdelay + btaps - 1 :].conj())
        )
        return covariance_matrix, covariance_vector
    return covariance_matrix
def get_correlations(Y: ComplexTensor, inverse_power: torch.Tensor,
                     taps, delay) -> Tuple[ComplexTensor, ComplexTensor]:
    """Compute power-weighted correlations over sliding windows of ``taps`` frames.

    Args:
        Y: Complex-valued STFT signal with shape (F, C, T).
        inverse_power: Weighting factor with shape (F, T).
        taps (int): Length of the correlation window.
        delay (int): Prediction delay; the weights and the target frames of
            ``Y`` are taken starting at index ``delay + taps - 1``.

    Returns:
        Tuple ``(correlation_matrix, correlation_vector)`` with shapes
        (F, taps * C, taps * C) and (F, taps, C, C).
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), \
        (inverse_power.size(0), Y.size(0))

    n_freq, n_chan, n_frames = Y.size()
    n_valid = n_frames - delay - taps + 1
    offset = delay + taps - 1

    # (F, C, T) -> (F, C, T', taps), tap axis flipped
    windows = signal_framing(
        Y, frame_length=taps, frame_step=1)[..., :n_valid, :]
    windows = FC.reverse(windows, dim=-1)
    weighted = windows.conj() * inverse_power[..., None, offset:, None]

    # (F, C, T', taps) x (F, C, T', taps) -> (F, taps, C, taps, C)
    #                                     -> (F, taps * C, taps * C)
    matrix = FC.einsum('fdtk,fetl->fkdle', (weighted, windows))
    matrix = matrix.view(n_freq, taps * n_chan, taps * n_chan)

    # (F, C, T', taps) x (F, C, T') -> (F, taps, C, C)
    vector = FC.einsum('fdtk,fet->fked', (weighted, Y[..., offset:]))

    return matrix, vector
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
    """Flip ``a`` along dimension ``dim``.

    ComplexTensor inputs are delegated to ``FC.reverse``; any other input is
    flipped with ``torch.flip``.
    """
    if not isinstance(a, ComplexTensor):
        return torch.flip(a, dims=(dim,))
    return FC.reverse(a, dim=dim)
def online_wpe_step(input_buffer: ComplexTensor, power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99, taps: int = 10, delay: int = 3):
    """One step of online (recursive) WPE dereverberation.

    Args:
        input_buffer: The last ``taps + delay + 1`` frames, shape
            (F, C, taps + delay + 1). The final frame is the one that is
            dereverberated; the first ``taps`` frames form the prediction window.
        power: Power estimate for the current frame. It is added to a
            per-leading-dim denominator, so it must broadcast to
            ``input_buffer.shape[:-2]``, e.g. shape (F,).
        inv_cov: Current estimate of R^-1, shape (F, taps * C, taps * C).
            Initialized to the identity when None.
        filter_taps: Current estimate of the filter taps, shape
            (F, taps * C, C). Initialized to zeros when None.
        alpha (float): Smoothing (forgetting) factor of the recursive update.
        taps (int): Number of filter taps.
        delay (int): Delay in frames between the window and the current frame.

    Returns:
        Tuple of
            - the dereverberated frame, shape (F, C),
            - the updated estimate of R^-1,
            - the updated estimate of the filter taps.

    Example (sketch; ``stft`` is a (F, C, T) ComplexTensor)::

        Q, G = None, None
        for t in range(taps + delay + 1, stft.size(-1) + 1):
            buffer = stft[..., t - taps - delay - 1:t]
            frame, Q, G = online_wpe_step(
                buffer, get_power_online(buffer), Q, G,
                alpha=alpha, taps=taps, delay=delay)
    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    C = input_buffer.size(-2)
    batch_dims = input_buffer.size()[:-2]

    # Default states are allocated on the input's device (and dtype) so the
    # first step also works for non-CPU inputs.
    if inv_cov is None:
        inv_cov = ComplexTensor(
            torch.eye(C * taps, dtype=input_buffer.dtype,
                      device=input_buffer.device).expand(
                *batch_dims, C * taps, C * taps))
    if filter_taps is None:
        filter_taps = ComplexTensor(
            torch.zeros(*batch_dims, C * taps, C, dtype=input_buffer.dtype,
                        device=input_buffer.device))

    # First `taps` frames of the buffer, highest time index first: (..., C, taps)
    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, taps) -> (..., C * taps)
    window = window.view(*batch_dims, -1)
    # Prediction error of the last buffered frame: (..., C)
    pred = input_buffer[..., -1] - \
        FC.einsum('...id,...i->...d', (filter_taps.conj(), window))

    nominator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = \
        FC.einsum('...i,...i->...', (window.conj(), nominator)) + alpha * power
    kalman_gain = nominator / denominator[..., None]

    # Recursive update of the inverse covariance, scaled by 1 / alpha
    inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im',
                                    (window.conj(), inv_cov, kalman_gain))
    inv_cov_k /= alpha

    filter_taps_k = \
        filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj()))
    return pred, inv_cov_k, filter_taps_k