def perform_WPD_filtering(
    filter_matrix: Union[torch.Tensor, ComplexTensor],
    Y: Union[torch.Tensor, ComplexTensor],
    bdelay: int,
    btaps: int,
) -> Union[torch.Tensor, ComplexTensor]:
    """Apply a WPD filter to a multi-channel STFT signal.

    Args:
        filter_matrix: Filter matrix (B, F, (btaps + 1) * C)
        Y : Complex STFT signal with shape (B, F, C, T)
        bdelay (int): Delay forwarded to ``signal_framing`` when framing ``Y``
        btaps (int): Number of taps; every frame holds ``btaps + 1`` entries

    Returns:
        enhanced (torch.complex64/ComplexTensor): (B, F, T)
    """
    # Frame the signal with zero padding, then flip the tap axis:
    # (B, F, C, T) --> (B, F, C, T, btaps + 1)
    framed = reverse(
        signal_framing(Y, btaps + 1, 1, bdelay, do_padding=True, pad_value=0),
        dim=-1,
    )

    n_batch, n_freq, _, n_frames = Y.shape
    # Move time in front of the (tap, channel) axes:
    # (B, F, C, T, btaps + 1) --> (B, F, T, btaps + 1, C)
    stacked = framed.permute(0, 1, 3, 4, 2).contiguous()
    # Merge tap and channel axes: --> (B, F, T, (btaps + 1) * C)
    stacked = stacked.view(n_batch, n_freq, n_frames, -1)

    # Contract the stacked axis against the conjugated filter: --> (B, F, T)
    return einsum("...tc,...c->...t", stacked, filter_matrix.conj())
def get_covariances(
    Y: Union[torch.Tensor, ComplexTensor],
    inverse_power: torch.Tensor,
    bdelay: int,
    btaps: int,
    get_vector: bool = False,
) -> Union[
    torch.Tensor,
    ComplexTensor,
    Tuple[Union[torch.Tensor, ComplexTensor], Union[torch.Tensor, ComplexTensor]],
]:
    """Calculate the power normalized spatio-temporal covariance matrix.

    The covariance is computed on the framed (and tap-reversed) signal.

    Args:
        Y : Complex STFT signal with shape (B, F, C, T)
        inverse_power : Weighting factor with shape (B, F, T)
        bdelay (int): Delay forwarded to ``signal_framing``
        btaps (int): Number of taps; every frame holds ``btaps + 1`` entries
        get_vector (bool): If True, also return the covariance vector

    Returns:
        Correlation matrix: (B, F, (btaps+1) * C, (btaps+1) * C)
        Correlation vector: (B, F, btaps + 1, C, C)
            (only returned, as the second element of a tuple,
            when ``get_vector`` is True)
    """
    assert inverse_power.dim() == 3, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    Bs, Fdim, C, T = Y.shape
    # Index of the first time frame kept from ``inverse_power`` and ``Y`` so
    # that they line up with the T' = T - bdelay - btaps + 1 frames of Psi.
    start = bdelay + btaps - 1

    # (B, F, C, T - bdelay - btaps + 1, btaps + 1)
    Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[
        ..., : T - bdelay - btaps + 1, :
    ]
    # Reverse along btaps-axis:
    # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-frame_length+1]
    Psi = reverse(Psi, dim=-1)
    Psi_norm = Psi * inverse_power[..., None, start:, None]

    # let T' = T - bdelay - btaps + 1
    # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
    #  -> (B, F, btaps + 1, C, btaps + 1, C)
    covariance_matrix = einsum("bfdtk,bfetl->bfkdle", Psi, Psi_norm.conj())

    # (B, F, btaps + 1, C, btaps + 1, C)
    #   -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
    # ``reshape`` instead of ``view``: the memory layout of an einsum result
    # is not guaranteed to be view-compatible, and ``reshape`` falls back to
    # a copy instead of raising in that case.
    covariance_matrix = covariance_matrix.reshape(
        Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C
    )

    if get_vector:
        # (B, F, C, T', btaps + 1) x (B, F, C, T')
        #    --> (B, F, btaps + 1, C, C)
        covariance_vector = einsum(
            "bfdtk,bfet->bfked", Psi_norm, Y[..., start:].conj()
        )
        return covariance_matrix, covariance_vector
    return covariance_matrix
def get_correlations(
    Y: Union[torch.Tensor, ComplexTensor], inverse_power: torch.Tensor, taps, delay
) -> Tuple[Union[torch.Tensor, ComplexTensor], Union[torch.Tensor, ComplexTensor]]:
    """Compute weighted correlations over a window of length taps.

    Args:
        Y : Complex-valued STFT signal with shape (F, C, T)
        inverse_power : Weighting factor with shape (F, T)
        taps (int): Length of the correlation window
        delay (int): Delay for the weighting factor

    Returns:
        Correlation matrix of shape (F, taps*C, taps*C)
        Correlation vector of shape (F, taps, C, C)
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    n_freq, n_chan, n_frames = Y.size()
    # First time index kept from ``inverse_power`` and ``Y`` below.
    offset = delay + taps - 1

    # Frame Y along time with step 1, then keep only the first
    # T - delay - taps + 1 frames.
    framed = signal_framing(Y, frame_length=taps, frame_step=1)
    framed = framed[..., : n_frames - offset, :]
    # Flip the order of entries along the taps axis.
    framed = reverse(framed, dim=-1)

    # Conjugated frames scaled by the weighting factor of the matching frames.
    weighted_conj = framed.conj() * inverse_power[..., None, offset:, None]

    # (F, C, T', taps) x (F, C, T', taps) -> (F, taps, C, taps, C)
    #   -> (F, taps * C, taps * C)
    corr_matrix = einsum("fdtk,fetl->fkdle", weighted_conj, framed).reshape(
        n_freq, taps * n_chan, taps * n_chan
    )

    # (F, C, T', taps) x (F, C, T') -> (F, taps, C, C)
    corr_vector = einsum("fdtk,fet->fked", weighted_conj, Y[..., offset:])

    return corr_matrix, corr_vector