def get_rtf(
    psd_speech,
    psd_noise,
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
):
    """Calculate the relative transfer function (RTF).

    Power-method algorithm:
        1) rtf = reference_vector
        2) for i in range(iterations):
               rtf = (psd_noise^-1 @ psd_speech) @ rtf
               rtf = rtf / ||rtf||_2   # this normalization can be skipped
        3) rtf = psd_noise @ rtf
        4) rtf = rtf / rtf[..., ref_channel, :]
    Note: step 4 (normalization at the reference channel) is NOT performed here.

    Args:
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): whether to use `solve` instead of `inverse`
    Returns:
        rtf (torch.complex64/ComplexTensor): (..., F, C, 1)
    """
    if mode == "power":
        # phi = psd_noise^-1 @ psd_speech
        if use_torch_solver:
            phi = solve(psd_speech, psd_noise)
        else:
            phi = matmul(inverse(psd_noise), psd_speech)
        # First application of phi: either select a column of phi (int
        # reference) or multiply by the reference vector (..., C) -> (..., C, 1).
        if isinstance(reference_vector, int):
            rtf = phi[..., reference_vector, None]
        else:
            rtf = matmul(phi, reference_vector[..., None, :, None])
        # iterations - 2 further power iterations (normalization skipped),
        # giving phi^(iterations-1) @ ref in total so far.
        for _ in range(iterations - 2):
            rtf = matmul(phi, rtf)
        # Since psd_noise @ phi == psd_speech, this single product equals
        # psd_noise @ phi^iterations @ ref, i.e. steps 2 and 3 combined.
        return matmul(psd_speech, rtf)

    if mode == "evd":
        # GEVD path only supports native complex tensors on torch >= 1.9.
        assert (
            is_torch_1_9_plus
            and is_torch_complex_tensor(psd_speech)
            and is_torch_complex_tensor(psd_noise)
        )
        eigenvectors = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1]
        # Principal generalized eigenvector (last column), then step 3.
        return matmul(psd_noise, eigenvectors[..., -1, None])

    raise ValueError("Unknown mode: %s" % mode)
def get_adjacent(spec, filter_length: int = 5):
    """Zero-pad and unfold stft, i.e.,

    add zeros to the beginning so that, using the multi-frame signal model,
    there will be as many output frames as input frames.

    Args:
        spec (torch.complex64/ComplexTensor): input spectrum (B, F, T)
        filter_length (int): length for frame extension
    Returns:
        ret (torch.complex64/ComplexTensor): output spectrum (B, F, T, filter_length)
    """  # noqa: D400
    if isinstance(spec, ComplexTensor):
        pad_func = FC.pad
    elif is_torch_complex_tensor(spec):
        pad_func = torch.nn.functional.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support.")
    # Left-pad the time axis with filter_length - 1 zeros so each input frame
    # has a full history window.
    padded = pad_func(spec, pad=[filter_length - 1, 0])
    # NOTE: torch.Tensor.unfold's keyword argument is `dimension`, not `dim`
    # (unfold(dim=...) raises TypeError on native tensors), so pass the
    # arguments positionally; this works for ComplexTensor as well.
    return padded.unfold(-1, filter_length, 1).contiguous()
