def test_naive_bayes_independence_lovecat():
    """Check that all-pairs `Ci` constraints survive both transition kinds.

    Ten copies of one normally distributed column are built, and every pair
    of column indices is passed to the State as a `Ci` constraint. The
    column partition is validated after the regular transitions and again
    after the lovecat transitions.
    """
    n_cols = 10
    column = gu.gen_rng(1).normal(size=(10, 1))
    data = np.repeat(column, n_cols, axis=1)
    all_pairs = list(itertools.combinations(range(n_cols), 2))
    state = State(
        data,
        cctypes=['normal'] * n_cols,
        Ci=all_pairs,
        rng=gu.gen_rng(0),
    )

    # Validate after the standard kernels.
    state.transition(N=10, progress=0)
    vu.validate_crp_constrained_partition(
        state.Zv(), [], all_pairs, {}, {})

    # Validate again after the lovecat transitions.
    state.transition_lovecat(N=100, progress=0)
    vu.validate_crp_constrained_partition(
        state.Zv(), [], all_pairs, {}, {})
def test_complex_independent_relationships():
    """Check that two `Ci` column pairs are respected after transitioning.

    The generator that draws the data column is the same object handed to
    the State, so the State continues from the generator's advanced stream.
    """
    rng = gu.gen_rng(1)
    n_cols = 10
    data = np.repeat(rng.normal(size=(10, 1)), n_cols, axis=1)
    independent_pairs = [(2, 8), (0, 3)]
    state = State(
        data,
        cctypes=['normal'] * n_cols,
        Ci=independent_pairs,
        rng=rng,
    )
    state.transition(N=10, progress=0)
    partition = state.Zv()
    vu.validate_crp_constrained_partition(
        partition, [], independent_pairs, {}, {})
def test_simple_dependence_constraint(Ci):
    """Check `Cd` pairs combined with the supplied `Ci` constraints.

    `Ci` is provided by the caller (presumably a fixture or parametrization
    not visible here -- TODO confirm). The 'columns' kernel is expected to
    raise ValueError on this state; the remaining kernels are then run and
    the column partition is validated against both constraint lists.
    """
    rng = gu.gen_rng(1)
    data = np.repeat(rng.normal(size=(10, 1)), 10, axis=1)
    dependent_pairs = [(2, 0), (8, 3)]
    state = State(
        data,
        cctypes=['normal'] * 10,
        Ci=Ci,
        Cd=dependent_pairs,
        rng=rng,
    )

    # Cannot transition columns with dependencies.
    with pytest.raises(ValueError):
        state.transition(N=10, kernels=['columns'], progress=0)

    # NOTE(review): 'alpha' appears twice in this list -- confirm intentional.
    other_kernels = [
        'rows', 'alpha', 'column_hypers', 'alpha', 'view_alphas']
    state.transition(N=10, kernels=other_kernels, progress=False)
    vu.validate_crp_constrained_partition(
        state.Zv(), dependent_pairs, Ci, {}, {})
