def test_cmi_multivariate_crash():
    """Multivariate (conditional) MI queries run; some invalid ones raise.

    Builds a State over a 5x5 identity matrix with columns {0, 1, 2} assigned
    to view 0 and {3, 4} to view 1. Several multivariate queries must complete
    without error. Two queries that repeat a column across the query and/or
    constraint sets must raise ``ValueError``.
    """
    data = np.eye(5)
    view_of = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1}
    state = State(data, Zv=view_of, cctypes=['normal'] * 5)

    # Queries expected to complete without raising.
    valid_queries = [
        ([0, 1], [0, 1], {2: 1}),
        ([0, 1], [0, 1], {2: None}),
        ([2, 4], [0, 1, 3], {}),
    ]
    for left, right, constraints in valid_queries:
        state.mutual_information(left, right, constraints, T=10, N=10)

    # Queries expected to be rejected with ValueError.
    invalid_queries = [
        # Column 2 appears in both a query set and the constraints.
        ([2, 4], [1, 3], {0: 1, 2: None}),
        # Column 3 appears in both query sets.
        ([2, 3, 4], [1, 3], {0: None}),
    ]
    for left, right, constraints in invalid_queries:
        with pytest.raises(ValueError):
            state.mutual_information(left, right, constraints, T=10, N=10)
def test_cmi_different_views__ci_():
    """Columns in separate views have ~zero MI, and conditioning preserves it.

    Three normal columns are drawn from a seeded rng and each column is placed
    in its own view. After running the transition kernels:

    - every pairwise MI must be close to 0;
    - conditioning a pair on the remaining column must reproduce the
      unconditional MI, whether that column is given a value or ``None``.
    """
    rng = gen_rng(0)
    n_rows = 50
    data = np.zeros((n_rows, 3))
    # Columns are drawn in order 0, 1, 2 so the seeded rng stream is consumed
    # in the same sequence as column-by-column assignment.
    for col, (loc, scale) in enumerate([(-5, 1), (2, 2), (12, 3)]):
        data[:, col] = rng.normal(loc=loc, scale=scale, size=n_rows)

    state = State(
        data,
        outputs=[0, 1, 2],
        cctypes=['normal', 'normal', 'normal'],
        Zv={0: 0, 1: 1, 2: 2},
        rng=rng,
    )
    state.transition(
        N=30,
        kernels=['alpha', 'view_alphas', 'column_params', 'column_hypers', 'rows'],
    )

    pairs = [(0, 1), (0, 2), (1, 2)]
    # Computed in the order (0,1), (0,2), (1,2), which keeps the sequence of
    # calls on `state` fixed.
    marginal = {
        (a, b): state.mutual_information([a], [b]) for a, b in pairs
    }

    # Marginal MI all zero.
    for pair in pairs:
        assert np.allclose(marginal[pair], 0)

    # CMI on variable in other view equal to MI.
    conditioned_cases = [
        ((0, 1), {2: 10}, {}),
        ((0, 2), {1: 0}, {}),
        ((1, 2), {0: -2}, {}),
        ((1, 2), {0: None}, {'T': 5}),
    ]
    for (a, b), constraints, extra in conditioned_cases:
        cmi = state.mutual_information([a], [b], constraints, **extra)
        assert np.allclose(cmi, marginal[(a, b)])
def test_cmi_marginal_crash():
    """Conditional MI runs when constraint variables are marginalized.

    A constraint value of ``None`` denotes a marginalized constraint variable.
    The test covers one and two such variables, and then two marginalized
    variables alongside one variable fixed to a value. Each query must complete
    without raising.
    """
    state = State(
        np.eye(5),
        Zv={0: 0, 1: 0, 2: 0, 3: 1, 4: 1},
        cctypes=['normal'] * 5,
    )
    constraint_sets = (
        {2: None},                 # One marginalized constraint variable.
        {2: None, 3: None},        # Two marginalized constraint variables.
        {2: None, 3: None, 4: 0},  # Two marginalized plus one constrained.
    )
    for constraints in constraint_sets:
        state.mutual_information([0], [1], constraints, T=10, N=10)