def test_bernoulli():
    """Smoke-test Engine on a two-category (Bernoulli-like) dataset.

    Checks that: (i) the simulated proportion of ones matches the empirical
    rate in the data; (ii) logpdfs for an unobserved row are normalized;
    (iii) logpdf queries constraining an observed cell raise ValueError.
    """
    # Switch for multiprocess (0 is faster).
    multiprocess = 0
    # Create categorical data of DATA_NUM_0 zeros and DATA_NUM_1 ones.
    data = np.transpose(np.array([[0] * DATA_NUM_0 + [1] * DATA_NUM_1]))
    # Run a single chain for a few iterations.
    # FIX: use the `multiprocess` switch instead of a hard-coded 0, so the
    # toggle above actually controls every call (behavior unchanged at 0).
    engine = Engine(
        data,
        cctypes=['categorical'],
        distargs=[{'k': 2}],
        rng=gu.gen_rng(0),
        multiprocess=multiprocess)
    engine.transition(NUM_ITER, multiprocess=multiprocess)
    # Simulate from hypothetical row and compute the proportion of ones.
    sample = engine.simulate(-1, [0], N=NUM_SIM, multiprocess=multiprocess)[0]
    sum_b = sum(s[0] for s in sample)
    observed_prob_of_1 = (float(sum_b) / float(NUM_SIM))
    true_prob_of_1 = float(DATA_NUM_1) / float(DATA_NUM_0 + DATA_NUM_1)
    # Check 10% relative match (rtol=.1; the old comment said 1%, which
    # contradicted the tolerance actually used).
    assert np.allclose(true_prob_of_1, observed_prob_of_1, rtol=.1)
    # Simulate from observed row as a crash test.
    sample = engine.simulate(1, [0], N=1, multiprocess=multiprocess)
    # Ensure normalized unobserved probabilities: p(0) + p(1) == 1.
    p0_uob = engine.logpdf(-1, {0: 0}, multiprocess=multiprocess)[0]
    p1_uob = engine.logpdf(-1, {0: 1}, multiprocess=multiprocess)[0]
    assert np.allclose(gu.logsumexp([p0_uob, p1_uob]), 0)
    # A logpdf query constraining an observed cell returns an error.
    with pytest.raises(ValueError):
        engine.logpdf(1, {0: 0}, multiprocess=multiprocess)
    with pytest.raises(ValueError):
        engine.logpdf(1, {0: 1}, multiprocess=multiprocess)
def test_serialize_composite_cgpm():
    """End-to-end test: compose foreign CGPMs onto a State, serialize and
    deserialize it, and run the same composition through an Engine.

    Verifies that hooked CGPMs survive a to_metadata/from_metadata round
    trip as distinct objects with identical logpdf scores, and that the
    scores do not degrade badly under foreign transitions.
    """
    rng = gu.gen_rng(2)
    # Generate the data.
    cctypes, distargs = cu.parse_distargs([
        'categorical(k=3)',  # RandomForest       0
        'normal',            # LinearRegression   1
        'categorical(k=3)',  # GPMCC              2
        'poisson',           # GPMCC              3
        'normal',            # GPMCC              4
        'lognormal'          # GPMCC              5
    ])
    T, Zv, Zc = tu.gen_data_table(
        35, [.4, .6], [[.33, .33, .34], [.5, .5]], cctypes, distargs,
        [.2] * len(cctypes), rng=rng)
    D = np.transpose(T)
    # Create GPMCC over columns 2-5 only; 0 and 1 are foreign CGPM outputs.
    state = State(
        D[:, 2:], outputs=[2, 3, 4, 5], cctypes=cctypes[2:],
        distargs=distargs[2:], rng=rng)
    # Create a Forest predicting column 0 from columns 1-4.
    forest = RandomForest(
        outputs=[0],
        inputs=[1, 2, 3, 4],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [1, 2, 3, 4]],
                'statargs': [distargs[i] for i in [1, 2, 3, 4]]},
            'k': distargs[0]['k']},
        rng=rng)
    # Create a Regression predicting column 1 from columns 3-5.
    linreg = LinearRegression(
        outputs=[1],
        inputs=[3, 4, 5],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [3, 4, 5]],
                'statargs': [distargs[i] for i in [3, 4, 5]]}},
        rng=rng)

    # Incorporate the data into both foreign CGPMs.
    def incorporate_data(cgpm, rowid, row):
        cgpm.incorporate(
            rowid,
            {i: row[i] for i in cgpm.outputs},
            {i: row[i] for i in cgpm.inputs},
        )
    for rowid, row in enumerate(D):
        incorporate_data(forest, rowid, row)
        incorporate_data(linreg, rowid, row)

    # Run state transitions.
    state.transition(N=10, progress=False)
    # Compose CGPMs, instructing State to run the transitions.
    token_forest = state.compose_cgpm(forest)
    token_linreg = state.compose_cgpm(linreg)
    state.transition_foreign(N=10, cols=[forest.outputs[0], linreg.outputs[0]])

    # Now run the serialization round trip.
    metadata = state.to_metadata()
    state2 = State.from_metadata(metadata)

    # Check that the tokens are in state2.
    assert token_forest in state2.hooked_cgpms
    assert token_linreg in state2.hooked_cgpms
    # The hooked cgpms must be unique objects after serialize/deserialize.
    assert state.hooked_cgpms[token_forest] != state2.hooked_cgpms[token_forest]
    assert state.hooked_cgpms[token_linreg] != state2.hooked_cgpms[token_linreg]
    # Check that the log scores of the hooked cgpms agree.
    assert np.allclose(
        state.hooked_cgpms[token_forest].logpdf_score(),
        state2.hooked_cgpms[token_forest].logpdf_score())
    assert np.allclose(
        state.hooked_cgpms[token_linreg].logpdf_score(),
        state2.hooked_cgpms[token_linreg].logpdf_score())

    # Now run some tests for the engine.
    e = Engine(
        D[:, 2:], outputs=[2, 3, 4, 5], cctypes=cctypes[2:],
        distargs=distargs[2:], num_states=2, rng=rng)
    e.compose_cgpm([forest, forest], multiprocess=1)
    e.compose_cgpm([linreg, linreg], multiprocess=1)
    e.transition_foreign(N=1, cols=[forest.outputs[0], linreg.outputs[0]])
    e.dependence_probability(0, 1)
    e.simulate(-1, [0, 1], {2: 1}, multiprocess=0)
    e.logpdf(-1, {1: 1}, {2: 1, 0: 0}, multiprocess=0)
    state3 = e.get_state(0)

    # There is no guarantee that the logpdf score improves with inference,
    # but it should not degrade by more than a few nats.
    def check_logpdf_delta(before, after):
        # FIX: the original tested (after - before) < 5, which is always
        # true when `before < after` fails (the delta is then negative),
        # making the predicate vacuous. The intended bound is on the DROP.
        return before < after or (before - after) < 5

    # FIX: the original called check_logpdf_delta without asserting its
    # result, so the check could never fail; assert the returned value.
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_forest].logpdf_score(),
        after=state3.hooked_cgpms[token_forest].logpdf_score())
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_linreg].logpdf_score(),
        after=state3.hooked_cgpms[token_linreg].logpdf_score())