# Example #4
# 0
def test_categorical_forest():
    """Exercise switching categorical and bernoulli dims to 'random_forest'.

    Steps, in order:
      1. Build a State over the module-level T/CCTYPES/DISTARGS and run one
         transition.
      2. If the categorical dim sits alone in its view, re-incorporate it
         into view 0, then update its cctype to 'random_forest'.
      3. Incorporate a bernoulli dim (output 191) into the same view and
         update it to 'random_forest' as well.
      4. Run the row / column_params / column_hypers kernels on that view.
      5. Expect ValueError from the 'columns' kernel.
      6. Expect an exception when updating the cctype of a dim that was
         incorporated into a fresh singleton view.
    """
    state = State(
        T, cctypes=CCTYPES, distargs=DISTARGS, rng=gu.gen_rng(1))
    state.transition(N=1, progress=False)
    cat_id = CCTYPES.index('categorical')

    # Bind distargs before the conditional: update_cctype below needs it on
    # both paths, and binding it only inside the `if` raised
    # UnboundLocalError whenever cat_id was not in a singleton view.
    distargs = DISTARGS[cat_id].copy()

    # If cat_id is singleton migrate first.
    if len(state.view_for(cat_id).dims) == 1:
        state.unincorporate_dim(cat_id)
        state.incorporate_dim(
            T[:,cat_id], outputs=[cat_id], cctype='categorical',
            distargs=distargs, v=0)
    state.update_cctype(cat_id, 'random_forest', distargs=distargs)

    bernoulli_id = CCTYPES.index('bernoulli')
    state.incorporate_dim(
        T[:,bernoulli_id], outputs=[191], cctype='bernoulli',
        v=state.Zv(cat_id))
    state.update_cctype(191, 'random_forest', distargs={'k':2})

    # Run valid transitions.
    state.transition(
        N=2, kernels=['rows','column_params','column_hypers'],
        views=[state.Zv(cat_id)], progress=False)

    # Running column transition should raise.
    with pytest.raises(ValueError):
        state.transition(N=1, kernels=['columns'], progress=False)

    # Updating cctype in singleton View should raise.
    distargs = DISTARGS[cat_id].copy()
    state.incorporate_dim(
        T[:,CCTYPES.index('categorical')], outputs=[98],
        cctype='categorical', distargs=distargs, v=max(state.views)+1)
    with pytest.raises(Exception):
        state.update_cctype(98, 'random_forest', distargs=distargs)
