コード例 #1
0
def get_rtf_matrix(
    psd_speeches,
    psd_noises,
    diagonal_loading: bool = True,
    ref_channel: int = 0,
    rtf_iterations: int = 3,
    use_torch_solver: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
):
    """Calculate the RTF matrix with each column the relative transfer function
    of the corresponding source.

    For every source ``i``, ``get_rtf`` is called with ``psd_speeches[i]`` and
    ``psd_noises[i]`` (the latter optionally diagonally loaded via ``tik_reg``).
    The per-source results are concatenated along the last dimension and then
    normalized by their entries at ``ref_channel``.

    Args:
        psd_speeches (list): speech PSD matrices, one per source.
            NOTE(review): presumably complex tensors shaped (..., F, C, C);
            confirm against callers.
        psd_noises (list): noise PSD matrices. Element ``i`` is paired with
            ``psd_speeches[i]``, so both lists must have the same length.
        diagonal_loading (bool): if True, each noise PSD is replaced by
            ``tik_reg(psd_n, reg=diag_eps, eps=eps)`` before use.
        ref_channel (int): reference channel. It is forwarded to ``get_rtf``
            as ``reference_vector`` and is also the normalization index.
        rtf_iterations (int): forwarded to ``get_rtf`` as ``iterations``.
        use_torch_solver (bool): forwarded unchanged to ``get_rtf``.
        diag_eps (float): ``reg`` argument passed to ``tik_reg``.
        eps (float): ``eps`` argument passed to ``tik_reg``.

    Returns:
        The concatenated RTFs divided by their values at index ``ref_channel``
        of the second-to-last dimension. The reference-channel entries
        therefore become 1. There is no eps guard, so a zero entry there yields
        nan/inf. NOTE(review): assumes ``get_rtf`` returns (..., C, 1), making
        the result (..., C, num_sources); confirm.

    Raises:
        AssertionError: if either argument is not a list.
        ValueError: if ``psd_speeches`` and ``psd_noises`` differ in length.
    """  # noqa: H405
    assert isinstance(psd_speeches, list) and isinstance(psd_noises, list)
    # Sources are paired by position. Without this check, a mismatch would
    # either silently drop trailing speech PSDs (fewer columns than sources)
    # or fail later with an uninformative IndexError.
    if len(psd_speeches) != len(psd_noises):
        raise ValueError(
            "psd_speeches and psd_noises must have the same length, got "
            f"{len(psd_speeches)} and {len(psd_noises)}"
        )
    rtf_mat = cat(
        [
            get_rtf(
                psd_s,
                tik_reg(psd_n, reg=diag_eps, eps=eps)
                if diagonal_loading else psd_n,
                reference_vector=ref_channel,
                iterations=rtf_iterations,
                use_torch_solver=use_torch_solver,
            )
            for psd_s, psd_n in zip(psd_speeches, psd_noises)
        ],
        dim=-1,
    )
    # Normalize each column by its entry at the reference channel. The `None`
    # keeps that axis as size 1 so the division broadcasts over channels.
    return rtf_mat / rtf_mat[..., ref_channel, None, :]
コード例 #2
0
def test_cat(dim):
    """Check that ``cat`` agrees with each backend's native ``cat``.

    The ``ComplexTensor``/``FC`` backend is always exercised. The native
    ``torch.complex``/``torch`` backend is added only when
    ``is_torch_1_9_plus`` is true.
    """
    backends = [(ComplexTensor, FC)]
    if is_torch_1_9_plus:
        backends.append((torch.complex, torch))

    for make_complex, backend in backends:
        # Draw real and imaginary parts in the same order as before.
        first = make_complex(torch.rand(2, 3, 4), torch.rand(2, 3, 4))
        second = make_complex(torch.rand(2, 3, 4), torch.rand(2, 3, 4))
        ours = cat([first, second], dim=dim)
        reference = backend.cat([first, second], dim=dim)
        assert backend.allclose(ours, reference)