def forward(self, input: ComplexTensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (ComplexTensor): spectrum [Batch, T, (C,) F]
        ilens (torch.Tensor): input lengths [Batch]
    """
    # Accept a ComplexTensor, or (for torch>=1.9) a native complex tensor.
    if not isinstance(input, ComplexTensor) and (
            is_torch_1_9_plus and not torch.is_complex(input)):
        raise TypeError("Only support complex tensors for stft decoder")
    bs = input.size(0)
    if input.dim() == 4:
        multi_channel = True
        # Fold the channel axis into the batch so that the iSTFT sees
        # independent single-channel spectra.
        # input: (Batch, T, C, F) -> (Batch * C, T, F)
        input = input.transpose(1, 2).reshape(-1, input.size(1), input.size(3))
    else:
        multi_channel = False

    # Inverse STFT back to the waveform domain.
    wav, wav_lens = self.stft.inverse(input, ilens)

    if multi_channel:
        # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
        wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

    return wav, wav_lens
def forward(self, psd_in: ComplexTensor, ilens: torch.LongTensor,
            scaling: float = 2.0) -> Tuple[torch.Tensor, torch.LongTensor]:
    """The forward function

    Args:
        psd_in (ComplexTensor): (B, F, C, C)
        ilens (torch.Tensor): (B,)
        scaling (float): temperature applied before the softmax
    Returns:
        u (torch.Tensor): attention weights over channels (B, C)
        ilens (torch.Tensor): (B,)
    """
    B, _, C = psd_in.size()[:3]
    assert psd_in.size(2) == psd_in.size(3), psd_in.size()
    # psd_in: (B, F, C, C)
    # torch.bool masks exist only from torch 1.2 on; fall back to uint8.
    datatype = torch.bool if is_torch_1_2_plus else torch.uint8
    # Zero the diagonal (same-channel) entries before averaging
    # over the off-diagonal cross-channel terms.
    psd = psd_in.masked_fill(
        torch.eye(C, dtype=datatype, device=psd_in.device), 0)
    # psd: (B, F, C, C) -> (B, C, F); divide by C - 1 off-diagonal entries
    psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
    # Calculate amplitude
    psd_feat = (psd.real**2 + psd.imag**2)**0.5
    # (B, C, F) -> (B, C, F2)
    mlp_psd = self.mlp_psd(psd_feat)
    # (B, C, F2) -> (B, C, 1) -> (B, C)
    e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
    u = F.softmax(scaling * e, dim=-1)
    return u, ilens
def trace(a: ComplexTensor) -> ComplexTensor:
    """Return the trace of ``a`` over the last two dims: (..., C, C) -> (...,).

    The diagonal is selected with an identity mask and then summed.
    """
    # NOTE(review): the two-step dtype handling below presumably works
    # around torch.eye not accepting dtype=torch.bool before 1.3 while
    # uint8 mask indexing is deprecated from 1.1 -- confirm before changing.
    if LooseVersion(torch.__version__) >= LooseVersion('1.3'):
        datatype = torch.bool
    else:
        datatype = torch.uint8
    E = torch.eye(a.real.size(-1), dtype=datatype).expand(*a.size())
    if LooseVersion(torch.__version__) >= LooseVersion('1.1'):
        E = E.type(torch.bool)
    return a[E].view(*a.size()[:-1]).sum(-1)
def get_covariances(
    Y: ComplexTensor,
    inverse_power: torch.Tensor,
    bdelay: int,
    btaps: int,
    get_vector: bool = False,
) -> ComplexTensor:
    """Calculates the power normalized spatio-temporal covariance
    matrix of the framed signal.

    Args:
        Y : Complext STFT signal with shape (B, F, C, T)
        inverse_power : Weighting factor with shape (B, F, T)
        bdelay (int): prediction delay (guard interval)
        btaps (int): number of filter taps
        get_vector (bool): also return the correlation vector

    Returns:
        Correlation matrix of shape (B, F, (btaps+1) * C, (btaps+1) * C)
        Correlation vector of shape (B, F, btaps + 1, C, C)
    """
    assert inverse_power.dim() == 3, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0),
                                                Y.size(0))

    Bs, Fdim, C, T = Y.shape

    # Frame the signal: (B, F, C, T - bdelay - btaps + 1, btaps + 1)
    Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[
        ..., : T - bdelay - btaps + 1, :
    ]
    # Reverse along btaps-axis:
    # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-frame_length+1]
    Psi = FC.reverse(Psi, dim=-1)
    # Weight each frame by the inverse power of its (delayed) target frame.
    Psi_norm = Psi * inverse_power[..., None, bdelay + btaps - 1 :, None]

    # let T' = T - bdelay - btaps + 1
    # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
    #  -> (B, F, btaps + 1, C, btaps + 1, C)
    covariance_matrix = FC.einsum("bfdtk,bfetl->bfkdle", (Psi, Psi_norm.conj()))

    # (B, F, btaps + 1, C, btaps + 1, C)
    #   -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
    covariance_matrix = covariance_matrix.view(
        Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C
    )

    if get_vector:
        # (B, F, C, T', btaps + 1) x (B, F, C, T')
        #    --> (B, F, btaps +1, C, C)
        covariance_vector = FC.einsum(
            "bfdtk,bfet->bfked", (Psi_norm, Y[..., bdelay + btaps - 1 :].conj())
        )
        return covariance_matrix, covariance_vector
    else:
        return covariance_matrix
def get_mvdr_vector(psd_s: ComplexTensor,
                    psd_n: ComplexTensor,
                    reference_vector: torch.Tensor,
                    eps: float = 1e-15) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_n (ComplexTensor): noise covariance matrix (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float): diagonal-loading / trace-regularization term
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Diagonal loading so that psd_n stays invertible.
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    # FIX: out-of-place add. The original `psd_n += eps * eye` silently
    # mutated the caller's noise PSD tensor in place.
    psd_n = psd_n + eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum('...fec,...c->...fe', [ws, reference_vector])
    return beamform_vector
def forward(
    self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Convert a (multi-channel) complex spectrum into normalized log-mel feats."""
    # Expect (B, T, F) or (B, T, C, F)
    if x.dim() not in (3, 4):
        raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
    if not torch.is_tensor(ilens):
        ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

    if x.dim() == 4:
        # Multi-channel: pick one channel
        # (random while training, always the first otherwise).
        ch = np.random.randint(x.size(2)) if self.training else 0
        feat = x[:, :, ch, :]
    else:
        feat = x

    # Power spectrum: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
    feat = feat.real ** 2 + feat.imag ** 2
    feat, _ = self.logmel(feat, ilens)

    # Optional global and utterance-level mean/variance normalization.
    if self.stats_file is not None:
        feat, _ = self.global_mvn(feat, ilens)
    if self.apply_uttmvn:
        feat, _ = self.uttmvn(feat, ilens)

    return feat, ilens
def test_complex_norm(dim):
    """complex_norm must agree between ComplexTensor and native complex input."""
    real = torch.rand(2, 3, 4)
    imag = torch.rand(2, 3, 4)
    mat = ComplexTensor(real, imag)
    mat_th = torch.complex(real, imag)
    norm = complex_norm(mat, dim=dim, keepdim=True)
    norm_th = complex_norm(mat_th, dim=dim, keepdim=True)
    # Same values, same rank (keepdim), and the reduced axis collapsed to 1.
    assert (torch.allclose(norm, norm_th)
            and norm.ndim == mat.ndim
            and mat.numel() == norm.numel() * mat.size(dim))
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))
    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
        -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
    """The forward function

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F)
        ilens (torch.Tensor): (B,)
    """
    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: (B, F, C, T)
    (mask_speech, mask_noise), _ = self.mask(data, ilens)

    # Spatial covariance matrices for speech and noise.
    psd_speech = get_power_spectral_density_matrix(data, mask_speech)
    psd_noise = get_power_spectral_density_matrix(data, mask_noise)

    # u: (B, C)
    if self.ref_channel < 0:
        # Estimate reference-channel weights from the speech PSD.
        u, _ = self.ref(psd_speech, ilens)
    else:
        # (optional) Create onehot vector for fixed reference microphone
        u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)),
                        device=data.device)
        u[..., self.ref_channel].fill_(1)

    # MVDR beamforming and application of the filter.
    ws = get_mvdr_vector(psd_speech, psd_noise, u)
    enhanced = apply_beamforming_vector(ws, data)

    # (..., F, T) -> (..., T, F)
    enhanced = enhanced.transpose(-1, -2)
    # mask_speech: (B, F, C, T) -> (B, T, C, F)
    mask_speech = mask_speech.transpose(-1, -3)

    return enhanced, ilens, mask_speech
