def forward(
    self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """The forward function.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F), double precision
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F), double precision
            (NOTE(review): in the multi-speaker branch this is a list of
            such tensors, one per speaker)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """

    def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
        # Derive the beamforming filter from the given speech/noise
        # statistics and apply it to ``data`` (B, F, C, T).
        # u: (B, C)
        if self.ref_channel < 0:
            # learned (attention-based) reference-channel weights
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)

        if beamformer_type in ("mpdr", "mvdr"):
            # MPDR and MVDR share the same closed-form solver;
            # only the meaning of psd_n differs (observation vs. noise)
            ws = get_mvdr_vector(psd_speech.double(), psd_n.double(),
                                 u.double())
            enhanced = apply_beamforming_vector(ws, data.double())
        elif beamformer_type == "wpd":
            ws = get_WPD_filter_v2(psd_speech.double(), psd_n.double(),
                                   u.double())
            enhanced = perform_WPD_filtering(ws, data.double(),
                                             self.bdelay, self.btaps)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                beamformer_type))

        # cast back to the caller's dtype
        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks)
    # floor masks with self.eps to increase numerical stability
    masks = [torch.clamp(m, min=self.eps) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            # without an explicit noise mask, use the complement of speech
            mask_noise = 1 - mask_speech

        data_d = data.double()
        psd_speech = get_power_spectral_density_matrix(
            data_d, mask_speech.double())
        if self.beamformer_type == "mvdr":
            # psd of noise
            psd_n = get_power_spectral_density_matrix(
                data_d, mask_noise.double())
        elif self.beamformer_type == "mpdr":
            # psd of observed signal
            psd_n = FC.einsum("...ct,...et->...ce",
                              [data_d, data_d.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power_speech = (data_d.real**2
                            + data_d.imag**2) * mask_speech.double()
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speech = power_speech.mean(dim=-2)
            inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
            # covariance of expanded observed speech
            psd_n = get_covariances(data_d, inverse_power, self.bdelay,
                                    self.btaps, get_vector=False)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                         self.beamformer_type)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        # NOTE(review): unlike the single-speaker branch there is no
        # .double() cast here -- presumably ``data`` is already double
        # precision as stated in the docstring; confirm against callers
        psd_speeches = [
            get_power_spectral_density_matrix(data, mask)
            for mask in mask_speech
        ]
        if self.beamformer_type == "mvdr":
            # psd of noise (only available when a noise mask is estimated)
            if mask_noise is not None:
                psd_n = get_power_spectral_density_matrix(data, mask_noise)
        elif self.beamformer_type == "mpdr":
            # psd of observed speech
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power = data.real**2 + data.imag**2
            power_speeches = [power * mask for mask in mask_speech]
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
            inverse_poweres = [
                1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
            ]
            # covariance of expanded observed speech (one per speaker)
            psd_n = [
                get_covariances(data, inv_ps, self.bdelay, self.btaps,
                                get_vector=False)
                for inv_ps in inverse_poweres
            ]
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced = []
        for i in range(self.num_spk):
            # temporarily remove speaker i so the remaining entries can
            # be summed up as interference
            psd_speech = psd_speeches.pop(i)
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                psd_noise = sum(psd_speeches)
                if mask_noise is not None:
                    psd_noise = psd_noise + psd_n
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_noise,
                                           self.beamformer_type)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(data, ilens, psd_speech, psd_n,
                                           self.beamformer_type)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_n[i],
                                           self.beamformer_type)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(
                        self.beamformer_type))
            # put speaker i's statistics back for the next iteration
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks
def forward(
    self,
    data: ComplexTensor,
    ilens: torch.LongTensor,
    powers: Union[List[torch.Tensor], None] = None,
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """DNN_Beamformer forward function.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
        powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
    Returns:
        enhanced (ComplexTensor): (B, T, F)
            (NOTE(review): in the multi-speaker branch this is a list of
            such tensors, one per speaker)
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """

    def apply_beamforming(data, ilens, psd_n, psd_speech,
                          psd_distortion=None):
        """Beamforming with the provided statistics.

        Args:
            data (ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD
                (B, F, (btaps+1)*C, (btaps+1)*C)
            psd_speech (ComplexTensor):
                Speech covariance matrix (B, F, C, C)
            psd_distortion (ComplexTensor):
                Noise covariance matrix (B, F, C, C)
        Return:
            enhanced (ComplexTensor): (B, F, T)
            ws (ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        """
        # u: (B, C)
        if self.ref_channel < 0:
            # learned (attention-based) reference-channel weights
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            u = u.double()
        else:
            if self.beamformer_type.endswith("_souden"):
                # (optional) Create onehot vector for fixed reference
                # microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device,
                                dtype=torch.double)
                u[..., self.ref_channel].fill_(1)
            else:
                # for simplifying computation in RTF-based beamforming
                u = self.ref_channel

        if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
            # RTF-based formulas need the distortion (noise) covariance
            # to estimate the relative transfer function
            ws = get_mvdr_vector_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type in ("mpdr_souden", "mvdr_souden",
                                      "wmpdr_souden"):
            ws = get_mvdr_vector(
                psd_speech.double(),
                psd_n.double(),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type == "wpd":
            ws = get_WPD_filter_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(ws, data.double(),
                                             self.bdelay, self.btaps)
        elif self.beamformer_type == "wpd_souden":
            ws = get_WPD_filter_v2(
                psd_speech.double(),
                psd_n.double(),
                u,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(ws, data.double(),
                                             self.bdelay, self.btaps)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        # cast back to the caller's dtype
        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)
    data_d = data.double()

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks), len(masks)
    # floor masks to increase numerical stability
    if self.mask_flooring:
        masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            # without an explicit noise mask, use the complement of speech
            mask_noise = 1 - mask_speech

        if self.beamformer_type.startswith(
                "wmpdr") or self.beamformer_type.startswith("wpd"):
            # wMPDR/WPD need the (inverse) speech power for weighting
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = (power_input * mask_speech.double()).mean(dim=-2)
            else:
                assert len(powers) == 1, len(powers)
                powers = powers[0]
            inverse_power = 1 / torch.clamp(powers, min=self.eps)

        psd_speech = get_power_spectral_density_matrix(
            data_d, mask_speech.double())
        if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double())
        if self.beamformer_type == "mvdr":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "mvdr_souden":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                             psd_speech)
        elif self.beamformer_type == "mpdr":
            # covariance of the observation itself
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "mpdr_souden":
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech)
        elif self.beamformer_type == "wmpdr":
            # observation covariance weighted by the inverse speech power
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "wmpdr_souden":
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech)
        elif self.beamformer_type == "wpd":
            # stacked (multi-tap) observation covariance for WPD
            psd_observed_bar = get_covariances(data_d, inverse_power,
                                               self.bdelay, self.btaps,
                                               get_vector=False)
            enhanced, ws = apply_beamforming(data, ilens,
                                             psd_observed_bar, psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "wpd_souden":
            psd_observed_bar = get_covariances(data_d, inverse_power,
                                               self.bdelay, self.btaps,
                                               get_vector=False)
            enhanced, ws = apply_beamforming(data, ilens,
                                             psd_observed_bar, psd_speech)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        if self.beamformer_type.startswith(
                "wmpdr") or self.beamformer_type.startswith("wpd"):
            # per-speaker (inverse) power weights for wMPDR/WPD
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = [(power_input * m.double()).mean(dim=-2)
                          for m in mask_speech]
            else:
                assert len(powers) == self.num_spk, len(powers)
            inverse_power = [
                1 / torch.clamp(p, min=self.eps) for p in powers
            ]

        psd_speeches = [
            get_power_spectral_density_matrix(data_d, mask.double())
            for mask in mask_speech
        ]
        if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double())
        if self.beamformer_type in ("mpdr", "mpdr_souden"):
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
        elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
            # one weighted observation covariance per speaker
            psd_observed = [
                FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inv_p[..., None, :], data_d.conj()],
                )
                for inv_p in inverse_power
            ]
        elif self.beamformer_type in ("wpd", "wpd_souden"):
            # one stacked observation covariance per speaker
            psd_observed_bar = [
                get_covariances(data_d, inv_p, self.bdelay, self.btaps,
                                get_vector=False)
                for inv_p in inverse_power
            ]

        # NOTE(review): ``ws`` is collected below but never returned;
        # presumably kept for future use -- verify against callers
        enhanced, ws = [], []
        for i in range(self.num_spk):
            # temporarily remove speaker i so the remaining entries can
            # be summed up as interference
            psd_speech = psd_speeches.pop(i)
            if (self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                psd_noise_i = (psd_noise + sum(psd_speeches)
                               if mask_noise is not None
                               else sum(psd_speeches))
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                           psd_speech,
                                           psd_distortion=psd_noise_i)
            elif self.beamformer_type == "mvdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                           psd_speech)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed,
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "mpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed,
                                           psd_speech)
            elif self.beamformer_type == "wmpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wmpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed[i],
                                           psd_speech)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed_bar[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wpd_souden":
                enh, w = apply_beamforming(data, ilens,
                                           psd_observed_bar[i],
                                           psd_speech)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(
                        self.beamformer_type))
            # put speaker i's statistics back for the next iteration
            psd_speeches.insert(i, psd_speech)

            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)
            ws.append(w)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks