def get_mvdr_vector(psd_s: ComplexTensor, psd_n: ComplexTensor, reference_vector: torch.Tensor, eps: float = 1e-15) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_n (ComplexTensor): noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float): tiny constant for diagonal loading and
            for regularizing the trace denominator

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Diagonal loading: add eps to the diagonal so psd_n is invertible.
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    # Out-of-place addition so the caller's psd_n is never mutated
    # (ComplexTensor `+=` may update real/imag storage in place).
    psd_n = psd_n + eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum('...fec,...c->...fe', [ws, reference_vector])
    return beamform_vector
def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. As follows:

    h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
            is the reference_vector.
        eps (float): tiny constant regularizing the trace denominator

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    # Invert Rf, falling back to progressively stronger random
    # regularization when plain inversion fails (e.g. singular Rf).
    # Scaling Rf and Phi by the same factor leaves the trace-normalized
    # filter `ws` below unchanged.
    try:
        inv_Rf = inv(Rf)
    except Exception:
        try:
            # First fallback: scale down and add a tiny random complex
            # perturbation before retrying the inversion.
            # NOTE(review): 10e4 equals 1e5 — confirm 1e4 was not intended.
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real))
                * 1e-4
            )
            Rf = Rf / 10e4
            Phi = Phi / 10e4
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)
        except Exception:
            # Second fallback: much stronger scaling and perturbation.
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real))
                * 1e-1
            )
            Rf = Rf / 10e10
            Phi = Phi / 10e10
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [inv_Rf, Phi])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps + 1) * C)
    return beamform_vector
def get_covariances(
    Y: ComplexTensor,
    inverse_power: torch.Tensor,
    bdelay: int,
    btaps: int,
    get_vector: bool = False,
) -> ComplexTensor:
    """Calculates the power normalized spatio-temporal covariance
        matrix of the framed signal.

    Args:
        Y : Complex STFT signal with shape (B, F, C, T)
        inverse_power : Weighting factor with shape (B, F, T)
        bdelay (int): prediction delay between the current frame and the
            first filtered frame
        btaps (int): number of filter taps
        get_vector (bool): also return the covariance vector when True

    Returns:
        Correlation matrix of shape (B, F, (btaps+1) * C, (btaps+1) * C)
        Correlation vector of shape (B, F, btaps + 1, C, C)
            (only when ``get_vector`` is True)
    """
    assert inverse_power.dim() == 3, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    Bs, Fdim, C, T = Y.shape

    # (B, F, C, T - bdelay - btaps + 1, btaps + 1)
    Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[
        ..., : T - bdelay - btaps + 1, :
    ]
    # Reverse along btaps-axis:
    # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-frame_length+1]
    Psi = FC.reverse(Psi, dim=-1)
    # Weight each frame by the inverse power of its current time step
    Psi_norm = Psi * inverse_power[..., None, bdelay + btaps - 1 :, None]

    # let T' = T - bdelay - btaps + 1
    # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
    #  -> (B, F, btaps + 1, C, btaps + 1, C)
    covariance_matrix = FC.einsum("bfdtk,bfetl->bfkdle", (Psi, Psi_norm.conj()))

    # (B, F, btaps + 1, C, btaps + 1, C)
    #   -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
    covariance_matrix = covariance_matrix.view(
        Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C
    )

    if get_vector:
        # (B, F, C, T', btaps + 1) x (B, F, C, T')
        #    --> (B, F, btaps +1, C, C)
        covariance_vector = FC.einsum(
            "bfdtk,bfet->bfked", (Psi_norm, Y[..., bdelay + btaps - 1 :].conj())
        )
        return covariance_matrix, covariance_vector
    else:
        return covariance_matrix
