def apply_beamforming(
    self,
    data,
    ilens,
    psd_n,
    psd_speech,
    psd_distortion=None,
    rtf_mat=None,
    spk=0,
):
    """Beamforming with the provided statistics.

    Args:
        data (torch.complex64/ComplexTensor): (B, F, C, T)
        ilens (torch.Tensor): (B,)
        psd_n (torch.complex64/ComplexTensor):
            Noise covariance matrix for MVDR (B, F, C, C)
            Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
            Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
        psd_speech (torch.complex64/ComplexTensor):
            Speech covariance matrix (B, F, C, C)
        psd_distortion (torch.complex64/ComplexTensor):
            Noise covariance matrix (B, F, C, C)
        rtf_mat (torch.complex64/ComplexTensor):
            RTF matrix (B, F, C, num_spk)
        spk (int): speaker index

    Return:
        enhanced (torch.complex64/ComplexTensor): (B, F, T)
        ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
    """
    # Build the reference vector `u` used to pick/weight the output channel.
    # u: (B, C)
    if self.ref_channel < 0:
        # Negative ref_channel => use the attention-based reference estimator.
        u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
        u = u.double()
    else:
        if self.beamformer_type.endswith("_souden"):
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(
                *(data.size()[:-3] + (data.size(-2),)),
                device=data.device,
                dtype=torch.double
            )
            u[..., self.ref_channel].fill_(1)
        else:
            # for simplifying computation in RTF-based beamforming:
            # pass the integer channel index instead of a onehot vector
            u = self.ref_channel

    # Dispatch on beamformer type; all filters are computed in double precision
    # (to_double) and cast back to the input dtype on return.
    if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
        ws = get_mvdr_vector_with_rtf(
            to_double(psd_n),
            to_double(psd_speech),
            to_double(psd_distortion),
            iterations=self.rtf_iterations,
            reference_vector=u,
            normalize_ref_channel=self.ref_channel,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type == "mvdr_tfs":
        # MVDR with per-interference noise covariances: one filter per
        # entry of psd_n, then keep the minimum-magnitude output per TF bin.
        assert isinstance(psd_n, (list, tuple))
        ws = [
            get_mvdr_vector_with_rtf(
                to_double(psd_n_i),
                to_double(psd_speech),
                to_double(psd_distortion),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            for psd_n_i in psd_n
        ]
        enhanced = stack([apply_beamforming_vector(w, to_double(data)) for w in ws])
        # Selection index is computed without gradients; gather stays in graph.
        with torch.no_grad():
            index = enhanced.abs().argmin(dim=0, keepdims=True)
        enhanced = enhanced.gather(0, index).squeeze(0)
        ws = stack(ws, dim=0)
    elif self.beamformer_type in (
        "mpdr_souden",
        "mvdr_souden",
        "wmpdr_souden",
    ):
        ws = get_mvdr_vector(
            to_double(psd_speech),
            to_double(psd_n),
            u,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type == "mvdr_tfs_souden":
        # Souden variant of mvdr_tfs: same min-magnitude selection scheme.
        assert isinstance(psd_n, (list, tuple))
        ws = [
            get_mvdr_vector(
                to_double(psd_speech),
                to_double(psd_n_i),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            for psd_n_i in psd_n
        ]
        enhanced = stack([apply_beamforming_vector(w, to_double(data)) for w in ws])
        with torch.no_grad():
            index = enhanced.abs().argmin(dim=0, keepdims=True)
        enhanced = enhanced.gather(0, index).squeeze(0)
        ws = stack(ws, dim=0)
    elif self.beamformer_type == "wpd":
        ws = get_WPD_filter_with_rtf(
            to_double(psd_n),
            to_double(psd_speech),
            to_double(psd_distortion),
            iterations=self.rtf_iterations,
            reference_vector=u,
            normalize_ref_channel=self.ref_channel,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = perform_WPD_filtering(
            ws, to_double(data), self.bdelay, self.btaps
        )
    elif self.beamformer_type == "wpd_souden":
        ws = get_WPD_filter_v2(
            to_double(psd_speech),
            to_double(psd_n),
            u,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = perform_WPD_filtering(
            ws, to_double(data), self.bdelay, self.btaps
        )
    elif self.beamformer_type in ("mwf", "wmwf"):
        ws = get_mwf_vector(
            to_double(psd_speech),
            to_double(psd_n),
            u,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type == "sdw_mwf":
        # Speech-distortion-weighted MWF; mwf_mu trades distortion vs. noise.
        ws = get_sdw_mwf_vector(
            to_double(psd_speech),
            to_double(psd_n),
            u,
            denoising_weight=self.mwf_mu,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type == "r1mwf":
        # Rank-1 MWF variant.
        ws = get_rank1_mwf_vector(
            to_double(psd_speech),
            to_double(psd_n),
            u,
            denoising_weight=self.mwf_mu,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type in ("lcmp", "wlcmp", "lcmv"):
        # RTF-constrained beamformers: `spk` selects which RTF column to keep.
        ws = get_lcmv_vector_with_rtf(
            to_double(psd_n),
            to_double(rtf_mat),
            reference_vector=spk,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type.startswith("gev"):
        ws = get_gev_vector(
            to_double(psd_n),
            to_double(psd_speech),
            mode="power",
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
        if self.beamformer_type == "gev_ban":
            # Blind analytic normalization post-filter for GEV.
            gain = blind_analytic_normalization(ws, to_double(psd_n))
            enhanced = enhanced * gain.unsqueeze(-1)
    else:
        raise ValueError(
            "Not supporting beamformer_type={}".format(self.beamformer_type)
        )

    # Cast results back to the caller's dtype.
    return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)
def forward(
    self,
    data: Union[torch.Tensor, ComplexTensor],
    ilens: torch.LongTensor,
    powers: Optional[List[torch.Tensor]] = None,
    oracle_masks: Optional[List[torch.Tensor]] = None,
) -> Tuple[Union[torch.Tensor, ComplexTensor], torch.LongTensor, torch.Tensor]:
    """DNN_Beamformer forward function.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (torch.complex64/ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
        powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
        oracle_masks (List[torch.Tensor] or None): oracle masks (B, F, C, T)
            if not None, oracle_masks will be used instead of self.mask

    Returns:
        enhanced (torch.complex64/ComplexTensor): (B, T, F)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """
    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)
    data_d = to_double(data)

    # mask: [(B, F, C, T)] — either supplied oracle masks or DNN-estimated.
    if oracle_masks is not None:
        masks = oracle_masks
    else:
        masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks), len(masks)
    # floor masks to increase numerical stability
    if self.mask_flooring:
        masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,) — derive noise mask as the complement
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech
        if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
            # These multi-constraint beamformers need >= 2 sources.
            raise NotImplementedError("Single source is not supported yet")

        # Collect the covariance statistics this beamformer type needs.
        beamformer_stats = prepare_beamformer_stats(
            data_d,
            [mask_speech],
            mask_noise,
            powers=powers,
            beamformer_type=self.beamformer_type,
            bdelay=self.bdelay,
            btaps=self.btaps,
            eps=self.eps,
        )

        if self.beamformer_type in ("mvdr", "mpdr", "wmpdr", "wpd"):
            # RTF-based beamformers also need the distortion covariance.
            enhanced, ws = self.apply_beamforming(
                data,
                ilens,
                beamformer_stats["psd_n"],
                beamformer_stats["psd_speech"],
                psd_distortion=beamformer_stats["psd_distortion"],
            )
        elif (
            self.beamformer_type.endswith("_souden")
            or self.beamformer_type == "mwf"
            or self.beamformer_type == "wmwf"
            or self.beamformer_type == "sdw_mwf"
            or self.beamformer_type == "r1mwf"
            or self.beamformer_type.startswith("gev")
        ):
            enhanced, ws = self.apply_beamforming(
                data,
                ilens,
                beamformer_stats["psd_n"],
                beamformer_stats["psd_speech"],
            )
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        beamformer_stats = prepare_beamformer_stats(
            data_d,
            mask_speech,
            mask_noise,
            powers=powers,
            beamformer_type=self.beamformer_type,
            bdelay=self.bdelay,
            btaps=self.btaps,
            eps=self.eps,
        )
        if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
            # RTF matrix shared across all speakers for constrained beamformers.
            rtf_mat = get_rtf_matrix(
                beamformer_stats["psd_speech"],
                beamformer_stats["psd_distortion"],
                diagonal_loading=self.diagonal_loading,
                ref_channel=self.ref_channel,
                rtf_iterations=self.rtf_iterations,
                use_torch_solver=self.use_torch_solver,
                diag_eps=self.diag_eps,
            )

        enhanced, ws = [], []
        for i in range(self.num_spk):
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type in ("mvdr", "mvdr_tfs", "wmpdr", "wpd"):
                enh, w = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"][i],
                    beamformer_stats["psd_speech"][i],
                    psd_distortion=beamformer_stats["psd_distortion"][i],
                )
            elif self.beamformer_type in (
                "mvdr_souden",
                "mvdr_tfs_souden",
                "wmpdr_souden",
                "wpd_souden",
                "wmwf",
                "sdw_mwf",
                "r1mwf",
                "gev",
                "gev_ban",
            ):
                enh, w = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"][i],
                    beamformer_stats["psd_speech"][i],
                )
            elif self.beamformer_type == "mpdr":
                # psd_n (observation covariance) is shared, not per-speaker.
                enh, w = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"],
                    beamformer_stats["psd_speech"][i],
                    psd_distortion=beamformer_stats["psd_distortion"][i],
                )
            elif self.beamformer_type in ("mpdr_souden", "mwf"):
                enh, w = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"],
                    beamformer_stats["psd_speech"][i],
                )
            elif self.beamformer_type == "lcmp":
                enh, w = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"],
                    beamformer_stats["psd_speech"][i],
                    rtf_mat=rtf_mat,
                    spk=i,
                )
            elif self.beamformer_type in ("lcmv", "wlcmp"):
                enh, w = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"][i],
                    beamformer_stats["psd_speech"][i],
                    rtf_mat=rtf_mat,
                    spk=i,
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)
            ws.append(w)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks
def forward( self, data: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor ) -> Tuple[Union[torch.Tensor, ComplexTensor], torch.LongTensor, Union[ torch.Tensor, ComplexTensor], ]: """DNN_WPE forward function. Notation: B: Batch C: Channel T: Time or Sequence length F: Freq or Some dimension of the feature vector Args: data: (B, T, C, F) ilens: (B,) Returns: enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F) ilens: (B,) masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F) power (List[torch.Tensor]): (B, F, T) """ # (B, T, C, F) -> (B, F, C, T) data = data.permute(0, 3, 2, 1) enhanced = [data for i in range(self.nmask)] masks = None power = None for i in range(self.iterations): # Calculate power: (..., C, T) power = [enh.real**2 + enh.imag**2 for enh in enhanced] if i == 0 and self.use_dnn_mask: # mask: (B, F, C, T) masks, _ = self.mask_est(data, ilens) # floor masks to increase numerical stability if self.mask_flooring: masks = [m.clamp(min=self.flooring_thres) for m in masks] if self.normalization: # Normalize along T masks = [m / m.sum(dim=-1, keepdim=True) for m in masks] # (..., C, T) * (..., C, T) -> (..., C, T) power = [p * masks[i] for i, p in enumerate(power)] # Averaging along the channel axis: (..., C, T) -> (..., T) power = [p.mean(dim=-2).clamp(min=self.eps) for p in power] # enhanced: (..., C, T) -> (..., C, T) # NOTE(kamo): Calculate in double precision enhanced = [ wpe_one_iteration( to_double(data.contiguous()), to_double(p), taps=self.taps, delay=self.delay, inverse_power=self.inverse_power, ) for p in power ] enhanced = [ enh.to(dtype=data.dtype).masked_fill( make_pad_mask(ilens, enh.real), 0) for enh in enhanced ] # (B, F, C, T) -> (B, T, C, F) enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced] if masks is not None: masks = ([m.transpose(-1, -3) for m in masks] if self.nmask > 1 else masks[0].transpose(-1, -3)) if self.nmask == 1: enhanced = enhanced[0] return enhanced, ilens, masks, power
def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
    """Beamforming with the provided statistics.

    NOTE(review): `self` is referenced throughout but is not a parameter —
    this function only works as a closure defined inside a method where
    `self` is in scope; confirm against the enclosing definition.

    Args:
        data (torch.complex64/ComplexTensor): (B, F, C, T)
        ilens (torch.Tensor): (B,)
        psd_n (torch.complex64/ComplexTensor):
            Noise covariance matrix for MVDR (B, F, C, C)
            Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
            Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
        psd_speech (torch.complex64/ComplexTensor):
            Speech covariance matrix (B, F, C, C)
        psd_distortion (torch.complex64/ComplexTensor):
            Noise covariance matrix (B, F, C, C)

    Return:
        enhanced (torch.complex64/ComplexTensor): (B, F, T)
        ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
    """
    # Reference vector u: (B, C)
    if self.ref_channel < 0:
        # Attention-based reference channel estimation.
        u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
        u = u.double()
    else:
        if self.beamformer_type.endswith("_souden"):
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(
                *(data.size()[:-3] + (data.size(-2),)),
                device=data.device,
                dtype=torch.double
            )
            u[..., self.ref_channel].fill_(1)
        else:
            # for simplifying computation in RTF-based beamforming
            u = self.ref_channel

    if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
        ws = get_mvdr_vector_with_rtf(
            to_double(psd_n),
            to_double(psd_speech),
            to_double(psd_distortion),
            iterations=self.rtf_iterations,
            reference_vector=u,
            normalize_ref_channel=self.ref_channel,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
        ws = get_mvdr_vector(
            to_double(psd_speech),
            to_double(psd_n),
            u,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = apply_beamforming_vector(ws, to_double(data))
    elif self.beamformer_type == "wpd":
        ws = get_WPD_filter_with_rtf(
            to_double(psd_n),
            to_double(psd_speech),
            to_double(psd_distortion),
            iterations=self.rtf_iterations,
            reference_vector=u,
            normalize_ref_channel=self.ref_channel,
            use_torch_solver=self.use_torch_solver,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = perform_WPD_filtering(
            ws, to_double(data), self.bdelay, self.btaps
        )
    elif self.beamformer_type == "wpd_souden":
        ws = get_WPD_filter_v2(
            to_double(psd_speech),
            to_double(psd_n),
            u,
            diagonal_loading=self.diagonal_loading,
            diag_eps=self.diag_eps,
        )
        enhanced = perform_WPD_filtering(
            ws, to_double(data), self.bdelay, self.btaps
        )
    else:
        raise ValueError(
            "Not supporting beamformer_type={}".format(self.beamformer_type)
        )

    # Cast back to the caller's dtype.
    return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)
def forward(
    self,
    data: Union[torch.Tensor, ComplexTensor],
    ilens: torch.LongTensor,
    powers: Union[List[torch.Tensor], None] = None,
) -> Tuple[Union[torch.Tensor, ComplexTensor], torch.LongTensor, torch.Tensor]:
    """DNN_Beamformer forward function.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (torch.complex64/ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
        powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)

    Returns:
        enhanced (torch.complex64/ComplexTensor): (B, T, F)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """

    def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
        """Beamforming with the provided statistics.

        Args:
            data (torch.complex64/ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (torch.complex64/ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD
                (B,F,(btaps+1)*C,(btaps+1)*C)
            psd_speech (torch.complex64/ComplexTensor):
                Speech covariance matrix (B, F, C, C)
            psd_distortion (torch.complex64/ComplexTensor):
                Noise covariance matrix (B, F, C, C)

        Return:
            enhanced (torch.complex64/ComplexTensor): (B, F, T)
            ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        """
        # Reference vector u: (B, C)
        if self.ref_channel < 0:
            # Attention-based reference channel estimation.
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            u = u.double()
        else:
            if self.beamformer_type.endswith("_souden"):
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)),
                    device=data.device,
                    dtype=torch.double
                )
                u[..., self.ref_channel].fill_(1)
            else:
                # for simplifying computation in RTF-based beamforming
                u = self.ref_channel

        if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
            ws = get_mvdr_vector_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
            ws = get_mvdr_vector(
                to_double(psd_speech),
                to_double(psd_n),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type == "wpd":
            ws = get_WPD_filter_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, to_double(data), self.bdelay, self.btaps
            )
        elif self.beamformer_type == "wpd_souden":
            ws = get_WPD_filter_v2(
                to_double(psd_speech),
                to_double(psd_n),
                u,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, to_double(data), self.bdelay, self.btaps
            )
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    # Pick the einsum provider matching the complex representation in use.
    if isinstance(data, ComplexTensor):
        complex_wrapper = FC
    elif is_torch_1_9_plus and torch.is_complex(data):
        complex_wrapper = torch
    else:
        # FIX: message previously said "1.8+" and misspelled "complex",
        # while the guard above actually requires torch >= 1.9.
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)
    data_d = to_double(data)

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks), len(masks)
    # floor masks to increase numerical stability
    if self.mask_flooring:
        masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech

        if self.beamformer_type.startswith(
            "wmpdr"
        ) or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real ** 2 + data_d.imag ** 2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = (power_input * mask_speech.double()).mean(dim=-2)
            else:
                assert len(powers) == 1, len(powers)
                powers = powers[0]
            inverse_power = 1 / torch.clamp(powers, min=self.eps)

        psd_speech = get_power_spectral_density_matrix(data_d, mask_speech.double())
        if mask_noise is not None and (
            self.beamformer_type == "mvdr_souden"
            or not self.beamformer_type.endswith("_souden")
        ):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double()
            )
        if self.beamformer_type == "mvdr":
            enhanced, ws = apply_beamforming(
                data, ilens, psd_noise, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "mvdr_souden":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise, psd_speech)
        elif self.beamformer_type == "mpdr":
            psd_observed = complex_wrapper.einsum(
                "...ct,...et->...ce", [data_d, data_d.conj()]
            )
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "mpdr_souden":
            psd_observed = complex_wrapper.einsum(
                "...ct,...et->...ce", [data_d, data_d.conj()]
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
        elif self.beamformer_type == "wmpdr":
            psd_observed = complex_wrapper.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "wmpdr_souden":
            psd_observed = complex_wrapper.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
        elif self.beamformer_type == "wpd":
            psd_observed_bar = get_covariances(
                data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
            )
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed_bar, psd_speech, psd_distortion=psd_noise
            )
        elif self.beamformer_type == "wpd_souden":
            psd_observed_bar = get_covariances(
                data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
            )
            enhanced, ws = apply_beamforming(
                data, ilens, psd_observed_bar, psd_speech
            )
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        if self.beamformer_type.startswith(
            "wmpdr"
        ) or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real ** 2 + data_d.imag ** 2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = [
                    (power_input * m.double()).mean(dim=-2) for m in mask_speech
                ]
            else:
                assert len(powers) == self.num_spk, len(powers)
            inverse_power = [1 / torch.clamp(p, min=self.eps) for p in powers]

        psd_speeches = [
            get_power_spectral_density_matrix(data_d, mask.double())
            for mask in mask_speech
        ]
        if mask_noise is not None and (
            self.beamformer_type == "mvdr_souden"
            or not self.beamformer_type.endswith("_souden")
        ):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double()
            )
        if self.beamformer_type in ("mpdr", "mpdr_souden"):
            psd_observed = complex_wrapper.einsum(
                "...ct,...et->...ce", [data_d, data_d.conj()]
            )
        elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
            psd_observed = [
                complex_wrapper.einsum(
                    "...ct,...et->...ce",
                    [data_d * inv_p[..., None, :], data_d.conj()],
                )
                for inv_p in inverse_power
            ]
        elif self.beamformer_type in ("wpd", "wpd_souden"):
            psd_observed_bar = [
                get_covariances(
                    data_d, inv_p, self.bdelay, self.btaps, get_vector=False
                )
                for inv_p in inverse_power
            ]

        enhanced, ws = [], []
        for i in range(self.num_spk):
            # Temporarily remove speaker i so the remaining PSDs act as noise.
            psd_speech = psd_speeches.pop(i)
            if (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                psd_noise_i = (
                    psd_noise + sum(psd_speeches)
                    if mask_noise is not None
                    else sum(psd_speeches)
                )
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                enh, w = apply_beamforming(
                    data, ilens, psd_noise_i, psd_speech, psd_distortion=psd_noise_i
                )
            elif self.beamformer_type == "mvdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_noise_i, psd_speech)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed,
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "mpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wmpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wmpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed[i], psd_speech)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed_bar[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wpd_souden":
                enh, w = apply_beamforming(
                    data, ilens, psd_observed_bar[i], psd_speech
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )
            # Restore speaker i's PSD for the next loop iteration.
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)
            ws.append(w)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks
