def perform_WPD_filtering(filter_matrix: ComplexTensor, Y: ComplexTensor,
                          bdelay: int, btaps: int) -> ComplexTensor:
    """Apply a WPD convolutional beamforming filter to a multi-channel STFT.

    Args:
        filter_matrix: Filter matrix (B, F, (btaps + 1) * C)
        Y : Complex STFT signal with shape (B, F, C, T)

    Returns:
        enhanced (ComplexTensor): (B, F, T)
    """
    B, Fdim, C, T = Y.shape

    # Stack the current frame with btaps delayed frames:
    # (B, F, C, T) --> (B, F, C, T, btaps + 1), newest frame last after reverse
    frames = signal_framing(Y, btaps + 1, 1, bdelay,
                            do_padding=True, pad_value=0)
    frames = FC.reverse(frames, dim=-1)

    # Flatten taps and channels together:
    # --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C)
    stacked = frames.permute(0, 1, 3, 4, 2).contiguous().view(B, Fdim, T, -1)

    # Inner product with the conjugated filter for each frame: (B, F, T)
    return FC.einsum("...tc,...c->...t", [stacked, filter_matrix.conj()])
def get_power_spectral_density_matrix(xs: ComplexTensor,
                                      mask: torch.Tensor,
                                      normalization=True,
                                      eps: float = 1e-15) -> ComplexTensor:
    """Return cross-channel power spectral density (PSD) matrix

    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool):
        eps (float):
    Returns
        psd (ComplexTensor): (..., F, C, C)

    """
    # Per-frame outer products across channels:
    # (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
    outer = FC.einsum('...ct,...et->...tce', [xs, xs.conj()])

    # Collapse the channel axis of the mask: (..., C, T) -> (..., T)
    avg_mask = mask.mean(dim=-2)

    if normalization:
        # Normalize along T. If the tensor is zero-padded in time, the
        # summation along the time axis is unaffected by the padding length.
        avg_mask = avg_mask / (avg_mask.sum(dim=-1, keepdim=True) + eps)

    # Mask-weighted sum over frames: (..., T, C, C) -> (..., C, C)
    return (outer * avg_mask[..., None, None]).sum(dim=-3)
def test_inv(ch):
    """Check that inv() matches numpy's matrix inverse on Hermitian input."""
    torch.manual_seed(100)
    real = torch.rand(2, 3, ch, ch)
    imag = torch.rand(2, 3, ch, ch)
    X = ComplexTensor(real, imag)
    # Symmetrize (X + X^H) so the matrix is Hermitian.
    X = X + X.conj().transpose(-1, -2)
    expected = ComplexTensor(np.linalg.inv(X.numpy()))
    assert FC.allclose(expected, inv(X), atol=1e-4)
