def test_complex_norm(dim):
    """Check `complex_norm` agrees between ComplexTensor and torch complex input.

    The same random data is wrapped once as a ``ComplexTensor`` and once as a
    native torch complex tensor; the norms along ``dim`` (with
    ``keepdim=True``) must match, keep the input rank, and reduce exactly the
    requested axis.

    Args:
        dim: axis to reduce over (presumably supplied by a parametrize
            decorator outside this block -- TODO confirm).
    """
    real_part = torch.rand(2, 3, 4)
    imag_part = torch.rand(2, 3, 4)
    wrapped = ComplexTensor(real_part, imag_part)
    native = torch.complex(wrapped.real, wrapped.imag)

    wrapped_norm = complex_norm(wrapped, dim=dim, keepdim=True)
    native_norm = complex_norm(native, dim=dim, keepdim=True)

    # Both representations must produce the same values.
    assert torch.allclose(wrapped_norm, native_norm)
    # keepdim=True: the reduced axis is kept, so the rank is unchanged.
    assert wrapped_norm.ndim == wrapped.ndim
    # Only `dim` was reduced: element count shrinks by exactly that axis size.
    assert wrapped.numel() == wrapped_norm.numel() * wrapped.size(dim)
def forward(self, ref, inf) -> torch.Tensor:
    """Compute the time-frequency absolute coherence loss.

    Reference:
        Independent Vector Analysis with Deep Neural Network Source Priors;
        Li et al 2020; https://arxiv.org/abs/2008.11273

    Args:
        ref: (Batch, T, F) or (Batch, T, C, F)
        inf: (Batch, T, F) or (Batch, T, C, F)

    Returns:
        loss: (Batch,)

    Raises:
        ValueError: if `ref` or `inf` is not complex, or if the inputs are
            neither 3- nor 4-dimensional.
    """
    assert ref.shape == inf.shape, (ref.shape, inf.shape)

    if not (is_complex(ref) and is_complex(inf)):
        raise ValueError("`ref` and `inf` must be complex tensors.")

    # Denominator: sqrt( E[|inf|^2] * E[|ref|^2] ), with EPS added on top.
    ref_norm = complex_norm(ref, dim=1)
    inf_norm = complex_norm(inf, dim=1)
    denom = ref_norm * inf_norm / ref.size(1) + EPS

    # Magnitude of the cross term averaged over dim 1, then normalised.
    cross_mean = (inf * ref.conj()).mean(dim=1)
    coh = cross_mean.abs() / denom

    n_dims = ref.dim()
    if n_dims == 3:
        return 1.0 - coh.mean(dim=1)
    if n_dims == 4:
        return 1.0 - coh.mean(dim=[1, 2])
    raise ValueError(
        "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
    )
def get_gev_vector(
    psd_noise: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Compute the generalized eigenvalue (GEV) beamformer vector.

    The beamformer vector h solves:
        psd_speech @ h = lambda * psd_noise @ h

    Reference:
        Blind acoustic beamforming based on generalized eigenvalue
        decomposition; E. Warsitz and R. Haeb-Umbach, 2007.

    Args:
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition, computed per frequency bin
                (only for torch builtin complex tensors)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to regularize psd_noise with
            `tik_reg` before any further processing
        diag_eps (float): passed to `tik_reg` as `reg`
        eps (float): passed to `tik_reg` as `eps`
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    Raises:
        ValueError: if `mode` is neither "power" nor "evd".
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if mode == "power":
        # Both branches build the matrix that the power method iterates on.
        if use_torch_solver:
            phi = solve(psd_speech, psd_noise)
        else:
            noise_inverse = inverse(psd_noise)
            phi = matmul(noise_inverse, psd_speech)

        # Starting vector: one column of phi, or phi applied to the
        # user-supplied reference vector.
        if isinstance(reference_vector, int):
            e_vec = phi[..., reference_vector, None]
        else:
            reference_column = reference_vector[..., None, :, None]
            e_vec = matmul(phi, reference_column)

        # No normalisation per iteration; the vector is normalised once at
        # the end of the function.
        for _ in range(iterations - 1):
            e_vec = matmul(phi, e_vec)
        e_vec = e_vec.squeeze(-1)
    elif mode == "evd":
        assert (
            is_torch_1_9_plus
            and is_torch_complex_tensor(psd_speech)
            and is_torch_complex_tensor(psd_noise)
        )
        # The decomposition is done one frequency bin at a time so that a
        # failure in a single bin can be handled without losing the others.
        e_vec = psd_noise.new_zeros(psd_noise.shape[:-1])
        num_freqs = psd_noise.shape[-3]
        for freq_idx in range(num_freqs):
            try:
                eig_result = generalized_eigenvalue_decomposition(
                    psd_speech[..., freq_idx, :, :],
                    psd_noise[..., freq_idx, :, :],
                )
                # Element [1] of the result, last entry along the final axis.
                e_vec[..., freq_idx, :] = eig_result[1][..., -1]
            except RuntimeError:
                # port from github.com/fgnt/nn-gev/blob/master/fgnt/beamforming.py#L106
                print(
                    "GEV beamformer: LinAlg error for frequency {}".format(
                        freq_idx
                    ),
                    flush=True,
                )
                num_channels = psd_noise.size(-1)
                fallback_ones = psd_noise.new_ones(e_vec[..., freq_idx, :].shape)
                noise_trace = FC.trace(psd_noise[..., freq_idx, :, :])
                e_vec[..., freq_idx, :] = (
                    fallback_ones / noise_trace * num_channels
                )
    else:
        raise ValueError("Unknown mode: %s" % mode)

    # Normalise along the last axis, then apply the GEV phase correction.
    normalized_vector = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
    return gev_phase_correction(normalized_vector)