def prepare_beamformer_stats(
    signal,
    masks_speech,
    mask_noise,
    powers=None,
    beamformer_type="mvdr",
    bdelay=3,
    btaps=5,
    eps=1e-6,
):
    """Prepare necessary statistics for constructing the specified beamformer.

    Args:
        signal (torch.complex64/ComplexTensor): (..., F, C, T)
        masks_speech (List[torch.Tensor]): (..., F, C, T)
            masks for all speech sources
        mask_noise (torch.Tensor): (..., F, C, T)
            noise mask
        powers (List[torch.Tensor]): powers for all speech sources (..., F, T)
            used for wMPDR or WPD beamformers
        beamformer_type (str): one of the pre-defined beamformer types
        bdelay (int): delay factor, used for WPD beamformser
        btaps (int): number of filter taps, used for WPD beamformser
        eps (torch.Tensor): tiny constant

    Returns:
        beamformer_stats (dict): a dictionary containing all necessary
            statistics e.g. "psd_n", "psd_speech", "psd_distortion"

    Note:
        * When `masks_speech` is a tensor or a single-element list, all
          returned statistics are tensors;
        * When `masks_speech` is a multi-element list, some returned
          statistics can be a list, e.g., "psd_n" for MVDR, "psd_speech"
          and "psd_distortion".
    """
    # Local import to avoid a circular dependency with dnn_beamformer.
    from espnet2.enh.layers.dnn_beamformer import BEAMFORMER_TYPES

    # FIX: the "%s" placeholder was never interpolated, so the assertion
    # message used to literally read "%s is not supported yet".
    assert beamformer_type in BEAMFORMER_TYPES, (
        "%s is not supported yet" % beamformer_type
    )

    if isinstance(masks_speech, (list, tuple)):
        masks_speech = [to_double(m) for m in masks_speech]
    else:
        masks_speech = [to_double(masks_speech)]
    num_spk = len(masks_speech)

    # Weighted beamformers need the (masked) per-source power estimates.
    if (
        beamformer_type.startswith("wmpdr")
        or beamformer_type.startswith("wpd")
        or beamformer_type == "wlcmp"
        or beamformer_type == "wmwf"
    ):
        if powers is None:
            power_input = signal.real**2 + signal.imag**2
            # Averaging along the channel axis: (..., C, T) -> (..., T)
            powers = [(power_input * m).mean(dim=-2) for m in masks_speech]
        else:
            assert len(powers) == num_spk, (len(powers), num_spk)
        inverse_powers = [1 / torch.clamp(p, min=eps) for p in powers]

    psd_speeches = [
        get_power_spectral_density_matrix(signal, m) for m in masks_speech
    ]
    if (
        beamformer_type == "mvdr_souden"
        or beamformer_type == "sdw_mwf"
        or beamformer_type == "r1mwf"
        or beamformer_type.startswith("mvdr_tfs")
        or not beamformer_type.endswith("_souden")
    ):
        # MVDR or other RTF-based formulas
        if mask_noise is not None:
            psd_bg = get_power_spectral_density_matrix(
                signal, to_double(mask_noise)
            )
        if num_spk == 1:
            assert mask_noise is not None
            psd_noise = psd_bg
        else:
            # Per-speaker noise: background + all other speakers' PSDs.
            psd_noise = []
            for i in range(num_spk):
                if beamformer_type.startswith("mvdr_tfs"):
                    # NOTE: psd_noise is a list only for this beamformer
                    psd_noise_i = [
                        psd for j, psd in enumerate(psd_speeches) if j != i
                    ]
                else:
                    psd_sum = sum(
                        psd for j, psd in enumerate(psd_speeches) if j != i
                    )
                    psd_noise_i = (
                        psd_bg + psd_sum if mask_noise is not None else psd_sum
                    )
                psd_noise.append(psd_noise_i)

    if beamformer_type in (
        "mvdr",
        "mvdr_souden",
        "mvdr_tfs_souden",
        "sdw_mwf",
        "r1mwf",
        "lcmv",
        "gev",
        "gev_ban",
    ):
        psd_n = psd_noise
    elif beamformer_type == "mvdr_tfs":
        psd_n = psd_noise
        # Collapse each per-speaker list to a single distortion PSD.
        psd_noise = [sum(psd_noise_i) for psd_noise_i in psd_noise]
    elif beamformer_type in ("mpdr", "mpdr_souden", "lcmp", "mwf"):
        psd_n = einsum("...ct,...et->...ce", signal, signal.conj())
    elif beamformer_type in ("wmpdr", "wmpdr_souden", "wlcmp", "wmwf"):
        psd_n = [
            einsum(
                "...ct,...et->...ce",
                signal * inv_p[..., None, :],
                signal.conj(),
            )
            for inv_p in inverse_powers
        ]
    elif beamformer_type in ("wpd", "wpd_souden"):
        psd_n = [
            get_covariances(signal, inv_p, bdelay, btaps, get_vector=False)
            for inv_p in inverse_powers
        ]

    if num_spk == 1:
        # Unwrap single-speaker statistics from their lists.
        psd_speeches = psd_speeches[0]
        if isinstance(psd_n, (list, tuple)):
            psd_n = psd_n[0]

    if beamformer_type in (
        "mvdr",
        "mpdr",
        "wmpdr",
        "wpd",
        "lcmp",
        "wlcmp",
        "lcmv",
        "mvdr_tfs",
    ):
        return {
            "psd_n": psd_n,
            "psd_speech": psd_speeches,
            "psd_distortion": psd_noise,
        }
    elif (
        beamformer_type.endswith("_souden")
        or beamformer_type.startswith("gev")
        or beamformer_type == "mwf"
        or beamformer_type == "wmwf"
        or beamformer_type == "sdw_mwf"
        or beamformer_type == "r1mwf"
    ):
        return {"psd_n": psd_n, "psd_speech": psd_speeches}