def test_independence_inference_quality_lovecat():
    """Check that lovecat ends with columns 0-3 and 4-7 in two views.

    Columns 0-3 are copies of one draw centred at 0; columns 4-7 are copies
    of a column whose first 25 rows are centred at 10 and last 25 at 20.
    Two starting points are tried: all columns in one view, and four
    `Cd`-linked pairs that each start in a view of their own.
    """
    def assert_two_views(st):
        # Columns 0-3 share a view, columns 4-7 share a different one.
        for col in (0, 1, 2, 3):
            assert st.Zv(col) == st.Zv(0)
        for col in (4, 5, 6, 7):
            assert st.Zv(col) == st.Zv(4)
        assert st.Zv(0) != st.Zv(4)

    # Draw order matters for reproducibility: loc=0, then loc=10, then loc=20.
    rng = gu.gen_rng(584)
    column_view_1 = rng.normal(loc=0, size=(50, 1))
    upper_half = rng.normal(loc=10, size=(25, 1))
    lower_half = rng.normal(loc=20, size=(25, 1))
    column_view_2 = np.concatenate((upper_half, lower_half))

    data = np.column_stack((
        np.repeat(column_view_1, 4, axis=1),
        np.repeat(column_view_2, 4, axis=1),
    ))
    cctypes = ['normal'] * 8

    # Start with every column in view 0.
    single_view = dict.fromkeys(xrange(8), 0)
    state = State(data, Zv=single_view, cctypes=cctypes, rng=gu.gen_rng(10))
    state.transition_lovecat(N=100, progress=1)
    assert_two_views(state)

    # Get lovecat to merge the dependent columns into one view.
    Cd = [(0, 1), (2, 3), (4, 5), (6, 7)]
    # Columns 2k and 2k+1 start together in view k.
    paired_views = {col: col // 2 for col in xrange(8)}
    state = State(
        data, Zv=paired_views, cctypes=cctypes, Cd=Cd, rng=gu.gen_rng(1))
    state.transition_lovecat(N=100, progress=1)
    assert_two_views(state)
# Example #6
# 0
def test_incorporate_state():
    """Walk a State through a sequence of dim incorporate/unincorporate calls.

    Starts from the first two columns of the module-level T and then, in
    order: incorporates dims into an existing view, into new singleton
    views, and without naming a view; checks that malformed incorporate
    calls raise ValueError; unincorporates and re-incorporates dims; and
    finally unincorporates all outputs but the last, checking that removing
    the final one raises ValueError. A transition is run after most
    mutations as a crash test.
    """
    state = State(
        T[:,:2], cctypes=CCTYPES[:2], distargs=DISTARGS[:2], rng=gu.gen_rng(0))
    state.transition(N=5)

    # NOTE(review): indexing .keys() requires keys() to return a list
    # (Python 2 dict behaviour) -- confirm the type of state.views.
    target = state.views.keys()[0]

    # Incorporate a new dim into view[0].
    state.incorporate_dim(
        T[:,2], outputs=[2], cctype=CCTYPES[2], distargs=DISTARGS[2], v=target)
    assert state.Zv(2) == target
    state.transition(N=1)

    # Incorporate a new dim into view[0] with a non-contiguous output.
    state.incorporate_dim(
        T[:,2], outputs=[10], cctype=CCTYPES[2], distargs=DISTARGS[2], v=target)
    assert state.Zv(10) == target
    state.transition(N=1)

    # Some crash testing queries.
    state.logpdf(-1, {10:1}, constraints={0:2, 1:1})
    state.simulate(-1, [10], constraints={0:2})

    # Incorporating with a duplicated output should raise.
    with pytest.raises(ValueError):
        state.incorporate_dim(
            T[:,2], outputs=[10], cctype=CCTYPES[2], distargs=DISTARGS[2],
            v=target)

    # Multivariate incorporate should raise.
    with pytest.raises(ValueError):
        state.incorporate_dim(
            T[:,2], outputs=[10, 2], cctype=CCTYPES[2],
            distargs=DISTARGS[2], v=target)

    # Missing output should raise.
    with pytest.raises(ValueError):
        state.incorporate_dim(
            T[:,2], outputs=[], cctype=CCTYPES[2],
            distargs=DISTARGS[2], v=target)

    # Wrong number of rows should raise.
    with pytest.raises(ValueError):
        state.incorporate_dim(
            T[:,2][:-1], outputs=[11], cctype=CCTYPES[2],
            distargs=DISTARGS[2], v=target)

    # Inputs should raise.
    with pytest.raises(ValueError):
        state.incorporate_dim(
            T[:,2], outputs=[11], inputs=[2], cctype=CCTYPES[2],
            distargs=DISTARGS[2], v=target)

    # Incorporate dim into a newly created singleton view.
    # max(state.views)+1 is used as a view index that is not yet taken.
    target = max(state.views)+1
    state.incorporate_dim(
        T[:,3], outputs=[3], cctype=CCTYPES[3],
        distargs=DISTARGS[3], v=target)
    assert state.Zv(3) == target
    state.transition(N=1)

    # Incorporate dim without specifying a view.
    state.incorporate_dim(T[:,4], outputs=[4],
        cctype=CCTYPES[4], distargs=DISTARGS[4])
    state.transition(N=1)

    # Unincorporate first dim.
    previous = state.n_cols()
    state.unincorporate_dim(0)
    assert state.n_cols() == previous-1
    state.transition(N=1)

    # Reincorporate dim without specifying a view.
    state.incorporate_dim(
        T[:,0], outputs=[0], cctype=CCTYPES[0], distargs=DISTARGS[0])
    state.transition(N=1)

    # Incorporate dim into singleton view, remove it, assert destroyed.
    target = max(state.views)+1
    state.incorporate_dim(
        T[:,5], outputs=[5], cctype=CCTYPES[5], distargs=DISTARGS[5],
        v=target)
    previous = len(state.views)
    state.unincorporate_dim(5)
    assert len(state.views) == previous-1
    state.transition(N=1)

    # Reincorporate dim into a singleton view.
    target = max(state.views)+1
    state.incorporate_dim(T[:,5], outputs=[5], cctype=CCTYPES[5],
        distargs=DISTARGS[5], v=target)
    state.transition(N=1)

    # Incorporate the rest of the dims in the default way.
    # Each new dim takes the next output index above the current maximum.
    for i in xrange(6, len(CCTYPES)):
        state.incorporate_dim(
            T[:,i], outputs=[max(state.outputs)+1],
            cctype=CCTYPES[i], distargs=DISTARGS[i])
    state.transition(N=1)

    # Unincorporating non-existent dim should raise.
    with pytest.raises(ValueError):
        state.unincorporate_dim(9999)

    # Unincorporate all the dims, except the last one.
    # The [:-1] slice is evaluated once, before any dim is removed.
    for o in state.outputs[:-1]:
        state.unincorporate_dim(o)
    assert state.n_cols() == 1
    state.transition(N=1)

    # Unincorporating last dim should raise.
    with pytest.raises(ValueError):
        state.unincorporate_dim(state.outputs[0])