Example #1
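This test exercises cgpm's Engine on binary (Bernoulli-style) categorical data. The snippet omits its module preamble, so the block below sketches the imports, following the probcomp/cgpm package layout, plus the module-level constants the test references; the constant values are illustrative assumptions, not the originals.

import numpy as np
import pytest

from cgpm.crosscat.engine import Engine
from cgpm.utils import general as gu

# Assumed module-level constants: the values here are illustrative;
# the original test module defines its own.
DATA_NUM_0 = 100
DATA_NUM_1 = 100
NUM_ITER = 100
NUM_SIM = 1000
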
def test_bernoulli():
    # Multiprocessing switch (0 disables it, which is faster for a test this small).
    multiprocess = 0

    # Create categorical data of DATA_NUM_0 zeros and DATA_NUM_1 ones.
    data = np.transpose(np.array([[0] * DATA_NUM_0 + [1] * DATA_NUM_1]))

    # Run a single chain for a few iterations.
    engine = Engine(data,
                    cctypes=['categorical'],
                    distargs=[{
                        'k': 2
                    }],
                    rng=gu.gen_rng(0),
                    multiprocess=multiprocess)
    engine.transition(NUM_ITER, multiprocess=multiprocess)

    # Simulate from a hypothetical row and compute the proportion of ones.
    sample = engine.simulate(-1, [0], N=NUM_SIM, multiprocess=multiprocess)[0]
    sum_b = sum(s[0] for s in sample)
    observed_prob_of_1 = (float(sum_b) / float(NUM_SIM))
    true_prob_of_1 = float(DATA_NUM_1) / float(DATA_NUM_0 + DATA_NUM_1)
    # Check agreement to within 10% relative tolerance.
    assert np.allclose(true_prob_of_1, observed_prob_of_1, rtol=.1)

    # Simulate from an observed row as a crash test (should not raise).
    sample = engine.simulate(1, [0], N=1, multiprocess=multiprocess)

    # Ensure the hypothetical-row probabilities are normalized:
    # logsumexp([logp(0), logp(1)]) should equal log(1) == 0.
    p0_uob = engine.logpdf(-1, {0: 0}, multiprocess=multiprocess)[0]
    p1_uob = engine.logpdf(-1, {0: 1}, multiprocess=multiprocess)[0]
    assert np.allclose(gu.logsumexp([p0_uob, p1_uob]), 0)

    # A logpdf query constraining an observed cell raises a ValueError.
    with pytest.raises(ValueError):
        engine.logpdf(1, {0: 0}, multiprocess=multiprocess)
    with pytest.raises(ValueError):
        engine.logpdf(1, {0: 1}, multiprocess=multiprocess)
Example #2
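This test builds a composite model: a crosscat State over four variables, with a RandomForest and a LinearRegression composed on top as foreign CGPMs, and then round-trips the whole thing through serialization. As above, the snippet omits its preamble; a plausible set of imports under the probcomp/cgpm layout is:

import numpy as np

from cgpm.crosscat.engine import Engine
from cgpm.crosscat.state import State
from cgpm.regressions.forest import RandomForest
from cgpm.regressions.linreg import LinearRegression
from cgpm.utils import config as cu
from cgpm.utils import general as gu
from cgpm.utils import test as tu
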
def test_serialize_composite_cgpm():
    rng = gu.gen_rng(2)

    # Generate the data.
    cctypes, distargs = cu.parse_distargs([
        'categorical(k=3)',     # RandomForest          0
        'normal',               # LinearRegression      1
        'categorical(k=3)',     # GPMCC                 2
        'poisson',              # GPMCC                 3
        'normal',               # GPMCC                 4
        'lognormal'             # GPMCC                 5
        ])
    T, Zv, Zc = tu.gen_data_table(
        35, [.4, .6], [[.33, .33, .34], [.5, .5]],
        cctypes, distargs, [.2]*len(cctypes), rng=rng)
    D = np.transpose(T)
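    # D has one row per record and one column per variable (35 x 6).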

    # Create GPMCC.
    state = State(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], rng=rng)

    # Create a Forest.
    forest = RandomForest(
        outputs=[0],
        inputs=[1,2,3,4],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [1,2,3,4]],
                'statargs': [distargs[i] for i in [1,2,3,4]]},
            'k': distargs[0]['k']},
        rng=rng)

    # Create a Regression.
    linreg = LinearRegression(
        outputs=[1],
        inputs=[3,4,5],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [3,4,5]],
                'statargs': [distargs[i] for i in [3,4,5]]}},
        rng=rng)

    # Incorporate the data: each CGPM observes its output dimensions and
    # conditions on its input dimensions.
    def incorporate_data(cgpm, rowid, row):
        cgpm.incorporate(
            rowid,
            {i: row[i] for i in cgpm.outputs},
            {i: row[i] for i in cgpm.inputs},
        )
    for rowid, row in enumerate(D):
        incorporate_data(forest, rowid, row)
        incorporate_data(linreg, rowid, row)

    # Run state transitions.
    state.transition(N=10, progress=False)

    # Compose the CGPMs onto the state, then run transitions on the
    # foreign variables.
    token_forest = state.compose_cgpm(forest)
    token_linreg = state.compose_cgpm(linreg)
    state.transition_foreign(N=10, cols=[forest.outputs[0], linreg.outputs[0]])

    # Now run the serialization.
    metadata = state.to_metadata()
    state2 = State.from_metadata(metadata)

    # Check that the tokens are in state2.
    assert token_forest in state2.hooked_cgpms
    assert token_linreg in state2.hooked_cgpms

    # The hooked cgpms must be unique objects after serialize/deserialize.
    assert state.hooked_cgpms[token_forest] != state2.hooked_cgpms[token_forest]
    assert state.hooked_cgpms[token_linreg] != state2.hooked_cgpms[token_linreg]

    # Check that the log scores of the hooked cgpms agree.
    assert np.allclose(
        state.hooked_cgpms[token_forest].logpdf_score(),
        state2.hooked_cgpms[token_forest].logpdf_score())
    assert np.allclose(
        state.hooked_cgpms[token_linreg].logpdf_score(),
        state2.hooked_cgpms[token_linreg].logpdf_score())

    # Now run analogous smoke tests through the Engine.
    e = Engine(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], num_states=2, rng=rng)
    e.compose_cgpm([forest, forest], multiprocess=1)
    e.compose_cgpm([linreg, linreg], multiprocess=1)
    e.transition_foreign(N=1, cols=[forest.outputs[0], linreg.outputs[0]])
    e.dependence_probability(0,1)
    e.simulate(-1, [0,1], {2:1}, multiprocess=0)
    e.logpdf(-1, {1:1}, {2:1, 0:0}, multiprocess=0)

    state3 = e.get_state(0)

    # There is no guarantee that the logpdf score improves with inference,
    # but it should not drop by more than a few nats.
    def check_logpdf_delta(before, after):
        return before < after or (before - after) < 5
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_forest].logpdf_score(),
        after=state3.hooked_cgpms[token_forest].logpdf_score())
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_linreg].logpdf_score(),
        after=state3.hooked_cgpms[token_linreg].logpdf_score())