def test_get_rtf(ch):
    """Check the relative transfer function (RTF) estimated by get_rtf.

    The RTF from the power iteration should be (up to scale) the principal
    eigenvector of Phi_N^-1 @ Phi_X; the final assertion checks the
    eigenvector relation via a symmetry property rather than eigenvalues.
    """
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    # x: clean speech restricted to the first `ch` channels
    # n: random noise of matching shape (B=2, T=16, C=ch)
    x = random_speech[..., :ch]
    n = torch.rand(2, 16, ch, dtype=torch.double)
    ilens = torch.LongTensor([16, 12])
    # Stft returns stacked (real, imag) in the last dim; unbind to build
    # a ComplexTensor, then (B, T, C, F) -> (B, F, C, T)
    X = ComplexTensor(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(-1, -3)
    N = ComplexTensor(*torch.unbind(stft(n, ilens)[0], dim=-1)).transpose(-1, -3)
    # Spatial covariance matrices: (B, F, C, C)
    Phi_X = FC.einsum("...ct,...et->...ce", [X, X.conj()])
    Phi_N = FC.einsum("...ct,...et->...ce", [N, N.conj()])
    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, reference_vector=0, iterations=20)
    # Normalize by the largest magnitude entry (API of max() differs
    # across torch versions: .values vs tuple indexing)
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)
    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available after pytorch 1.1.0+
        mat = FC.solve(Phi_X, Phi_N)[0]
        max_eigenvec = FC.solve(rtf, Phi_N)[0]
    else:
        mat = FC.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = FC.matmul(Phi_N.inverse2(), rtf)
    # If max_eigenvec is an eigenvector of mat, then
    # mat @ max_eigenvec is parallel to max_eigenvec, which makes
    # the two outer-product expressions below equal.
    factor = FC.matmul(mat, max_eigenvec)
    assert FC.allclose(
        FC.matmul(max_eigenvec, factor.transpose(-1, -2)),
        FC.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
def forward(
    self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """The forward function

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F), double precision
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F), double precision
            NOTE(review): in the multi-speaker case this is a list of
            (B, T, F) tensors, one per speaker.
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """
    def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
        """Compute the beamforming filter and apply it to `data`.

        Returns (enhanced (B, F, T), filter weights ws).
        """
        # u: (B, C) reference-channel selection vector
        if self.ref_channel < 0:
            # Estimate the reference with an attention-based module
            u, _ = self.ref(psd_speech.float(), ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)

        if beamformer_type in ("mpdr", "mvdr"):
            ws = get_mvdr_vector(psd_speech, psd_n, u.double())
            enhanced = apply_beamforming_vector(ws, data)
        elif beamformer_type == "wpd":
            ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
            enhanced = perform_WPD_filtering(ws, data,
                                             self.bdelay, self.btaps)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                beamformer_type))

        return enhanced, ws

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data.float(), ilens)
    assert self.nmask == len(masks)
    # floor masks with self.eps to increase numerical stability
    masks = [torch.clamp(m, min=self.eps) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,) — derive the noise mask as the complement
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech

        psd_speech = get_power_spectral_density_matrix(
            data, mask_speech.double())
        # Each beamformer type uses a different "noise" statistic psd_n:
        if self.beamformer_type == "mvdr":
            # psd of noise
            psd_n = get_power_spectral_density_matrix(
                data, mask_noise.double())
        elif self.beamformer_type == "mpdr":
            # psd of observed signal
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power_speech = (data.real**2 + data.imag**2) \
                * mask_speech.double()
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speech = power_speech.mean(dim=-2)
            inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
            # covariance of expanded observed speech
            psd_n = get_covariances(data, inverse_power, self.bdelay,
                                    self.btaps, get_vector=False)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                         self.beamformer_type)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        # One speech PSD per speaker
        psd_speeches = [
            get_power_spectral_density_matrix(data, mask)
            for mask in mask_speech
        ]
        if self.beamformer_type == "mvdr":
            # psd of noise
            # NOTE(review): psd_n stays undefined when mask_noise is None;
            # the per-speaker loop below only reads it under the same
            # mask_noise-is-not-None guard, so this is safe.
            if mask_noise is not None:
                psd_n = get_power_spectral_density_matrix(data, mask_noise)
        elif self.beamformer_type == "mpdr":
            # psd of observed speech
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power = data.real**2 + data.imag**2
            power_speeches = [power * mask for mask in mask_speech]
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
            inverse_poweres = [
                1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
            ]
            # covariance of expanded observed speech (one per speaker)
            psd_n = [
                get_covariances(data, inv_ps, self.bdelay, self.btaps,
                                get_vector=False)
                for inv_ps in inverse_poweres
            ]
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced = []
        for i in range(self.num_spk):
            # Temporarily remove speaker i so that sum(psd_speeches)
            # below covers only the interfering speakers.
            psd_speech = psd_speeches.pop(i)
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                psd_noise = sum(psd_speeches)
                if mask_noise is not None:
                    psd_noise = psd_noise + psd_n
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_noise, self.beamformer_type)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(data, ilens, psd_speech, psd_n,
                                           self.beamformer_type)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_n[i], self.beamformer_type)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(
                        self.beamformer_type))
            # Restore speaker i for the next iterations
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]

    return enhanced, ilens, masks
