Code Example #1
File: test_vscgpm.py  Project: leighhopcroft/cgpm
import numpy as np

from cgpm.crosscat.engine import Engine
from cgpm.venturescript.vscgpm import VsCGpm


def test_engine_composition():

    # Toy dataset: columns 0 and 1 are modeled by the foreign VsCGpm, and
    # column 3 is modeled by the CrossCat engine.
    X = np.asarray([
        [1, 2, 0, 1],
        [1, 1, 0, 0],
    ])
    # CrossCat engine over column 3 only, with two parallel states.
    engine = Engine(X[:, [3]], outputs=[3], cctypes=['normal'], num_states=2)
    # Foreign CGPM backed by VentureScript; `source_abstract` is a
    # VentureScript program string defined at module level in the original
    # test file and is not reproduced here.
    cgpm = VsCGpm(
        outputs=[0, 1],
        inputs=[3],
        source=source_abstract,
    )

    # Incorporate each row: columns 0 and 1 as observations, column 3 as the
    # input (conditioning) variable.
    for i, row in enumerate(X):
        cgpm.incorporate(i, {0: row[0], 1: row[1]}, {3: row[3]})

    # Run the foreign CGPM's own inference, then hook it into each of the
    # engine's two states (one list entry per state).
    cgpm.transition(N=2)
    engine.compose_cgpm([cgpm, cgpm], multiprocess=True)
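
A possible follow-up (not part of the original test), reusing the Engine query calls demonstrated in Code Example #2 below: once the VsCGpm is hooked into the engine's states, its foreign outputs 0 and 1 can be simulated and scored given a value of the CrossCat variable 3. The rowid -1, the constraint value 0.5, and the target values are illustrative placeholders.

# Hypothetical query sketch, not in the original test.
engine.transition_foreign(N=2, cols=[0, 1])
samples = engine.simulate(-1, [0, 1], {3: 0.5}, multiprocess=0)
logp = engine.logpdf(-1, {0: 1, 1: 2}, {3: 0.5}, multiprocess=0)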
Code Example #2
File: test_serialize.py  Project: wilsondy/cgpm
import numpy as np

from cgpm.crosscat.engine import Engine
from cgpm.crosscat.state import State
from cgpm.regressions.forest import RandomForest
from cgpm.regressions.linreg import LinearRegression
from cgpm.utils import config as cu
from cgpm.utils import general as gu
from cgpm.utils import test as tu


def test_serialize_composite_cgpm():
    rng = gu.gen_rng(2)

    # Generate the data.
    cctypes, distargs = cu.parse_distargs([
        'categorical(k=3)',     # RandomForest          0
        'normal',               # LinearRegression      1
        'categorical(k=3)',     # GPMCC                 2
        'poisson',              # GPMCC                 3
        'normal',               # GPMCC                 4
        'lognormal'             # GPMCC                 5
        ])
    T, Zv, Zc = tu.gen_data_table(
        35, [.4, .6], [[.33, .33, .34], [.5, .5]],
        cctypes, distargs, [.2]*len(cctypes), rng=rng)
    D = np.transpose(T)

    # Create GPMCC.
    state = State(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], rng=rng)

    # Create a Forest.
    forest = RandomForest(
        outputs=[0],
        inputs=[1,2,3,4],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [1,2,3,4]],
                'statargs': [distargs[i] for i in [1,2,3,4]]},
            'k': distargs[0]['k']},
        rng=rng)

    # Create a Regression.
    linreg = LinearRegression(
        outputs=[1],
        inputs=[3,4,5],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [3,4,5]],
                'statargs': [distargs[i] for i in [3,4,5]]}},
        rng=rng)

    # Incorporate the data.
    def incorporate_data(cgpm, rowid, row):
        cgpm.incorporate(
            rowid,
            {i: row[i] for i in cgpm.outputs},
            {i: row[i] for i in cgpm.inputs},
        )
    for rowid, row in enumerate(D):
        incorporate_data(forest, rowid, row)
        incorporate_data(linreg, rowid, row)

    # Run the CrossCat state transitions.
    state.transition(N=10, progress=False)

    # Compose the foreign CGPMs into the state, then run transitions on them
    # through the state.
    token_forest = state.compose_cgpm(forest)
    token_linreg = state.compose_cgpm(linreg)
    state.transition_foreign(N=10, cols=[forest.outputs[0], linreg.outputs[0]])

    # Now run the serialization.
    metadata = state.to_metadata()
    state2 = State.from_metadata(metadata)

    # Check that the tokens are in state2.
    assert token_forest in state2.hooked_cgpms
    assert token_linreg in state2.hooked_cgpms

    # The hooked cgpms must be distinct objects after serialize/deserialize.
    assert state.hooked_cgpms[token_forest] is not state2.hooked_cgpms[token_forest]
    assert state.hooked_cgpms[token_linreg] is not state2.hooked_cgpms[token_linreg]

    # Check that the log scores of the hooked cgpms agree.
    assert np.allclose(
        state.hooked_cgpms[token_forest].logpdf_score(),
        state2.hooked_cgpms[token_forest].logpdf_score())
    assert np.allclose(
        state.hooked_cgpms[token_linreg].logpdf_score(),
        state2.hooked_cgpms[token_linreg].logpdf_score())

    # Now run some tests for the engine.
    e = Engine(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], num_states=2, rng=rng)
    e.compose_cgpm([forest, forest], multiprocess=1)
    e.compose_cgpm([linreg, linreg], multiprocess=1)
    e.transition_foreign(N=1, cols=[forest.outputs[0], linreg.outputs[0]])
    e.dependence_probability(0,1)
    e.simulate(-1, [0,1], {2:1}, multiprocess=0)
    e.logpdf(-1, {1:1}, {2:1, 0:0}, multiprocess=0)

    state3 = e.get_state(0)

    # There is no guarantee that the logpdf score improves with inference,
    # but it should not drop by more than a few nats.
    def check_logpdf_delta(before, after):
        return before < after or (before - after) < 5
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_forest].logpdf_score(),
        after=state3.hooked_cgpms[token_forest].logpdf_score())
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_linreg].logpdf_score(),
        after=state3.hooked_cgpms[token_linreg].logpdf_score())
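
The State round trip above has a natural Engine-level counterpart. Below is a minimal sketch, assuming the Engine exposes the same to_metadata/from_metadata pair as State; that interface is not shown in the example above, so treat it as an assumption.

# Hypothetical extension, not in the original test: round-trip the engine,
# including its hooked foreign CGPMs, then repeat a query on the copy.
e2 = Engine.from_metadata(e.to_metadata())
e2.logpdf(-1, {1: 1}, {2: 1, 0: 0}, multiprocess=0)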