def trace(a: ComplexTensor) -> ComplexTensor:
    # Select the diagonal of each trailing (N, N) block with a boolean
    # identity mask and sum it; uint8 mask indexing is deprecated in
    # recent PyTorch, so use torch.bool.
    E = torch.eye(a.real.size(-1), dtype=torch.bool).expand(*a.size())
    return a[E].view(*a.size()[:-1]).sum(-1)
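# A quick sanity check for the diagonal-mask trick in `trace` above, run on
# a plain torch.Tensor (the ComplexTensor path indexes the real and
# imaginary parts the same way): the boolean identity mask picks out each
# trailing (N, N) block's diagonal, whose sum matches torch.diagonal().sum().
import torch

_a = torch.arange(2 * 3 * 3, dtype=torch.float32).view(2, 3, 3)
_E = torch.eye(3, dtype=torch.bool).expand(2, 3, 3)
assert torch.equal(
    _a[_E].view(2, 3).sum(-1),
    torch.diagonal(_a, dim1=-2, dim2=-1).sum(-1),
)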
def _compute_loss(
    self,
    speech_mix,
    speech_lengths,
    speech_ref,
    dereverb_speech_ref=None,
    noise_ref=None,
    cal_loss=True,
):
    """Compute loss according to self.loss_type.

    Args:
        speech_mix: (Batch, samples) or (Batch, samples, channels)
        speech_lengths: (Batch,), default None for the chunk iterator,
            because the chunk iterator does not return speech_lengths;
            see espnet2/iterators/chunk_iter_factory.py.
        speech_ref: (Batch, num_speaker, samples)
            or (Batch, num_speaker, samples, channels)
        dereverb_speech_ref: (Batch, N, samples)
            or (Batch, num_speaker, samples, channels)
        noise_ref: (Batch, num_noise_type, samples)
            or (Batch, num_noise_type, samples, channels)
        cal_loss: whether to calculate the enhancement loss, default is True
    Returns:
        loss: (torch.Tensor) speech enhancement loss
        speech_pre: (List[torch.Tensor] or List[ComplexTensor])
            enhanced speech or spectrum(s)
        others: (OrderedDict) estimated masks or None
        output_lengths: (Batch,)
        perm: () best permutation
    """
    feature_mix, flens = self.encoder(speech_mix, speech_lengths)
    feature_pre, flens, others = self.separator(feature_mix, flens)

    if self.loss_type not in ["si_snr", "ci_sdr"]:
        spectrum_mix = feature_mix
        spectrum_pre = feature_pre
        # predict separated speech and masks
        if self.stft_consistency:
            # pseudo STFT -> time-domain -> STFT (compute loss)
            tmp_t_domain = [
                self.decoder(sp, speech_lengths)[0] for sp in spectrum_pre
            ]
            spectrum_pre = [
                self.encoder(sp, speech_lengths)[0] for sp in tmp_t_domain
            ]

        if spectrum_pre is not None and not isinstance(
            spectrum_pre[0], ComplexTensor
        ):
            spectrum_pre = [
                ComplexTensor(*torch.unbind(sp, dim=-1)) for sp in spectrum_pre
            ]

        if not cal_loss:
            loss, perm = None, None
            return loss, spectrum_pre, others, flens, perm

        # prepare reference speech and reference spectrum
        speech_ref = torch.unbind(speech_ref, dim=1)
        # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
        spectrum_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref]

        # compute TF masking loss
        if self.loss_type == "magnitude":
            # compute loss on magnitude spectrum
            assert spectrum_pre is not None
            magnitude_pre = [abs(ps + 1e-15) for ps in spectrum_pre]
            if spectrum_ref[0].dim() > magnitude_pre[0].dim():
                # only select one channel as the reference
                magnitude_ref = [
                    abs(sr[..., self.ref_channel, :]) for sr in spectrum_ref
                ]
            else:
                magnitude_ref = [abs(sr) for sr in spectrum_ref]

            tf_loss, perm = self._permutation_loss(
                magnitude_ref, magnitude_pre, self.tf_mse_loss
            )
        elif self.loss_type.startswith("spectrum"):
            # compute loss on complex spectrum
            if self.loss_type == "spectrum":
                loss_func = self.tf_mse_loss
            elif self.loss_type == "spectrum_log":
                loss_func = self.tf_log_mse_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            assert spectrum_pre is not None
            if spectrum_ref[0].dim() > spectrum_pre[0].dim():
                # only select one channel as the reference
                spectrum_ref = [
                    sr[..., self.ref_channel, :] for sr in spectrum_ref
                ]

            tf_loss, perm = self._permutation_loss(
                spectrum_ref, spectrum_pre, loss_func
            )
        elif self.loss_type.startswith("mask"):
            if self.loss_type == "mask_mse":
                loss_func = self.tf_mse_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            assert others is not None
            mask_pre_ = [
                others["mask_spk{}".format(spk + 1)]
                for spk in range(self.num_spk)
            ]

            # prepare ideal masks
            mask_ref = self._create_mask_label(
                spectrum_mix, spectrum_ref, mask_type=self.mask_type
            )

            # compute TF masking loss
            tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_, loss_func)

            if "mask_dereverb1" in others:
                if dereverb_speech_ref is None:
                    raise ValueError(
                        "No dereverberated reference for training!\n"
                        'Please specify "--use_dereverb_ref true" in run.sh'
                    )

                mask_wpe_pre = [
                    others["mask_dereverb{}".format(spk + 1)]
                    for spk in range(self.num_spk)
                    if "mask_dereverb{}".format(spk + 1) in others
                ]
                assert len(mask_wpe_pre) == dereverb_speech_ref.size(1), (
                    len(mask_wpe_pre),
                    dereverb_speech_ref.size(1),
                )
                dereverb_speech_ref = torch.unbind(dereverb_speech_ref, dim=1)
                dereverb_spectrum_ref = [
                    self.encoder(dr, speech_lengths)[0]
                    for dr in dereverb_speech_ref
                ]
                dereverb_mask_ref = self._create_mask_label(
                    spectrum_mix, dereverb_spectrum_ref, mask_type=self.mask_type
                )

                tf_dereverb_loss, perm_d = self._permutation_loss(
                    dereverb_mask_ref, mask_wpe_pre, loss_func
                )
                tf_loss = tf_loss + tf_dereverb_loss

            if "mask_noise1" in others:
                if noise_ref is None:
                    raise ValueError(
                        "No noise reference for training!\n"
                        'Please specify "--use_noise_ref true" in run.sh'
                    )

                noise_ref = torch.unbind(noise_ref, dim=1)
                noise_spectrum_ref = [
                    self.encoder(nr, speech_lengths)[0] for nr in noise_ref
                ]
                noise_mask_ref = self._create_mask_label(
                    spectrum_mix, noise_spectrum_ref, mask_type=self.mask_type
                )

                mask_noise_pre = [
                    others["mask_noise{}".format(n + 1)]
                    for n in range(self.num_noise_type)
                ]
                tf_noise_loss, perm_n = self._permutation_loss(
                    noise_mask_ref, mask_noise_pre, loss_func
                )
                tf_loss = tf_loss + tf_noise_loss
        else:
            raise ValueError("Unsupported loss type: %s" % self.loss_type)

        loss = tf_loss
        return loss, spectrum_pre, others, flens, perm
    else:
        speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
        if not cal_loss:
            loss, perm = None, None
            return loss, speech_pre, None, speech_lengths, perm

        # speech_pre: list[(batch, sample)]
        assert speech_pre[0].dim() == 2, speech_pre[0].dim()

        if speech_ref.dim() == 4:
            # For si_snr loss of multi-channel input,
            # only select one channel as the reference
            speech_ref = speech_ref[..., self.ref_channel]
        speech_ref = torch.unbind(speech_ref, dim=1)

        if self.loss_type == "si_snr":
            # compute si-snr loss
            loss, perm = self._permutation_loss(
                speech_ref, speech_pre, self.si_snr_loss_zeromean
            )
        elif self.loss_type == "ci_sdr":
            # compute ci-sdr loss
            loss, perm = self._permutation_loss(
                speech_ref, speech_pre, self.ci_sdr_loss
            )
        else:
            raise ValueError("Unsupported loss type: %s" % self.loss_type)
        return loss, speech_pre, None, speech_lengths, perm
def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
        -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """The forward function

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F)
    """
    def apply_beamforming(data, ilens, psd_speech, psd_noise):
        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech, ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)

        ws = get_mvdr_vector(psd_speech, psd_noise, u)
        enhanced = apply_beamforming_vector(ws, data)

        return enhanced, ws

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: (B, F, C, T)
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks)

    if self.nmask == 2:  # (mask_speech, mask_noise)
        mask_speech, mask_noise = masks

        psd_speech = get_power_spectral_density_matrix(data, mask_speech)
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)

        enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
        mask_speech = mask_speech.transpose(-1, -3)
    else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
        mask_speech = list(masks[:-1])
        mask_noise = masks[-1]

        psd_speeches = [
            get_power_spectral_density_matrix(data, mask)
            for mask in mask_speech
        ]
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)

        enhanced = []
        ws = []
        for i in range(self.nmask - 1):
            psd_speech = psd_speeches.pop(i)
            # treat all other speakers' psd_speech as noises
            enh, w = apply_beamforming(data, ilens, psd_speech,
                                       sum(psd_speeches) + psd_noise)
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            mask_speech[i] = mask_speech[i].transpose(-1, -3)

            enhanced.append(enh)
            ws.append(w)

    return enhanced, ilens, mask_speech
High_noise, Low_noise = Decomposition(y_.squeeze(0), 0.10) # dncnn_data = High_noise[:,0].unsqueeze(0) decompose_data = torch.cat([y_, High_noise, Low_noise], dim=1) # x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1) model = model.cpu() decom_output = decompose_model(decompose_data.float()).squeeze( 0) # inference dncnn_output = model(High_noise[:, 0].unsqueeze(1)).squeeze(0) # dncnn_high, dncnn_low = Decomposition(dncnn_output, 0.10) # x_ = output[0].cpu().detach().numpy().astype(np.float32) output = ComplexTensor(dncnn_output[1] + decom_output[3], decom_output[2] + decom_output[4]).abs() # output = ComplexTensor(dncnn_high[0,0] + decom_output[3], decom_output[2] + decom_output[4]).abs() x_ = output.cpu().detach().numpy().astype(np.float32) # x_ = ComplexTensor(output[:, 0] + output[:, 2], Low_noise[:, 0] + Low_noise[:, 1]).abs().squeeze(0) # x_ = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs().squeeze(0) # x_ = torch.add(output[:,0], output[:,1]).squeeze(0) # x_ = x_.cpu().detach().numpy().astype(np.float32) # x_ = x_.view(y.shape[0], y.shape[1]) # x_ = x_.cpu() # x_ = x_.detach().numpy().astype(np.float32) elapsed_time = time.time() - start_time
def test_gev_phase_correction(): mat = ComplexTensor(torch.rand(2, 3, 4), torch.rand(2, 3, 4)) mat_th = torch.complex(mat.real, mat.imag) norm = gev_phase_correction(mat) norm_th = gev_phase_correction(mat_th) assert np.allclose(norm.numpy(), norm_th.numpy())
# lowfreq_input = torch.cat([y_, High_noise[:,0].unsqueeze(1), High_noise[:,1].unsqueeze(1), Low_noise[:,0].unsqueeze(1), Low_noise[:,1].unsqueeze(1)], dim=1) lowfreq_input = torch.cat([ High_noise[:, 1].unsqueeze(1), Low_noise[:, 0].unsqueeze(1), Low_noise[:, 1].unsqueeze(1) ], dim=1) output_dncnn = model_dncnn(dncnn_input.float()).squeeze( 0) # inference lowfreq_output = model_lowfreq(lowfreq_input) # x_ = output[0].cpu().detach().numpy().astype(np.float32) # output = ComplexTensor(High_origin[0,0].cuda() + output[3], output[2] + output[4]).abs() output = ComplexTensor( output_dncnn[0] + lowfreq_output[0, 1], lowfreq_output[0, 0] + lowfreq_output[0, 2]).abs() # output = ComplexTensor(output_dncnn[0] + Low_origin[0, 0], # High_origin[0, 1] + Low_origin[0, 1]).abs() x_ = output.cpu().detach().numpy().astype(np.float32) plt.figure() plt.imshow(x, cmap='jet') plt.show() plt.close() plt.figure() plt.imshow(output.detach().numpy(), cmap='jet') plt.show() plt.close()
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--in-scp', type=str, required=True)
    parser.add_argument('--clean-scp', type=str,
                        help='Decode using oracle clean power')
    parser.add_argument('--out-dir', type=str, required=True)
    parser.add_argument('--model-state', type=str)
    parser.add_argument('--model-config', type=str)
    parser.add_argument('--stft-file', type=str, default='./stft.json')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--ref-channels', type=str2int_tuple, default=None)
    parser.add_argument('--online', type=strtobool, default=False)
    parser.add_argument('--taps', type=int, default=5)
    parser.add_argument('--delay', type=int, default=3)
    args = parser.parse_args()

    device = 'cuda' if args.ngpu >= 1 else 'cpu'

    if args.model_config is not None:
        with open(args.model_config) as f:
            model_config = json.load(f)
        model_config.update(use_dnn=True)
        _ = model_config.pop('width')
        norm_scale = model_config.pop('norm_scale')
        model = DNN_WPE(**model_config)
        if args.model_state is not None:
            model.load_state_dict(torch.load(args.model_state))
        model = model.to(device)
    else:
        model = None
        norm_scale = False

    reader = SoundScpReader(args.in_scp)
    writer = SoundScpWriter(args.out_dir, 'wav')
    if args.clean_scp is not None:
        clean_reader = SoundScpReader(args.clean_scp)
    else:
        clean_reader = None
    stft_func = Stft(args.stft_file)

    for key in tqdm(reader):
        # inp: (T, C)
        rate, inp = reader[key]
        if inp.ndim == 1:
            inp = inp[:, None]
        if args.ref_channels is not None:
            inp = inp[:, args.ref_channels]

        # Scaling int to [-1, 1]; read the integer range before casting,
        # because np.iinfo is undefined for float dtypes
        int_max = np.iinfo(inp.dtype).max
        inp = inp.astype(np.float32) / (int_max - 1)
        if norm_scale:
            scale = np.abs(inp).mean()
            inp /= scale
        else:
            scale = 1.

        # inp: (T, C) -> inp_stft: (C, F, T)
        inp_stft = stft_func(inp.T)

        if clean_reader is not None:
            _, clean = clean_reader[key]
            if clean.ndim == 1:
                clean = clean[:, None]
            # clean: (T, C) -> clean_stft: (C, F, T)
            clean_stft = stft_func(clean.T)
            power = (clean_stft.real ** 2 + clean_stft.imag ** 2).mean(0)
        elif model is not None:
            # To torch(C, F, T) -> (1, C, T, F)
            inp_stft_th = ComplexTensor(
                inp_stft.transpose(0, 2, 1)[None]).to(device)
            with torch.no_grad():
                _, power = model(inp_stft_th, return_wpe=False)
            # power: (1, C, T, F) -> (F, C, T)
            power = power[0].permute(2, 0, 1)
            # To numpy: (F, C, T) -> (F, T)
            power = power.cpu().numpy().mean(1)
        else:
            power = None

        # enh_stft: (F, C, T)
        if not args.online:
            enh_stft = wpe(
                inp_stft.transpose(1, 0, 2),
                power=power,
                taps=args.taps,
                delay=args.delay,
                iterations=1 if model is not None else 3,
            )
        else:
            enh_stft = online_wpe(inp_stft.transpose(1, 0, 2),
                                  power=power,
                                  taps=args.taps,
                                  delay=args.delay)
        # enh_stft: (F, C, T) -> (C, F, T)
        enh_stft = enh_stft.transpose(1, 0, 2)
        # take the first channel: (F, T)
        enh_stft = enh_stft[0]
        # enh_stft: (F, T) -> enh: (T,)
        enh = stft_func.istft(enh_stft).T

        # Truncate
        enh = enh[:inp.shape[0]]
        if norm_scale:
            enh *= scale

        # Rescaling [-1, 1] to int16
        enh = (enh * (np.iinfo(np.int16).max - 1)).astype(np.int16)
        writer[key] = (rate, enh)
def forward(self, data: ComplexTensor, ilens: torch.LongTensor=None, return_wpe: bool=True) -> Tuple[Optional[ComplexTensor], torch.Tensor]: if ilens is None: ilens = torch.full((data.size(0),), data.size(2), dtype=torch.long, device=data.device) r = -self.rcontext if self.rcontext != 0 else None enhanced = data[:, :, self.lcontext:r, :] if self.lcontext != 0 or self.rcontext != 0: assert all(ilens[0] == i for i in ilens) # Create context window (a.k.a Splicing) if self.model_type in ('blstm', 'lstm'): width = data.size(2) - self.lcontext - self.rcontext # data: (B, C, l + w + r, F) indices = [i + j for i in range(width) for j in range(1 + self.lcontext + self.rcontext)] _y = data[:, :, indices] # data: (B, C, l, (1 + w + r), F) data = _y.view( data.size(0), data.size(1), width, (1 + self.lcontext + self.rcontext) * data.size(3)) ilens = torch.full((data.size(0),), width, dtype=torch.long, device=data.device) del _y for i in range(self.iterations): power = enhanced.real ** 2 + enhanced.imag ** 2 # Calculate power: (B, C, T, Context, F) if i == 0 and self.use_dnn: # mask: (B, C, T, F) mask = self.estimator(data, ilens) if mask.size(2) != power.size(2): assert mask.size(2) == (power.size(2) + self.rcontext + self.lcontext) r = -self.rcontext if self.rcontext != 0 else None mask = mask[:, :, self.lcontext:r, :] if self.normalization: # Normalize along T mask = mask / mask.sum(dim=-2)[..., None] if self.out_type == 'mask': power = power * mask else: power = mask if self.out_type == 'amplitude': power = power ** 2 elif self.out_type == 'log_power': power = power.exp() elif self.out_type == 'power': pass else: raise NotImplementedError(self.out_type) if not return_wpe: return None, power # power: (B, C, T, F) -> _power: (B, F, T) _power = power.mean(dim=1).transpose(-1, -2).contiguous() # data: (B, C, T, F) -> _data: (B, F, C, T) _data = data.permute(0, 3, 1, 2).contiguous() # _enhanced: (B, F, C, T) _enhanced_real = [] _enhanced_imag = [] for d, p, l in zip(_data, _power, ilens): # e: (F, C, T) -> (T, C, F) e = wpe_one_iteration( d[..., :l], p[..., :l], taps=self.taps, delay=self.delay, inverse_power=self.inverse_power).transpose(0, 2) _enhanced_real.append(e.real) _enhanced_imag.append(e.imag) # _enhanced: B x (T, C, F) -> (B, T, C, F) -> (B, F, C, T) _enhanced_real = pad_sequence(_enhanced_real, batch_first=True).transpose(1, 3) _enhanced_imag = pad_sequence(_enhanced_imag, batch_first=True).transpose(1, 3) _enhanced = ComplexTensor(_enhanced_real, _enhanced_imag) # enhanced: (B, F, C, T) -> (B, C, T, F) enhanced = _enhanced.permute(0, 2, 3, 1) # enhanced: (B, C, T, F), power: (B, C, T, F) return enhanced, power
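# A standalone illustration of the context-window ("splicing") indices built
# in the forward function above, assuming lcontext = rcontext = 1: every
# output frame gathers itself plus its neighbours along T, and the context
# frames are folded into the feature dimension.
import torch

lcontext, rcontext = 1, 1
data = torch.rand(2, 1, 6, 4)  # (B, C, T, F)
width = data.size(2) - lcontext - rcontext
indices = [i + j
           for i in range(width)
           for j in range(1 + lcontext + rcontext)]
spliced = data[:, :, indices].view(
    data.size(0), data.size(1), width,
    (1 + lcontext + rcontext) * data.size(3))
assert spliced.shape == (2, 1, 4, 12)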
# y_ = torch.cat([y_, High_noise, Low_noise], dim=1)
# x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1)
dncnn_input = High_noise[:, 0].unsqueeze(1).cuda()
lowfreq_input = torch.cat([High_noise[:, 1].unsqueeze(1),
                           Low_noise[:, 0].unsqueeze(1),
                           Low_noise[:, 1].unsqueeze(1)], dim=1).cuda()

output_dncnn = model_dncnn(dncnn_input.float()).squeeze(0)  # inference
# squeeze (not unsqueeze) the batch dim, so the channel indices below
# match the [0, c] indexing used in the sibling script
lowfreq_output = model_lowfreq(lowfreq_input).squeeze(0)
# x_ = output[0].cpu().detach().numpy().astype(np.float32)
# output = ComplexTensor(High_origin[0,0].cuda() + output[3], output[2] + output[4]).abs()
output = ComplexTensor(output_dncnn[0] + lowfreq_output[1],
                       lowfreq_output[0] + lowfreq_output[2]).abs()
x_ = output.cpu().detach().numpy().astype(np.float32)
# x_ = ComplexTensor(output[:, 0] + output[:, 2], Low_noise[:, 0] + Low_noise[:, 1]).abs().squeeze(0)
# x_ = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs().squeeze(0)
# x_ = torch.add(output[:,0], output[:,1]).squeeze(0)
# x_ = x_.cpu().detach().numpy().astype(np.float32)
# x_ = x_.view(y.shape[0], y.shape[1])
torch.cuda.synchronize()
elapsed_time = time.time() - start_time
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. As follows:

    h = (Rf^-1 @ vbar) / (vbar^H @ Rf^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): whether to use `solve` instead of `inverse`
        diagonal_loading (bool): whether to add a tiny term to the diagonal
            of psd_noise
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
def forward( self, data: ComplexTensor, ilens: torch.LongTensor ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]: """DNN_WPE forward function. Notation: B: Batch C: Channel T: Time or Sequence length F: Freq or Some dimension of the feature vector Args: data: (B, T, C, F) ilens: (B,) Returns: enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F) ilens: (B,) masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F) power (List[torch.Tensor]): (B, F, T) """ # (B, T, C, F) -> (B, F, C, T) data = data.permute(0, 3, 2, 1) enhanced = [data for i in range(self.nmask)] masks = None power = None for i in range(self.iterations): # Calculate power: (..., C, T) power = [enh.real**2 + enh.imag**2 for enh in enhanced] if i == 0 and self.use_dnn_mask: # mask: (B, F, C, T) masks, _ = self.mask_est(data, ilens) # floor masks to increase numerical stability if self.mask_flooring: masks = [m.clamp(min=self.flooring_thres) for m in masks] if self.normalization: # Normalize along T masks = [m / m.sum(dim=-1, keepdim=True) for m in masks] # (..., C, T) * (..., C, T) -> (..., C, T) power = [p * masks[i] for i, p in enumerate(power)] # Averaging along the channel axis: (..., C, T) -> (..., T) power = [p.mean(dim=-2).clamp(min=self.eps) for p in power] # enhanced: (..., C, T) -> (..., C, T) # NOTE(kamo): Calculate in double precision enhanced = [ wpe_one_iteration( data.contiguous().double(), p.double(), taps=self.taps, delay=self.delay, inverse_power=self.inverse_power, ) for p in power ] enhanced = [ enh.to(dtype=data.dtype).masked_fill( make_pad_mask(ilens, enh.real), 0) for enh in enhanced ] # (B, F, C, T) -> (B, T, C, F) enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced] if masks is not None: masks = ([m.transpose(-1, -3) for m in masks] if self.nmask > 1 else masks[0].transpose(-1, -3)) if self.nmask == 1: enhanced = enhanced[0] return enhanced, ilens, masks, power
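# A minimal sketch of the per-iteration power update inside DNN_WPE.forward,
# assuming one mask and random inputs; shapes follow the internal
# (B, F, C, T) layout. This mirrors only the statistics computation, not the
# full WPE iteration.
import torch
from torch_complex.tensor import ComplexTensor

enh = ComplexTensor(torch.rand(2, 129, 4, 50), torch.rand(2, 129, 4, 50))
mask = torch.rand(2, 129, 4, 50)
power = (enh.real ** 2 + enh.imag ** 2) * mask  # (B, F, C, T)
power = power.mean(dim=-2).clamp(min=1e-8)      # channel-averaged: (B, F, T)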
def get_mvdr_vector_with_rtf( psd_n: ComplexTensor, psd_speech: ComplexTensor, psd_noise: ComplexTensor, iterations: int = 3, reference_vector: Union[int, torch.Tensor, None] = None, normalize_ref_channel: Optional[int] = None, use_torch_solver: bool = True, diagonal_loading: bool = True, diag_eps: float = 1e-7, eps: float = 1e-8, ) -> ComplexTensor: """Return the MVDR (Minimum Variance Distortionless Response) vector calculated with RTF: h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf) Reference: On optimal frequency-domain multichannel linear filtering for noise reduction; M. Souden et al., 2010; https://ieeexplore.ieee.org/document/5089420 Args: psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C) psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C) psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C) iterations (int): number of iterations in power method reference_vector (torch.Tensor or int): (..., C) or scalar normalize_ref_channel (int): reference channel for normalizing the RTF use_torch_solver (bool): Whether to use `solve` instead of `inverse` diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n diag_eps (float): eps (float): Returns: beamform_vector (ComplexTensor): (..., F, C) """ # noqa: H405, D205, D400 if diagonal_loading: psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps) # (B, F, C, 1) rtf = get_rtf( psd_speech, psd_noise, reference_vector, iterations=iterations, use_torch_solver=use_torch_solver, ) # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1) if use_torch_solver: numerator = FC.solve(rtf, psd_n)[0].squeeze(-1) else: numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1) denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator]) if normalize_ref_channel is not None: scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj() beamforming_vector = numerator * scale / ( denominator.real.unsqueeze(-1) + eps) else: beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps) return beamforming_vector
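# A hedged usage sketch for get_mvdr_vector_with_rtf. The covariance
# matrices are built as X @ X^H / T from random data so they are Hermitian
# positive semi-definite; passing an integer reference channel assumes
# `get_rtf` accepts one, as the signature above suggests. Shape check only,
# not a physically meaningful beamforming setup.
import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

def _random_psd(B, F_, C, T=32):
    X = ComplexTensor(torch.rand(B, F_, C, T), torch.rand(B, F_, C, T))
    return FC.einsum("...ct,...et->...ce", [X, X.conj()]) / T

psd_speech = _random_psd(2, 65, 4)
psd_noise = _random_psd(2, 65, 4)
# For MVDR, psd_n is the noise covariance itself
ws = get_mvdr_vector_with_rtf(psd_noise, psd_speech, psd_noise,
                              reference_vector=0)
assert ws.real.shape == (2, 65, 4)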
def forward(
    self, input: ComplexTensor, ilens: torch.Tensor
) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq]
        ilens (torch.Tensor): input lengths [Batch]
    Returns:
        enhanced speech (single-channel): List[ComplexTensor]
        output lengths
        other predicted data: OrderedDict[
            'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
            'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
    assert input.dim() in (3, 4), input.dim()
    enhanced = input
    others = OrderedDict()

    if (self.training and self.loss_type is not None
            and self.loss_type.startswith("mask")):
        # Only estimating masks during training for saving memory
        if self.use_wpe:
            if input.dim() == 3:
                mask_w, ilens = self.wpe.predict_mask(
                    input.unsqueeze(-2), ilens)
                mask_w = mask_w.squeeze(-2)
            elif input.dim() == 4:
                mask_w, ilens = self.wpe.predict_mask(input, ilens)
            if mask_w is not None:
                if isinstance(enhanced, list):
                    # single-source WPE
                    for spk in range(self.num_spk):
                        others["mask_dereverb{}".format(spk + 1)] = mask_w[spk]
                else:
                    # multi-source WPE
                    others["mask_dereverb1"] = mask_w

        if self.use_beamformer and input.dim() == 4:
            others_b, ilens = self.beamformer.predict_mask(input, ilens)
            for spk in range(self.num_spk):
                others["mask_spk{}".format(spk + 1)] = others_b[spk]
            if len(others_b) > self.num_spk:
                others["mask_noise1"] = others_b[self.num_spk]

        return None, ilens, others
    else:
        powers = None
        # Performing both mask estimation and enhancement
        if input.dim() == 3:
            # single-channel input (B, T, F)
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(
                    input.unsqueeze(-2), ilens)
                if isinstance(enhanced, list):
                    # single-source WPE
                    enhanced = [enh.squeeze(-2) for enh in enhanced]
                    if mask_w is not None:
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk].squeeze(-2)
                else:
                    # multi-source WPE
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        others["dereverb1"] = enhanced
                        others["mask_dereverb1"] = mask_w.squeeze(-2)
        else:
            # multi-channel input (B, T, C, F)
            # 1. WPE
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                if mask_w is not None:
                    if isinstance(enhanced, list):
                        # single-source WPE
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk]
                    else:
                        # multi-source WPE
                        others["dereverb1"] = enhanced
                        others["mask_dereverb1"] = mask_w.squeeze(-2)

            # 2. Beamformer
            if self.use_beamformer:
                # powers can only be shared with wMPDR/WPD-family
                # beamformers; otherwise reset them
                if (
                    not (
                        self.beamformer.beamformer_type.startswith("wmpdr")
                        or self.beamformer.beamformer_type.startswith("wpd")
                    )
                    or not self.shared_power
                    or (self.wpe.nmask == 1 and self.num_spk > 1)
                ):
                    powers = None

                # enhanced: (B, T, C, F) -> (B, T, F)
                if isinstance(enhanced, list):
                    # outputs of single-source WPE
                    raise NotImplementedError(
                        "Single-source WPE is not supported with beamformer "
                        "in multi-speaker cases.")
                else:
                    # output of multi-source WPE
                    enhanced, ilens, others_b = self.beamformer(
                        enhanced, ilens, powers=powers)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

    if not isinstance(enhanced, list):
        enhanced = [enhanced]

    return enhanced, ilens, others
batch_x, batch_y = batch_yx[:, 0].cuda(), batch_yx[:, 1].cuda() temp = batch_x High_noise, Low_noise = Decomposition(batch_y, 0.125) High_origin, Low_origin = Decomposition(batch_x, 0.125) # batch_y = batch_y.unsqueeze(1) # batch_x = batch_x.unsqueeze(1) # batch_y = torch.cat([High_noise, Low_noise], dim=1) batch_x = torch.cat([High_origin, Low_origin], dim=1) output = model(batch_y.cuda()) loss_hf = criterion(output.cpu(), batch_x) output_hf = ComplexTensor(output[:, 0], output[:, 1]).abs() output_lf = ComplexTensor(output[:, 2], output[:, 3]).abs() final_output = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs() High_noise = ComplexTensor(batch_y[:, 0], batch_y[:, 1]).abs() High_origin = ComplexTensor(batch_x[:, 0], batch_x[:, 1]).abs() Low_noise = ComplexTensor(batch_y[:, 2], batch_y[:, 3]).abs() Low_origin = ComplexTensor(batch_x[:, 2], batch_x[:, 3]).abs() fig = plt.figure() gs = GridSpec(nrows=2, ncols=4) highfreq1 = fig.add_subplot(gs[0, 0])
def apply_beamforming_vector(beamform_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    # (..., C) x (..., C, T) -> (..., T)
    return FC.einsum('...c,...ct->...t', [beamform_vector.conj(), mix])
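# Quick shape check for apply_beamforming_vector: a (B, F, C) filter applied
# to a (B, F, C, T) mixture yields a (B, F, T) single-channel output.
import torch
from torch_complex.tensor import ComplexTensor

ws = ComplexTensor(torch.rand(2, 129, 4), torch.rand(2, 129, 4))
mix = ComplexTensor(torch.rand(2, 129, 4, 50), torch.rand(2, 129, 4, 50))
enh = apply_beamforming_vector(ws, mix)
assert enh.real.shape == (2, 129, 50)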
torch.cuda.synchronize() start_time = time.time() # 0.14 High_noise, Low_noise = Decomposition(y_.squeeze(0), 0.125) High_noise = High_noise.cuda() Low_noise = Low_noise.cuda() y_ = y_.cuda() y_ = torch.cat([y_, High_noise, Low_noise], dim=1) # x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1) output = decompose_model(y_.cuda().float()).squeeze( 0) # inference # x_ = output[0].cpu().detach().numpy().astype(np.float32) output = ComplexTensor(output[1] + output[3], output[2] + output[4]).abs() x_ = output.cpu().detach().numpy().astype(np.float32) # x_ = ComplexTensor(output[:, 0] + output[:, 2], Low_noise[:, 0] + Low_noise[:, 1]).abs().squeeze(0) # x_ = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs().squeeze(0) # x_ = torch.add(output[:,0], output[:,1]).squeeze(0) # x_ = x_.cpu().detach().numpy().astype(np.float32) # x_ = x_.view(y.shape[0], y.shape[1]) # x_ = x_.cpu() # x_ = x_.detach().numpy().astype(np.float32) torch.cuda.synchronize() elapsed_time = time.time() - start_time psnr_x_ = compare_psnr(x, x_)
def test_conformer_separator_forward_backward_complex( input_dim, num_spk, adim, aheads, layers, linear_units, positionwise_layer_type, positionwise_conv_kernel_size, normalize_before, concat_after, dropout_rate, input_layer, positional_dropout_rate, attention_dropout_rate, nonlinear, conformer_pos_enc_layer_type, conformer_self_attn_layer_type, conformer_activation_type, use_macaron_style_in_conformer, use_cnn_in_conformer, conformer_enc_kernel_size, padding_idx, ): model = ConformerSeparator( input_dim=input_dim, num_spk=num_spk, adim=adim, aheads=aheads, layers=layers, linear_units=linear_units, dropout_rate=dropout_rate, positional_dropout_rate=positional_dropout_rate, attention_dropout_rate=attention_dropout_rate, input_layer=input_layer, normalize_before=normalize_before, concat_after=concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, use_macaron_style_in_conformer=use_macaron_style_in_conformer, nonlinear=nonlinear, conformer_pos_enc_layer_type=conformer_pos_enc_layer_type, conformer_self_attn_layer_type=conformer_self_attn_layer_type, conformer_activation_type=conformer_activation_type, use_cnn_in_conformer=use_cnn_in_conformer, conformer_enc_kernel_size=conformer_enc_kernel_size, padding_idx=padding_idx, ) model.train() real = torch.rand(2, 10, input_dim) imag = torch.rand(2, 10, input_dim) x = ComplexTensor(real, imag) x_lens = torch.tensor([10, 8], dtype=torch.long) masked, flens, others = model(x, ilens=x_lens) assert isinstance(masked[0], ComplexTensor) assert len(masked) == num_spk masked[0].abs().mean().backward()
# --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C) Ytilde = Ytilde.permute(0, 1, 3, 4, 2).contiguous().view(Bs, Fdim, T, -1) # (B, F, T, 1) enhanced = FC.einsum("...tc,...c->...t", [Ytilde, filter_matrix.conj()]) return enhanced if __name__ == "__main__": ############################################ # Example # ############################################ eps = 1e-10 btaps = 5 bdelay = 3 # pretend to be some STFT: (B, F, C, T) Z = ComplexTensor(torch.rand(4, 256, 2, 518), torch.rand(4, 256, 2, 518)) # Calculate power: (B, F, C, T) power = Z.real ** 2 + Z.imag ** 2 # pretend to be some mask mask_speech = torch.ones_like(Z.real) # (..., C, T) * (..., C, T) -> (..., C, T) power = power * mask_speech # Averaging along the channel axis: (B, F, C, T) -> (B, F, T) power = power.mean(dim=-2) # (B, F, T) --> (B * F, T) power = power.view(-1, power.shape[-1]) inverse_power = 1 / torch.clamp(power, min=eps) B, Fdim, C, T = Z.shape
def test_trace(): t = ComplexTensor(_get_complex_array(10, 10)) x = numpy.trace(t.numpy()) y = F.trace(t).numpy() numpy.testing.assert_allclose(x, y)
def get_WPD_filter_v2(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector with filter v2.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. As follows:

    h = (Rf^-1 @ Phi_{xx}) @ u / tr[(Rf^-1) @ Phi_{xx}]

    This implementation is more efficient than `get_WPD_filter` as
    it skips unnecessary computation with zeros.

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, C, C) is speech PSD.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C) is the reference_vector.
        eps (float):
    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
    """
    C = reference_vector.shape[-1]
    try:
        inv_Rf = inv(Rf)
    except Exception:
        try:
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real),
                              torch.rand_like(Rf.real)) * 1e-4
            )
            Rf = Rf / 10e4
            Phi = Phi / 10e4
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)
        except Exception:
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real),
                              torch.rand_like(Rf.real)) * 1e-1
            )
            Rf = Rf / 10e10
            Phi = Phi / 10e10
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)

    # (B, F, (btaps+1) * C, (btaps+1) * C) --> (B, F, (btaps+1) * C, C)
    inv_Rf_pruned = inv_Rf[..., :C]
    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [inv_Rf_pruned, Phi])
    # ws: (..., (btaps+1) * C, C) / (...,) -> (..., (btaps+1) * C, C)
    ws = numerator / (FC.trace(numerator[..., :C, :])[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps+1) * C)
    return beamform_vector
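# A hedged, shape-only usage sketch for get_WPD_filter_v2. The random
# Phi/Rf here are not physically meaningful PSDs; Rf is made invertible by
# building it as X @ X^H with enough frames. `inv` is assumed to be the
# complex matrix inverse used elsewhere in this module.
import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

B, F_, C, btaps = 2, 33, 4, 2
Phi = ComplexTensor(torch.rand(B, F_, C, C), torch.rand(B, F_, C, C))
X = ComplexTensor(torch.rand(B, F_, (btaps + 1) * C, 32),
                  torch.rand(B, F_, (btaps + 1) * C, 32))
Rf = FC.einsum("...ct,...et->...ce", [X, X.conj()])
u = torch.zeros(B, C)
u[:, 0] = 1  # one-hot reference channel
filter_matrix = get_WPD_filter_v2(Phi, Rf, u)  # (B, F_, (btaps + 1) * C)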
def forward_enh(
    self,
    speech_mix: torch.Tensor,
    speech_mix_lengths: torch.Tensor = None,
    resort_pre: bool = True,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech_mix: (Batch, samples) or (Batch, samples, channels)
        speech_ref: (Batch, num_speaker, samples)
            or (Batch, num_speaker, samples, channels)
        speech_mix_lengths: (Batch,), default None for the chunk iterator,
            because the chunk iterator does not return speech_lengths;
            see espnet2/iterators/chunk_iter_factory.py.
    """
    # clean speech signal of each speaker
    speech_ref = [
        kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
    ]
    # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
    speech_ref = torch.stack(speech_ref, dim=1)

    if "noise_ref1" in kwargs:
        # noise signal (optional, required when using
        # frontend models with beamforming)
        noise_ref = [
            kwargs["noise_ref{}".format(n + 1)]
            for n in range(self.num_noise_type)
        ]
        # (Batch, num_noise_type, samples) or
        # (Batch, num_noise_type, samples, channels)
        noise_ref = torch.stack(noise_ref, dim=1)
    else:
        noise_ref = None

    # dereverberated noisy signal
    # (optional, only used for frontend models with WPE)
    dereverb_speech_ref = kwargs.get("dereverb_ref", None)

    batch_size = speech_mix.shape[0]
    speech_lengths = (speech_mix_lengths if speech_mix_lengths is not None
                      else torch.ones(batch_size).int() * speech_mix.shape[1])
    assert speech_lengths.dim() == 1, speech_lengths.shape
    # Check that batch_size is unified
    assert speech_mix.shape[0] == speech_ref.shape[
        0] == speech_lengths.shape[0], (
            speech_mix.shape,
            speech_ref.shape,
            speech_lengths.shape,
        )

    # for data-parallel
    speech_ref = speech_ref[:, :, :speech_lengths.max()]
    speech_mix = speech_mix[:, :speech_lengths.max()]

    if self.loss_type != "si_snr":
        # prepare reference speech and reference spectrum
        speech_ref = torch.unbind(speech_ref, dim=1)
        spectrum_ref = [self.enh_model.stft(sr)[0] for sr in speech_ref]
        # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
        spectrum_ref = [
            ComplexTensor(sr[..., 0], sr[..., 1]) for sr in spectrum_ref
        ]
        spectrum_mix = self.enh_model.stft(speech_mix)[0]
        spectrum_mix = ComplexTensor(spectrum_mix[..., 0],
                                     spectrum_mix[..., 1])

        # predict separated speech and masks
        spectrum_pre, tf_length, mask_pre = self.enh_model(
            speech_mix, speech_lengths)

        # TODO(Chenda), Shall we add options for computing loss on
        # the masked spectrum?
# compute TF masking loss if self.loss_type == "magnitude": # compute loss on magnitude spectrum magnitude_pre = [abs(ps) for ps in spectrum_pre] magnitude_ref = [abs(sr) for sr in spectrum_ref] tf_loss, perm = self._permutation_loss(magnitude_ref, magnitude_pre, self.tf_mse_loss) elif self.loss_type == "spectrum": # compute loss on complex spectrum tf_loss, perm = self._permutation_loss(spectrum_ref, spectrum_pre, self.tf_mse_loss) elif self.loss_type.startswith("mask"): if self.loss_type == "mask_mse": loss_func = self.tf_mse_loss else: raise ValueError("Unsupported loss type: %s" % self.loss_type) assert mask_pre is not None mask_pre_ = [ mask_pre["spk{}".format(spk + 1)] for spk in range(self.num_spk) ] # prepare ideal masks mask_ref = self._create_mask_label(spectrum_mix, spectrum_ref, mask_type=self.mask_type) # compute TF masking loss tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_, loss_func) if "dereverb" in mask_pre: if dereverb_speech_ref is None: raise ValueError( "No dereverberated reference for training!\n" 'Please specify "--use_dereverb_ref true" in run.sh' ) dereverb_spectrum_ref = self.enh_model.stft( dereverb_speech_ref)[0] dereverb_spectrum_ref = ComplexTensor( dereverb_spectrum_ref[..., 0], dereverb_spectrum_ref[..., 1]) # ComplexTensor(B, T, F) or ComplexTensor(B, T, C, F) dereverb_mask_ref = self._create_mask_label( spectrum_mix, [dereverb_spectrum_ref], mask_type=self.mask_type)[0] tf_loss = (tf_loss + loss_func( dereverb_mask_ref, mask_pre["dereverb"]).mean()) if "noise1" in mask_pre: if noise_ref is None: raise ValueError( "No noise reference for training!\n" 'Please specify "--use_noise_ref true" in run.sh') noise_ref = torch.unbind(noise_ref, dim=1) noise_spectrum_ref = [ self.enh_model.stft(nr)[0] for nr in noise_ref ] noise_spectrum_ref = [ ComplexTensor(nr[..., 0], nr[..., 1]) for nr in noise_spectrum_ref ] noise_mask_ref = self._create_mask_label( spectrum_mix, noise_spectrum_ref, mask_type=self.mask_type) mask_noise_pre = [ mask_pre["noise{}".format(n + 1)] for n in range(self.num_noise_type) ] tf_noise_loss, perm_n = self._permutation_loss( noise_mask_ref, mask_noise_pre, loss_func) tf_loss = tf_loss + tf_noise_loss else: raise ValueError("Unsupported loss type: %s" % self.loss_type) if spectrum_pre is None and self.loss_type == "mask": # Need the wav prediction in training # TODO(Jing): should coordinate with the enh/nets/***, this is ugly now. 
self.enh_model.training = False speech_pre, *__ = self.enh_model.forward_rawwav( speech_mix, speech_lengths) self.enh_model.training = self.training else: speech_pre, *__ = self.enh_model.forward_rawwav( speech_mix, speech_lengths) loss = tf_loss else: if speech_ref.dim() == 4: # For si_snr loss of multi-channel input, # only select one channel as the reference speech_ref = speech_ref[..., self.ref_channel] speech_pre, speech_lengths, *__ = self.enh_model.forward_rawwav( speech_mix, speech_lengths) # speech_pre: list[(batch, sample)] assert speech_pre[0].dim() == 2, speech_pre[0].dim() speech_ref = torch.unbind(speech_ref, dim=1) # compute si-snr loss si_snr_loss, perm = self._permutation_loss( speech_ref, speech_pre, self.si_snr_loss_zeromean) loss = si_snr_loss if resort_pre: # speech_pre : list[(bs,T)] of spk # perm : list[(num_spk)] of batch speech_pre_list = [] for batch_idx, p in enumerate(perm): batch_list = [] for spk_idx in p: batch_list.append(speech_pre[spk_idx][batch_idx]) # spk,T speech_pre_list.append(torch.stack(batch_list, dim=0)) speech_pre = torch.stack(speech_pre_list, dim=0) # bs,num_spk,T else: speech_pre = torch.stack(speech_pre, dim=1) # bs,num_spk,T return loss, perm, speech_pre
def signal_framing( signal: Union[torch.Tensor, ComplexTensor], frame_length: int, frame_step: int, bdelay: int, do_padding: bool = False, pad_value: int = 0, indices: List = None, ) -> Union[torch.Tensor, ComplexTensor]: """Expand `signal` into several frames, with each frame of length `frame_length`. Args: signal : (..., T) frame_length: length of each segment frame_step: step for selecting frames bdelay: delay for WPD do_padding: whether or not to pad the input signal at the beginning of the time dimension pad_value: value to fill in the padding Returns: torch.Tensor: if do_padding: (..., T, frame_length) else: (..., T - bdelay - frame_length + 2, frame_length) """ if indices is None: frame_length2 = frame_length - 1 # pad to the right at the last dimension of `signal` (time dimension) if do_padding: # (..., T) --> (..., T + bdelay + frame_length - 2) signal = FC.pad( signal, (bdelay + frame_length2 - 1, 0), "constant", pad_value ) # indices: # [[ 0, 1, ..., frame_length2 - 1, frame_length2 - 1 + bdelay ], # [ 1, 2, ..., frame_length2, frame_length2 + bdelay ], # [ 2, 3, ..., frame_length2 + 1, frame_length2 + 1 + bdelay ], # ... # [ T-bdelay-frame_length2, ..., T-1-bdelay, T-1 ] indices = [ [*range(i, i + frame_length2), i + frame_length2 + bdelay - 1] for i in range(0, signal.shape[-1] - frame_length2 - bdelay + 1, frame_step) ] if isinstance(signal, ComplexTensor): real = signal_framing( signal.real, frame_length, frame_step, bdelay, do_padding, pad_value, indices, ) imag = signal_framing( signal.imag, frame_length, frame_step, bdelay, do_padding, pad_value, indices, ) return ComplexTensor(real, imag) else: # (..., T - bdelay - frame_length + 2, frame_length) signal = signal[..., indices] # signal[..., :-1] = -signal[..., :-1] return signal
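# A small demonstration of signal_framing on a 1-D signal, using
# frame_length=3, frame_step=1, bdelay=2: each frame keeps frame_length-1
# adjacent samples plus one delayed sample, matching the index pattern in
# the comment above (frame 0 selects samples [0, 1, 3]).
import torch

sig = torch.arange(10, dtype=torch.float32)
frames = signal_framing(sig, frame_length=3, frame_step=1, bdelay=2)
assert frames.shape == (7, 3)  # (T - bdelay - frame_length + 2, frame_length)
assert frames[0].tolist() == [0., 1., 3.]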
def forward(self, input: torch.Tensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        enhanced speech (single-channel): torch.Tensor or List[torch.Tensor]
        output lengths
        predicted masks: OrderedDict[
            'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # wave -> stft -> magnitude spectrum
    input_spectrum, flens = self.stft(input, ilens)
    # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
    input_spectrum = ComplexTensor(input_spectrum[..., 0],
                                   input_spectrum[..., 1])
    if self.normalize_input:
        input_spectrum = input_spectrum / abs(input_spectrum).max()

    enhanced = input_spectrum
    masks = OrderedDict()

    if input_spectrum.dim() == 3:
        # single-channel input
        if self.use_wpe:
            # (B, T, F)
            enhanced, flens, mask_w = self.wpe(
                input_spectrum.unsqueeze(-2), flens)
            enhanced = enhanced.squeeze(-2)
            if mask_w is not None:
                masks["dereverb"] = mask_w.squeeze(-2)
    elif input_spectrum.dim() == 4:
        # multi-channel input
        # 1. WPE
        if self.use_wpe:
            # (B, T, C, F)
            enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
            if mask_w is not None:
                masks["dereverb"] = mask_w

        # 2. Beamformer
        if self.use_beamformer:
            # enhanced: (B, T, C, F) -> (B, T, F)
            enhanced, flens, masks_b = self.beamformer(enhanced, flens)
            for spk in range(self.num_spk):
                masks["spk{}".format(spk + 1)] = masks_b[spk]
            if len(masks_b) > self.num_spk:
                masks["noise1"] = masks_b[self.num_spk]
    else:
        raise ValueError(
            "Invalid spectrum dimension: {}".format(input_spectrum.shape)
        )

    # Convert ComplexTensor to torch.Tensor
    # (B, T, F) -> (B, T, F, 2)
    if isinstance(enhanced, list):
        # multi-speaker output
        enhanced = [torch.stack([enh.real, enh.imag], dim=-1)
                    for enh in enhanced]
    else:
        # single-speaker output
        enhanced = torch.stack([enhanced.real, enhanced.imag],
                               dim=-1).float()
    return enhanced, flens, masks
def online_wpe_step(input_buffer: ComplexTensor, power: torch.Tensor, inv_cov: ComplexTensor = None, filter_taps: ComplexTensor = None, alpha: float = 0.99, taps: int = 10, delay: int = 3): """One step of online dereverberation. Args: input_buffer: (F, C, taps + delay + 1) power: Estimate for the current PSD (F, T) inv_cov: Current estimate of R^-1 filter_taps: Current estimate of filter taps (F, taps * C, taps) alpha (float): Smoothing factor taps (int): Number of filter taps delay (int): Delay in frames Returns: Dereverberated frame of shape (F, D) Updated estimate of R^-1 Updated estimate of the filter taps >>> frame_length = 512 >>> frame_shift = 128 >>> taps = 6 >>> delay = 3 >>> alpha = 0.999 >>> frequency_bins = frame_length // 2 + 1 >>> Q = None >>> G = None >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G, ... alpha=alpha, taps=taps, delay=delay) """ assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size() C = input_buffer.size(-2) if inv_cov is None: inv_cov = ComplexTensor( torch.eye(C * taps, dtype=input_buffer.dtype).expand( *input_buffer.size()[:-2], C * taps, C * taps)) if filter_taps is None: filter_taps = ComplexTensor( torch.zeros(*input_buffer.size()[:-2], C * taps, C, dtype=input_buffer.dtype)) window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1) # (..., C, T) -> (..., C * T) window = window.view(*input_buffer.size()[:-2], -1) pred = input_buffer[..., -1] - FC.einsum('...id,...i->...d', (filter_taps.conj(), window)) nominator = FC.einsum('...ij,...j->...i', (inv_cov, window)) denominator = \ FC.einsum('...i,...i->...', (window.conj(), nominator)) + alpha * power kalman_gain = nominator / denominator[..., None] inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im', (window.conj(), inv_cov, kalman_gain)) inv_cov_k /= alpha filter_taps_k = \ filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj())) return pred, inv_cov_k, filter_taps_k
def forward(
    self,
    speech_mix: torch.Tensor,
    speech_mix_lengths: torch.Tensor = None,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech_mix: (Batch, samples) or (Batch, samples, channels)
        speech_ref: (Batch, num_speaker, samples)
            or (Batch, num_speaker, samples, channels)
        speech_mix_lengths: (Batch,), default None for the chunk iterator,
            because the chunk iterator does not return speech_lengths;
            see espnet2/iterators/chunk_iter_factory.py.
    """
    # clean speech signal of each speaker
    speech_ref = [
        kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
    ]
    # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
    speech_ref = torch.stack(speech_ref, dim=1)

    if "noise_ref1" in kwargs:
        # noise signal (optional, required when using
        # frontend models with beamforming)
        noise_ref = [
            kwargs["noise_ref{}".format(n + 1)]
            for n in range(self.num_noise_type)
        ]
        # (Batch, num_noise_type, samples) or
        # (Batch, num_noise_type, samples, channels)
        noise_ref = torch.stack(noise_ref, dim=1)
    else:
        noise_ref = None

    # dereverberated noisy signal
    # (optional, only used for frontend models with WPE)
    dereverb_speech_ref = kwargs.get("dereverb_ref", None)

    batch_size = speech_mix.shape[0]
    speech_lengths = (speech_mix_lengths if speech_mix_lengths is not None
                      else torch.ones(batch_size).int() * speech_mix.shape[1])
    assert speech_lengths.dim() == 1, speech_lengths.shape
    # Check that batch_size is unified
    assert speech_mix.shape[0] == speech_ref.shape[
        0] == speech_lengths.shape[0], (
            speech_mix.shape,
            speech_ref.shape,
            speech_lengths.shape,
        )

    # for data-parallel
    speech_ref = speech_ref[:, :, :speech_lengths.max()]
    speech_mix = speech_mix[:, :speech_lengths.max()]

    if self.loss_type != "si_snr":
        # prepare reference speech and reference spectrum
        speech_ref = torch.unbind(speech_ref, dim=1)
        spectrum_ref = [self.enh_model.stft(sr)[0] for sr in speech_ref]
        # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
        spectrum_ref = [
            ComplexTensor(sr[..., 0], sr[..., 1]) for sr in spectrum_ref
        ]
        spectrum_mix = self.enh_model.stft(speech_mix)[0]
        spectrum_mix = ComplexTensor(spectrum_mix[..., 0],
                                     spectrum_mix[..., 1])

        # predict separated speech and masks
        spectrum_pre, tf_length, mask_pre = self.enh_model(
            speech_mix, speech_lengths)

        # compute TF masking loss
        if self.loss_type == "magnitude":
            # compute loss on magnitude spectrum
            magnitude_pre = [abs(ps) for ps in spectrum_pre]
            magnitude_ref = [abs(sr) for sr in spectrum_ref]
            tf_loss, perm = self._permutation_loss(magnitude_ref,
                                                   magnitude_pre,
                                                   self.tf_mse_loss)
        elif self.loss_type == "spectrum":
            # compute loss on complex spectrum
            tf_loss, perm = self._permutation_loss(spectrum_ref,
                                                   spectrum_pre,
                                                   self.tf_mse_loss)
        elif self.loss_type.startswith("mask"):
            if self.loss_type == "mask_mse":
                loss_func = self.tf_mse_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            assert mask_pre is not None
            mask_pre_ = [
                mask_pre["spk{}".format(spk + 1)]
                for spk in range(self.num_spk)
            ]

            # prepare ideal masks
            mask_ref = self._create_mask_label(spectrum_mix,
                                               spectrum_ref,
                                               mask_type=self.mask_type)

            # compute TF masking loss
            tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_,
                                                   loss_func)

            if "dereverb" in mask_pre:
                if dereverb_speech_ref is None:
                    raise ValueError(
                        "No dereverberated reference for training!\n"
                        'Please specify "--use_dereverb_ref true" in run.sh'
                    )

                dereverb_spectrum_ref = self.enh_model.stft(
                    dereverb_speech_ref)[0]
                dereverb_spectrum_ref = ComplexTensor(
                    dereverb_spectrum_ref[..., 0],
                    dereverb_spectrum_ref[..., 1])
                # ComplexTensor(B, T, F) or ComplexTensor(B, T, C, F)
                dereverb_mask_ref = self._create_mask_label(
                    spectrum_mix, [dereverb_spectrum_ref],
                    mask_type=self.mask_type)[0]

                tf_loss = (tf_loss + loss_func(
                    dereverb_mask_ref, mask_pre["dereverb"]).mean())

            if "noise1" in mask_pre:
                if noise_ref is None:
                    raise ValueError(
                        "No noise reference for training!\n"
                        'Please specify "--use_noise_ref true" in run.sh')

                noise_ref = torch.unbind(noise_ref, dim=1)
                noise_spectrum_ref = [
                    self.enh_model.stft(nr)[0] for nr in noise_ref
                ]
                noise_spectrum_ref = [
                    ComplexTensor(nr[..., 0], nr[..., 1])
                    for nr in noise_spectrum_ref
                ]
                noise_mask_ref = self._create_mask_label(
                    spectrum_mix, noise_spectrum_ref,
                    mask_type=self.mask_type)

                mask_noise_pre = [
                    mask_pre["noise{}".format(n + 1)]
                    for n in range(self.num_noise_type)
                ]
                tf_noise_loss, perm_n = self._permutation_loss(
                    noise_mask_ref, mask_noise_pre, loss_func)
                tf_loss = tf_loss + tf_noise_loss
        else:
            raise ValueError("Unsupported loss type: %s" % self.loss_type)

        if self.training:
            si_snr = None
        else:
            speech_pre = [
                self.enh_model.stft.inverse(ps, speech_lengths)[0]
                for ps in spectrum_pre
            ]
            if speech_ref[0].dim() == 3:
                # For si_snr loss, only select one channel as the reference
                speech_ref = [
                    sr[..., self.ref_channel] for sr in speech_ref
                ]
            # compute si-snr loss
            si_snr_loss, perm = self._permutation_loss(speech_ref,
                                                       speech_pre,
                                                       self.si_snr_loss,
                                                       perm=perm)
            si_snr = -si_snr_loss.detach()

        loss = tf_loss
        stats = dict(
            si_snr=si_snr,
            loss=loss.detach(),
        )
    else:
        if speech_ref.dim() == 4:
            # For si_snr loss of multi-channel input,
            # only select one channel as the reference
            speech_ref = speech_ref[..., self.ref_channel]

        speech_pre, speech_lengths, *__ = self.enh_model.forward_rawwav(
            speech_mix, speech_lengths)
        # speech_pre: list[(batch, sample)]
        assert speech_pre[0].dim() == 2, speech_pre[0].dim()
        speech_ref = torch.unbind(speech_ref, dim=1)

        # compute si-snr loss
        si_snr_loss, perm = self._permutation_loss(
            speech_ref, speech_pre, self.si_snr_loss_zeromean)
        si_snr = -si_snr_loss
        loss = si_snr_loss
        stats = dict(si_snr=si_snr.detach(), loss=loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                           loss.device)
    return loss, stats, weight
High_origin, Low_origin = Decomposition(batch_x, 0.80) High_noise, Low_noise = Decomposition(batch_y, 0.80) # batch_y = batch_y.unsqueeze(1) # batch_x = batch_x.unsqueeze(1) # # batch_y = torch.cat([High_noise.unsqueeze(1), Low_noise.unsqueeze(1)], dim=1) # batch_x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1) output_hf = model_hf(High_noise.cuda()) output_lf = model_lf(Low_noise.cuda()) loss_hf = criterion(output_hf.cpu(), High_origin.unsqueeze(1)) loss_lf = criterion(output_lf.cpu(), Low_origin.unsqueeze(1)) output = ComplexTensor(output_hf[:, 0] + output_lf[:, 0], output_hf[:, 1] + output_lf[:, 1]).abs() output_hf = ComplexTensor(output_hf[:, 0], output_hf[:, 1]).abs() output_lf = ComplexTensor(output_lf[:, 0], output_lf[:, 1]).abs() High_noise = ComplexTensor(High_noise[:, 0], High_noise[:, 1]).abs() High_origin = ComplexTensor(High_origin[:, 0], High_origin[:, 1]).abs() Low_noise = ComplexTensor(Low_noise[:, 0], Low_noise[:, 1]).abs() Low_origin = ComplexTensor(Low_origin[:, 0], Low_origin[:, 1]).abs() fig = plt.figure() gs = GridSpec(nrows=2, ncols=4)
def forward(
    self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """The forward function.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F), double precision
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F), double precision
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """

    def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech.float(), ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(
                *(data.size()[:-3] + (data.size(-2),)), device=data.device
            )
            u[..., self.ref_channel].fill_(1)

        if beamformer_type in ("mpdr", "mvdr"):
            ws = get_mvdr_vector(psd_speech, psd_n, u.double())
            enhanced = apply_beamforming_vector(ws, data)
        elif beamformer_type == "wpd":
            ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
            enhanced = perform_WPD_filtering(ws, data, self.bdelay, self.btaps)
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        return enhanced, ws

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data.float(), ilens)
    assert self.nmask == len(masks)
    # floor masks with self.eps to increase numerical stability
    masks = [torch.clamp(m, min=self.eps) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech

        psd_speech = get_power_spectral_density_matrix(data, mask_speech.double())
        if self.beamformer_type == "mvdr":
            # psd of noise
            psd_n = get_power_spectral_density_matrix(data, mask_noise.double())
        elif self.beamformer_type == "mpdr":
            # psd of observed signal
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power_speech = (data.real**2 + data.imag**2) * mask_speech.double()
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speech = power_speech.mean(dim=-2)
            inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
            # covariance of expanded observed speech
            psd_n = get_covariances(
                data, inverse_power, self.bdelay, self.btaps, get_vector=False
            )
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        enhanced, ws = apply_beamforming(
            data, ilens, psd_speech, psd_n, self.beamformer_type
        )
        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        psd_speeches = [
            get_power_spectral_density_matrix(data, mask) for mask in mask_speech
        ]
        if self.beamformer_type == "mvdr":
            # psd of noise
            if mask_noise is not None:
                psd_n = get_power_spectral_density_matrix(data, mask_noise)
        elif self.beamformer_type == "mpdr":
            # psd of observed speech
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power = data.real**2 + data.imag**2
            power_speeches = [power * mask for mask in mask_speech]
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
            inverse_poweres = [
                1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
            ]
            # covariance of expanded observed speech
            psd_n = [
                get_covariances(
                    data, inv_ps, self.bdelay, self.btaps, get_vector=False
                )
                for inv_ps in inverse_poweres
            ]
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        enhanced = []
        for i in range(self.num_spk):
            psd_speech = psd_speeches.pop(i)
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                psd_noise = sum(psd_speeches)
                if mask_noise is not None:
                    psd_noise = psd_noise + psd_n
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, psd_noise, self.beamformer_type
                )
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, psd_n, self.beamformer_type
                )
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, psd_n[i], self.beamformer_type
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )
            psd_speeches.insert(i, psd_speech)
            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks
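# Hedged shape-check for the forward() above (not part of the original
# source): `beamformer` is assumed to be an already-constructed instance of
# this module with num_spk=1; data is double precision (B, T, C, F) as the
# docstring requires.
import torch
from torch_complex.tensor import ComplexTensor

B, T, C, Freq = 2, 100, 4, 257
data = ComplexTensor(
    torch.randn(B, T, C, Freq, dtype=torch.double),
    torch.randn(B, T, C, Freq, dtype=torch.double),
)
ilens = torch.LongTensor([100, 96])
enhanced, ilens, masks = beamformer(data, ilens)
# num_spk == 1: enhanced is (B, T, F); num_spk > 1: a list of such tensors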
def composition(high, low, end=0.2, start=0.0):
    # NOTE: `end` and `start` are unused; they are kept only for signature
    # compatibility. The original body read
    #     ComplexTensor(high[:, 1] + low[:, 1], low[:, 0] + low[:, 2])
    # which mixes real/imag channels and indexes a channel 2 that the
    # (B, 2, T, F) layout used elsewhere does not have. Assuming channel 0
    # is the real part and channel 1 the imaginary part (as in the training
    # code above), the intended recombination is:
    output = ComplexTensor(high[:, 0] + low[:, 0], high[:, 1] + low[:, 1]).abs()
    return output
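# Illustrative round trip under the assumptions above: splitting into
# complementary zero-padded bands and recombining should recover the input's
# magnitude spectrum. Uses the hypothetical `decomposition_sketch` from the
# earlier sketch, so this checks the assumed channel convention, not the
# original code.
import torch
from torch_complex.tensor import ComplexTensor

spec = torch.randn(4, 2, 100, 257)        # (B, 2=real/imag, T, F)
high, low = decomposition_sketch(spec, 0.80)
mag = composition(high, low)              # (B, T, F) magnitude
ref = ComplexTensor(spec[:, 0], spec[:, 1]).abs()
assert torch.allclose(mag, ref)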
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor,
            and dict.
    Returns:
        Tensor or ComplexTensor: Type converted inputs.
    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag:
        tensor([1., 1., 1.])
        )
    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
        return ComplexTensor(x["real"], x["imag"])

    # If torch.Tensor, return as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = (
            "x must be numpy.ndarray, torch.Tensor or a dict like "
            "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
            "but got {}".format(type(x))
        )
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)
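# Quick demonstration of the complex-ndarray branch, which the docstring
# examples above do not exercise: a complex numpy array is wrapped into a
# ComplexTensor via the `x.dtype.kind == "c"` path.
import numpy as np

xs = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
ct = to_torch_tensor(xs)
# ct.real -> tensor([1., 3.]), ct.imag -> tensor([2., 4.])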
def forward(
    self,
    data: ComplexTensor,
    ilens: torch.LongTensor,
    powers: Union[List[torch.Tensor], None] = None,
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """DNN_Beamformer forward function.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
        powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
    Returns:
        enhanced (ComplexTensor): (B, T, F)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """

    def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
        """Beamforming with the provided statistics.

        Args:
            data (ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD
                (B, F, (btaps+1)*C, (btaps+1)*C)
            psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
            psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C)
        Return:
            enhanced (ComplexTensor): (B, F, T)
            ws (ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        """
        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            u = u.double()
        else:
            if self.beamformer_type.endswith("_souden"):
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)),
                    device=data.device,
                    dtype=torch.double,
                )
                u[..., self.ref_channel].fill_(1)
            else:
                # for simplifying computation in RTF-based beamforming
                u = self.ref_channel

        if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
            ws = get_mvdr_vector_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
            ws = get_mvdr_vector(
                psd_speech.double(),
                psd_n.double(),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type == "wpd":
            ws = get_WPD_filter_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, data.double(), self.bdelay, self.btaps
            )
        elif self.beamformer_type == "wpd_souden":
            ws = get_WPD_filter_v2(
                psd_speech.double(),
                psd_n.double(),
                u,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, data.double(), self.bdelay, self.btaps
            )
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)
    data_d = data.double()

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks), len(masks)
    # floor masks to increase numerical stability
    if self.mask_flooring:
        masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech

        if self.beamformer_type.startswith(
            "wmpdr"
        ) or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = (power_input * mask_speech.double()).mean(dim=-2)
            else:
                assert len(powers) == 1, len(powers)
                powers = powers[0]
            inverse_power = 1 / torch.clamp(powers, min=self.eps)

        psd_speech = get_power_spectral_density_matrix(data_d, mask_speech.double())
        if mask_noise is not None and (
            self.beamformer_type == "mvdr_souden"
            or not self.beamformer_type.endswith("_souden")
        ):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double()
            )
        if self.beamformer_type == "mvdr":
            enhanced, ws = apply_beamforming(
                data, ilens, psd_noise, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "mvdr_souden":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise, psd_speech)
        elif self.beamformer_type == "mpdr":
            psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "mpdr_souden":
            psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
        elif self.beamformer_type == "wmpdr":
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "wmpdr_souden":
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
        elif self.beamformer_type == "wpd":
            psd_observed_bar = get_covariances(
                data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
            )
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed_bar, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "wpd_souden":
            psd_observed_bar = get_covariances(
                data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
            )
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed_bar, psd_speech
            )
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        if self.beamformer_type.startswith(
            "wmpdr"
        ) or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = [
                    (power_input * m.double()).mean(dim=-2) for m in mask_speech
                ]
            else:
                assert len(powers) == self.num_spk, len(powers)
            inverse_power = [1 / torch.clamp(p, min=self.eps) for p in powers]

        psd_speeches = [
            get_power_spectral_density_matrix(data_d, mask.double())
            for mask in mask_speech
        ]
        if mask_noise is not None and (
            self.beamformer_type == "mvdr_souden"
            or not self.beamformer_type.endswith("_souden")
        ):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double()
            )
        if self.beamformer_type in ("mpdr", "mpdr_souden"):
            psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
        elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
            psd_observed = [
                FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inv_p[..., None, :], data_d.conj()],
                )
                for inv_p in inverse_power
            ]
        elif self.beamformer_type in ("wpd", "wpd_souden"):
            psd_observed_bar = [
                get_covariances(
                    data_d, inv_p, self.bdelay, self.btaps, get_vector=False
                )
                for inv_p in inverse_power
            ]

        enhanced, ws = [], []
        for i in range(self.num_spk):
            psd_speech = psd_speeches.pop(i)
            if (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                psd_noise_i = (
                    psd_noise + sum(psd_speeches)
                    if mask_noise is not None
                    else sum(psd_speeches)
                )
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                enh, w = apply_beamforming(
                    data, ilens, psd_noise_i, psd_speech, psd_distortion=psd_noise_i
                )
            elif self.beamformer_type == "mvdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_noise_i, psd_speech)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed,
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "mpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wmpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wmpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed[i], psd_speech)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed_bar[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wpd_souden":
                enh, w = apply_beamforming(
                    data, ilens, psd_observed_bar[i], psd_speech
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)
            ws.append(w)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks
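# Hedged usage sketch (not from the original source): for "wmpdr" or "wpd"
# beamformers, per-speaker power estimates (B, F, T) can be supplied via
# `powers` instead of being re-estimated from the masks inside forward().
# The `beamformer` instance and num_spk=2 below are assumptions.
import torch
from torch_complex.tensor import ComplexTensor

B, T, C, Freq, num_spk = 2, 100, 4, 257, 2
data = ComplexTensor(torch.randn(B, T, C, Freq), torch.randn(B, T, C, Freq))
ilens = torch.LongTensor([100, 96])
precomputed = [torch.rand(B, Freq, T) + 1e-6 for _ in range(num_spk)]
enhanced, ilens, masks = beamformer(data, ilens, powers=precomputed)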