def test_engine_composition():
    """Smoke test: compose a VsCGpm into each state of a 2-state Engine."""
    from cgpm.crosscat.engine import Engine
    # Four columns; only column 3 is modeled by the Engine, columns 0 and 1
    # are handled by the foreign VsCGpm with column 3 as its input.
    X = np.asarray([
        [1, 2, 0, 1],
        [1, 1, 0, 0],
    ])
    engine = Engine(
        X[:, [3]], outputs=[3], cctypes=['normal'], num_states=2)
    cgpm = VsCGpm(outputs=[0, 1], inputs=[3], source=source_abstract)
    # Feed every row into the foreign CGPM before composing.
    for rowid, row in enumerate(X):
        cgpm.incorporate(rowid, {0: row[0], 1: row[1]}, {3: row[3]})
    cgpm.transition(N=2)
    # One copy of the foreign CGPM per engine state, run through the
    # multiprocess code path.
    engine.compose_cgpm([cgpm, cgpm], multiprocess=True)
def test_serialize_composite_cgpm():
    """Round-trip a State with hooked foreign CGPMs through to_metadata.

    Builds a State over columns 2-5, hooks a RandomForest (output 0) and a
    LinearRegression (output 1) as foreign CGPMs, serializes, and checks the
    deserialized copy has the same tokens, distinct objects, and matching
    logpdf scores.  Finishes with an Engine smoke test of the same composition.
    """
    rng = gu.gen_rng(2)
    # Generate the data.
    cctypes, distargs = cu.parse_distargs([
        'categorical(k=3)',     # RandomForest        0
        'normal',               # LinearRegression    1
        'categorical(k=3)',     # GPMCC               2
        'poisson',              # GPMCC               3
        'normal',               # GPMCC               4
        'lognormal'             # GPMCC               5
    ])
    T, Zv, Zc = tu.gen_data_table(
        35, [.4, .6], [[.33, .33, .34], [.5, .5]],
        cctypes, distargs, [.2]*len(cctypes), rng=rng)
    D = np.transpose(T)
    # Create GPMCC over the last four columns only.
    state = State(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], rng=rng)
    # Create a Forest predicting column 0 from columns 1-4.
    forest = RandomForest(
        outputs=[0],
        inputs=[1,2,3,4],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [1,2,3,4]],
                'statargs': [distargs[i] for i in [1,2,3,4]]},
            'k': distargs[0]['k']},
        rng=rng)
    # Create a Regression predicting column 1 from columns 3-5.
    linreg = LinearRegression(
        outputs=[1],
        inputs=[3,4,5],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [3,4,5]],
                'statargs': [distargs[i] for i in [3,4,5]]}},
        rng=rng)
    # Incorporate the data into the foreign CGPMs.
    def incorporate_data(cgpm, rowid, row):
        cgpm.incorporate(
            rowid,
            {i: row[i] for i in cgpm.outputs},
            {i: row[i] for i in cgpm.inputs},
        )
    for rowid, row in enumerate(D):
        incorporate_data(forest, rowid, row)
        incorporate_data(linreg, rowid, row)
    # Run state transitions.
    state.transition(N=10, progress=False)
    # Compose CGPMs, instructing State to run the transitions.
    token_forest = state.compose_cgpm(forest)
    token_linreg = state.compose_cgpm(linreg)
    state.transition_foreign(N=10, cols=[forest.outputs[0], linreg.outputs[0]])
    # Now run the serialization.
    metadata = state.to_metadata()
    state2 = State.from_metadata(metadata)
    # Check that the tokens are in state2.
    assert token_forest in state2.hooked_cgpms
    assert token_linreg in state2.hooked_cgpms
    # The hooked cgpms must be unique objects after serialize/deserialize.
    assert state.hooked_cgpms[token_forest] != state2.hooked_cgpms[token_forest]
    assert state.hooked_cgpms[token_linreg] != state2.hooked_cgpms[token_linreg]
    # Check that the log scores of the hooked cgpms agree.
    assert np.allclose(
        state.hooked_cgpms[token_forest].logpdf_score(),
        state2.hooked_cgpms[token_forest].logpdf_score())
    assert np.allclose(
        state.hooked_cgpms[token_linreg].logpdf_score(),
        state2.hooked_cgpms[token_linreg].logpdf_score())
    # Now run some tests for the engine.
    e = Engine(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], num_states=2, rng=rng)
    e.compose_cgpm([forest, forest], multiprocess=1)
    e.compose_cgpm([linreg, linreg], multiprocess=1)
    e.transition_foreign(N=1, cols=[forest.outputs[0], linreg.outputs[0]])
    e.dependence_probability(0,1)
    e.simulate(-1, [0,1], {2:1}, multiprocess=0)
    e.logpdf(-1, {1:1}, {2:1, 0:0}, multiprocess=0)
    state3 = e.get_state(0)
    # There is no guarantee that the logpdf score improves with inference, but
    # it should not drop by more than a few nats.
    def check_logpdf_delta(before, after):
        # BUG FIX: the original tested (after - before) < 5, which is always
        # True whenever the score failed to improve, and the boolean result
        # was discarded at the call sites -- the check was vacuous.  The
        # intended contract is: improved, or decreased by fewer than 5 nats.
        return after >= before or (before - after) < 5
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_forest].logpdf_score(),
        after=state3.hooked_cgpms[token_forest].logpdf_score())
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_linreg].logpdf_score(),
        after=state3.hooked_cgpms[token_linreg].logpdf_score())