def test_data_processing_inequality_mc(dist):
    """
    Check the data processing inequality for maximum correlation.

    Given a Markov chain X - Y - Z, the outer pair can be no more
    correlated than an adjacent pair:
        rho(X:Z) <= rho(X:Y)
    """
    # Variable indices 0, 1, 2 stand for X, Y, Z respectively.
    adjacent = maximum_correlation(dist, [[0], [1]])
    distant = maximum_correlation(dist, [[0], [2]])
    # epsilon absorbs numerical error from the optimization.
    assert distant <= adjacent + epsilon
def test_maximum_correlation_tensorization(dist1, dist2):
    """
    Test tensorization of maximum correlation:
        rho(X X' : Y Y') = max(rho(X:Y), rho(X':Y'))

    dist1 supplies the pair (X, Y) and dist2 supplies (X', Y'). In the
    combined distribution the variables are indexed X=0, Y=1, X'=2, Y'=3,
    so the groups [0, 2] and [1, 3] are XX' and YY'.
    """
    # NOTE(review): the dunder is called explicitly instead of `dist1 @ dist2`,
    # presumably to keep this file parseable on Pythons without the `@`
    # operator -- confirm before switching. The identity being tested assumes
    # the result is the independent product of the two distributions.
    mixed = dist1.__matmul__(dist2)
    rho_mixed = maximum_correlation(mixed, [[0, 2], [1, 3]])
    rho_a = maximum_correlation(dist1, [[0], [1]])
    rho_b = maximum_correlation(dist2, [[0], [1]])
    assert rho_mixed == pytest.approx(max(rho_a, rho_b), abs=1e-4)
def test_max_correlation_mutual_information_y_marginal(dist):
    """
    Bound maximum correlation by mutual information using Y's marginal:
        (p_min * rho(X:Y))^2 <= (2 ln 2) I(X:Y),   p_min = min_y p(y)

    Both rho(X:Y) and I(X:Y) are symmetric in X and Y, so the bound that
    ``test_max_correlation_mutual_information`` checks with X's marginal
    must also hold with Y's marginal.
    """
    # This test previously reused the name test_max_correlation_mutual_information
    # with an identical body. The later definition of that name silently
    # replaced it, so it was never collected by pytest. A distinct name and
    # the symmetric variant make it a real, separate check.
    p_min = dist.marginal([1]).pmf.min()
    rho = maximum_correlation(dist, [[0], [1]])
    i = I(dist, [[0], [1]])
    assert (p_min * rho)**2 <= (2 * np.log(2)) * i + epsilon
def test_max_correlation_mutual_information(dist):
    """
    Bound the maximum correlation by the mutual information:
        (p_min * rho(X:Y))^2 <= (2 ln 2) I(X:Y)
    where p_min is the smallest entry of the marginal pmf of X.
    """
    smallest = dist.marginal([0]).pmf.min()
    rho = maximum_correlation(dist, [[0], [1]])
    mutual_info = I(dist, [[0], [1]])
    lhs = (smallest * rho) ** 2
    rhs = (2 * np.log(2)) * mutual_info
    # epsilon tolerates numerical error in both estimates.
    assert lhs <= rhs + epsilon
def test_maximum_correlation_failure(rvs):
    """
    maximum_correlation must raise ditException when rvs does not contain
    exactly two variable groups (len(rvs) != 2).
    """
    expect_error = pytest.raises(ditException)
    with expect_error:
        maximum_correlation(dyadic, rvs)
def test_maximum_correlation(dist, rvs, crvs):
    """
    Compare maximum_correlation against known reference values.

    Every case supplied to this test is expected to give exactly 1.0.
    """
    result = maximum_correlation(dist, rvs, crvs)
    assert result == pytest.approx(1.0)