def forward(
        self, x: ComplexTensor,
        ilens: Union[torch.LongTensor, np.ndarray, List[int]]
) -> Tuple[torch.Tensor, torch.LongTensor]:
    # (B, T, F) or (B, T, C, F)
    if x.dim() not in (3, 4):
        raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
    if not torch.is_tensor(ilens):
        ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

    if x.dim() == 4:
        # h: (B, T, C, F) -> h: (B, T, F)
        if self.training:
            # Select 1ch randomly
            ch = np.random.randint(x.size(2))
            h = x[:, :, ch, :]
        else:
            # Use the first channel
            h = x[:, :, 0, :]
    else:
        h = x

    # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
    h = h.real ** 2 + h.imag ** 2

    h, _ = self.logmel(h, ilens)
    if self.stats_file is not None:
        h, _ = self.global_mvn(h, ilens)
    if self.apply_uttmvn:
        h, _ = self.uttmvn(h, ilens)

    return h, ilens
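
# A minimal, self-contained sketch of the channel-selection and power-spectrum
# steps in the forward above, using random data; the sizes (B, T, C, F) are
# illustrative only, and the snippet assumes `torch` and `torch_complex` are
# available.
def _example_channel_select_and_power():
    import numpy as np
    import torch
    from torch_complex.tensor import ComplexTensor

    B, T, C, F = 2, 100, 4, 257
    x = ComplexTensor(torch.randn(B, T, C, F), torch.randn(B, T, C, F))

    # Training-style random channel pick; inference would use channel 0
    ch = np.random.randint(x.size(2))
    h = x[:, :, ch, :]

    # Power spectrum: |X|^2 = Re(X)^2 + Im(X)^2
    power = h.real ** 2 + h.imag ** 2
    assert power.shape == (B, T, F)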
def get_mvdr_vector(psd_s: ComplexTensor,
                    psd_n: ComplexTensor,
                    reference_vector: torch.Tensor,
                    eps: float = 1e-15) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): speech PSD matrix (..., F, C, C)
        psd_n (ComplexTensor): noise PSD matrix (..., F, C, C)
        reference_vector (torch.Tensor): one-hot reference vector (..., C)
        eps (float): small constant for numerical stability

    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps to the diagonal to keep psd_n invertible
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    psd_n += eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum('...fec,...c->...fe', [ws, reference_vector])
    return beamform_vector
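
# A minimal usage sketch for get_mvdr_vector with random Hermitian PSD
# matrices; batch/frequency/channel sizes are illustrative, and the PSDs are
# built as X @ X^H so that psd_n is invertible after the diagonal loading.
def _example_get_mvdr_vector():
    import torch
    import torch_complex.functional as FC
    from torch_complex.tensor import ComplexTensor

    B, F, C = 2, 257, 4

    def random_psd(*shape):
        x = ComplexTensor(torch.randn(*shape), torch.randn(*shape))
        return FC.matmul(x, x.conj().transpose(-1, -2))

    psd_speech = random_psd(B, F, C, C)
    psd_noise = random_psd(B, F, C, C)
    u = torch.zeros(B, C)
    u[:, 0] = 1.0  # one-hot vector selecting the reference microphone

    w = get_mvdr_vector(psd_speech, psd_noise, u)
    assert w.shape == (B, F, C)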
def forward(self, input: ComplexTensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (ComplexTensor): spectrum [Batch, T, (C,) F]
        ilens (torch.Tensor): input lengths [Batch]
    """
    if not isinstance(input, ComplexTensor) and (
            is_torch_1_9_plus and not torch.is_complex(input)):
        raise TypeError("Only support complex tensors for stft decoder")

    bs = input.size(0)
    if input.dim() == 4:
        multi_channel = True
        # input: (Batch, T, C, F) -> (Batch * C, T, F)
        input = input.transpose(1, 2).reshape(-1, input.size(1),
                                              input.size(3))
    else:
        multi_channel = False

    wav, wav_lens = self.stft.inverse(input, ilens)

    if multi_channel:
        # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
        wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

    return wav, wav_lens
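
# A small sketch of the multi-channel reshape bookkeeping used above: fold the
# channel axis into the batch for the inverse STFT, then unfold it back onto
# the waveform. Shapes are illustrative, and `fake_istft` is a hypothetical
# placeholder standing in for self.stft.inverse.
def _example_multichannel_reshape():
    import torch

    B, T, C, F = 2, 100, 4, 257
    n_samples = 1600
    spec = torch.randn(B, T, C, F)

    # (B, T, C, F) -> (B * C, T, F)
    folded = spec.transpose(1, 2).reshape(-1, T, F)

    def fake_istft(x):  # placeholder: one waveform per "batch" entry
        return torch.randn(x.size(0), n_samples)

    wav = fake_istft(folded)  # (B * C, Nsamples)
    # (B * C, Nsamples) -> (B, C, Nsamples) -> (B, Nsamples, C)
    wav = wav.reshape(B, -1, wav.size(1)).transpose(1, 2)
    assert wav.shape == (B, n_samples, C)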
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 1. Domain-conversion: e.g. Stft: time -> time-freq
    input_stft, feats_lens = self.stft(input, input_lengths)

    assert input_stft.dim() >= 4, input_stft.shape
    # "2" refers to the real/imag parts of Complex
    assert input_stft.shape[-1] == 2, input_stft.shape

    # Change torch.Tensor to ComplexTensor
    # input_stft: (..., F, 2) -> (..., F)
    input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])

    # 2. [Option] Speech enhancement
    if self.frontend is not None:
        assert isinstance(input_stft, ComplexTensor), type(input_stft)
        # input_stft: (Batch, Length, [Channel], Freq)
        input_stft, _, mask = self.frontend(input_stft, feats_lens)

    # 3. [Multi channel case]: Select a channel
    if input_stft.dim() == 4:
        # h: (B, T, C, F) -> h: (B, T, F)
        if self.training:
            # Select 1ch randomly
            ch = np.random.randint(input_stft.size(2))
            input_stft = input_stft[:, :, ch, :]
        else:
            # Use the first channel
            input_stft = input_stft[:, :, 0, :]

    # 4. STFT -> Power spectrum
    # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
    input_power = input_stft.real ** 2 + input_stft.imag ** 2

    # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
    # input_power: (Batch, [Channel,] Length, Freq)
    #           -> input_feats: (Batch, Length, Dim)
    input_feats, _ = self.logmel(input_power, feats_lens)

    return input_feats, feats_lens
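
# A hedged end-to-end sketch of steps 1 and 4 above using plain torch.stft;
# the Stft layer used in the forward wraps it with padding and length
# handling, which this sketch omits. Window length and hop are illustrative.
def _example_stft_to_power():
    import torch
    from torch_complex.tensor import ComplexTensor

    wav = torch.randn(2, 16000)  # (Batch, Nsamples)
    spec = torch.stft(wav, n_fft=512, hop_length=128,
                      window=torch.hann_window(512), return_complex=True)
    # (B, F, T) complex -> (B, T, F, 2) real, matching the layout expected
    # by the ComplexTensor conversion above
    spec = torch.view_as_real(spec).transpose(1, 2)

    x = ComplexTensor(spec[..., 0], spec[..., 1])
    power = x.real ** 2 + x.imag ** 2  # (B, T, F)
    assert power.shape == (2, spec.size(1), 257)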
def forward(
        self, x: ComplexTensor,
        ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
    assert len(x) == len(ilens), (len(x), len(ilens))
    # (B, T, F) or (B, T, C, F)
    if x.dim() not in (3, 4):
        raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
    if not torch.is_tensor(ilens):
        ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

    mask = None
    h = x
    if h.dim() == 4:
        if self.training:
            choices = [(False, False)] if not self.use_frontend_for_all \
                else []
            if self.use_wpe:
                choices.append((True, False))
            if self.use_beamformer:
                choices.append((False, True))
            use_wpe, use_beamformer = \
                choices[numpy.random.randint(len(choices))]
        else:
            use_wpe = self.use_wpe
            use_beamformer = self.use_beamformer

        # 1. WPE
        if use_wpe:
            # h: (B, T, C, F) -> h: (B, T, C, F)
            h, ilens, mask = self.wpe(h, ilens)

        # 2. Beamformer
        if use_beamformer:
            # h: (B, T, C, F) -> h: (B, T, F)
            h, ilens, mask = self.beamformer(h, ilens)

    return h, ilens, mask
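
# A standalone sketch of the training-time sampling above: each forward pass
# picks one of "no frontend", "WPE only", or "beamformer only", unless
# use_frontend_for_all removes the no-frontend option. Flag values here are
# illustrative defaults.
def _example_frontend_choice(use_wpe=True, use_beamformer=True,
                             use_frontend_for_all=False):
    import numpy

    choices = [(False, False)] if not use_frontend_for_all else []
    if use_wpe:
        choices.append((True, False))
    if use_beamformer:
        choices.append((False, True))
    # Returns a (use_wpe, use_beamformer) pair for this forward pass
    return choices[numpy.random.randint(len(choices))]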
def tik_reg(mat: ComplexTensor,
            reg: float = 1e-8, eps: float = 1e-8) -> ComplexTensor:
    """Perform Tikhonov regularization (only modifying real part).

    Args:
        mat (ComplexTensor): input matrix (..., C, C)
        reg (float): regularization factor
        eps (float): small constant to avoid an all-zero diagonal load

    Returns:
        ret (ComplexTensor): regularized matrix (..., C, C)
    """
    # Add eps
    C = mat.size(-1)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    shape = [1 for _ in range(mat.dim() - 2)] + [C, C]
    eye = eye.view(*shape).repeat(*mat.shape[:-2], 1, 1)
    with torch.no_grad():
        epsilon = FC.trace(mat).real[..., None, None] * reg
        # in case that correlation_matrix is all-zero
        epsilon = epsilon + eps
    mat = mat + epsilon * eye
    return mat
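
# A minimal usage sketch for tik_reg on a random Hermitian matrix; sizes are
# illustrative. Because the diagonal load scales with the trace, matrices of
# very different magnitudes are regularized proportionally.
def _example_tik_reg():
    import torch
    import torch_complex.functional as FC
    from torch_complex.tensor import ComplexTensor

    x = ComplexTensor(torch.randn(3, 4, 4), torch.randn(3, 4, 4))
    mat = FC.matmul(x, x.conj().transpose(-1, -2))  # Hermitian (3, 4, 4)
    mat_reg = tik_reg(mat, reg=1e-6)
    assert mat_reg.shape == mat.shape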
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps (float): small constant added to the diagonal before inversion

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        correlation_matrix.shape[-2:]
    eye = eye.view(*shape)
    correlation_matrix += eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
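
# A shape-only usage sketch for get_filter_matrix_conj with random inputs;
# F (frequency bins), taps, and C (channels) are illustrative. A real WPE
# pipeline would build these from delayed-signal correlations instead.
def _example_get_filter_matrix_conj():
    import torch
    import torch_complex.functional as FC
    from torch_complex.tensor import ComplexTensor

    F, taps, C = 5, 3, 2
    x = ComplexTensor(torch.randn(F, taps * C, taps * C),
                      torch.randn(F, taps * C, taps * C))
    # Hermitian correlation matrix (F, taps * C, taps * C)
    corr_matrix = FC.matmul(x, x.conj().transpose(-1, -2))
    corr_vector = ComplexTensor(torch.randn(F, taps, C, C),
                                torch.randn(F, taps, C, C))

    fmat = get_filter_matrix_conj(corr_matrix, corr_vector)
    assert fmat.shape == (F, taps, C, C)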
def forward(self, input: torch.Tensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        enhanced speech (single-channel): torch.Tensor or List[torch.Tensor]
        output lengths
        predicted masks: OrderedDict[
            'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # wave -> stft -> complex spectrum
    input_spectrum, flens = self.stft(input, ilens)
    # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
    input_spectrum = ComplexTensor(input_spectrum[..., 0],
                                   input_spectrum[..., 1])
    if self.normalize_input:
        input_spectrum = input_spectrum / abs(input_spectrum).max()

    enhanced = input_spectrum
    masks = OrderedDict()

    if input_spectrum.dim() == 3:
        # single-channel input
        if self.use_wpe:
            # (B, T, F)
            enhanced, flens, mask_w = self.wpe(
                input_spectrum.unsqueeze(-2), flens)
            enhanced = enhanced.squeeze(-2)
            if mask_w is not None:
                masks["dereverb"] = mask_w.squeeze(-2)

    elif input_spectrum.dim() == 4:
        # multi-channel input
        # 1. WPE
        if self.use_wpe:
            # (B, T, C, F)
            enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
            if mask_w is not None:
                masks["dereverb"] = mask_w

        # 2. Beamformer
        if self.use_beamformer:
            # enhanced: (B, T, C, F) -> (B, T, F)
            enhanced, flens, masks_b = self.beamformer(enhanced, flens)
            for spk in range(self.num_spk):
                masks["spk{}".format(spk + 1)] = masks_b[spk]
            if len(masks_b) > self.num_spk:
                masks["noise1"] = masks_b[self.num_spk]
    else:
        raise ValueError(
            "Invalid spectrum dimension: {}".format(input_spectrum.shape)
        )

    # Convert ComplexTensor to torch.Tensor
    # (B, T, F) -> (B, T, F, 2)
    if isinstance(enhanced, list):
        # multi-speaker output
        enhanced = [torch.stack([enh.real, enh.imag], dim=-1)
                    for enh in enhanced]
    else:
        # single-speaker output
        enhanced = torch.stack([enhanced.real, enhanced.imag],
                               dim=-1).float()
    return enhanced, flens, masks
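
# A tiny sketch of the output conversion at the end of the forward above:
# packing a ComplexTensor back into a real tensor with a trailing real/imag
# axis; shapes are illustrative.
def _example_complex_to_real_pair():
    import torch
    from torch_complex.tensor import ComplexTensor

    enh = ComplexTensor(torch.randn(2, 100, 257), torch.randn(2, 100, 257))
    # (B, T, F) complex -> (B, T, F, 2) real
    packed = torch.stack([enh.real, enh.imag], dim=-1).float()
    assert packed.shape == (2, 100, 257, 2)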
def forward(
    self, input: ComplexTensor, ilens: torch.Tensor
) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        enhanced speech (single-channel): List[ComplexTensor]
        output lengths
        other predicted data: OrderedDict[
            'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
            'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
    assert input.dim() in (3, 4), input.dim()
    enhanced = input
    others = OrderedDict()

    if (self.training and self.loss_type is not None
            and self.loss_type.startswith("mask")):
        # Only estimating masks during training for saving memory
        if self.use_wpe:
            if input.dim() == 3:
                mask_w, ilens = self.wpe.predict_mask(
                    input.unsqueeze(-2), ilens)
                mask_w = mask_w.squeeze(-2)
            elif input.dim() == 4:
                mask_w, ilens = self.wpe.predict_mask(input, ilens)
            if mask_w is not None:
                if isinstance(enhanced, list):
                    # single-source WPE
                    for spk in range(self.num_spk):
                        others["mask_dereverb{}".format(spk + 1)] = \
                            mask_w[spk]
                else:
                    # multi-source WPE
                    others["mask_dereverb1"] = mask_w

        if self.use_beamformer and input.dim() == 4:
            others_b, ilens = self.beamformer.predict_mask(input, ilens)
            for spk in range(self.num_spk):
                others["mask_spk{}".format(spk + 1)] = others_b[spk]
            if len(others_b) > self.num_spk:
                others["mask_noise1"] = others_b[self.num_spk]

        return None, ilens, others
    else:
        powers = None
        # Performing both mask estimation and enhancement
        if input.dim() == 3:
            # single-channel input (B, T, F)
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(
                    input.unsqueeze(-2), ilens)
                if isinstance(enhanced, list):
                    # single-source WPE
                    enhanced = [enh.squeeze(-2) for enh in enhanced]
                    if mask_w is not None:
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk].squeeze(-2)
                else:
                    # multi-source WPE
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        others["dereverb1"] = enhanced
                        others["mask_dereverb1"] = mask_w.squeeze(-2)
        else:
            # multi-channel input (B, T, C, F)
            # 1. WPE
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                if mask_w is not None:
                    if isinstance(enhanced, list):
                        # single-source WPE
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk]
                    else:
                        # multi-source WPE
                        others["dereverb1"] = enhanced
                        others["mask_dereverb1"] = mask_w.squeeze(-2)
            # 2. Beamformer
            if self.use_beamformer:
                # Keep the WPE powers only for wmpdr/wpd beamformers with
                # shared_power enabled and compatible mask counts.
                # (The original `or` chain over the two startswith checks
                # was always true, so powers would unconditionally be
                # discarded.)
                if (not (self.beamformer.beamformer_type.startswith("wmpdr")
                         or self.beamformer.beamformer_type.startswith("wpd"))
                        or not self.shared_power
                        or (self.wpe.nmask == 1 and self.num_spk > 1)):
                    powers = None

                # enhanced: (B, T, C, F) -> (B, T, F)
                if isinstance(enhanced, list):
                    # outputs of single-source WPE
                    raise NotImplementedError(
                        "Single-source WPE is not supported with beamformer "
                        "in multi-speaker cases.")
                else:
                    # output of multi-source WPE
                    enhanced, ilens, others_b = self.beamformer(
                        enhanced, ilens, powers=powers)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others
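
# A standalone check of the power-sharing gate above with plain strings;
# beamformer type names and flag values are illustrative. Only "wmpdr*" /
# "wpd*" types with shared power (and a compatible WPE mask count) keep the
# powers estimated by WPE.
def _example_share_powers(beamformer_type, shared_power, nmask, num_spk):
    return ((beamformer_type.startswith("wmpdr")
             or beamformer_type.startswith("wpd"))
            and shared_power
            and not (nmask == 1 and num_spk > 1))

# e.g. _example_share_powers("wmpdr", True, 2, 2) -> True
#      _example_share_powers("mvdr", True, 2, 2)  -> False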
def forward(self, input: torch.Tensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        enhanced speech (single-channel): torch.Tensor or List[torch.Tensor]
        output lengths
        predicted masks: OrderedDict[
            'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # wave -> stft -> complex spectrum
    input_spectrum, flens = self.stft(input, ilens)
    # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
    input_spectrum = ComplexTensor(input_spectrum[..., 0],
                                   input_spectrum[..., 1])
    if self.normalize_input:
        input_spectrum = input_spectrum / abs(input_spectrum).max()

    # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
    assert input_spectrum.dim() in (3, 4), input_spectrum.dim()
    enhanced = input_spectrum
    masks = OrderedDict()

    if self.training and self.loss_type.startswith("mask"):
        # Only estimating masks for training
        if self.use_wpe:
            if input_spectrum.dim() == 3:
                mask_w, flens = self.wpe.predict_mask(
                    input_spectrum.unsqueeze(-2), flens)
                mask_w = mask_w.squeeze(-2)
            elif input_spectrum.dim() == 4:
                if self.use_beamformer:
                    enhanced, flens, mask_w = self.wpe(
                        input_spectrum, flens)
                else:
                    mask_w, flens = self.wpe.predict_mask(
                        input_spectrum, flens)
            if mask_w is not None:
                masks["dereverb"] = mask_w

        if self.use_beamformer and input_spectrum.dim() == 4:
            masks_b, flens = self.beamformer.predict_mask(enhanced, flens)
            for spk in range(self.num_spk):
                masks["spk{}".format(spk + 1)] = masks_b[spk]
            if len(masks_b) > self.num_spk:
                masks["noise1"] = masks_b[self.num_spk]

        return None, flens, masks
    else:
        # Performing both mask estimation and enhancement
        if input_spectrum.dim() == 3:
            # single-channel input (B, T, F)
            if self.use_wpe:
                enhanced, flens, mask_w = self.wpe(
                    input_spectrum.unsqueeze(-2), flens)
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)
        else:
            # multi-channel input (B, T, C, F)
            # 1. WPE
            if self.use_wpe:
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

    # Convert ComplexTensor to torch.Tensor
    # (B, T, F) -> (B, T, F, 2)
    if isinstance(enhanced, list):
        # multi-speaker output
        enhanced = [
            torch.stack([enh.real, enh.imag], dim=-1)
            for enh in enhanced
        ]
    else:
        # single-speaker output
        enhanced = torch.stack([enhanced.real, enhanced.imag],
                               dim=-1).float()
    return enhanced, flens, masks
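
# A small sketch of the mask-dictionary convention used above: per-speaker
# masks are keyed "spk1".."spkN", and an optional extra estimate becomes
# "noise1". Tensor sizes and the number of speakers are illustrative.
def _example_mask_dict(num_spk=2):
    from collections import OrderedDict

    import torch

    masks_b = [torch.rand(2, 100, 4, 257) for _ in range(num_spk + 1)]
    masks = OrderedDict()
    for spk in range(num_spk):
        masks["spk{}".format(spk + 1)] = masks_b[spk]
    if len(masks_b) > num_spk:
        masks["noise1"] = masks_b[num_spk]
    return masks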