def test_conditional_corrcoeff(corr):
    """Check `conditional_corrcoeff` against a 2D Gaussian with known correlation.

    The marginal variances are 0.1 and 10.0, so sqrt(0.1 * 10.0) == 1 and the
    off-diagonal covariance ``corr`` is also the true correlation coefficient.
    That is why the estimate is compared directly against ``corr``.

    NOTE(review): ``corr`` is supplied from outside this function (presumably a
    pytest parametrization not visible here); it must satisfy |corr| < 1 for the
    covariance matrix to be positive definite.
    """
    mean = torch.tensor([0.6, 5.0])
    covariance = torch.tensor([[0.1, corr], [corr, 10.0]])
    gaussian = MultivariateNormal(mean, covariance)

    # Per-dimension bounds handed to `conditional_corrcoeff` (presumably the
    # evaluation grid); they span many standard deviations of each marginal.
    grid_limits = torch.tensor([[-2.0, 3.0], [-70, 90]])
    corr_matrix = conditional_corrcoeff(
        density=gaussian,
        condition=torch.ones(1, 2),
        limits=grid_limits,
        resolution=500,
    )
    estimated = corr_matrix[0, 1]

    error = (estimated - corr).abs()
    assert error < 1e-3
def test_average_cond_coeff_matrix():
    """Check the full conditional correlation matrix of a 3D Gaussian.

    Dimensions 0 and 1 have variances 100 and 10 with covariance 30. Dimension
    2 has zero covariance with both. Conditioning any pair on the remaining
    dimension therefore leaves the 0-1 correlation at 30 / sqrt(100 * 10) and
    every correlation involving dimension 2 at zero.
    """
    mean = torch.tensor([10.0, 5, 1])
    covariance = torch.tensor(
        [
            [100.0, 30.0, 0],
            [30.0, 10.0, 0],
            [0, 0, 1.0],
        ]
    )
    gaussian = MultivariateNormal(mean, covariance)

    estimated_matrix = conditional_corrcoeff(
        density=gaussian,
        condition=torch.zeros(1, 3),
        limits=torch.tensor([[-60.0, 60.0], [-20, 20], [-7, 7]]),
        resolution=500,
    )

    # sqrt(30**2 / (100 * 10)) == 30 / sqrt(1000): the true 0-1 correlation.
    rho_01 = torch.tensor(30.0**2 / (100.0 * 10.0)).sqrt()
    expected_matrix = torch.eye(3)
    expected_matrix[0, 1] = rho_01
    expected_matrix[1, 0] = rho_01

    within_tolerance = (estimated_matrix - expected_matrix).abs() < 1e-3
    assert within_tolerance.all()