def get_WPD_filter_with_rtf(
    psd_observed_bar: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    psd_noise: Union[torch.Tensor, ComplexTensor],
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the WPD vector calculated with RTF.

    WPD is the Weighted Power minimization Distortionless response
    convolutional beamformer. As follows:

    h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019,
        doi: 10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (torch.complex64/ComplexTensor):
            stacked observation covariance matrix
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): whether to use `solve` instead of `inverse`
        diagonal_loading (bool):
            whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization weight for diagonal loading
        eps (float): denominator flooring constant
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """
    if isinstance(psd_speech, ComplexTensor):
        pad_func = FC.pad
    elif is_torch_complex_tensor(psd_speech):
        pad_func = torch.nn.functional.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support.")

    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    # NOTE: reference_vector must be passed by keyword: get_rtf's third
    # positional parameter is `mode`, so passing it positionally would feed
    # the reference into `mode` and raise ValueError("Unknown mode: ...").
    # A `None` reference falls back to channel 0.
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector=0 if reference_vector is None else reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # (B, F, (K+1)*C, 1): extend the RTF with zeros to match the stacked
    # (multi-tap) observation covariance dimension.
    rtf = pad_func(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = solve(rtf, psd_observed_bar).squeeze(-1)
    else:
        numerator = matmul(inverse(psd_observed_bar), rtf).squeeze(-1)
    # denominator: vbar^H @ R^-1 @ vbar (scalar per (..., F))
    denominator = einsum("...d,...d->...", rtf.squeeze(-1).conj(), numerator)
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
def signal_framing(
    signal: Union[torch.Tensor, ComplexTensor],
    frame_length: int,
    frame_step: int,
    bdelay: int,
    do_padding: bool = False,
    pad_value: int = 0,
    indices: List = None,
) -> Union[torch.Tensor, ComplexTensor]:
    """Expand `signal` into frames of length `frame_length` (for WPD).

    Each frame consists of `frame_length - 1` consecutive samples plus one
    extra sample delayed by `bdelay`.

    Args:
        signal : (..., T)
        frame_length: length of each segment
        frame_step: step for selecting frames
        bdelay: delay for WPD
        do_padding: whether or not to pad the input signal at the beginning
                      of the time dimension
        pad_value: value to fill in the padding

    Returns:
        torch.Tensor:
            if do_padding: (..., T, frame_length)
            else: (..., T - bdelay - frame_length + 2, frame_length)
    """
    # Pick pad/complex constructors matching the input's complex flavor;
    # plain real tensors only need the pad function.
    if isinstance(signal, ComplexTensor):
        complex_wrapper = ComplexTensor
        pad_func = FC.pad
    elif is_torch_complex_tensor(signal):
        complex_wrapper = torch.complex
        pad_func = torch.nn.functional.pad
    else:
        pad_func = torch.nn.functional.pad

    taps = frame_length - 1

    if do_padding:
        # Left-pad the time dimension: (..., T) --> (..., T + bdelay + frame_length - 2)
        signal = pad_func(signal, (bdelay + taps - 1, 0), "constant", pad_value)
        # Padding is done once; recursive calls below must not pad again.
        do_padding = False

    if indices is None:
        # Frame index table, one row per output frame:
        # [[ 0, 1, ..., taps - 1, taps - 1 + bdelay ],
        #  [ 1, 2, ..., taps,     taps + bdelay     ],
        #  ...
        #  [ T - bdelay - taps, ..., T - 1 - bdelay, T - 1 ]]
        last_start = signal.shape[-1] - taps - bdelay + 1
        indices = [
            list(range(start, start + taps)) + [start + taps + bdelay - 1]
            for start in range(0, last_start, frame_step)
        ]

    if not is_complex(signal):
        # (..., T - bdelay - frame_length + 2, frame_length)
        return signal[..., indices]

    # Complex input: frame real and imaginary parts with the SAME indices,
    # then recombine with the matching constructor.
    common_args = (frame_length, frame_step, bdelay, do_padding, pad_value, indices)
    framed_real = signal_framing(signal.real, *common_args)
    framed_imag = signal_framing(signal.imag, *common_args)
    return complex_wrapper(framed_real, framed_imag)
def get_gev_vector(
    psd_noise: Union[torch.Tensor, ComplexTensor],
    psd_speech: Union[torch.Tensor, ComplexTensor],
    mode="power",
    reference_vector: Union[int, torch.Tensor] = 0,
    iterations: int = 3,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Union[torch.Tensor, ComplexTensor]:
    """Return the generalized eigenvalue (GEV) beamformer vector:

        psd_speech @ h = lambda * psd_noise @ h

    Reference:
        Blind acoustic beamforming based on generalized eigenvalue decomposition;
        E. Warsitz and R. Haeb-Umbach, 2007.

    Args:
        psd_noise (torch.complex64/ComplexTensor):
            noise covariance matrix (..., F, C, C)
        psd_speech (torch.complex64/ComplexTensor):
            speech covariance matrix (..., F, C, C)
        mode (str): one of ("power", "evd")
            "power": power method
            "evd": eigenvalue decomposition (only for torch builtin complex tensors)
        reference_vector (torch.Tensor or int): (..., C) or scalar
        iterations (int): number of iterations in power method
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool):
            Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization weight used by `tik_reg`
        eps (float): also forwarded to `tik_reg` as its flooring constant
    Returns:
        beamform_vector (torch.complex64/ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    # Regularize the noise covariance before (pseudo-)inversion.
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    if mode == "power":
        # phi = psd_noise^-1 @ psd_speech; its principal eigenvector is
        # approximated by power iteration below.
        if use_torch_solver:
            phi = solve(psd_speech, psd_noise)
        else:
            phi = matmul(inverse(psd_noise), psd_speech)
        # First application of phi: select a column (int reference) or
        # multiply by the reference vector (..., C) -> (..., C, 1).
        e_vec = (phi[..., reference_vector, None]
                 if isinstance(reference_vector, int) else matmul(
                     phi, reference_vector[..., None, :, None]))
        # iterations - 1 further power iterations (normalization is skipped
        # here; the final vector is normalized once at the end).
        for _ in range(iterations - 1):
            e_vec = matmul(phi, e_vec)
            # e_vec = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
        e_vec = e_vec.squeeze(-1)
    elif mode == "evd":
        # Exact GEVD path: requires native complex tensors on torch >= 1.9.
        assert (is_torch_1_9_plus and is_torch_complex_tensor(psd_speech)
                and is_torch_complex_tensor(psd_noise))
        # e_vec = generalized_eigenvalue_decomposition(psd_speech, psd_noise)[1][...,-1]
        # Solve the GEVD per frequency bin so that a LinAlg failure in one
        # bin does not abort the whole batch.
        e_vec = psd_noise.new_zeros(psd_noise.shape[:-1])
        for f in range(psd_noise.shape[-3]):
            try:
                # Principal generalized eigenvector (last column) of bin f.
                e_vec[..., f, :] = generalized_eigenvalue_decomposition(
                    psd_speech[..., f, :, :], psd_noise[..., f, :, :])[1][...,
                                                                          -1]
            except RuntimeError:
                # Fallback for a numerically singular bin: a trace-scaled
                # all-ones vector.
                # port from github.com/fgnt/nn-gev/blob/master/fgnt/beamforming.py#L106
                print(
                    "GEV beamformer: LinAlg error for frequency {}".format(f),
                    flush=True,
                )
                C = psd_noise.size(-1)
                # NOTE(review): FC.trace is the torch_complex helper, but
                # psd_noise is asserted above to be a native complex tensor
                # on this path — confirm FC.trace accepts native tensors.
                e_vec[..., f, :] = (psd_noise.new_ones(e_vec[..., f, :].shape)
                                    / FC.trace(psd_noise[..., f, :, :]) * C)
    else:
        raise ValueError("Unknown mode: %s" % mode)

    # Unit-normalize, then fix the arbitrary GEV phase per frequency.
    # NOTE(review): no eps in this division — a zero e_vec would yield NaNs;
    # presumably ruled out upstream, verify against callers.
    beamforming_vector = e_vec / complex_norm(e_vec, dim=-1, keepdim=True)
    beamforming_vector = gev_phase_correction(beamforming_vector)
    return beamforming_vector