예제 #1
0
def test_conditional_corrcoeff(corr):
    """
    Check that `conditional_corrcoeff` recovers a known correlation coefficient.

    A 2-D Gaussian is built with `corr` as its off-diagonal covariance entry.
    The two variances (0.1 and 10.0) multiply to 1, so that entry is also the
    correlation coefficient between the two dimensions. The estimate for the
    dimension pair (0, 1) must match it to within 1e-3.

    Args:
        corr: Off-diagonal covariance of the Gaussian, which doubles as the
            expected correlation. Presumably supplied by a pytest
            parametrization that is not visible here -- TODO confirm.
    """
    mean = torch.tensor([0.6, 5.0])
    covariance = torch.tensor([[0.1, corr], [corr, 10.0]])
    gaussian = MultivariateNormal(mean, covariance)

    corr_matrix = conditional_corrcoeff(
        density=gaussian,
        condition=torch.ones(1, 2),
        limits=torch.tensor([[-2.0, 3.0], [-70, 90]]),
        resolution=500,
    )
    # Only the coefficient for the dimension pair (0, 1) is checked.
    estimated_corr = corr_matrix[0, 1]

    deviation = torch.abs(corr - estimated_corr)
    assert deviation < 1e-3
예제 #2
0
def test_average_cond_coeff_matrix():
    """
    Check the full matrix returned by `conditional_corrcoeff` on a 3-D Gaussian.

    Dimensions 0 and 1 are correlated (covariance 30, variances 100 and 10),
    while dimension 2 has zero covariance with both. The expected matrix
    therefore has ones on the diagonal, cov / sqrt(var_0 * var_1) at the
    (0, 1) and (1, 0) entries, and zeros everywhere else. Every entry of the
    result must match it to within 1e-3.
    """
    mean = torch.tensor([10.0, 5, 1])
    covariance = torch.tensor(
        [[100.0, 30.0, 0], [30.0, 10.0, 0], [0, 0, 1.0]]
    )
    gaussian = MultivariateNormal(mean, covariance)

    estimated = conditional_corrcoeff(
        density=gaussian,
        condition=torch.zeros(1, 3),
        limits=torch.tensor([[-60.0, 60.0], [-20, 20], [-7, 7]]),
        resolution=500,
    )

    # Correlation of dims 0 and 1: sqrt(cov^2 / (var_0 * var_1)).
    rho_01 = torch.sqrt(torch.tensor(30.0**2 / 100.0 / 10.0))
    expected = torch.tensor(
        [
            [1.0, rho_01, 0.0],
            [rho_01, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
    )

    within_tolerance = torch.abs(expected - estimated) < 1e-3
    assert within_tolerance.all()