def filter_minimum_gain_like(G_min, w, y, alpha=None, k: float = 10.0, eps: float = EPS):
    """Approximate a minimum gain operation.

    speech_estimate = alpha w^H y + (1 - alpha) G_min Y, where
    alpha = 1 / (1 + exp(-2 k x)), x = w^H y - G_min Y

    Args:
        G_min (float): minimum gain
        w (ComplexTensor): filter coefficients (..., L, N)
        y (ComplexTensor): buffered and stacked input (..., L, N)
        alpha: mixing factor
        k (float): scaling in tanh-like function
        eps (float)
    Returns:
        output (ComplexTensor): minimum gain-filtered output
        alpha (float): optional
    """
    # w^H y: contract the last (N) axis -> (..., L)
    estimate = FC.einsum("...d,...d->...", [w.conj(), y])
    # most recent buffered observation per frame: (..., L)
    latest = y[..., -1]
    return minimum_gain_like(G_min, latest, estimate, alpha, k, eps)
def einsum(equation, *operands):
    """Einsum dispatching between `torch.einsum` and `FC.einsum`.

    Args:
        equation (str): einsum subscript equation
        *operands: either a single tuple/list of tensors, or the tensors
            themselves (ComplexTensor, complex torch.Tensor, or real
            torch.Tensor)

    Returns:
        result of the einsum, of the appropriate complex/real type

    Raises:
        ValueError: when no operands are given, or when more than two
            operands with mixed dtypes are given
    """
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
    # mixed input with complex and real tensors.
    if len(operands) == 0:
        # Guard explicitly: indexing operands[0] below would otherwise
        # raise an unhelpful IndexError.
        raise ValueError("0 or More than 2 operands are not supported.")
    elif len(operands) == 1:
        if isinstance(operands[0], (tuple, list)):
            operands = operands[0]
        complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
        return complex_module.einsum(equation, *operands)
    elif len(operands) != 2:
        op0 = operands[0]
        same_type = all(op.dtype == op0.dtype for op in operands[1:])
        if same_type:
            _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
            return _einsum(equation, *operands)
        else:
            raise ValueError("0 or More than 2 operands are not supported.")

    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        # Work around mixed real/complex input by splitting the complex
        # operand into real and imaginary parts.
        if not torch.is_complex(a):
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        elif not torch.is_complex(b):
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)
        else:
            return torch.einsum(equation, a, b)
    else:
        return torch.einsum(equation, a, b)
def get_power_spectral_density_matrix(xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15) -> ComplexTensor:
    """Return the cross-channel power spectral density (PSD) matrix.

    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool): normalize the mask to sum to one over time
        eps (float): tiny constant guarding the normalization divide

    Returns:
        psd (ComplexTensor): (..., F, C, C)
    """
    # Per-frame outer products: (..., C, T) x (..., C, T) -> (..., T, C, C)
    outer = FC.einsum('...ct,...et->...tce', [xs, xs.conj()])

    # Collapse the channel axis of the mask: (..., C, T) -> (..., T)
    weight = mask.mean(dim=-2)

    if normalization:
        # With zero padding, padded frames contribute nothing to the sum,
        # so the result does not depend on the padding length.
        weight = weight / (weight.sum(dim=-1, keepdim=True) + eps)

    # Weighted sum over time: (..., T, C, C) -> (..., C, C)
    return (outer * weight[..., None, None]).sum(dim=-3)
def perform_WPD_filtering(filter_matrix: ComplexTensor, Y: ComplexTensor, bdelay: int, btaps: int) -> ComplexTensor: """Perform WPD filtering. Args: filter_matrix: Filter matrix (B, F, (btaps + 1) * C) Y : Complex STFT signal with shape (B, F, C, T) Returns: enhanced (ComplexTensor): (B, F, T) """ # (B, F, C, T) --> (B, F, C, T, btaps + 1) Ytilde = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=True, pad_value=0) Ytilde = FC.reverse(Ytilde, dim=-1) Bs, Fdim, C, T = Y.shape # --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C) Ytilde = Ytilde.permute(0, 1, 3, 4, 2).contiguous().view(Bs, Fdim, T, -1) # (B, F, T, 1) enhanced = FC.einsum("...tc,...c->...t", [Ytilde, filter_matrix.conj()]) return enhanced
def get_WPD_filter(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the WPD (Weighted Power minimization Distortionless response)
    convolutional beamformer:

        h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            PSD of the zero-padded speech [x^T(t,f) 0 ... 0]^T.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, (btaps+1) * C)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
        diag_eps (float):
        eps (float):

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
    """
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)

    # Rf^-1 @ Phi: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    if use_torch_solver:
        num = FC.solve(Phi, Rf)[0]
    else:
        num = FC.matmul(Rf.inverse2(), Phi)

    # Normalize by the trace: (..., C, C) / (...,) -> (..., C, C)
    ws = num / (FC.trace(num)[..., None, None] + eps)

    # Project onto the reference: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    # (B, F, (btaps + 1) * C)
    return FC.einsum("...fec,...c->...fe", [ws, reference_vector])
def test_get_rtf(ch):
    """Check the eigenvector property of the RTF returned by `get_rtf`:
    rtf should satisfy rtf ~ Phi_N @ MaxEigVec(Phi_N^-1 @ Phi_X)."""
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    n = torch.rand(2, 16, ch, dtype=torch.double)
    ilens = torch.LongTensor([16, 12])
    # (B, T, C, F) -> (B, F, C, T)
    X = ComplexTensor(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(-1, -3)
    N = ComplexTensor(*torch.unbind(stft(n, ilens)[0], dim=-1)).transpose(-1, -3)
    # Covariance matrices: (B, F, C, C)
    Phi_X = FC.einsum("...ct,...et->...ce", [X, X.conj()])
    Phi_N = FC.einsum("...ct,...et->...ce", [N, N.conj()])
    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, reference_vector=0, iterations=20)
    # Normalize by the largest-magnitude component per frequency bin
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        # older PyTorch: max() returns a (values, indices) tuple
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)
    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available after pytorch 1.1.0+
        mat = FC.solve(Phi_X, Phi_N)[0]
        max_eigenvec = FC.solve(rtf, Phi_N)[0]
    else:
        mat = FC.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = FC.matmul(Phi_N.inverse2(), rtf)
    factor = FC.matmul(mat, max_eigenvec)
    # If max_eigenvec is an eigenvector of `mat`, then factor is parallel
    # to max_eigenvec and the two outer products below coincide.
    assert FC.allclose(
        FC.matmul(max_eigenvec, factor.transpose(-1, -2)),
        FC.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
def get_correlations(Y: ComplexTensor, inverse_power: torch.Tensor,
                     taps, delay) -> Tuple[ComplexTensor, ComplexTensor]:
    """Calculates weighted correlations of a window of length taps

    Args:
        Y : Complex-valued STFT signal with shape (F, C, T)
        inverse_power : Weighting factor with shape (F, T)
        taps (int): Lengths of correlation window
        delay (int): Delay for the weighting factor

    Returns:
        Correlation matrix of shape (F, taps*C, taps*C)
        Correlation vector of shape (F, taps, C, C)
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), \
        (inverse_power.size(0), Y.size(0))

    F, C, T = Y.size()

    # Y: (F, C, T) -> Psi: (F, C, T, taps)
    Psi = signal_framing(Y, frame_length=taps,
                         frame_step=1)[..., :T - delay - taps + 1, :]
    # Reverse along taps-axis (most recent frame first)
    Psi = FC.reverse(Psi, dim=-1)
    # Weight by the inverse power of the predicted (delayed) time step
    Psi_conj_norm = \
        Psi.conj() * inverse_power[..., None, delay + taps - 1:, None]

    # (F, C, T, taps) x (F, C, T, taps) -> (F, taps, C, taps, C)
    correlation_matrix = FC.einsum('fdtk,fetl->fkdle', (Psi_conj_norm, Psi))
    # (F, taps, C, taps, C) -> (F, taps * C, taps * C)
    correlation_matrix = correlation_matrix.view(F, taps * C, taps * C)

    # (F, C, T, taps) x (F, C, T) -> (F, taps, C, C)
    correlation_vector = FC.einsum(
        'fdtk,fet->fked', (Psi_conj_norm, Y[..., delay + taps - 1:]))

    return correlation_matrix, correlation_vector
