def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # Stack taps and channels: (F, taps, C1, C2) -> (F, C1, taps * C2)
    stacked_vector = (correlation_vector
                      .permute(0, 2, 1, 3)
                      .contiguous()
                      .view(F, C, taps * C))

    inverse_matrix = correlation_matrix.inverse()

    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    solution = FC.matmul(stacked_vector, inverse_matrix.transpose(-1, -2))

    # Unstack: (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    return solution.view(F, C, taps, C).permute(0, 2, 3, 1)
def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """Estimate time-frequency masks from a complex spectrogram.

    Args:
        xs: (B, F, C, T)
        ilens: (B,)
    Returns:
        masks: A tuple of masks (one per output layer), each (B, F, C, T)
        ilens: (B,)
    """
    assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
    _, _, C, input_length = xs.size()
    # (B, F, C, T) -> (B, C, T, F)
    xs = xs.permute(0, 2, 3, 1)

    # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
    xs = (xs.real**2 + xs.imag**2)**0.5
    # xs: (B, C, T, F) -> xs: (B * C, T, F); each channel is fed as a sample
    xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
    # ilens: (B,) -> ilens_: (B * C)
    ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

    # xs: (B * C, T, F) -> xs: (B * C, T, D)
    xs, _, _ = self.brnn(xs, ilens_)
    # xs: (B * C, T, D) -> xs: (B, C, T, D)
    xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

    masks = []
    for linear in self.linears:
        # xs: (B, C, T, D) -> mask:(B, C, T, F)
        mask = linear(xs)

        if self.nonlinear == "sigmoid":
            mask = torch.sigmoid(mask)
        elif self.nonlinear == "relu":
            mask = torch.relu(mask)
        elif self.nonlinear == "tanh":
            mask = torch.tanh(mask)
        elif self.nonlinear == "crelu":
            mask = torch.clamp(mask, min=0, max=1)

        # Zero padding.
        # BUG FIX: Tensor.masked_fill is out-of-place and the original
        # discarded its return value, so padded frames were never zeroed.
        mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

        # (B, C, T, F) -> (B, F, C, T)
        mask = mask.permute(0, 3, 1, 2)

        # Take cares of multi gpu cases: If input_length > max(ilens)
        if mask.size(-1) < input_length:
            mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
        masks.append(mask)

    return tuple(masks), ilens
def forward( self, data: ComplexTensor, ilens: torch.LongTensor ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]: """The forward function Notation: B: Batch C: Channel T: Time or Sequence length F: Freq or Some dimension of the feature vector Args: data: (B, C, T, F), double precision ilens: (B,) Returns: data: (B, C, T, F), double precision ilens: (B,) """ # (B, T, C, F) -> (B, F, C, T) enhanced = data = data.permute(0, 3, 2, 1) mask = None for i in range(self.iterations): # Calculate power: (..., C, T) power = enhanced.real**2 + enhanced.imag**2 if i == 0 and self.use_dnn_mask: # mask: (B, F, C, T) (mask, ), _ = self.mask_est(enhanced, ilens) if self.normalization: # Normalize along T mask = mask / mask.sum(dim=-1)[..., None] # (..., C, T) * (..., C, T) -> (..., C, T) power = power * mask # Averaging along the channel axis: (..., C, T) -> (..., T) power = power.mean(dim=-2) # enhanced: (..., C, T) -> (..., C, T) # NOTE(kamo): Calculate in double precision enhanced = wpe_one_iteration( data.contiguous().double(), power.double(), taps=self.taps, delay=self.delay, inverse_power=self.inverse_power, ) enhanced = enhanced.type(data.dtype) enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0) # (B, F, C, T) -> (B, T, C, F) enhanced = enhanced.permute(0, 3, 2, 1) if mask is not None: mask = mask.transpose(-1, -3) return enhanced, ilens, mask
def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """Predict masks for beamforming.

    Args:
        data (ComplexTensor): (B, T, C, F), double precision
        ilens (torch.Tensor): (B,)
    Returns:
        masks (torch.Tensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
    """
    # (B, T, C, F) -> (B, F, C, T), single precision for the estimator
    estimator_input = data.permute(0, 3, 2, 1).float()
    masks, _ = self.mask(estimator_input, ilens)

    # Undo the layout change: (B, F, C, T) -> (B, T, C, F)
    restored = []
    for m in masks:
        restored.append(m.transpose(-1, -3))
    return restored, ilens
def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
        -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
    """Mask-driven MVDR beamforming.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F)
        ilens (torch.Tensor): (B,)
        mask_speech: (B, T, C, F)
    """
    # Move to the (B, F, C, T) layout used by the statistics helpers
    observation = data.permute(0, 3, 2, 1)

    # Estimate speech/noise masks, each (B, F, C, T)
    (mask_speech, mask_noise), _ = self.mask(observation, ilens)

    # Spatial covariance matrices of speech and noise
    psd_speech = get_power_spectral_density_matrix(observation, mask_speech)
    psd_noise = get_power_spectral_density_matrix(observation, mask_noise)

    # Reference vector u: (B, C)
    if self.ref_channel < 0:
        u, _ = self.ref(psd_speech, ilens)
    else:
        # One-hot vector selecting the fixed reference microphone
        onehot_shape = observation.size()[:-3] + (observation.size(-2),)
        u = torch.zeros(*onehot_shape, device=observation.device)
        u[..., self.ref_channel].fill_(1)

    filters = get_mvdr_vector(psd_speech, psd_noise, u)
    enhanced = apply_beamforming_vector(filters, observation)

    # (..., F, T) -> (..., T, F); mask back to (B, T, C, F)
    return enhanced.transpose(-1, -2), ilens, mask_speech.transpose(-1, -3)
def predict_mask(
        self, data: ComplexTensor,
        ilens: torch.LongTensor) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Predict mask for WPE dereverberation.

    Args:
        data (ComplexTensor): (B, T, C, F), double precision
        ilens (torch.Tensor): (B,)
    Returns:
        masks (torch.Tensor): (B, T, C, F), or None when DNN masking is off
        ilens (torch.Tensor): (B,)
    """
    if not self.use_dnn_mask:
        # Nothing to estimate in the purely iterative configuration
        return None, ilens

    # (B, T, C, F) -> (B, F, C, T), single precision for the estimator
    (mask, ), ilens = self.mask_est(
        data.permute(0, 3, 2, 1).float(), ilens)
    # (B, F, C, T) -> (B, T, C, F)
    return mask.transpose(-1, -3), ilens
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps: Diagonal-loading constant added before inversion for
            numerical stability.
    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    # Identity matrix broadcastable over the leading (frequency) dims
    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        correlation_matrix.shape[-2:]
    eye = eye.view(*shape)
    # BUG FIX: the original used `correlation_matrix += eps * eye`, which
    # mutates the caller's tensor as a hidden side effect of this helper.
    # Use an out-of-place addition so the input is left untouched.
    correlation_matrix = correlation_matrix + eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps, C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
        -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
    """Mask-driven MVDR beamforming, supporting multiple speakers.

    Notation:
        B: Batch; C: Channel; T: Time; F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced: (B, T, F), or a list of such tensors (one per speaker)
            when self.nmask > 2
        ilens (torch.Tensor): (B,)
        mask_speech: (B, T, C, F), or a list of such masks per speaker
    """
    def apply_beamforming(data, ilens, psd_speech, psd_noise):
        # Shared MVDR solve; returns the beamformed signal and the filter.
        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech, ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)
        ws = get_mvdr_vector(psd_speech, psd_noise, u)
        enhanced = apply_beamforming_vector(ws, data)
        return enhanced, ws

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: (B, F, C, T)
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks)

    if self.nmask == 2:  # (mask_speech, mask_noise)
        mask_speech, mask_noise = masks
        psd_speech = get_power_spectral_density_matrix(data, mask_speech)
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)
        enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
        mask_speech = mask_speech.transpose(-1, -3)
    else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
        mask_speech = list(masks[:-1])
        mask_noise = masks[-1]
        psd_speeches = [
            get_power_spectral_density_matrix(data, mask)
            for mask in mask_speech
        ]
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)
        enhanced = []
        ws = []
        for i in range(self.nmask - 1):
            # Temporarily remove speaker i so the remaining speakers'
            # covariances can be summed into the noise term.
            psd_speech = psd_speeches.pop(i)
            # treat all other speakers' psd_speech as noises
            enh, w = apply_beamforming(data, ilens, psd_speech,
                                       sum(psd_speeches) + psd_noise)
            psd_speeches.insert(i, psd_speech)
            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            mask_speech[i] = mask_speech[i].transpose(-1, -3)
            enhanced.append(enh)
            ws.append(w)

    return enhanced, ilens, mask_speech
def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """Mask-driven beamforming supporting MVDR, MPDR and WPD.

    Notation:
        B: Batch; C: Channel; T: Time; F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F), double precision
        ilens (torch.Tensor): (B,)
    Returns:
        enhanced (ComplexTensor): (B, T, F), or a list per speaker
            when self.num_spk > 1
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): list of (B, T, C, F)
    """
    def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
        # Shared solve for all beamformer variants.
        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech.float(), ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)
        if beamformer_type in ("mpdr", "mvdr"):
            ws = get_mvdr_vector(psd_speech, psd_n, u.double())
            enhanced = apply_beamforming_vector(ws, data)
        elif beamformer_type == "wpd":
            ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
            enhanced = perform_WPD_filtering(ws, data,
                                             self.bdelay, self.btaps)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                beamformer_type))
        return enhanced, ws

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data.float(), ilens)
    assert self.nmask == len(masks)
    # floor masks with self.eps to increase numerical stability
    masks = [torch.clamp(m, min=self.eps) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech

        psd_speech = get_power_spectral_density_matrix(
            data, mask_speech.double())
        if self.beamformer_type == "mvdr":
            # psd of noise
            psd_n = get_power_spectral_density_matrix(
                data, mask_noise.double())
        elif self.beamformer_type == "mpdr":
            # psd of observed signal
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power_speech = (data.real**2
                            + data.imag**2) * mask_speech.double()
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speech = power_speech.mean(dim=-2)
            inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
            # covariance of expanded observed speech
            psd_n = get_covariances(data, inverse_power, self.bdelay,
                                    self.btaps, get_vector=False)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                         self.beamformer_type)
        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        # NOTE(review): unlike the single-speaker branch, the masks are not
        # cast to double here — presumably intentional, but worth confirming.
        psd_speeches = [
            get_power_spectral_density_matrix(data, mask)
            for mask in mask_speech
        ]
        if self.beamformer_type == "mvdr":
            # psd of noise
            if mask_noise is not None:
                psd_n = get_power_spectral_density_matrix(data, mask_noise)
        elif self.beamformer_type == "mpdr":
            # psd of observed speech
            psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
        elif self.beamformer_type == "wpd":
            # Calculate power: (..., C, T)
            power = data.real**2 + data.imag**2
            power_speeches = [power * mask for mask in mask_speech]
            # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
            power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
            inverse_poweres = [
                1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
            ]
            # covariance of expanded observed speech (one per speaker)
            psd_n = [
                get_covariances(data, inv_ps, self.bdelay, self.btaps,
                                get_vector=False)
                for inv_ps in inverse_poweres
            ]
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        enhanced = []
        for i in range(self.num_spk):
            # Temporarily remove speaker i so the other speakers can be
            # folded into the noise statistics.
            psd_speech = psd_speeches.pop(i)
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                psd_noise = sum(psd_speeches)
                if mask_noise is not None:
                    psd_noise = psd_noise + psd_n
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_noise, self.beamformer_type)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_n, self.beamformer_type)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           psd_n[i], self.beamformer_type)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(
                        self.beamformer_type))
            psd_speeches.insert(i, psd_speech)
            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks
def forward(self, data: ComplexTensor,
            ilens: Optional[torch.LongTensor] = None,
            return_wpe: bool = True) \
        -> Tuple[Optional[ComplexTensor], torch.Tensor]:
    """Iterative WPE with optional DNN power estimation and context splicing.

    Args:
        data: (B, C, T, F) complex spectrogram
        ilens: (B,); when None, all sequences are assumed full length
        return_wpe: when False, only the (mask-derived) power estimate is
            returned and the WPE iterations are skipped
    Returns:
        enhanced: (B, C, T, F) dereverberated signal, or None when
            return_wpe is False
        power: (B, C, T, F) power estimate
    """
    if ilens is None:
        ilens = torch.full((data.size(0),), data.size(2),
                           dtype=torch.long, device=data.device)

    # Trim the left/right context frames from the initial estimate
    r = -self.rcontext if self.rcontext != 0 else None
    enhanced = data[:, :, self.lcontext:r, :]
    if self.lcontext != 0 or self.rcontext != 0:
        # Splicing requires all sequences in the batch to be equal length
        assert all(ilens[0] == i for i in ilens)

    # Create context window (a.k.a Splicing)
    # NOTE(review): placement of this block relative to the context check
    # above is inferred from the original layout — confirm with callers.
    if self.model_type in ('blstm', 'lstm'):
        width = data.size(2) - self.lcontext - self.rcontext
        # data: (B, C, l + w + r, F)
        indices = [i + j for i in range(width)
                   for j in range(1 + self.lcontext + self.rcontext)]
        _y = data[:, :, indices]
        # data: (B, C, l, (1 + w + r), F)
        data = _y.view(
            data.size(0), data.size(1),
            width,
            (1 + self.lcontext + self.rcontext) * data.size(3))
        ilens = torch.full((data.size(0),), width,
                           dtype=torch.long, device=data.device)
        del _y

    for i in range(self.iterations):
        power = enhanced.real ** 2 + enhanced.imag ** 2
        # Calculate power: (B, C, T, Context, F)
        if i == 0 and self.use_dnn:
            # The DNN estimate is computed once, on the first iteration.
            # mask: (B, C, T, F)
            mask = self.estimator(data, ilens)
            if mask.size(2) != power.size(2):
                # Estimator kept the context frames; drop them here
                assert mask.size(2) == (power.size(2)
                                        + self.rcontext + self.lcontext)
                r = -self.rcontext if self.rcontext != 0 else None
                mask = mask[:, :, self.lcontext:r, :]
            if self.normalization:
                # Normalize along T
                mask = mask / mask.sum(dim=-2)[..., None]
            if self.out_type == 'mask':
                power = power * mask
            else:
                # The network output itself is the power estimate,
                # in the parameterization given by out_type
                power = mask
                if self.out_type == 'amplitude':
                    power = power ** 2
                elif self.out_type == 'log_power':
                    power = power.exp()
                elif self.out_type == 'power':
                    pass
                else:
                    raise NotImplementedError(self.out_type)

        if not return_wpe:
            # Caller only wants the power estimate (e.g. for training
            # the estimator); skip the expensive WPE step
            return None, power

        # power: (B, C, T, F) -> _power: (B, F, T)
        _power = power.mean(dim=1).transpose(-1, -2).contiguous()
        # data: (B, C, T, F) -> _data: (B, F, C, T)
        _data = data.permute(0, 3, 1, 2).contiguous()

        # _enhanced: (B, F, C, T); per-sample loop to honor valid lengths
        _enhanced_real = []
        _enhanced_imag = []
        for d, p, l in zip(_data, _power, ilens):
            # e: (F, C, T) -> (T, C, F)
            e = wpe_one_iteration(
                d[..., :l], p[..., :l],
                taps=self.taps, delay=self.delay,
                inverse_power=self.inverse_power).transpose(0, 2)
            _enhanced_real.append(e.real)
            _enhanced_imag.append(e.imag)
        # _enhanced: B x (T, C, F) -> (B, T, C, F) -> (B, F, C, T)
        _enhanced_real = pad_sequence(_enhanced_real,
                                      batch_first=True).transpose(1, 3)
        _enhanced_imag = pad_sequence(_enhanced_imag,
                                      batch_first=True).transpose(1, 3)
        _enhanced = ComplexTensor(_enhanced_real, _enhanced_imag)
        # enhanced: (B, F, C, T) -> (B, C, T, F)
        enhanced = _enhanced.permute(0, 2, 3, 1)

    # enhanced: (B, C, T, F), power: (B, C, T, F)
    return enhanced, power
def forward( self, data: ComplexTensor, ilens: torch.LongTensor ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]: """DNN_WPE forward function. Notation: B: Batch C: Channel T: Time or Sequence length F: Freq or Some dimension of the feature vector Args: data: (B, T, C, F) ilens: (B,) Returns: enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F) ilens: (B,) masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F) power (List[torch.Tensor]): (B, F, T) """ # (B, T, C, F) -> (B, F, C, T) data = data.permute(0, 3, 2, 1) enhanced = [data for i in range(self.nmask)] masks = None power = None for i in range(self.iterations): # Calculate power: (..., C, T) power = [enh.real**2 + enh.imag**2 for enh in enhanced] if i == 0 and self.use_dnn_mask: # mask: (B, F, C, T) masks, _ = self.mask_est(data, ilens) # floor masks to increase numerical stability if self.mask_flooring: masks = [m.clamp(min=self.flooring_thres) for m in masks] if self.normalization: # Normalize along T masks = [m / m.sum(dim=-1, keepdim=True) for m in masks] # (..., C, T) * (..., C, T) -> (..., C, T) power = [p * masks[i] for i, p in enumerate(power)] # Averaging along the channel axis: (..., C, T) -> (..., T) power = [p.mean(dim=-2).clamp(min=self.eps) for p in power] # enhanced: (..., C, T) -> (..., C, T) # NOTE(kamo): Calculate in double precision enhanced = [ wpe_one_iteration( data.contiguous().double(), p.double(), taps=self.taps, delay=self.delay, inverse_power=self.inverse_power, ) for p in power ] enhanced = [ enh.to(dtype=data.dtype).masked_fill( make_pad_mask(ilens, enh.real), 0) for enh in enhanced ] # (B, F, C, T) -> (B, T, C, F) enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced] if masks is not None: masks = ([m.transpose(-1, -3) for m in masks] if self.nmask > 1 else masks[0].transpose(-1, -3)) if self.nmask == 1: enhanced = enhanced[0] return enhanced, ilens, masks, power
def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor,
        powers: Union[List[torch.Tensor], None] = None,
) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
    """DNN_Beamformer forward function.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq

    Args:
        data (ComplexTensor): (B, T, C, F)
        ilens (torch.Tensor): (B,)
        powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
    Returns:
        enhanced (ComplexTensor): (B, T, F), or a list per speaker when
            self.num_spk > 1
        ilens (torch.Tensor): (B,)
        masks (torch.Tensor): (B, T, C, F)
    """
    def apply_beamforming(data, ilens, psd_n, psd_speech,
                          psd_distortion=None):
        """Beamforming with the provided statistics.

        Args:
            data (ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD
                (B,F,(btaps+1)*C,(btaps+1)*C)
            psd_speech (ComplexTensor):
                Speech covariance matrix (B, F, C, C)
            psd_distortion (ComplexTensor):
                Noise covariance matrix (B, F, C, C)
        Return:
            enhanced (ComplexTensor): (B, F, T)
            ws (ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        """
        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            u = u.double()
        else:
            if self.beamformer_type.endswith("_souden"):
                # (optional) Create onehot vector for fixed reference mic
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device, dtype=torch.double)
                u[..., self.ref_channel].fill_(1)
            else:
                # for simplifying computation in RTF-based beamforming
                u = self.ref_channel

        if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
            ws = get_mvdr_vector_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type in ("mpdr_souden", "mvdr_souden",
                                      "wmpdr_souden"):
            ws = get_mvdr_vector(
                psd_speech.double(),
                psd_n.double(),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif self.beamformer_type == "wpd":
            ws = get_WPD_filter_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(ws, data.double(),
                                             self.bdelay, self.btaps)
        elif self.beamformer_type == "wpd_souden":
            ws = get_WPD_filter_v2(
                psd_speech.double(),
                psd_n.double(),
                u,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(ws, data.double(),
                                             self.bdelay, self.btaps)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    # data (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)
    data_d = data.double()

    # mask: [(B, F, C, T)]
    masks, _ = self.mask(data, ilens)
    assert self.nmask == len(masks), len(masks)
    # floor masks to increase numerical stability
    if self.mask_flooring:
        masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

    if self.num_spk == 1:  # single-speaker case
        if self.use_noise_mask:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
        else:
            # (mask_speech,)
            mask_speech = masks[0]
            mask_noise = 1 - mask_speech

        # wMPDR/WPD need the (possibly precomputed) speech power weights
        if self.beamformer_type.startswith(
                "wmpdr") or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = (power_input * mask_speech.double()).mean(dim=-2)
            else:
                assert len(powers) == 1, len(powers)
                powers = powers[0]
            inverse_power = 1 / torch.clamp(powers, min=self.eps)

        psd_speech = get_power_spectral_density_matrix(
            data_d, mask_speech.double())
        if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double())
        if self.beamformer_type == "mvdr":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "mvdr_souden":
            enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                             psd_speech)
        elif self.beamformer_type == "mpdr":
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "mpdr_souden":
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech)
        elif self.beamformer_type == "wmpdr":
            # Observation covariance weighted by inverse speech power
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "wmpdr_souden":
            psd_observed = FC.einsum(
                "...ct,...et->...ce",
                [data_d * inverse_power[..., None, :], data_d.conj()],
            )
            enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                             psd_speech)
        elif self.beamformer_type == "wpd":
            psd_observed_bar = get_covariances(data_d, inverse_power,
                                               self.bdelay, self.btaps,
                                               get_vector=False)
            enhanced, ws = apply_beamforming(data, ilens, psd_observed_bar,
                                             psd_speech,
                                             psd_distortion=psd_noise)
        elif self.beamformer_type == "wpd_souden":
            psd_observed_bar = get_covariances(data_d, inverse_power,
                                               self.bdelay, self.btaps,
                                               get_vector=False)
            enhanced, ws = apply_beamforming(data, ilens, psd_observed_bar,
                                             psd_speech)
        else:
            raise ValueError("Not supporting beamformer_type={}".format(
                self.beamformer_type))

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
    else:  # multi-speaker case
        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            mask_speech = list(masks)
            mask_noise = None

        if self.beamformer_type.startswith(
                "wmpdr") or self.beamformer_type.startswith("wpd"):
            if powers is None:
                power_input = data_d.real**2 + data_d.imag**2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = [(power_input * m.double()).mean(dim=-2)
                          for m in mask_speech]
            else:
                assert len(powers) == self.num_spk, len(powers)
            inverse_power = [
                1 / torch.clamp(p, min=self.eps) for p in powers
            ]

        psd_speeches = [
            get_power_spectral_density_matrix(data_d, mask.double())
            for mask in mask_speech
        ]
        if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")):
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, mask_noise.double())
        if self.beamformer_type in ("mpdr", "mpdr_souden"):
            psd_observed = FC.einsum("...ct,...et->...ce",
                                     [data_d, data_d.conj()])
        elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
            # One weighted observation covariance per speaker
            psd_observed = [
                FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inv_p[..., None, :], data_d.conj()],
                ) for inv_p in inverse_power
            ]
        elif self.beamformer_type in ("wpd", "wpd_souden"):
            psd_observed_bar = [
                get_covariances(data_d, inv_p, self.bdelay, self.btaps,
                                get_vector=False)
                for inv_p in inverse_power
            ]

        enhanced, ws = [], []
        for i in range(self.num_spk):
            # Temporarily remove speaker i so the other speakers can be
            # folded into the noise statistics.
            psd_speech = psd_speeches.pop(i)
            if (self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                psd_noise_i = (psd_noise + sum(psd_speeches)
                               if mask_noise is not None
                               else sum(psd_speeches))
            # treat all other speakers' psd_speech as noises
            if self.beamformer_type == "mvdr":
                enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                           psd_speech,
                                           psd_distortion=psd_noise_i)
            elif self.beamformer_type == "mvdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                           psd_speech)
            elif self.beamformer_type == "mpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed,
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "mpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed,
                                           psd_speech)
            elif self.beamformer_type == "wmpdr":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wmpdr_souden":
                enh, w = apply_beamforming(data, ilens, psd_observed[i],
                                           psd_speech)
            elif self.beamformer_type == "wpd":
                enh, w = apply_beamforming(
                    data,
                    ilens,
                    psd_observed_bar[i],
                    psd_speech,
                    psd_distortion=psd_noise_i,
                )
            elif self.beamformer_type == "wpd_souden":
                enh, w = apply_beamforming(data, ilens,
                                           psd_observed_bar[i], psd_speech)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(
                        self.beamformer_type))
            psd_speeches.insert(i, psd_speech)
            # (..., F, T) -> (..., T, F)
            enh = enh.transpose(-1, -2)
            enhanced.append(enh)
            ws.append(w)

    # (..., F, C, T) -> (..., T, C, F)
    masks = [m.transpose(-1, -3) for m in masks]
    return enhanced, ilens, masks
def forward(self, xs: ComplexTensor, ts: ComplexTensor,
            ilens: torch.LongTensor,
            loss_types: Union[str, Sequence[str]] = 'power_mse',
            ref_channel: int = 0) -> Dict[str, torch.Tensor]:
    """Compute training/evaluation losses for the (DNN-)WPE model.

    Args:
        xs (ComplexTensor): Observed spectrogram (B, C, T, F)
        ts (ComplexTensor): Target spectrogram; the channel-indexed usage
            below implies (B, C, T, F) — the original inline comment said
            (B, T, F), TODO confirm against callers.
        ilens (torch.LongTensor): (B,)
        loss_types: One loss name or a sequence of names; one loss is
            computed per name and returned under that key.
        ref_channel: Reference channel for the classic-WPE losses.
    Returns:
        Dict[str, torch.Tensor]: loss name -> 1-element tensor
    """
    int16_scale = numpy.iinfo(numpy.int16).max - 1

    def _istft_first_channel(spec_list):
        # B x (C, T, F) -> first channel -> B x time signal (T,)
        return [self.stft_func.istft(s[0].cpu().numpy().T)
                for s in spec_list]

    def _mean_pesq(ref_signals, deg_signals):
        # PESQ via subprocess can be parallelized by threading.
        # BUG FIX: scaling is done out-of-place here; the original used
        # `_y *= ...`, which mutated the cached ys_time/ts_time arrays
        # in place and corrupted them for subsequent loss types.
        with ThreadPoolExecutor(self.pesq_nworker) as executor:
            futures = [
                executor.submit(calc_pesq,
                                (r * int16_scale).astype(numpy.int16),
                                (d * int16_scale).astype(numpy.int16),
                                self.stft_func.fs)
                for r, d in zip(ref_signals, deg_signals)
            ]
            return torch.tensor(numpy.mean([f.result() for f in futures]))

    if isinstance(loss_types, str):
        loss_types = [loss_types]

    # The WPE output is only needed for the 'dnnwpe_*' losses
    return_wpe = any('dnnwpe' in loss_type for loss_type in loss_types)
    # ys: (B, C, T, F), power: (B, C, T, F)
    ys, power = self.model(xs, ilens, return_wpe=return_wpe)

    # Trim context frames from the observation to match the model output
    r = -self.rcontext if self.rcontext != 0 else None
    xs = xs[:, :, self.lcontext:r, :]
    if ys is not None:
        assert xs.shape == ys.shape, (xs.shape, ys.shape)
        assert xs.shape == power.shape, (xs.shape, power.shape)

    # Lazily-computed caches shared between loss types
    uts = None       # unpadded target spectrogram
    uys = None       # unpadded enhanced spectrogram
    upower = None    # unpadded power estimate
    ys_time = None   # enhanced time signals (first channel)
    ts_time = None   # target time signals (first channel)
    xs_time = None   # observed time signals (first channel)

    loss_dict = OrderedDict()
    for loss_type in loss_types:
        if loss_type == 'dnnwpe_power_mse':
            if uys is None:
                uys = FC.cat(unpad(ys, ilens, length_dim=2), dim=1)
            if uts is None:
                uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)
            # MSE between log powers
            _ys = (uys.real**2 + uys.imag**2).log()
            _ts = (uts.real**2 + uts.imag**2).log()
            loss = mse_loss(_ys, _ts)

        elif loss_type == 'dnnwpe_mse':
            if uys is None:
                uys = FC.cat(unpad(ys, ilens, length_dim=2), dim=1)
            if uts is None:
                uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)
            # MSE on stacked real/imaginary parts
            _ys = torch.cat([uys.real, uys.imag], dim=-1)
            _ts = torch.cat([uts.real, uts.imag], dim=-1)
            loss = mse_loss(_ys, _ts)

        elif loss_type == 'power_mse':
            if upower is None:
                upower = torch.cat(unpad(power, ilens, length_dim=2),
                                   dim=1)
            if uts is None:
                uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)
            loss = mse_loss(upower, uts.real**2 + uts.imag**2)

        # For evaluation as not differentiable
        elif loss_type == 'dnnwpe_stoi':
            # Use the first channel only to make faster calculation
            if ys_time is None:
                ys_time = _istft_first_channel(
                    unpad(ys, ilens, length_dim=2))
            if ts_time is None:
                ts_time = _istft_first_channel(
                    unpad(ts, ilens, length_dim=2))
            _losses = [stoi(_t, _y, self.stft_func.fs)
                       for _y, _t in zip(ys_time, ts_time)]
            loss = torch.tensor(numpy.mean(_losses))

        # For evaluation as not differentiable
        elif loss_type == 'dnnwpe_pesq':
            # Use the first channel only to make faster calculation
            if ys_time is None:
                ys_time = _istft_first_channel(
                    unpad(ys, ilens, length_dim=2))
            if ts_time is None:
                ts_time = _istft_first_channel(
                    unpad(ts, ilens, length_dim=2))
            loss = _mean_pesq(ts_time, ys_time)

        # For evaluation as not differentiable
        elif loss_type == 'unprocessed_pesq':
            # Use the first channel only to make faster calculation
            if xs_time is None:
                xs_time = _istft_first_channel(
                    unpad(xs, ilens, length_dim=2))
            if ts_time is None:
                ts_time = _istft_first_channel(
                    unpad(ts, ilens, length_dim=2))
            # BUG FIX: the original scaled with `x * iinfo.max - 1`, which
            # binds as `(x * max) - 1` instead of `x * (max - 1)` like the
            # other PESQ branches; _mean_pesq applies the correct scaling.
            loss = _mean_pesq(ts_time, xs_time)

        elif loss_type == 'wpe_pesq':
            with torch.no_grad():
                # (B, C, T, F) -> (B, F, C, T)
                _xs = xs.permute(0, 3, 1, 2).contiguous()
                # Classic WPE, keep the reference channel: (B, F, T)
                _ys = wpe(_xs, 5, 3, 3)[:, :, ref_channel]
                _ys = unpad(_ys, ilens, length_dim=2)
            # BUG FIX: kept in a local variable so the cached DNN-WPE
            # ys_time is not clobbered when several *_pesq losses are
            # requested in the same call.
            wpe_time = [self.stft_func.istft(_y.cpu().numpy())
                        for _y in _ys]
            if ts_time is None:
                ts_time = _istft_first_channel(
                    unpad(ts, ilens, length_dim=2))
            loss = _mean_pesq(ts_time, wpe_time)

        elif loss_type == 'wpe_mse':
            # Note: No updated parameters existing
            if uts is None:
                uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)
            with torch.no_grad():
                # (B, C, T, F) -> (B, F, C, T)
                _xs = xs.permute(0, 3, 1, 2)
                # _ys: (B, F, C, T) -> (B, T, F)
                _ys = wpe(_xs, 5, 3, 3)[:, :, ref_channel].transpose(1, 2)
                _uys = FC.cat(unpad(_ys, ilens, length_dim=1), dim=0)
            _ys = _uys.real**2 + _uys.imag**2
            _ts = (uts.real**2 + uts.imag**2)[ref_channel]
            loss = mse_loss(_ys, _ts)

        else:
            raise NotImplementedError(f'loss_type={loss_type}')

        # Don't return scalar
        loss_dict[loss_type] = loss[None]
    return loss_dict