def apply_beamforming(data, ilens, psd_speech, psd_noise):
    """Beamform ``data`` with an MVDR filter built from the two PSD matrices.

    Args:
        data: multi-channel input spectrum; indexing below assumes the
            channel axis is at ``dim=-2`` (presumably (B, F, C, T) — confirm
            against the caller).
        ilens: input lengths, (B,).
        psd_speech: speech power spectral density matrix.
        psd_noise: noise power spectral density matrix.
    Returns:
        Tuple of the beamformed spectrum and the filter weights.
    """
    # Reference vector u: (B, C)
    if self.ref_channel < 0:
        # Estimate reference-selection weights from the speech PSD.
        ref_vec, _ = self.ref(psd_speech, ilens)
    else:
        # One-hot vector selecting the fixed reference microphone.
        onehot_shape = data.size()[:-3] + (data.size(-2),)
        ref_vec = torch.zeros(*onehot_shape, device=data.device)
        ref_vec[..., self.ref_channel].fill_(1)

    beam_weights = get_mvdr_vector(psd_speech, psd_noise, ref_vec)
    enhanced = apply_beamforming_vector(beam_weights, data)
    return enhanced, beam_weights
def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
        -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
    """Estimate masks, build an MVDR beamformer, and enhance the input.

    Notation:
        B: Batch; C: Channel; T: Time or Sequence length; F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F)
        ilens (torch.Tensor): (B,)
        mask_speech: estimated speech mask, transposed back to (B, T, C, F)
    """
    # (B, T, C, F) -> (B, F, C, T): process each frequency bin independently.
    data = data.permute(0, 3, 2, 1)

    # Masks come out as (B, F, C, T).
    (mask_speech, mask_noise), _ = self.mask(data, ilens)
    psd_speech = get_power_spectral_density_matrix(data, mask_speech)
    psd_noise = get_power_spectral_density_matrix(data, mask_noise)

    # Reference vector u: (B, C)
    if self.ref_channel < 0:
        ref_vec, _ = self.ref(psd_speech, ilens)
    else:
        # One-hot vector selecting the fixed reference microphone.
        onehot_shape = data.size()[:-3] + (data.size(-2),)
        ref_vec = torch.zeros(*onehot_shape, device=data.device)
        ref_vec[..., self.ref_channel].fill_(1)

    beam_weights = get_mvdr_vector(psd_speech, psd_noise, ref_vec)
    enhanced = apply_beamforming_vector(beam_weights, data)

    # (..., F, T) -> (..., T, F): restore the caller's time-major layout.
    enhanced = enhanced.transpose(-1, -2)
    mask_speech = mask_speech.transpose(-1, -3)
    return enhanced, ilens, mask_speech
def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
    """Beamform ``data`` with the filter selected by ``beamformer_type``.

    Args:
        data: multi-channel input spectrum; indexing below assumes the
            channel axis is at ``dim=-2`` — confirm against the caller.
        ilens: input lengths, (B,).
        psd_speech: speech power spectral density matrix.
        psd_n: second PSD matrix handed to the filter solver (its exact
            meaning depends on ``beamformer_type``).
        beamformer_type: one of "mvdr", "mpdr", or "wpd".
    Returns:
        Tuple of the enhanced spectrum and the filter weights.
    Raises:
        ValueError: if ``beamformer_type`` is not supported.
    """
    # Reference vector u: (B, C)
    if self.ref_channel < 0:
        # Reference estimator runs in single precision.
        ref_vec, _ = self.ref(psd_speech.float(), ilens)
    else:
        # One-hot vector selecting the fixed reference microphone.
        onehot_shape = data.size()[:-3] + (data.size(-2),)
        ref_vec = torch.zeros(*onehot_shape, device=data.device)
        ref_vec[..., self.ref_channel].fill_(1)

    if beamformer_type in ("mpdr", "mvdr"):
        ws = get_mvdr_vector(psd_speech, psd_n, ref_vec.double())
        enhanced = apply_beamforming_vector(ws, data)
    elif beamformer_type == "wpd":
        ws = get_WPD_filter_v2(psd_speech, psd_n, ref_vec.double())
        enhanced = perform_WPD_filtering(ws, data, self.bdelay, self.btaps)
    else:
        raise ValueError("Not supporting beamformer_type={}".format(
            beamformer_type))
    return enhanced, ws