def get_power_spectral_density_matrix(
        complex_tensor: ComplexTensor) -> ComplexTensor:
    """Compute per-frame cross-channel power spectral density (PSD) terms.

    Args:
        complex_tensor: [..., F, C, T]
    Returns:
        psd: [..., F, T, C, C] (outer products per time frame)
    """
    conjugated = complex_tensor.conj()
    # outer product: [..., C_1, T] x [..., C_2, T] => [..., T, C_1, C_2]
    return FC.einsum("...ct,...et->...tce", [complex_tensor, conjugated])
def apply_crf_filter(cRM_filter: ComplexTensor,
                     mix: ComplexTensor) -> ComplexTensor:
    """Apply a complex Ratio Filter to the mixture.

    Args:
        cRM_filter: complex Ratio Filter, [B, F, T, Filter_delay]
        mix: mixture, [B, C, F, Filter_delay, T]
    Returns:
        filtered signal, [B, C, F, T]
    """
    # Contract the filter-delay axis:
    # [B, F, T, D] x [B, C, F, D, T] => [B, C, F, T]
    return FC.einsum("bftd, bcfdt -> bcft", [cRM_filter.conj(), mix])
def perform_filter_operation_v2(Y: ComplexTensor,
                                filter_matrix_conj: ComplexTensor,
                                taps, delay) -> ComplexTensor:
    """perform_filter_operation_v2

    Subtract the estimated reverberation tail from the observation.

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter_matrix_conj : conjugate filter matrix (F, taps, C, C)
        taps (int): number of filter taps
        delay (int): prediction delay

    Returns:
        dereverberated signal of shape (F, C, T)
    """
    T = Y.size(-1)
    # Y_tilde: (taps, F, C, T)
    # The i-th entry is Y delayed by (delay + i) frames, left-padded with
    # zeros so that every entry keeps the original length T.
    Y_tilde = FC.stack([FC.pad(Y[:, :, :T - delay - i], (delay + i, 0),
                               mode='constant', value=0)
                        for i in range(taps)],
                       dim=0)
    # Sum the filtered, delayed signals over taps and input channels
    reverb_tail = FC.einsum('fpde,pfdt->fet', (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail
def get_WPD_filter_v2(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the WPD vector (v2).

    This implementation is more efficient than `get_WPD_filter` as
    it skips unnecessary computation with zeros.

    Args:
        Phi (ComplexTensor): (B, F, C, C)
            is speech PSD.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            is the reference_vector.
        diagonal_loading (bool):
            Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
    """
    C = reference_vector.shape[-1]
    if diagonal_loading:
        Rf = tik_reg(Rf, reg=diag_eps, eps=eps)
    inv_Rf = Rf.inverse2()
    # Only the first C columns matter, since the speech PSD is the PSD of
    # the zero-padded speech vector [x^T 0 ... 0]^T.
    # (B, F, (btaps+1) * C, C)
    inv_Rf_pruned = inv_Rf[..., :C]
    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.matmul(inv_Rf_pruned, Phi)
    # ws: (..., (btaps+1) * C, C) / (...,) -> (..., (btaps+1) * C, C)
    # The trace is taken over the top-left C x C sub-block only.
    ws = numerator / (FC.trace(numerator[..., :C, :])[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps+1) * C)
    return beamform_vector
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """  # noqa: D400
    if diagonal_loading:
        psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)

    # psd_n^-1 @ psd_s: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    if use_torch_solver:
        num = FC.solve(psd_s, psd_n)[0]
    else:
        num = FC.matmul(psd_n.inverse2(), psd_s)

    # Normalize by the trace: (..., C, C) / (...,) -> (..., C, C)
    ws = num / (FC.trace(num)[..., None, None] + eps)

    # Project onto the reference: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    return FC.einsum("...fec,...c->...fe", [ws, reference_vector])
def get_mfmvdr_vector(gammax, Phi, use_torch_solver: bool = True, eps: float = EPS):
    """Compute conventional MFMPDR/MFMVDR filter.

    Args:
        gammax (ComplexTensor): (..., L, N)
        Phi (ComplexTensor): (..., L, N, N)
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        eps (float)
    Returns:
        beamforming_vector (ComplexTensor): (..., L, N)
    """
    gammax_col = gammax.unsqueeze(-1)
    # Phi^-1 @ gammax: (..., L, N)
    if use_torch_solver:
        num = FC.solve(gammax_col, Phi)[0].squeeze(-1)
    else:
        num = FC.matmul(Phi.inverse2(), gammax_col).squeeze(-1)
    # gammax^H @ Phi^-1 @ gammax: (..., L)
    denom = FC.einsum("...d,...d->...", [gammax.conj(), num])
    return num / (denom.real.unsqueeze(-1) + eps)
def perform_filter_operation(Y: ComplexTensor,
                             filter_matrix_conj: ComplexTensor, taps, delay) \
        -> ComplexTensor:
    """perform_filter_operation

    Subtract the estimated reverberation tail from the observation
    (loop-over-taps variant).

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter_matrix_conj : conjugate filter matrix (F, taps, C, C)
        taps (int): number of filter taps
        delay (int): prediction delay

    Returns:
        dereverberated signal of shape (F, C, T)
    """
    T = Y.size(-1)
    # Accumulator for the predicted reverberation tail: (F, C, T)
    reverb_tail = ComplexTensor(torch.zeros_like(Y.real),
                                torch.zeros_like(Y.real))
    for tau_minus_delay in range(taps):
        # Apply the tap's filter to the truncated signal:
        # (F, C_d, C_e) x (F, C_d, T') -> (F, C_e, T')
        new = FC.einsum('fde,fdt->fet',
                        (filter_matrix_conj[:, tau_minus_delay, :, :],
                         Y[:, :, :T - delay - tau_minus_delay]))
        # Left-pad with zeros to realign the delayed signal to length T
        new = FC.pad(new, (delay + tau_minus_delay, 0),
                     mode='constant', value=0)
        reverb_tail = reverb_tail + new
    return Y - reverb_tail
def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """The forward function

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F), double precision
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F), double precision
            (a list of such tensors in the multi-speaker case)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """

    def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
        # Select/estimate the reference channel and apply the beamformer.
        # u: (B, C)
        if self.ref_channel < 0:
            # attention-based reference-channel estimation
            u, _ = self.ref(psd_speech.float(), ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)

        if beamformer_type in ("mpdr", "mvdr"):
            ws = get_mvdr_vector(psd_speech, psd_n, u.double())
            enhanced = apply_beamforming_vector(ws, data)
        elif beamformer_type == "wpd":
            ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
            enhanced = perform_WPD_filtering(ws, data, self.bdelay, self.btaps)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                beamformer_type))
        return enhanced, ws

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data.float(), ilens)
    assert self.nmask == len(masks)
    # floor masks with self.eps to increase numerical stability
    masks = [torch.clamp(m, min=self.eps) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech

        psd_speech = get_power_spectral_density_matrix(
            data, mask_speech.double())
        if self.beamformer_type == "mvdr":
            # psd of noise
            psd_n = get_power_spectral_density_matrix(
                data, mask_noise.double())
        elif self.beamformer_type == "mpdr":
            # psd of observed signal
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power_speech = (data.real**2 + data.imag**2) * mask_speech.double()
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speech = power_speech.mean(dim=-2)
            inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
            # covariance of expanded observed speech
            psd_n = get_covariances(data, inverse_power, self.bdelay,
                                    self.btaps, get_vector=False)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                         self.beamformer_type)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        # NOTE(review): unlike the single-speaker branch, the masks here
        # are not cast with .double() — confirm whether this is intended.
        psd_speeches = [
            get_power_spectral_density_matrix(data, mask)
            for mask in mask_speech
        ]
        if self.beamformer_type == "mvdr":
            # psd of noise (only available when a noise mask is provided;
            # otherwise the noise PSD is built from the other speakers below)
            if mask_noise is not None:
                psd_n = get_power_spectral_density_matrix(data, mask_noise)
        elif self.beamformer_type == "mpdr":
            # psd of observed speech
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power = data.real**2 + data.imag**2
            power_speeches = [power * mask for mask in mask_speech]
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
            inverse_poweres = [
                1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
            ]
            # covariance of expanded observed speech (one per speaker)
            psd_n = [
                get_covariances(data, inv_ps, self.bdelay, self.btaps,
                                get_vector=False)
                for inv_ps in inverse_poweres
            ]
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced = []
        for i in range(self.num_spk):
            # temporarily remove speaker i so the remaining PSDs can be
            # summed as interference; re-inserted below
            psd_speech = psd_speeches.pop(i)
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                psd_noise = sum(psd_speeches)
                if mask_noise is not None:
                    psd_noise = psd_noise + psd_n
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_noise, self.beamformer_type)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(data, ilens, psd_speech, psd_n,
                                           self.beamformer_type)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(data, ilens, psd_speech, psd_n[i],
                                           self.beamformer_type)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(
                        self.beamformer_type))
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]

    return enhanced, ilens, masks
