コード例 #1
0
ファイル: test_cmi.py プロジェクト: leighhopcroft/cgpm
def test_entropy_bernoulli_bivariate__ci_():
    """Check joint-entropy estimates on a bivariate Bernoulli dataset.

    Data are generated as X ~ PX and then Y | X=x ~ PY[x]. The exact joint
    entropy H(X, Y) of that model is compared against two estimates from a
    fitted ``Engine``:

    * the entropy implied by ``logpdf_bulk`` over all four joint outcomes;
    * ``mutual_information`` of the column pair with itself, which should
      equal the joint entropy because I(Z; Z) = H(Z).
    """
    rng = gen_rng(10)

    # Generate a bivariate Bernoulli dataset: X ~ PX, then Y | X=x ~ PY[x].
    PX = [.3, .7]
    PY = [[.2, .8], [.6, .4]]
    TX = rng.choice([0, 1], p=PX, size=250)
    TY = np.zeros(shape=len(TX))
    for x in (0, 1):
        mask = TX == x
        # Each conditional is drawn from its own row PY[x]. This must match
        # the model that entropy_exact is computed under below.
        TY[mask] = rng.choice([0, 1], p=PY[x], size=np.count_nonzero(mask))
    # Column 0 holds Y and column 1 holds X.
    T = np.column_stack((TY, TX))

    engine = Engine(
        T,
        cctypes=['categorical', 'categorical'],
        distargs=[{
            'k': 2
        }, {
            'k': 2
        }],
        num_states=64,
        rng=rng,
    )

    engine.transition_lovecat(N=200)

    # Exact joint entropy: H(X, Y) = -sum_{x,y} P(x, y) log P(x, y),
    # where P(x, y) = PX[x] * PY[x][y].
    entropy_exact = -sum(
        PX[x] * PY[x][y] * np.log(PX[x] * PY[x][y])
        for x in (0, 1)
        for y in (0, 1)
    )

    # logpdf computation: evaluate every joint outcome {0: y, 1: x}.
    # NOTE(review): logps presumably holds one sequence of four
    # log-probabilities per state, giving one entropy per state; confirm
    # against Engine.logpdf_bulk.
    queries = [{0: y, 1: x} for y in (0, 1) for x in (0, 1)]
    logps = engine.logpdf_bulk([-1] * len(queries), queries)
    entropy_logpdf = [-np.sum(np.exp(logp) * logp) for logp in logps]

    # mutual_information computation (columns [0, 1] against themselves).
    entropy_mi = engine.mutual_information([0, 1], [0, 1], N=1000)

    # Punt CLT analysis and go for a small tolerance.
    assert np.allclose(entropy_exact, entropy_logpdf, atol=.15)
    assert np.allclose(entropy_exact, entropy_mi, atol=.15)
    assert np.allclose(entropy_logpdf, entropy_mi, atol=.1)
コード例 #2
0
ファイル: test_cmi.py プロジェクト: wilsondy/cgpm
def test_entropy_bernoulli_univariate__ci_():
    """Check entropy estimates for a single Bernoulli column.

    Draws 250 samples with P(0) = .3 and P(1) = .7 and fits a 16-state
    ``Engine``. It then checks that three values agree within small
    tolerances:

    * the exact entropy;
    * the entropy implied by the per-outcome ``logpdf_bulk`` values;
    * ``mutual_information`` of the column with itself.
    """
    rng = gen_rng(10)

    # Univariate Bernoulli data, reshaped into a single column.
    probs = [.3, .7]
    data = rng.choice([0, 1], p=probs, size=250).reshape(-1, 1)

    engine = Engine(data, cctypes=['bernoulli'], rng=rng, num_states=16)
    engine.transition(S=15)

    # Exact entropy: -sum_p p * log(p), from the same table used to sample.
    entropy_exact = -sum(p * np.log(p) for p in probs)

    # Entropy implied by the model's log-probabilities of both outcomes.
    outcomes = [{0: 0}, {0: 1}]
    logps = engine.logpdf_bulk([-1, -1], outcomes)
    entropy_logpdf = [-np.sum(np.exp(lp) * lp) for lp in logps]

    # Entropy as the mutual information of column 0 with itself.
    entropy_mi = engine.mutual_information([0], [0], N=1000)

    # Punt CLT analysis and go for 1 dp.
    assert np.allclose(entropy_exact, entropy_logpdf, atol=.1)
    assert np.allclose(entropy_exact, entropy_mi, atol=.1)
    assert np.allclose(entropy_logpdf, entropy_mi, atol=.05)