def get_rtf_matrix(
    psd_speeches,
    psd_noises,
    diagonal_loading: bool = True,
    ref_channel: int = 0,
    rtf_iterations: int = 3,
    use_torch_solver: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Build the RTF matrix, one relative transfer function per source.

    For every entry of ``psd_noises`` the matching entry of ``psd_speeches``
    (same list index) is passed to ``get_rtf``. The per-source results are
    concatenated along the last dimension and then divided by the entry at
    ``ref_channel`` on the second-to-last dimension.

    Args:
        psd_speeches (list): per-source speech PSD matrices, indexed in
            parallel with ``psd_noises``.
        psd_noises (list): per-source noise PSD matrices.
        diagonal_loading (bool): if True, each noise PSD is passed through
            ``tik_reg`` before being handed to ``get_rtf``.
        ref_channel (int): forwarded to ``get_rtf`` as ``reference_vector``
            and used as the index of the normalization row.
        rtf_iterations (int): forwarded to ``get_rtf`` as ``iterations``.
        use_torch_solver (bool): forwarded unchanged to ``get_rtf``.
        diag_eps (float): forwarded to ``tik_reg`` as ``reg``.
        eps (float): forwarded unchanged to ``tik_reg``.

    Returns:
        The concatenated RTFs, normalized at the reference channel.
        NOTE(review): the exact tensor shape depends on what ``get_rtf``
        returns -- confirm against its definition.
    """
    assert isinstance(psd_speeches, list)
    assert isinstance(psd_noises, list)

    per_source_rtfs = []
    for src_idx, noise_psd in enumerate(psd_noises):
        # Look up the speech PSD first so a too-short ``psd_speeches`` fails
        # before any regularization work is done.
        speech_psd = psd_speeches[src_idx]
        if diagonal_loading:
            noise_psd = tik_reg(noise_psd, reg=diag_eps, eps=eps)
        per_source_rtfs.append(
            get_rtf(
                speech_psd,
                noise_psd,
                reference_vector=ref_channel,
                iterations=rtf_iterations,
                use_torch_solver=use_torch_solver,
            )
        )

    stacked = cat(per_source_rtfs, dim=-1)

    # Normalize at the reference channel; the ``None`` keeps the indexed
    # dimension so the division broadcasts over it.
    reference_row = stacked[..., ref_channel, None, :]
    return stacked / reference_row
def test_cat(dim):
    """Check that ``cat`` matches each complex backend's own ``cat``.

    ``ComplexTensor`` with ``FC`` is always exercised; native
    ``torch.complex`` with ``torch`` is added when ``is_torch_1_9_plus``
    is true.

    Args:
        dim: dimension along which the two random (2, 3, 4) complex tensors
            are concatenated. Presumably supplied by a pytest
            parametrization that is not visible here -- TODO confirm.
    """
    # (constructor, reference module) pairs to compare against.
    backends = [(ComplexTensor, FC)]
    if is_torch_1_9_plus:
        backends.append((torch.complex, torch))

    for make_complex, backend in backends:
        first = make_complex(torch.rand(2, 3, 4), torch.rand(2, 3, 4))
        second = make_complex(torch.rand(2, 3, 4), torch.rand(2, 3, 4))
        operands = [first, second]

        actual = cat(operands, dim=dim)
        expected = backend.cat(operands, dim=dim)
        assert backend.allclose(actual, expected)