def get_mvdr_vector_with_rtf(
    psd_n: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector
        calculated with RTF:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        # Regularize the noise covariance before the RTF power iteration
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # Relative transfer function estimated by the power method: (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_n)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1)
    # rtf^H @ psd_n^-1 @ rtf: (...,)
    denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        # Rescale so the reference channel of the RTF has unit response
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
def apply_beamforming_vector(beamforming_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    """Apply a per-frame beamforming vector to the mixture.

    The channel axis (c) is contracted; each frequency bin is processed
    independently of the others.
    (B, F, T, C) x (B, F, C, T) -> (B, F, T)
    """
    weights_conj = beamforming_vector.conj()
    return FC.einsum("bftc, bfct -> bft", [weights_conj, mix])
import numpy
import pytest
import torch_complex.functional as F
from torch_complex.tensor import ComplexTensor


def _get_complex_array(*shape):
    """Return a random complex ndarray of the given shape."""
    # Fixed from `+ 1j +`, which produced a *constant* imaginary part of 1
    # instead of a random imaginary part.
    return numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)


@pytest.mark.parametrize('nop,top',
                         [(numpy.concatenate, F.cat),
                          (numpy.stack, F.stack),
                          (lambda x: numpy.einsum('ai,ij,jk->ak', *x),
                           lambda x: F.einsum('ai,ij,jk->ak', x))])
def test_operation(nop, top):
    """Check that the torch_complex op matches the numpy reference op."""
    if top is None:
        top = nop
    n1 = _get_complex_array(10, 10)
    n2 = _get_complex_array(10, 10)
    n3 = _get_complex_array(10, 10)
    t1 = ComplexTensor(n1.copy())
    t2 = ComplexTensor(n2.copy())
    t3 = ComplexTensor(n3.copy())
    x = nop([n1, n2, n3])
    y = top([t1, t2, t3])
    y = y.numpy()
    numpy.testing.assert_allclose(x, y)
def apply_beamforming_vector(beamform_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    """Apply a beamforming vector to the mixture.

    Contracts the channel axis: (..., C) x (..., C, T) -> (..., T)
    """
    weights_conj = beamform_vector.conj()
    return FC.einsum('...c,...ct->...t', [weights_conj, mix])
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. As follows:

    h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        # Regularize the noise covariance before the RTF power iteration
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # Relative transfer function via the power method: (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # Zero-pad the RTF up to the stacked (taps) dimension: (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
    # vbar^H @ R^-1 @ vbar: (...,)
    denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        # Rescale so the reference channel of the RTF has unit response
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
def online_wpe_step(input_buffer: ComplexTensor, power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99, taps: int = 10, delay: int = 3):
    """One step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1)
        power: Estimate for the current PSD (F, T)
        inv_cov: Current estimate of R^-1
        filter_taps: Current estimate of filter taps (F, taps * C, taps)
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (F, D)
        Updated estimate of R^-1
        Updated estimate of the filter taps

    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)
    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    C = input_buffer.size(-2)

    # Lazily initialize the running state on the first call
    if inv_cov is None:
        # identity as the initial inverse covariance estimate
        inv_cov = ComplexTensor(
            torch.eye(C * taps, dtype=input_buffer.dtype).expand(
                *input_buffer.size()[:-2], C * taps, C * taps))
    if filter_taps is None:
        # zero filter: first prediction equals the raw observation
        filter_taps = ComplexTensor(
            torch.zeros(*input_buffer.size()[:-2], C * taps, C,
                        dtype=input_buffer.dtype))

    # Past frames (excluding the delay gap), most recent first
    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, T) -> (..., C * T)
    window = window.view(*input_buffer.size()[:-2], -1)
    # Dereverberated current frame: observation minus filtered past frames
    pred = input_buffer[..., -1] - \
        FC.einsum('...id,...i->...d', (filter_taps.conj(), window))

    # Gain computation (RLS-style recursive update of the inverse
    # covariance; NOTE(review): matches the standard recursive
    # least-squares form — confirm against the reference derivation)
    nominator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = \
        FC.einsum('...i,...i->...', (window.conj(), nominator)) + alpha * power
    kalman_gain = nominator / denominator[..., None]

    # Rank-1 downdate of the inverse covariance, then rescale by alpha
    inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im',
                                    (window.conj(), inv_cov, kalman_gain))
    inv_cov_k /= alpha

    # Filter-tap update driven by the prediction residual
    filter_taps_k = \
        filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj()))
    return pred, inv_cov_k, filter_taps_k
import torch_complex.functional as F
from torch_complex.tensor import ComplexTensor


def _get_complex_array(*shape):
    """Return a random complex ndarray of the given shape."""
    # Fixed from `+ 1j +`, which produced a *constant* imaginary part of 1
    # instead of a random imaginary part.
    return numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)


@pytest.mark.parametrize(
    "nop,top",
    [
        (numpy.concatenate, F.cat),
        (numpy.stack, F.stack),
        (
            lambda x: numpy.einsum("ai,ij,jk->ak", *x),
            lambda x: F.einsum("ai,ij,jk->ak", x),
        ),
    ],
)
def test_operation(nop, top):
    """Exercise the torch_complex op against the numpy reference inputs."""
    if top is None:
        top = nop
    n1 = _get_complex_array(10, 10)
    n2 = _get_complex_array(10, 10)
    n3 = _get_complex_array(10, 10)
    t1 = ComplexTensor(n1.copy())
    t2 = ComplexTensor(n2.copy())
    t3 = ComplexTensor(n3.copy())
    x = nop([n1, n2, n3])
    y = top([t1, t2, t3])
    # NOTE(review): unlike the sibling version of this test, no comparison
    # between `x` and `y` follows here — the assertion appears to be
    # missing/truncated; confirm against the canonical test file.
def forward(
    self,
    data: ComplexTensor,
    ilens: torch.LongTensor,
    powers: Union[List[torch.Tensor], None] = None,
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """DNN_Beamformer forward function.

    Estimates masks with the mask network, builds the speech/noise (and,
    for weighted variants, power-normalized) covariance statistics, and
    applies the configured beamformer to each speaker.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
        powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
    Returns:
        enhanced (ComplexTensor): (B, T, F)
            (a list of such tensors, one per speaker, when num_spk > 1)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """

    def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
        """Beamforming with the provided statistics.

        Closure dispatching on self.beamformer_type; all statistics are
        promoted to double precision before filtering and the result is
        cast back to data's dtype.

        Args:
            data (ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
            psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
            psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C)
        Return:
            enhanced (ComplexTensor): (B, F, T)
            ws (ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        """
        # u: (B, C) — reference selector: either estimated by self.ref
        # (attention-based reference, when ref_channel < 0) or fixed.
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            u = u.double()
        else:
            if self.beamformer_type.endswith("_souden"):
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device,
                                dtype=torch.double)
                u[..., self.ref_channel].fill_(1)
            else:
                # for simplifying computation in RTF-based beamforming
                u = self.ref_channel

        if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
            # RTF-based formulation: needs psd_distortion to estimate the
            # relative transfer function (must not be None here).
            ws = get_mvdr_vector_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type in ("mpdr_souden", "mvdr_souden",
                                      "wmpdr_souden"):
            # Souden formulation: no explicit steering vector needed.
            ws = get_mvdr_vector(
                psd_speech.double(),
                psd_n.double(),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type == "wpd":
            # Convolutional (WPD) filter over stacked delayed observations.
            ws = get_WPD_filter_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(ws, data.double(),
                                             self.bdelay, self.btaps)
        elif self.beamformer_type == "wpd_souden":
            ws = get_WPD_filter_v2(
                psd_speech.double(),
                psd_n.double(),
                u,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(ws, data.double(),
                                             self.bdelay, self.btaps)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)
    # Statistics are accumulated in double precision for numerical stability.
    data_d = data.double()

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks), len(masks)
    # floor masks to increase numerical stability
    if self.mask_flooring:
        masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            # Complementary mask as a noise-mask surrogate.
            mask_noise = 1 - mask_speech

        # Weighted variants need per-frame inverse powers of the target.
        if self.beamformer_type.startswith(
                "wmpdr") or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = (power_input * mask_speech.double()).mean(dim=-2)
            else:
                assert len(powers) == 1, len(powers)
                powers = powers[0]
            inverse_power = 1 / torch.clamp(powers, min=self.eps)

        psd_speech = get_power_spectral_density_matrix(
            data_d, mask_speech.double())
        if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double())
        # Dispatch on beamformer type: choose the "noise" statistic
        # (noise PSD, observation PSD, or stacked covariance) accordingly.
        if self.beamformer_type == "mvdr":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "mvdr_souden":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                             psd_speech)
        elif self.beamformer_type == "mpdr":
            # Observation covariance: (..., C, T) x (..., C, T) -> (..., C, C)
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "mpdr_souden":
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech)
        elif self.beamformer_type == "wmpdr":
            # Power-weighted observation covariance.
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "wmpdr_souden":
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech)
        elif self.beamformer_type == "wpd":
            # Stacked (spatio-temporal) covariance of delayed observations.
            psd_observed_bar = get_covariances(data_d, inverse_power,
                                               self.bdelay, self.btaps,
                                               get_vector=False)
            enhanced, ws = apply_beamforming(data, ilens, psd_observed_bar,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "wpd_souden":
            psd_observed_bar = get_covariances(data_d, inverse_power,
                                               self.bdelay, self.btaps,
                                               get_vector=False)
            enhanced, ws = apply_beamforming(data, ilens, psd_observed_bar,
                                             psd_speech)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        if self.beamformer_type.startswith(
                "wmpdr") or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = [(power_input * m.double()).mean(dim=-2)
                          for m in mask_speech]
            else:
                assert len(powers) == self.num_spk, len(powers)
            inverse_power = [
                1 / torch.clamp(p, min=self.eps) for p in powers
            ]

        # One speech PSD per speaker.
        psd_speeches = [
            get_power_spectral_density_matrix(data_d, mask.double())
            for mask in mask_speech
        ]
        if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double())
        # Speaker-independent observation statistics are precomputed once.
        if self.beamformer_type in ("mpdr", "mpdr_souden"):
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
        elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
            psd_observed = [
                FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inv_p[..., None, :], data_d.conj()],
                ) for inv_p in inverse_power
            ]
        elif self.beamformer_type in ("wpd", "wpd_souden"):
            psd_observed_bar = [
                get_covariances(data_d, inv_p, self.bdelay, self.btaps,
                                get_vector=False) for inv_p in inverse_power
            ]

        enhanced, ws = [], []
        for i in range(self.num_spk):
            # Temporarily remove speaker i so sum(psd_speeches) covers
            # only the interfering speakers; re-inserted below.
            psd_speech = psd_speeches.pop(i)
            if (self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                psd_noise_i = (psd_noise + sum(psd_speeches)
                               if mask_noise is not None
                               else sum(psd_speeches))
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                           psd_speech,
                                           psd_distortion=psd_noise_i)
            elif self.beamformer_type == "mvdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                           psd_speech)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed,
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "mpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed,
                                           psd_speech)
            elif self.beamformer_type == "wmpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wmpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed[i],
                                           psd_speech)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed_bar[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wpd_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed_bar[i],
                                           psd_speech)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(
                        self.beamformer_type))
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)
            ws.append(w)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]

    return enhanced, ilens, masks