def apply_beamforming_vector(beamform_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    """Apply a beamforming vector to a multi-channel mixture.

    Computes the conjugate inner product over channels:
    (..., C) x (..., C, T) -> (..., T)
    """
    return FC.einsum('...c,...ct->...t', [beamform_vector.conj(), mix])
def online_wpe_step(input_buffer: ComplexTensor,
                    power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99,
                    taps: int = 10,
                    delay: int = 3):
    """One step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1)
        power: Estimate for the current PSD (F, T)
        inv_cov: Current estimate of R^-1 (identity-initialized when None)
        filter_taps: Current estimate of filter taps (F, taps * C, C),
            zero-initialized when None
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (F, D)
        Updated estimate of R^-1
        Updated estimate of the filter taps

    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)
    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    num_ch = input_buffer.size(-2)
    batch_dims = input_buffer.size()[:-2]
    dim = num_ch * taps

    # Lazily initialize the recursive state on the first call.
    if inv_cov is None:
        eye = torch.eye(dim, dtype=input_buffer.dtype)
        inv_cov = ComplexTensor(eye.expand(*batch_dims, dim, dim))
    if filter_taps is None:
        filter_taps = ComplexTensor(
            torch.zeros(*batch_dims, dim, num_ch, dtype=input_buffer.dtype))

    # Context frames (excluding the delay gap and current frame), newest
    # first, flattened to (..., C * taps).
    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, T) -> (..., C * T)
    window = window.view(*batch_dims, -1)

    # Prediction residual: current frame minus the filtered context.
    pred = input_buffer[..., -1] \
        - FC.einsum('...id,...i->...d', (filter_taps.conj(), window))

    # RLS-style gain computation.
    numerator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = FC.einsum('...i,...i->...',
                            (window.conj(), numerator)) + alpha * power
    kalman_gain = numerator / denominator[..., None]

    # Rank-one downdate of the inverse covariance with forgetting factor.
    updated_inv_cov = (inv_cov - FC.einsum(
        '...j,...jm,...i->...im',
        (window.conj(), inv_cov, kalman_gain))) / alpha

    # Push the (conjugated) residual into the filter taps.
    updated_filter_taps = filter_taps + FC.einsum(
        '...i,...m->...im', (kalman_gain, pred.conj()))

    return pred, updated_inv_cov, updated_filter_taps
def wpe_step_v3(Y, inverse_power, taps=10, delay=3, statistics_mode='full',
                solver='torch_complex.solve'):
    """Single WPE iteration: estimate the filter G and dereverberate Y.

    Tested with 1.7.0.dev20200807

    Properties (Compared to lower versions):
     - faster
     - less memory for backward
     - (less peak memory)? Looks so. Difficult to profile.

    Args:
        Y: (..., channel, frames)
        inverse_power: per-frame inverse power weights, broadcast against
            the frame axis of Y in the correlation einsums below
        taps: number of filter taps
        delay: prediction delay in frames
        statistics_mode: 'full' supported; 'valid' raises NotImplementedError
        solver: solver name forwarded to _solve

    Returns:
        X: dereverberated signal, same shape as Y
    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        raise NotImplementedError(statistics_mode)
        # NOTE(review): unreachable after the raise; kept as documentation
        # of the intended 'valid' slicing. `s` is also unused below because
        # the [s]-based code path is commented out.
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    # Accept numpy input; move it to the device of inverse_power.
    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    # Y_tilde: stacked delayed observations (the WPE "context" matrix)
    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non contignous property for tensors with
    # negation (i.e. ComplexTensor.conj changes the sign of imag), so the
    # conjugated context is rebuilt from Y.conj() instead of conjugating
    # Y_tilde directly.
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)
    # Y_tilde_conj = Y_tilde.conj()

    # This code is faster, but with backward graph the memory consumption is
    # too high. (Pytorch is at the moment not intelligent enough)
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ transpose(Y_tilde_conj[s])
    # P = Y_tilde_inverse_power[s] @ transpose(Y_conj[s])

    def get_correlation(m, Y1, Y2):
        """Weighted correlation sum_t m[t] * Y1[:, t] * Y2[:, t]^T,
        with the complex product expanded into four real einsums."""
        # (a+bi)(c+di): real = ac - bd
        real = torch.einsum('...t,...dt,...et->...de',
                            m, Y1.real, Y2.real) - torch.einsum(
            '...t,...dt,...et->...de', m, Y1.imag, Y2.imag)
        # imag = ad + bc
        imag = torch.einsum('...t,...dt,...et->...de',
                            m, Y1.real, Y2.imag) + torch.einsum(
            '...t,...dt,...et->...de', m, Y1.imag, Y2.real)
        return ComplexTensor(real, imag)

    # R_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de', inverse_power, Y_tilde_conj, Y_tilde)
    R_conj = get_correlation(inverse_power, Y_tilde_conj, Y_tilde)

    # P_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de',
    #     inverse_power, Y_tilde_conj, Y
    # )
    P_conj = get_correlation(inverse_power, Y_tilde_conj, Y)

    # Solve the (conjugated) normal equations for the filter coefficients.
    G_conj = _solve(R=R_conj, P=P_conj, solver=solver)

    # Matmul converts the non contignous Y_tilde to contignous, hence use
    # einsum. Einsum does not work on the gpu with non contignous, hence the
    # complex product is again expanded into real einsums here.
    # X = Y - torch_complex.functional.einsum(
    #     '...ij,...ik->...jk', G_conj, Y_tilde)
    X = ComplexTensor(
        Y.real
        - torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.real)
        + torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.imag),
        Y.imag
        - torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.imag)
        - torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.real),
    )
    return X
