def test_crosscat_two_component_nominal__ci_():
    """Check two-view CrossCat inference on bivariate data plus a nominal.

    The model starts with variable 0 in one view and variables 1 and 50
    (a 4-category nominal) in another. The statistical checks only run when
    the ``--integration`` option is set; otherwise the test is a smoke test
    that runs a single step of each transition.
    """
    prng = get_prng(10)
    # NOTE(review): the ``pytest.config`` global was removed in pytest 5.0;
    # confirm the pytest version this suite is pinned to.
    integration = pytest.config.getoption('--integration')

    # Build CGPM with adversarial initialization: the variables are split
    # across two views. Sub-models are built in the same order as they
    # appear in the nested expression form, since they all share `prng`.
    row_divide_a = CRP([-1], [], rng=prng)
    components_a = Product([Normal([0], [], rng=prng)], rng=prng)
    view_a = FlexibleRowMixture(
        cgpm_row_divide=row_divide_a,
        cgpm_components_base=components_a,
        rng=prng)

    row_divide_b = CRP([-2], [], rng=prng)
    components_b = Product([
        Normal([1], [], rng=prng),
        Categorical([50], [], distargs={'k': 4}, rng=prng),
    ], rng=prng)
    view_b = FlexibleRowMixture(
        cgpm_row_divide=row_divide_b,
        cgpm_components_base=components_b,
        rng=prng)

    crosscat = Product([view_a, view_b], rng=prng)

    # Fetch data and add a nominal variable: rows 0-59 are labelled in
    # consecutive runs of 15 with the values 0, 1, 2, 3.
    data_xy = make_bivariate_two_clusters(prng)
    data_z = np.zeros(len(data_xy))
    for label, start in enumerate(range(0, 60, 15)):
        data_z[start:start + 15] = label
    data = np.column_stack((data_xy, data_z))

    # Observe.
    for rowid, (x, y, z) in enumerate(data):
        crosscat.observe(rowid, {0: x, 1: y, 50: z})

    # Run inference: a full sweep first, then hyperparameter-only sweeps.
    if integration:
        steps_full, steps_hypers = 50, 100
    else:
        steps_full, steps_hypers = 1, 1
    synthesizer = GibbsCrossCat(crosscat)
    synthesizer.transition(N=steps_full, progress=False)
    synthesizer.transition(
        N=steps_hypers,
        kernels=['hypers_distributions', 'hypers_row_divide'],
        progress=False)

    # Assert views are merged into one.
    if integration:
        assert len(synthesizer.crosscat.cgpms) == 1
    crp_output = synthesizer.crosscat.cgpms[0].cgpm_row_divide.outputs[0]

    # Check joint samples for all nominals.
    samples = synthesizer.crosscat.simulate(
        None, [crp_output, 0, 1, 50], N=250)
    if integration:
        check_sampled_data(samples, [0, 7], 3, 110)

    # Check joint samples for nominals [0, 2].
    samples_a = [s for s in samples if s[50] in [0, 2]]
    if integration:
        check_sampled_data(samples_a, [0, 7], 3, 45)

    # Check joint samples for nominals [1, 3].
    samples_b = [s for s in samples if s[50] in [1, 3]]
    if integration:
        check_sampled_data(samples_b, [0, 7], 3, 45)

    # Check conditional samples in correct quadrants: each nominal value
    # is paired with the cluster center passed to the checker.
    for z, center in [(0, 0), (1, 0), (2, 7), (3, 7)]:
        samples = synthesizer.crosscat.simulate(
            None, [0, 1], {50: z}, N=100)
        if integration:
            check_sampled_data(samples, [center], 3, 90)
def custom_program(crosscat):
    """Run Gibbs transitions over a fixed kernel subset and return the model.

    ``N`` is not a parameter of this function; it is resolved from the
    surrounding scope, which is not visible in this block.
    """
    from cgpm2.transition_crosscat import GibbsCrossCat

    # Only these kernels are applied, in this order.
    selected_kernels = [
        'hypers_distributions',
        'hypers_row_divide',
        'hypers_column_divide',
        'row_assignments',
    ]
    sampler = GibbsCrossCat(crosscat)
    sampler.transition(N=N, kernels=selected_kernels, progress=False)
    return sampler.crosscat
def custom_program(crosscat):
    """Run one Gibbs transition on ``crosscat`` and return the result.

    No ``kernels`` argument is passed, so the synthesizer's own default
    kernel selection applies.
    """
    from cgpm2.transition_crosscat import GibbsCrossCat

    sampler = GibbsCrossCat(crosscat)
    sampler.transition(N=1, progress=False)
    return sampler.crosscat