Exemple #1
0
def test_incorporate_state():
    """Exercise State.incorporate_dim and State.unincorporate_dim.

    Builds a state on the first two columns of T, then adds and removes dims
    in several ways (into an existing view, with a non-contiguous output,
    into a new singleton view, and without naming a view), interleaved with
    transitions. Requests the test treats as invalid (duplicated output,
    multiple outputs, no output, wrong row count, inputs supplied, unknown
    dim, removing the last dim) are asserted to raise ValueError.
    """
    state = State(
        T[:,:2], cctypes=CCTYPES[:2], distargs=DISTARGS[:2], rng=gu.gen_rng(0))
    state.transition(N=5)

    def incorporate(col, outputs, **kwargs):
        # Incorporate column `col` of T with its matching cctype and distargs.
        # Extra keywords (v, inputs) are forwarded only when the caller gives
        # them, so omitting `v` still lets the state pick the view.
        state.incorporate_dim(
            T[:,col], outputs=outputs, cctype=CCTYPES[col],
            distargs=DISTARGS[col], **kwargs)

    # list(...) keeps this working on Python 3, where keys() returns a view
    # that cannot be indexed.
    target = list(state.views.keys())[0]

    # Incorporate a new dim into view[0].
    incorporate(2, [2], v=target)
    assert state.Zv(2) == target
    state.transition(N=1)

    # Incorporate a new dim into view[0] with a non-contiguous output.
    incorporate(2, [10], v=target)
    assert state.Zv(10) == target
    state.transition(N=1)

    # Some crash testing queries.
    state.logpdf(-1, {10:1}, constraints={0:2, 1:1})
    state.simulate(-1, [10], constraints={0:2})

    # Incorporating with a duplicated output should raise.
    with pytest.raises(ValueError):
        incorporate(2, [10], v=target)

    # Multivariate incorporate should raise.
    with pytest.raises(ValueError):
        incorporate(2, [10, 2], v=target)

    # Missing output should raise.
    with pytest.raises(ValueError):
        incorporate(2, [], v=target)

    # Wrong number of rows should raise. Called directly because the data is
    # a truncated column rather than a plain T[:,col].
    with pytest.raises(ValueError):
        state.incorporate_dim(
            T[:,2][:-1], outputs=[11], cctype=CCTYPES[2],
            distargs=DISTARGS[2], v=target)

    # Inputs should raise.
    with pytest.raises(ValueError):
        incorporate(2, [11], inputs=[2], v=target)

    # Incorporate dim into a newly created singleton view.
    target = max(state.views)+1
    incorporate(3, [3], v=target)
    assert state.Zv(3) == target
    state.transition(N=1)

    # Incorporate dim without specifying a view.
    incorporate(4, [4])
    state.transition(N=1)

    # Unincorporate first dim.
    previous = state.n_cols()
    state.unincorporate_dim(0)
    assert state.n_cols() == previous-1
    state.transition(N=1)

    # Reincorporate dim without specifying a view.
    incorporate(0, [0])
    state.transition(N=1)

    # Incorporate dim into singleton view, remove it, assert destroyed.
    target = max(state.views)+1
    incorporate(5, [5], v=target)
    previous = len(state.views)
    state.unincorporate_dim(5)
    assert len(state.views) == previous-1
    state.transition(N=1)

    # Reincorporate dim into a singleton view.
    target = max(state.views)+1
    incorporate(5, [5], v=target)
    state.transition(N=1)

    # Incorporate the rest of the dims in the default way. range() is used
    # instead of xrange() so the loop runs on both Python 2 and Python 3.
    for i in range(6, len(CCTYPES)):
        incorporate(i, [max(state.outputs)+1])
    state.transition(N=1)

    # Unincorporating non-existent dim should raise.
    with pytest.raises(ValueError):
        state.unincorporate_dim(9999)

    # Unincorporate all the dims, except the last one. The slice is a
    # snapshot, so removing dims inside the loop does not disturb iteration.
    for o in state.outputs[:-1]:
        state.unincorporate_dim(o)
    assert state.n_cols() == 1
    state.transition(N=1)

    # Unincorporating last dim should raise.
    with pytest.raises(ValueError):
        state.unincorporate_dim(state.outputs[0])
Exemple #2
0
# Render the state.
state.plot()

# Smoke-test the simulate/logpdf query helpers against the complex state,
# in the same order as before.
for crash_check in (
        test_crash_simulate_joint,
        test_crash_logpdf_joint,
        test_crash_simulate_conditional,
        test_crash_logpdf_conditional,
        test_crash_simulate_joint_observed,
        test_crash_logpdf_joint_observed,
        test_crash_simulate_conditional_observed,
        test_crash_logpdf_conditional_observed):
    crash_check(state)

# Chain-rule consistency for state 1: the joint log density of {0: 1, 1: 2}
# must match logpdf({0: 1} given {1: 2}) + logpdf({1: 2}).
# NOTE(review): the third positional argument is presumably `constraints`.
joint = state.logpdf(-1, {0: 1, 1: 2})
chain = (
    state.logpdf(-1, {0: 1}, {1: 2})
    + state.logpdf(-1, {1: 2})
)
assert np.allclose(joint, chain)

# Deliberately disabled section: it never runs while the guard is False.
if False:
    state2 = State(
        T.T,
        cctypes=cctypes,
        distargs=distargs,
        rng=gu.gen_rng(12),
    )
    state2.transition(N=10, progress=1)

    # Same joint / chain-rule queries on state 2; the results are computed
    # and discarded.
    state2.logpdf(-1, {0: 1, 1: 2})
    state2.logpdf(-1, {0: 1}, {1: 2}) + state2.logpdf(-1, {1: 2})

    # Monte Carlo average of the conditional over the two states:
    # log(1/2) + logsumexp of the per-state conditional log densities.
    mc_conditional = np.log(.5) + gu.logsumexp([
        state.logpdf(-1, {0: 1}, {1: 2}),
        state2.logpdf(-1, {0: 1}, {1: 2}),
    ])