def wpe_step_v2(Y, inverse_power, taps=10, delay=3, statistics_mode='full',
                solver='torch_complex.solve'):
    """Single WPE iteration using gradient checkpointing for the einsums.

    Args:
        Y: (..., channel, frames)
        inverse_power: per-frame inverse power weights used in the
            correlation einsums below
        taps: number of filter taps
        delay: prediction delay in frames
        statistics_mode: 'full' supported; 'valid' raises NotImplementedError
        solver: solver name forwarded to _solve

    Returns:
        X: dereverberated signal, same shape as Y
    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        raise NotImplementedError(statistics_mode)
        # NOTE(review): unreachable after the raise; kept as documentation
        # of the intended 'valid' slicing. `s` is also unused below because
        # the [s]-based code path is commented out.
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    # Accept numpy input; move it to the device of inverse_power.
    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    # Y_tilde: stacked delayed observations (the WPE "context" matrix)
    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non contignous property for tensors with
    # negation (i.e. ComplexTensor.conj changes the sign of imag), so the
    # conjugated context is rebuilt from Y.conj() instead of conjugating
    # Y_tilde directly.
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)
    # Y_tilde_conj = Y_tilde.conj()

    # This code is faster, but with backward graph the memory consumption is
    # too high. (Pytorch is at the moment not intelligent enough)
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ hermite(Y_tilde[s])
    # P = Y_tilde_inverse_power[s] @ hermite(Y[s])

    import torch.utils.checkpoint

    # remove when https://github.com/pytorch/pytorch/issues/42418
    # has a solution.
    # This may be very expencive, because the calculation of R dominates the
    # execution time of WPE
    def get_R(inverse_power, Y_tilde_real, Y_tilde_imag):
        """Checkpointed weighted covariance R of the context frames;
        inputs are made contiguous so einsum works on GPU."""
        Y_tilde_real = Y_tilde_real.contiguous()
        Y_tilde_imag = Y_tilde_imag.contiguous()
        Y_tilde = ComplexTensor(Y_tilde_real, Y_tilde_imag)
        # Conjugate by negating the imaginary part explicitly.
        Y_tilde_conj = ComplexTensor(Y_tilde_real, -Y_tilde_imag)
        R = torch_complex.functional.einsum('...t,...dt,...et->...de',
                                            inverse_power, Y_tilde,
                                            Y_tilde_conj)
        return R.real, R.imag

    R = ComplexTensor(*torch.utils.checkpoint.checkpoint(
        get_R, inverse_power, Y_tilde.real, Y_tilde.imag))

    # print('wpe rss before P', ByteSize(process.memory_info().rss))

    # Weighted cross-correlation between context frames and observations.
    P = torch_complex.functional.einsum('...t,...dt,...et->...de',
                                        inverse_power, Y_tilde, Y_conj)

    # Solve the normal equations for the filter coefficients.
    G = _solve(R=R, P=P, solver=solver)

    # remove when https://github.com/pytorch/pytorch/issues/42418
    # has a solution.
    def contiguous_einsum(equation, *operands):
        """Checkpointed complex einsum that forces contiguous operands
        (einsum fails on GPU with non-contiguous inputs)."""
        def foo(*operands):
            # Operands arrive as interleaved (real, imag) pairs.
            assert len(operands) % 2 == 0, len(operands)
            operands = [
                ComplexTensor(real.contiguous(), imag.contiguous())
                for real, imag in zip(operands[::2], operands[1::2])
            ]
            ret = torch_complex.functional.einsum(equation, operands)
            return ret.real, ret.imag

        # Flatten ComplexTensors into plain tensors for checkpoint().
        operands = [part for o in operands for part in [o.real, o.imag]]
        real, imag = torch.utils.checkpoint.checkpoint(foo, *operands)
        return ComplexTensor(real, imag)

    # Matmul cannot handle the non contignous Y_tilde, hence use einsum.
    # Einsum does not work on the gpu with non contignous, hence use
    # torch.utils.checkpoint.checkpoint.
    X = Y - contiguous_einsum('...ij,...ik->...jk', G.conj(), Y_tilde)
    return X