def complex_matrix2real_matrix(c: ComplexTensor) -> torch.Tensor:
    """Represent a complex square matrix as one real block matrix.

    NOTE(kamo):
    A complex scalar a + bi corresponds to a real 2x2 block
        a * x + b * y  with  x = |1 0|   y = |0 -1|
                                 |0 1|,      |1  0|
    so a complex matrix maps to
        |A -B|
        |B  A|
    and a complex vector to
        |A|
        |B|
    """
    assert c.size(-2) == c.size(-1), c.size()
    # (*, m, m) -> (*, 2m, 2m)
    top_rows = torch.cat([c.real, -c.imag], dim=-1)
    bottom_rows = torch.cat([c.imag, c.real], dim=-1)
    return torch.cat([top_rows, bottom_rows], dim=-2)
def forward(
    self, xs: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """The forward function

    Args:
        xs: (B, F, C, T)
        ilens: (B,)
    Returns:
        hs (torch.Tensor): The hidden vector (B, F, C, T)
        masks: A tuple of the masks. (B, F, C, T)
        ilens: (B,)
    """
    assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
    _, _, C, input_length = xs.size()
    # (B, F, C, T) -> (B, C, T, F)
    xs = xs.permute(0, 2, 3, 1)

    # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
    xs = (xs.real**2 + xs.imag**2)**0.5
    # xs: (B, C, T, F) -> xs: (B * C, T, F)
    xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
    # ilens: (B,) -> ilens_: (B * C)
    ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

    # xs: (B * C, T, F) -> xs: (B * C, T, D)
    xs, _, _ = self.brnn(xs, ilens_)
    # xs: (B * C, T, D) -> xs: (B, C, T, D)
    xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

    masks = []
    for linear in self.linears:
        # xs: (B, C, T, D) -> mask:(B, C, T, F)
        mask = linear(xs)

        if self.nonlinear == "sigmoid":
            mask = torch.sigmoid(mask)
        elif self.nonlinear == "relu":
            mask = torch.relu(mask)
        elif self.nonlinear == "tanh":
            mask = torch.tanh(mask)
        elif self.nonlinear == "crelu":
            mask = torch.clamp(mask, min=0, max=1)

        # Zero padding
        # FIX: Tensor.masked_fill is out-of-place; the original call
        # discarded its result, so padded frames were never zeroed.
        mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

        # (B, C, T, F) -> (B, F, C, T)
        mask = mask.permute(0, 3, 1, 2)

        # Take cares of multi gpu cases: If input_length > max(ilens)
        if mask.size(-1) < input_length:
            mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
        masks.append(mask)

    return tuple(masks), ilens
def get_correlations(Y: ComplexTensor, inverse_power: torch.Tensor,
                     taps, delay) -> Tuple[ComplexTensor, ComplexTensor]:
    """Calculates weighted correlations of a window of length taps

    Args:
        Y : Complex-valued STFT signal with shape (F, C, T)
        inverse_power : Weighting factor with shape (F, T)
        taps (int): Lenghts of correlation window
        delay (int): Delay for the weighting factor

    Returns:
        Correlation matrix of shape (F, taps*C, taps*C)
        Correlation vector of shape (F, taps, C, C)
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), \
        (inverse_power.size(0), Y.size(0))

    F, C, T = Y.size()
    # Y: (F, C, T) -> Psi: (F, C, T, taps)
    Psi = signal_framing(Y, frame_length=taps,
                         frame_step=1)[..., :T - delay - taps + 1, :]
    # Reverse along taps-axis
    Psi = FC.reverse(Psi, dim=-1)
    # Weight each window by the inverse power of its (delayed) target frame.
    Psi_conj_norm = \
        Psi.conj() * inverse_power[..., None, delay + taps - 1:, None]

    # (F, C, T, taps) x (F, C, T, taps) -> (F, taps, C, taps, C)
    correlation_matrix = FC.einsum('fdtk,fetl->fkdle', (Psi_conj_norm, Psi))
    # (F, taps, C, taps, C) -> (F, taps * C, taps * C)
    correlation_matrix = correlation_matrix.view(F, taps * C, taps * C)

    # (F, C, T, taps) x (F, C, T) -> (F, taps, C, C)
    correlation_vector = FC.einsum(
        'fdtk,fet->fked', (Psi_conj_norm, Y[..., delay + taps - 1:]))

    return correlation_matrix, correlation_vector
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps: diagonal-loading term added before inversion
    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    # Regularize the diagonal so that the matrix stays invertible.
    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        correlation_matrix.shape[-2:]
    eye = eye.view(*shape)
    # FIX: out-of-place add. The original `correlation_matrix += eps * eye`
    # mutated the caller's tensor in place.
    correlation_matrix = correlation_matrix + eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))
    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
def wpe_one_iteration(Y: ComplexTensor, power: torch.Tensor,
                      taps: int = 10, delay: int = 3,
                      eps: float = 1e-10,
                      inverse_power: bool = True) -> ComplexTensor:
    """WPE for one iteration

    Args:
        Y: Complex valued STFT signal with shape (..., C, T)
        power: : (..., T)
        taps: Number of filter taps
        delay: Delay as a guard interval, such that X does not become zero.
        eps: lower clamp for the power before taking its reciprocal
        inverse_power (bool): if True, `power` is inverted here;
            otherwise it is used as the weighting factor directly
    Returns:
        enhanced: (..., C, T)
    """
    assert Y.size()[:-2] == power.size()[:-1]
    # Flatten all leading (batch/freq) dims so the helpers can work on
    # plain (N, C, T) / (N, T) views.
    batch_freq_size = Y.size()[:-2]
    Y = Y.view(-1, *Y.size()[-2:])
    power = power.view(-1, power.size()[-1])

    if inverse_power:
        # NOTE: the boolean flag is rebound to the weighting tensor here.
        inverse_power = 1 / torch.clamp(power, min=eps)
    else:
        inverse_power = power

    correlation_matrix, correlation_vector = \
        get_correlations(Y, inverse_power, taps, delay)
    filter_matrix_conj = get_filter_matrix_conj(correlation_matrix,
                                                correlation_vector)
    enhanced = perform_filter_operation_v2(Y, filter_matrix_conj, taps, delay)

    # Restore the original leading dims.
    enhanced = enhanced.view(*batch_freq_size, *Y.size()[-2:])
    return enhanced
def perform_filter_operation_v2(Y: ComplexTensor,
                                filter_matrix_conj: ComplexTensor,
                                taps, delay) -> ComplexTensor:
    """perform_filter_operation_v2

    Subtract the estimated reverberation tail from Y in one einsum.

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    # Y_tilde: (taps, F, C, T); slice i is Y delayed by (delay + i) frames,
    # zero-padded on the left back to length T.
    Y_tilde = FC.stack([FC.pad(Y[:, :, :T - delay - i], (delay + i, 0),
                               mode='constant', value=0)
                        for i in range(taps)], dim=0)
    # Sum over taps and input channels: -> reverb_tail (F, C, T)
    reverb_tail = FC.einsum('fpde,pfdt->fet',
                            (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Turn a raw waveform batch into log-mel features."""
    # 1. Domain-conversion: e.g. Stft: time -> time-freq
    input_stft, feats_lens = self.stft(input, input_lengths)
    assert input_stft.dim() >= 4, input_stft.shape
    # "2" refers to the real/imag parts of Complex
    assert input_stft.shape[-1] == 2, input_stft.shape
    # torch.Tensor (..., F, 2) -> ComplexTensor (..., F)
    input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])

    # 2. [Option] Speech enhancement
    if self.frontend is not None:
        assert isinstance(input_stft, ComplexTensor), type(input_stft)
        # input_stft: (Batch, Length, [Channel], Freq)
        input_stft, _, mask = self.frontend(input_stft, feats_lens)

    # 3. [Multi channel case]: Select a channel
    # (random while training, always the first one otherwise)
    if input_stft.dim() == 4:
        ch = np.random.randint(input_stft.size(2)) if self.training else 0
        input_stft = input_stft[:, :, ch, :]

    # 4. STFT -> Power spectrum
    # ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
    input_power = input_stft.real ** 2 + input_stft.imag ** 2

    # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
    # input_power: (Batch, [Channel,] Length, Freq)
    #   -> input_feats: (Batch, Length, Dim)
    input_feats, _ = self.logmel(input_power, feats_lens)

    return input_feats, feats_lens
def forward(self, xs: ComplexTensor, input_lengths: torch.LongTensor) \
        -> torch.Tensor:
    """Estimate a time-frequency mask from a complex spectrum.

    Args:
        xs (ComplexTensor): (B, C, T, F)
        input_lengths (torch.LongTensor): (B,)
    Returns:
        out (torch.Tensor): sigmoid mask (B, C, T, F), zeroed on padding
    """
    assert xs.size(0) == input_lengths.size(0), (xs.size(0),
                                                 input_lengths.size(0))
    # xs: (B, C, T, D)
    C = xs.size(1)

    # Convert the complex input into the configured real-valued feature.
    if self.feat_type == 'amplitude':
        # xs: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
    elif self.feat_type == 'power':
        # xs: (B, C, T, F) -> (B, C, T, F)
        xs = xs.real ** 2 + xs.imag ** 2
    elif self.feat_type == 'log_power':
        # xs: (B, C, T, F) -> (B, C, T, F)
        xs = torch.log(xs.real ** 2 + xs.imag ** 2)
    elif self.feat_type == 'concat':
        # xs: (B, C, T, F) -> (B, C, T, 2 * F)
        xs = torch.cat([xs.real, xs.imag], -1)
    else:
        raise NotImplementedError(f'Not implemented: {self.feat_type}')

    if self.model_type in ('blstm', 'lstm'):
        # xs: (B, C, T, F) -> xs: (B, C, T, D)
        xs = self.net(xs, input_lengths)
    elif self.model_type == 'cnn':
        if self.channel_independent:
            # xs: (B, C, T, F) -> xs: (B * C, F, T)
            xs = xs.view(-1, *xs.size()[2:]).transpose(1, 2)
            # xs: (B * C, F, T) -> xs: (B * C, D, T)
            xs = self.net(xs)
            # xs: (B * C, D, T) -> (B, C, T, D)
            xs = xs.transpose(1, 2).contiguous().view(
                -1, C, xs.size(2), xs.size(1))
        else:
            # xs: (B, C, T, F) -> xs: (B, C, T, F)
            xs = self.net(xs)
    else:
        raise NotImplementedError(f'Not implemented: {self.model_type}')

    # xs: (B, C, T, D) -> out:(B, C, T, F)
    out = self.linear(xs)
    out = torch.sigmoid(out)
    # Zero padding
    # FIX: Tensor.masked_fill is out-of-place; the original call discarded
    # its result, so padded frames were never zeroed.
    out = out.masked_fill(make_pad_mask(input_lengths, out, length_dim=2), 0)
    return out
def perform_filter_operation(Y: ComplexTensor,
                             filter_matrix_conj: ComplexTensor,
                             taps, delay) \
        -> ComplexTensor:
    """perform_filter_operation

    Subtract the reverberation tail from Y, accumulating it tap by tap.

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    reverb_tail = ComplexTensor(torch.zeros_like(Y.real),
                                torch.zeros_like(Y.real))
    for k in range(taps):
        # (F, C, C) x (F, C, T') -> (F, C, T'), then left-pad back to T
        tap_term = FC.einsum('fde,fdt->fet',
                             (filter_matrix_conj[:, k, :, :],
                              Y[:, :, :T - delay - k]))
        tap_term = FC.pad(tap_term, (delay + k, 0),
                          mode='constant', value=0)
        reverb_tail = reverb_tail + tap_term
    return Y - reverb_tail
def tik_reg(mat: ComplexTensor, reg: float = 1e-8, eps: float = 1e-8) -> ComplexTensor:
    """Perform Tikhonov regularization (only modifying real part).

    Args:
        mat (ComplexTensor): input matrix (..., C, C)
        reg (float): regularization factor
        eps (float): floor added in case the trace is all-zero
    Returns:
        ret (ComplexTensor): regularized matrix (..., C, C)
    """
    C = mat.size(-1)
    # Identity broadcast/tiled to mat's full batch shape.
    batch_ones = [1] * (mat.dim() - 2)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    eye = eye.view(*batch_ones, C, C).repeat(*mat.shape[:-2], 1, 1)
    # Loading term scaled by the (real) trace; no gradient through it.
    with torch.no_grad():
        epsilon = FC.trace(mat).real[..., None, None] * reg
        # in case that correlation_matrix is all-zero
        epsilon = epsilon + eps
    return mat + epsilon * eye
def trace(a: ComplexTensor) -> ComplexTensor:
    """Return the trace over the last two dims: (..., C, C) -> (...,).

    NOTE(review): uint8 mask indexing is deprecated in newer torch;
    presumably this variant targets an older torch version -- confirm
    before porting.
    """
    # Identity mask broadcast to a's full shape, used to pick the diagonal.
    E = torch.eye(a.real.size(-1), dtype=torch.uint8).expand(*a.size())
    return a[E].view(*a.size()[:-1]).sum(-1)
def forward(self, data: ComplexTensor, ilens: torch.LongTensor=None,
            return_wpe: bool=True) -> Tuple[Optional[ComplexTensor],
                                            torch.Tensor]:
    """Iterative DNN-supported WPE dereverberation.

    Args:
        data (ComplexTensor): (B, C, T, F)
        ilens (torch.LongTensor): (B,); defaults to full length
        return_wpe (bool): if False, return (None, power) after the
            power/mask estimation only
    Returns:
        enhanced (ComplexTensor or None): (B, C, T, F)
        power (torch.Tensor): (B, C, T, F)
    """
    if ilens is None:
        ilens = torch.full((data.size(0),), data.size(2),
                           dtype=torch.long, device=data.device)
    # Strip the left/right context frames from the target signal.
    r = -self.rcontext if self.rcontext != 0 else None
    enhanced = data[:, :, self.lcontext:r, :]

    if self.lcontext != 0 or self.rcontext != 0:
        # Context windowing assumes equal lengths in the batch.
        assert all(ilens[0] == i for i in ilens)
        # Create context window (a.k.a Splicing)
        if self.model_type in ('blstm', 'lstm'):
            width = data.size(2) - self.lcontext - self.rcontext
            # data: (B, C, l + w + r, F)
            indices = [i + j for i in range(width)
                       for j in range(1 + self.lcontext + self.rcontext)]
            _y = data[:, :, indices]
            # data: (B, C, l, (1 + w + r), F)
            data = _y.view(
                data.size(0), data.size(1), width,
                (1 + self.lcontext + self.rcontext) * data.size(3))
            ilens = torch.full((data.size(0),), width,
                               dtype=torch.long, device=data.device)
            del _y

    for i in range(self.iterations):
        # Calculate power: (B, C, T, Context, F)
        power = enhanced.real ** 2 + enhanced.imag ** 2
        if i == 0 and self.use_dnn:
            # mask: (B, C, T, F)
            mask = self.estimator(data, ilens)
            if mask.size(2) != power.size(2):
                # Estimator kept the context frames; strip them here too.
                assert mask.size(2) == (power.size(2)
                                        + self.rcontext + self.lcontext)
                r = -self.rcontext if self.rcontext != 0 else None
                mask = mask[:, :, self.lcontext:r, :]
            if self.normalization:
                # Normalize along T
                mask = mask / mask.sum(dim=-2)[..., None]
            if self.out_type == 'mask':
                power = power * mask
            else:
                # The DNN output itself is used as the power estimate.
                power = mask
                if self.out_type == 'amplitude':
                    power = power ** 2
                elif self.out_type == 'log_power':
                    power = power.exp()
                elif self.out_type == 'power':
                    pass
                else:
                    raise NotImplementedError(self.out_type)
        if not return_wpe:
            return None, power
        # power: (B, C, T, F) -> _power: (B, F, T)
        _power = power.mean(dim=1).transpose(-1, -2).contiguous()
        # data: (B, C, T, F) -> _data: (B, F, C, T)
        _data = data.permute(0, 3, 1, 2).contiguous()
        # _enhanced: (B, F, C, T)
        _enhanced_real = []
        _enhanced_imag = []
        # Process each utterance up to its true length.
        for d, p, l in zip(_data, _power, ilens):
            # e: (F, C, T) -> (T, C, F)
            e = wpe_one_iteration(
                d[..., :l], p[..., :l],
                taps=self.taps, delay=self.delay,
                inverse_power=self.inverse_power).transpose(0, 2)
            _enhanced_real.append(e.real)
            _enhanced_imag.append(e.imag)
        # _enhanced: B x (T, C, F) -> (B, T, C, F) -> (B, F, C, T)
        _enhanced_real = pad_sequence(_enhanced_real,
                                      batch_first=True).transpose(1, 3)
        _enhanced_imag = pad_sequence(_enhanced_imag,
                                      batch_first=True).transpose(1, 3)
        _enhanced = ComplexTensor(_enhanced_real, _enhanced_imag)
        # enhanced: (B, F, C, T) -> (B, C, T, F)
        enhanced = _enhanced.permute(0, 2, 3, 1)

    # enhanced: (B, C, T, F), power: (B, C, T, F)
    return enhanced, power
# (..., C, T) * (..., C, T) -> (..., C, T) power = power * mask_speech # Averaging along the channel axis: (B, F, C, T) -> (B, F, T) power = power.mean(dim=-2) # (B, F, T) --> (B * F, T) power = power.view(-1, power.shape[-1]) inverse_power = 1 / torch.clamp(power, min=eps) B, Fdim, C, T = Z.shape # covariance matrix: (B, F, (btaps+1) * C, (btaps+1) * C) covariance_matrix = get_covariances( Z, inverse_power, bdelay, btaps, get_vector=False ) # speech signal PSD: (B, F, C, C) psd_speech = beamformer.get_power_spectral_density_matrix( Z, mask_speech, btaps, normalization=True ) # reference vector: (B, C) ref_channel = 0 u = torch.zeros(*(Z.size()[:-3] + (Z.size(-2),)), device=Z.device) u[..., ref_channel].fill_(1) # (B, F, (btaps + 1) * C) WPD_filter = get_WPD_filter_v2(psd_speech, covariance_matrix, u) # (B, F, T) enhanced = perform_WPD_filtering(WPD_filter, Z, bdelay, btaps)
def online_wpe_step(input_buffer: ComplexTensor, power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99, taps: int = 10, delay: int = 3):
    """One step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1)
        power: Estimate for the current PSD (F, T)
        inv_cov: Current estimate of R^-1
        filter_taps: Current estimate of filter taps (F, taps * C, taps)
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (F, D)
        Updated estimate of R^-1
        Updated estimate of the filter taps

    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)
    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    C = input_buffer.size(-2)

    # Lazily initialize the recursive state on the first call.
    if inv_cov is None:
        inv_cov = ComplexTensor(
            torch.eye(C * taps, dtype=input_buffer.dtype).expand(
                *input_buffer.size()[:-2], C * taps, C * taps))
    if filter_taps is None:
        filter_taps = ComplexTensor(
            torch.zeros(*input_buffer.size()[:-2], C * taps, C,
                        dtype=input_buffer.dtype))

    # Window of past frames, excluding the guard interval, newest last.
    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, T) -> (..., C * T)
    window = window.view(*input_buffer.size()[:-2], -1)
    # Prediction residual: current frame minus estimated reverb tail.
    pred = input_buffer[..., -1] - \
        FC.einsum('...id,...i->...d', (filter_taps.conj(), window))

    # Recursive update of the inverse covariance and the filter taps.
    nominator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = \
        FC.einsum('...i,...i->...', (window.conj(), nominator)) + alpha * power
    kalman_gain = nominator / denominator[..., None]

    inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im',
                                    (window.conj(), inv_cov, kalman_gain))
    inv_cov_k /= alpha
    filter_taps_k = \
        filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj()))
    return pred, inv_cov_k, filter_taps_k
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. As follows:

        h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal
            of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # Relative transfer function of the speech source: (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # Zero-pad the RTF to the stacked (convolutional) dimension:
    # (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C),
                 "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...",
                            [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector