def test_complex_impl_consistency():
    # The native complex-tensor path requires torch>=1.9.
    if not is_torch_1_9_plus:
        return
    mat_th = torch.complex(
        torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag)
    )
    mat_ct = ComplexTensor(
        torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag)
    )
    bs = mat_th.shape[0]
    rank = mat_th.shape[-1]
    vec_th = torch.complex(torch.rand(bs, rank), torch.rand(bs, rank)).type_as(mat_th)
    vec_ct = ComplexTensor(vec_th.real, vec_th.imag)
    # Each pair applies the same operation to a native torch complex tensor and
    # to a ComplexTensor; the two implementations must agree numerically.
    for result_th, result_ct in (
        (abs(mat_th), abs(mat_ct)),
        (inverse(mat_th), inverse(mat_ct)),
        (
            matmul(mat_th, vec_th.unsqueeze(-1)),
            matmul(mat_ct, vec_ct.unsqueeze(-1)),
        ),
        (
            solve(vec_th.unsqueeze(-1), mat_th),
            solve(vec_ct.unsqueeze(-1), mat_ct),
        ),
        (
            einsum("bec,bc->be", mat_th, vec_th),
            einsum("bec,bc->be", mat_ct, vec_ct),
        ),
    ):
        np.testing.assert_allclose(result_th.numpy(), result_ct.numpy(), atol=1e-6)
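# A minimal sketch of the module-level setup the test above relies on. The
# dispatch helpers (einsum/inverse/matmul/solve accepting both native complex
# tensors and ComplexTensor) live elsewhere in the repo; the import path below
# and the fixture values are assumptions for illustration only.
import numpy as np
import torch
from packaging.version import Version
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complex_utils import einsum, inverse, matmul, solve  # assumed path

is_torch_1_9_plus = Version(torch.__version__) >= Version("1.9.0")

# Random, well-conditioned complex matrices of shape (batch, rank, rank), so
# that inverse() and solve() are numerically stable in the comparisons.
_rng = np.random.default_rng(0)
mat_np = _rng.standard_normal((2, 3, 3)) + 1j * _rng.standard_normal((2, 3, 3))
mat_np = mat_np @ mat_np.conj().transpose(0, 2, 1) + 3.0 * np.eye(3)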
def forward(self, input: torch.Tensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        enhanced speech (single-channel): torch.Tensor or List[torch.Tensor]
        output lengths
        predicted masks: OrderedDict[
            'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # wave -> STFT -> complex spectrum
    input_spectrum, flens = self.stft(input, ilens)
    # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
    input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
    if self.normalize_input:
        input_spectrum = input_spectrum / abs(input_spectrum).max()

    enhanced = input_spectrum
    masks = OrderedDict()
    if input_spectrum.dim() == 3:  # single-channel input
        if self.use_wpe:
            # (B, T, F)
            enhanced, flens, mask_w = self.wpe(input_spectrum.unsqueeze(-2), flens)
            enhanced = enhanced.squeeze(-2)
            if mask_w is not None:
                masks["dereverb"] = mask_w.squeeze(-2)

    elif input_spectrum.dim() == 4:  # multi-channel input
        # 1. WPE
        if self.use_wpe:
            # (B, T, C, F)
            enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
            if mask_w is not None:
                masks["dereverb"] = mask_w

        # 2. Beamformer
        if self.use_beamformer:
            # enhanced: (B, T, C, F) -> (B, T, F)
            enhanced, flens, masks_b = self.beamformer(enhanced, flens)
            for spk in range(self.num_spk):
                masks["spk{}".format(spk + 1)] = masks_b[spk]
            if len(masks_b) > self.num_spk:
                masks["noise1"] = masks_b[self.num_spk]
    else:
        raise ValueError(
            "Invalid spectrum dimension: {}".format(input_spectrum.shape)
        )

    # Convert ComplexTensor to torch.Tensor
    # (B, T, F) -> (B, T, F, 2)
    if isinstance(enhanced, list):
        # multi-speaker output
        enhanced = [torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced]
    else:
        # single-speaker output
        enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()
    return enhanced, flens, masks
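# The final ComplexTensor -> torch.Tensor conversion above can be exercised in
# isolation; a minimal, self-contained check (shapes are arbitrary examples):
import torch
from torch_complex.tensor import ComplexTensor

spec = ComplexTensor(torch.randn(4, 100, 257), torch.randn(4, 100, 257))
stacked = torch.stack([spec.real, spec.imag], dim=-1)  # (B, T, F) -> (B, T, F, 2)
assert stacked.shape == (4, 100, 257, 2)
assert torch.equal(stacked[..., 0], spec.real)
assert torch.equal(stacked[..., 1], spec.imag)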
def forward(
    self, input: ComplexTensor, ilens: torch.Tensor
) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        enhanced speech (single-channel): List[ComplexTensor]
        output lengths
        other predicted data: OrderedDict[
            'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
            'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
    assert input.dim() in (3, 4), input.dim()
    enhanced = input
    others = OrderedDict()
    if (
        self.training
        and self.loss_type is not None
        and self.loss_type.startswith("mask")
    ):
        # Only estimate masks during training to save memory
        if self.use_wpe:
            if input.dim() == 3:
                mask_w, ilens = self.wpe.predict_mask(input.unsqueeze(-2), ilens)
                mask_w = mask_w.squeeze(-2)
            elif input.dim() == 4:
                mask_w, ilens = self.wpe.predict_mask(input, ilens)
            if mask_w is not None:
                if isinstance(mask_w, list):
                    # single-source WPE: one dereverberation mask per speaker
                    for spk in range(self.num_spk):
                        others["mask_dereverb{}".format(spk + 1)] = mask_w[spk]
                else:
                    # multi-source WPE: a single mask for the whole mixture
                    others["mask_dereverb1"] = mask_w

        if self.use_beamformer and input.dim() == 4:
            others_b, ilens = self.beamformer.predict_mask(input, ilens)
            for spk in range(self.num_spk):
                others["mask_spk{}".format(spk + 1)] = others_b[spk]
            if len(others_b) > self.num_spk:
                others["mask_noise1"] = others_b[self.num_spk]

        return None, ilens, others
    else:
        powers = None
        # Perform both mask estimation and enhancement
        if input.dim() == 3:
            # single-channel input (B, T, F)
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(
                    input.unsqueeze(-2), ilens
                )
                if isinstance(enhanced, list):
                    # single-source WPE
                    enhanced = [enh.squeeze(-2) for enh in enhanced]
                    if mask_w is not None:
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk].squeeze(-2)
                else:
                    # multi-source WPE
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        others["dereverb1"] = enhanced
                        others["mask_dereverb1"] = mask_w.squeeze(-2)
        else:
            # multi-channel input (B, T, C, F)
            # 1. WPE
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                if mask_w is not None:
                    if isinstance(enhanced, list):
                        # single-source WPE
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk]
                    else:
                        # multi-source WPE
                        others["dereverb1"] = enhanced
                        others["mask_dereverb1"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                if (
                    (
                        not self.beamformer.beamformer_type.startswith("wmpdr")
                        and not self.beamformer.beamformer_type.startswith("wpd")
                    )
                    or not self.shared_power
                    or (self.wpe.nmask == 1 and self.num_spk > 1)
                ):
                    # WPE power spectra can only be reused by wMPDR/WPD beamformers
                    powers = None

                # enhanced: (B, T, C, F) -> (B, T, F)
                if isinstance(enhanced, list):
                    # outputs of single-source WPE
                    raise NotImplementedError(
                        "Single-source WPE is not supported with beamformer "
                        "in multi-speaker cases."
                    )
                else:
                    # output of multi-source WPE
                    enhanced, ilens, others_b = self.beamformer(
                        enhanced, ilens, powers=powers
                    )
                    for spk in range(self.num_spk):
                        others["mask_spk{}".format(spk + 1)] = others_b[spk]
                    if len(others_b) > self.num_spk:
                        others["mask_noise1"] = others_b[self.num_spk]

    if not isinstance(enhanced, list):
        enhanced = [enhanced]

    return enhanced, ilens, others
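# The power-sharing guard in the beamformer branch is easy to get wrong: with
# "or" between the two negated startswith() checks it is always true, because
# no string starts with both "wmpdr" and "wpd", so powers would never be
# shared. A standalone sketch of the intended rule (function and argument
# names here are hypothetical, not part of the class above):
def should_share_powers(
    beamformer_type: str, shared_power: bool, wpe_nmask: int, num_spk: int
) -> bool:
    """WPE power spectra are reusable only by wMPDR/WPD beamformers."""
    uses_powers = beamformer_type.startswith(("wmpdr", "wpd"))
    degenerate = wpe_nmask == 1 and num_spk > 1
    return uses_powers and shared_power and not degenerate

assert should_share_powers("wmpdr", True, 2, 2)
assert not should_share_powers("mvdr", True, 2, 2)   # beamformer ignores powers
assert not should_share_powers("wpd", False, 2, 2)   # sharing disabled
assert not should_share_powers("wpd", True, 1, 2)    # one mask, many speakers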
def forward(self, input: torch.Tensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        enhanced speech (single-channel): torch.Tensor or List[torch.Tensor]
        output lengths
        predicted masks: OrderedDict[
            'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
            'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
            ...
            'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
        ]
    """
    # wave -> STFT -> complex spectrum
    input_spectrum, flens = self.stft(input, ilens)
    # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
    input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
    if self.normalize_input:
        input_spectrum = input_spectrum / abs(input_spectrum).max()

    # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
    assert input_spectrum.dim() in (3, 4), input_spectrum.dim()
    enhanced = input_spectrum
    masks = OrderedDict()
    if self.training and self.loss_type.startswith("mask"):
        # Only estimate masks during training to save memory
        if self.use_wpe:
            if input_spectrum.dim() == 3:
                mask_w, flens = self.wpe.predict_mask(
                    input_spectrum.unsqueeze(-2), flens
                )
                mask_w = mask_w.squeeze(-2)
            elif input_spectrum.dim() == 4:
                if self.use_beamformer:
                    # the beamformer below needs the dereverberated spectrum
                    enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                else:
                    mask_w, flens = self.wpe.predict_mask(input_spectrum, flens)
            if mask_w is not None:
                masks["dereverb"] = mask_w

        if self.use_beamformer and input_spectrum.dim() == 4:
            masks_b, flens = self.beamformer.predict_mask(enhanced, flens)
            for spk in range(self.num_spk):
                masks["spk{}".format(spk + 1)] = masks_b[spk]
            if len(masks_b) > self.num_spk:
                masks["noise1"] = masks_b[self.num_spk]

        return None, flens, masks
    else:
        # Perform both mask estimation and enhancement
        if input_spectrum.dim() == 3:
            # single-channel input (B, T, F)
            if self.use_wpe:
                enhanced, flens, mask_w = self.wpe(
                    input_spectrum.unsqueeze(-2), flens
                )
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)
        else:
            # multi-channel input (B, T, C, F)
            # 1. WPE
            if self.use_wpe:
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

    # Convert ComplexTensor to torch.Tensor
    # (B, T, F) -> (B, T, F, 2)
    if isinstance(enhanced, list):
        # multi-speaker output
        enhanced = [torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced]
    else:
        # single-speaker output
        enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()
    return enhanced, flens, masks
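# Usage sketch for the two execution modes above. `BeamformerNet` stands in
# for the class owning this forward(); its constructor arguments are elided
# because they depend on the surrounding configuration (hypothetical example):
#
#   model = BeamformerNet(...)                  # use_wpe / use_beamformer / etc.
#   mixture = torch.randn(4, 16000, 2)          # (Batch, Nsample, Channel)
#   ilens = torch.LongTensor([16000] * 4)
#
#   model.train()                               # with a mask-based loss_type:
#   out, flens, masks = model(mixture, ilens)   # out is None; masks has "spk1", ...
#
#   model.eval()
#   with torch.no_grad():
#       enhanced, flens, masks = model(mixture, ilens)
#   # enhanced: (4, Frames, Freq, 2) real/imag-stacked STFT, or a list of them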