Esempio n. 1
0
def test_transition_hypers(cctype):
    """Smoke test that the hyperparameter kernel actually moves the hypers."""
    name, arg = cctype
    model = cu.cctype_class(name)(outputs=[0],
                                  inputs=None,
                                  distargs=arg,
                                  rng=gu.gen_rng(10))
    D, Zv, Zc = tu.gen_data_table(50, [1], [[.33, .33, .34]], [name], [arg],
                                  [.8],
                                  rng=gu.gen_rng(1))

    observations = np.ravel(D)[:25]

    # Incorporate the first batch and run the hyper kernel.
    hypers_before = model.get_hypers()
    for rowid, value in enumerate(observations):
        model.incorporate(rowid, {0: value}, None)
    model.transition_hypers(N=3)
    hypers_after = model.get_hypers()
    # At least one hyperparameter must have changed.
    assert any(
        not np.allclose(hypers_after[h], hypers_before[h])
        for h in hypers_after)

    # Incorporate a second batch under fresh rowids and transition again.
    for rowid, value in enumerate(observations):
        model.incorporate(rowid + 25, {0: value}, None)
    model.transition_hypers(N=3)
    hypers_final = model.get_hypers()
    assert any(
        not np.allclose(hypers_after[h], hypers_final[h])
        for h in hypers_final)
def test_dependence_probability_pairwise():
    """Pairwise dependence probabilities must recover the pinned view partition.

    Because Zv is fixed at construction, two columns are dependent exactly
    when their Zv entries are equal, so each state's matrix is exact 0/1.
    """
    cctypes, distargs = cu.parse_distargs(['normal', 'normal', 'normal'])

    T, Zv, _Zc = tu.gen_data_table(10, [.5, .5], [[.25, .25, .5], [.3, .7]],
                                   cctypes,
                                   distargs, [.95] * len(cctypes),
                                   rng=gu.gen_rng(100))

    outputs = [0, 1, 2]
    # Pin each output's view assignment to the generating partition Zv.
    engine = Engine(T.T,
                    outputs=outputs,
                    cctypes=cctypes,
                    num_states=4,
                    distargs=distargs,
                    Zv={o: z
                        for o, z in zip(outputs, Zv)},
                    rng=gu.gen_rng(0))

    # One len(outputs) x len(outputs) matrix per state.
    Ds = engine.dependence_probability_pairwise(multiprocess=0)
    assert len(Ds) == engine.num_states()
    assert all(np.shape(D) == (len(outputs), len(outputs)) for D in Ds)
    for D in Ds:
        for col0, col1 in itertools.product(outputs, outputs):
            i0 = outputs.index(col0)
            i1 = outputs.index(col1)
            actual = D[i0, i1]
            # Dependent exactly when the two columns share a view.
            expected = Zv[i0] == Zv[i1]
            assert actual == expected

    # Restricting to a column subset shrinks the matrices accordingly.
    Ds = engine.dependence_probability_pairwise(colnos=[0, 2], multiprocess=0)
    assert len(Ds) == engine.num_states()
    assert all(np.shape(D) == (2, 2) for D in Ds)
Esempio n. 3
0
def engine():
    """Build a 4-state Engine over a 20-row table of nine mixed-type columns."""
    # Set up the data generation.
    cctypes, distargs = cu.parse_distargs([
        'normal', 'poisson', 'bernoulli', 'categorical(k=4)', 'lognormal',
        'exponential', 'beta', 'geometric', 'vonmises',
    ])

    # One view, three clusters, high separation.
    T, Zv, Zc = tu.gen_data_table(
        20, [1], [[.25, .25, .5]], cctypes, distargs,
        [.95] * len(cctypes), rng=gu.gen_rng(10))

    # Engine expects rows-as-observations, hence the transpose.
    return Engine(
        T.T,
        cctypes=cctypes,
        distargs=distargs,
        num_states=4,
        rng=gu.gen_rng(312),
        multiprocess=False)
Esempio n. 4
0
def simulate_synthetic(n_samples, cctype, distargs):
    """Generate a shuffled synthetic data table for a single cctype column.

    Uses module-level VIEW_WEIGHTS, CLUSTER_WEIGHTS, and SEPARATION.
    Returns an (n_samples, 1) array with row order randomized.
    """
    rng = gu.gen_rng(12)
    table, _view_assignments, _cluster_assignments = tu.gen_data_table(
        n_samples,
        VIEW_WEIGHTS,
        CLUSTER_WEIGHTS, [cctype], [distargs],
        SEPARATION,
        rng=rng)
    # Shuffle in place so row order carries no cluster information.
    rng.shuffle(table[0])
    return np.asarray(table).T
def state():
    """Construct a 120-row State over [normal, bernoulli, normal] columns."""
    rng = gu.gen_rng(5)
    rows = 120
    cctypes = ['normal', 'bernoulli', 'normal']
    # Two real-valued columns arranged in quadrants.
    G = generate_quadrants(rows, rng)
    # One bernoulli column with two well-separated clusters.
    B, Zv, Zrv = tu.gen_data_table(
        rows, [1], [[.5, .5]], ['bernoulli'], [None], [.95], rng=rng)
    # Reorder so the bernoulli column sits between the two normal columns.
    T = np.column_stack((G, B.T))[:, [0, 2, 1]]
    model = State(T, outputs=[0, 1, 2], cctypes=cctypes, rng=rng)
    model.transition(N=20)
    return model
Esempio n. 6
0
def retrieve_bernoulli_dataset():
    """Return a 150-row, 4-column bernoulli table split across two views."""
    data, _Zv, _Zc = tu.gen_data_table(
        n_rows=150,
        view_weights=None,
        cluster_weights=[[.5, .5], [.1, .9]],
        cctypes=['bernoulli'] * 4,
        distargs=[None] * 4,
        separation=[0.95] * 4,
        # Columns 0-1 in view 0, columns 2-3 in view 1.
        view_partition=[0, 0, 1, 1],
        rng=gu.gen_rng(12))
    return data
Esempio n. 7
0
def retrieve_normal_dataset():
    """Return a 150-row, 6-column normal table split across two views."""
    data, _Zv, _Zc = tu.gen_data_table(
        n_rows=150,
        view_weights=None,
        cluster_weights=[[.2, .2, .2, .4], [.3, .2, .5]],
        cctypes=['normal'] * 6,
        distargs=[None] * 6,
        separation=[0.95] * 6,
        # Columns 0-2 in view 0, columns 3-5 in view 1.
        view_partition=[0, 0, 0, 1, 1, 1],
        rng=gu.gen_rng(12))
    return data
Esempio n. 8
0
def generate_dataset():
    """Generate a synthetic data table with missing cells, transposed.

    NOTE(review): relies on module-level ``cctypes``, ``distargs``, and
    ``rng`` that are not visible in this snippet — confirm they are defined
    earlier in the file.
    """
    # Set up the data generation, 20 rows by 8 cols, with some missing values.
    D, Zv, Zc = tu.gen_data_table(20, [1], [[.25, .25, .5]],
                                  cctypes,
                                  distargs, [.95] * len(cctypes),
                                  rng=gu.gen_rng(2))

    # Generate some missing entries in D.
    # For each row of D, pick 4 column indices (repeats allowed) to NaN out.
    missing = rng.choice(range(D.shape[1]), size=(D.shape[0], 4), replace=True)
    for i, m in enumerate(missing):
        D[i, m] = np.nan

    T = np.transpose(D)
    return T
Esempio n. 9
0
def test_dependence_probability():
    '''Test that Loom correctly recovers a 2-view dataset.'''
    # Six normal columns, views {0,1,2} and {3,4,5} by construction.
    D, Zv, Zc = tu.gen_data_table(n_rows=150,
                                  view_weights=None,
                                  cluster_weights=[
                                      [.2, .2, .2, .4],
                                      [.3, .2, .5],
                                  ],
                                  cctypes=['normal'] * 6,
                                  distargs=[None] * 6,
                                  separation=[0.95] * 6,
                                  view_partition=[0, 0, 0, 1, 1, 1],
                                  rng=gu.gen_rng(12))

    # Deliberately non-contiguous output indices; len(D) == 6 columns.
    engine = Engine(
        D.T,
        outputs=[7, 2, 12, 80, 129, 98],
        cctypes=['normal'] * len(D),
        distargs=[None] * 6,
        rng=gu.gen_rng(122),
        num_states=20,
    )

    # Loom inference should improve the average model score.
    logscore0 = engine.logpdf_score()
    engine.transition_loom(N=100)
    logscore1 = engine.logpdf_score()
    assert numpy.mean(logscore1) > numpy.mean(logscore0)

    # Average the pairwise matrices over the 20 states.
    dependence_probability = numpy.mean(
        engine.dependence_probability_pairwise(), axis=0)

    # Within-view pairs (first view: positions 0-2) should be dependent.
    assert dependence_probability[0, 1] > 0.8
    assert dependence_probability[1, 2] > 0.8
    assert dependence_probability[0, 2] > 0.8

    # Within-view pairs (second view: positions 3-5) should be dependent.
    assert dependence_probability[3, 4] > 0.8
    assert dependence_probability[4, 5] > 0.8
    assert dependence_probability[3, 5] > 0.8

    # Cross-view pairs should be (nearly) independent.
    assert dependence_probability[0, 3] < 0.2
    assert dependence_probability[0, 4] < 0.2
    assert dependence_probability[0, 5] < 0.2

    assert dependence_probability[1, 3] < 0.2
    assert dependence_probability[1, 4] < 0.2
    assert dependence_probability[1, 5] < 0.2

    assert dependence_probability[2, 3] < 0.2
    assert dependence_probability[2, 4] < 0.2
    assert dependence_probability[2, 5] < 0.2
Esempio n. 10
0
def populate_crosscat(crosscat, prng):
    """Observe a 10-row synthetic mixed-type table (with NaNs) into crosscat."""
    table, _Zv, _Zrv = gen_data_table(
        n_rows=10,
        view_weights=[.4, .6],
        cluster_weights=[[.3, .4, .3], [.5, .5]],
        cctypes=['normal', 'normal', 'poisson', 'normal', 'normal',
                 'categorical'],
        distargs=[None, None, None, None, None, {'k': 4}],
        separation=[0.99] * 6,
        rng=prng)
    # Insert two missing cells before observing.
    table[0, 1] = table[3, 1] = float('nan')
    dataset = np.transpose(table)
    for rowid, row in enumerate(dataset):
        crosscat.observe(rowid, dict(enumerate(row)))
    return crosscat
Esempio n. 11
0
def state():
    """Build a single-view State over six mixed-type columns.

    Fix: replaced the Python-2-only ``xrange`` with ``range`` so the fixture
    also runs under Python 3; the iterable here is tiny, so behavior and
    cost are unchanged on Python 2.
    """
    # Set up the data generation
    cctypes, distargs = cu.parse_distargs(
        ['normal', 'poisson', 'bernoulli', 'lognormal', 'beta', 'vonmises'])
    T, Zv, Zc = tu.gen_data_table(30, [1], [[.25, .25, .5]],
                                  cctypes,
                                  distargs, [.95] * len(cctypes),
                                  rng=gu.gen_rng(0))
    T = T.T
    # Pin every column to view 0.
    s = State(T,
              cctypes=cctypes,
              distargs=distargs,
              Zv={i: 0 for i in range(len(cctypes))},
              rng=gu.gen_rng(0))
    return s
Esempio n. 12
0
def test_errors():
    """Targets loomcat._validate_transition.

    Fix: the original repeated the ``progress=True`` pytest.raises check
    twice back-to-back; the duplicate asserted nothing new and was removed.
    """
    D, Zv, Zc = tu.gen_data_table(n_rows=150,
                                  view_weights=None,
                                  cluster_weights=[
                                      [.2, .2, .2, .4],
                                      [.3, .2, .5],
                                  ],
                                  cctypes=['normal'] * 6,
                                  distargs=[None] * 6,
                                  separation=[0.95] * 6,
                                  view_partition=[0, 0, 0, 1, 1, 1],
                                  rng=gu.gen_rng(12))

    state = State(
        D.T,
        outputs=range(10, 16),
        cctypes=['normal'] * len(D),
        distargs=[None] * 6,
        rng=gu.gen_rng(122),
    )

    engine = Engine(
        D.T,
        outputs=range(10, 16),
        cctypes=['normal'] * len(D),
        distargs=[None] * 6,
        rng=gu.gen_rng(122),
    )

    def check_errors(cgpm):
        # transition_loom accepts only N; any other kernel knob must raise.
        with pytest.raises(ValueError):
            cgpm.transition_loom(N=10, S=5)
        with pytest.raises(ValueError):
            cgpm.transition_loom(N=10, kernels=['alpha'])
        with pytest.raises(ValueError):
            cgpm.transition_loom(N=10, progress=True)
        with pytest.raises(ValueError):
            cgpm.transition_loom(N=10, checkpoint=2)
        # A plain N-only transition must succeed.
        cgpm.transition_loom(N=2)

    check_errors(state)
    check_errors(engine)
Esempio n. 13
0
def test_multiple_stattypes():
    '''Test cgpm statistical types are heuristically converted to Loom types.'''
    cctypes, distargs = cu.parse_distargs([
        'normal', 'poisson', 'bernoulli', 'categorical(k=4)', 'lognormal',
        'exponential', 'beta', 'geometric', 'vonmises'
    ])

    T, Zv, Zc = tu.gen_data_table(200, [1], [[.25, .25, .5]],
                                  cctypes,
                                  distargs, [.95] * len(cctypes),
                                  rng=gu.gen_rng(10))

    engine = Engine(
        T.T,
        cctypes=cctypes,
        distargs=distargs,
        rng=gu.gen_rng(15),
        num_states=16,
    )

    # Loom inference should improve the average model score.
    logscore0 = engine.logpdf_score()
    engine.transition_loom(N=5)
    logscore1 = engine.logpdf_score()
    assert numpy.mean(logscore1) > numpy.mean(logscore0)

    # Check serialization round-trip via the metadata/factory protocol.
    metadata = engine.to_metadata()
    modname = importlib.import_module(metadata['factory'][0])
    builder = getattr(modname, metadata['factory'][1])
    engine2 = builder.from_metadata(metadata)

    # To JSON.
    json_metadata = json.dumps(engine.to_metadata())
    engine3 = builder.from_metadata(json.loads(json_metadata))

    # Assert all states in engine, engine2, and engine3 have same loom_path.
    loom_paths = list(
        itertools.chain.from_iterable([s._loom_path for s in e.states]
                                      for e in [engine, engine2, engine3]))
    assert all(p == loom_paths[0] for p in loom_paths)

    # A deserialized engine must still support inference and queries.
    engine2.transition(S=5)
    dependence_probability = engine2.dependence_probability_pairwise()

    assert numpy.all(dependence_probability > 0.85)
Esempio n. 14
0
def generate_real_nominal_data(N, rng=None):
    """Generate a bivariate dataset: real x plus a 6-level nominal z.

    The real variable x is drawn from a 3-cluster normal mixture; the
    nominal variable z deterministically encodes the latent cluster, where
    levels (2k, 2k+1) correspond to cluster k, alternating by the
    within-cluster draw count.

    Fix: replaced the Python-2-only ``xrange`` with ``range`` so the helper
    also runs under Python 3 (identical iteration semantics).

    Returns (data, indicators): data is an (N, 2) array, indicators lists
    the six nominal levels.
    """
    if rng is None: rng = gu.gen_rng(0)
    T, Zv, Zc = tu.gen_data_table(
        N, [1], [[.3, .5, .2]], ['normal'], [None], [.95], rng=rng)
    data = np.zeros((N, 2))
    data[:,0] = T[0]
    indicators = [0, 1, 2, 3, 4, 5]
    counts = {0:0, 1:0, 2:0}
    for i in range(N):
        k = Zc[0][i]
        # Even/odd level within cluster k's pair, alternating per draw.
        data[i,1] = 2*indicators[k] + counts[k] % 2
        counts[k] += 1
    return data, indicators
Esempio n. 15
0
def get_engine():
    """Create a 6-state Engine over mixed-type data containing missing cells."""
    cctypes, distargs = cu.parse_distargs(
        ['normal', 'poisson', 'bernoulli', 'lognormal', 'beta', 'vonmises'])
    table, _Zv, _Zc = tu.gen_data_table(
        20, [1], [[.25, .25, .5]], cctypes, distargs,
        [.95] * len(cctypes), rng=gu.gen_rng(0))
    table = table.T
    # Blank out a few cells so callers can condition on partial rows.
    for col in (0, 1, 2, 3):
        table[5, col] = np.nan
    table[8, 4] = np.nan
    engine = Engine(
        table,
        cctypes=cctypes,
        distargs=distargs,
        num_states=6,
        rng=gu.gen_rng(0))
    engine.transition(N=2)
    return engine
Esempio n. 16
0
def state():
    """Build a single-view State and retrain column 0 as a random forest.

    Fix: replaced the Python-2-only ``xrange`` with ``range`` so the fixture
    also runs under Python 3; the iterable is tiny, so cost is unchanged.
    """
    cctypes, distargs = cu.parse_distargs(
        ['categorical(k=5)', 'normal', 'poisson', 'bernoulli'])
    T, Zv, Zc = tu.gen_data_table(50, [1], [[.33, .33, .34]],
                                  cctypes,
                                  distargs, [.95] * len(cctypes),
                                  rng=gu.gen_rng(0))
    # Pin every column to view 0.
    s = State(T.T,
              cctypes=cctypes,
              distargs=distargs,
              Zv={i: 0 for i in range(len(cctypes))},
              rng=gu.gen_rng(0))
    s.update_cctype(0, 'random_forest', distargs={'k': 5})
    # XXX Uncomment me for a bug!
    # state.update_cctype(1, 'linear_regression')
    kernels = [
        'rows', 'view_alphas', 'alpha', 'column_params', 'column_hypers'
    ]
    s.transition(N=1, kernels=kernels)
    return s
Esempio n. 17
0
from math import log

import numpy as np

from cgpm.regressions.ols import OrdinaryLeastSquares
from cgpm.utils import config as cu
from cgpm.utils import general as gu
from cgpm.utils import test as tu

# Synthetic training data: nine low-separation columns.  Column 0 is the
# regression target; columns 1-8 serve as the OLS inputs.
cctypes, distargs = cu.parse_distargs([
    'normal', 'categorical(k=3)', 'poisson', 'bernoulli', 'lognormal',
    'exponential', 'geometric', 'vonmises', 'normal'
])

T, Zv, Zc = tu.gen_data_table(
    100, [1], [[.33, .33, .34]], cctypes, distargs,
    [.2] * len(cctypes), rng=gu.gen_rng(0))

D = T.T

# Input metadata for OrdinaryLeastSquares: stattypes of columns 1-8, with
# statargs for the categorical (k=3) and bernoulli (k=2) inputs.
OLS_DISTARGS = {
    'inputs': {
        'stattypes': cctypes[1:],
        'statargs': [{'k': 3}, None, {'k': 2}, None, None, None, None, None],
    }
}
OLS_OUTPUTS = [0]
Esempio n. 18
0
import pytest

import matplotlib.pyplot as plt
import numpy as np

from scipy.stats import ks_2samp

from cgpm.crosscat.engine import Engine
from cgpm.utils import general as gu
from cgpm.utils import test as tu

N_SAMPLES = 250

# One normal column drawn from a 3-cluster mixture; DATA column 0 holds it.
T, Zv, Zc = tu.gen_data_table(N_SAMPLES, [1], [[.3, .5, .2]], ['normal'],
                              [None], [.95],
                              rng=gu.gen_rng(0))

DATA = np.zeros((N_SAMPLES, 2))
DATA[:, 0] = T[0]

# Six nominal levels; levels (2k, 2k+1) encode latent cluster k.
INDICATORS = [0, 1, 2, 3, 4, 5]

counts = {0: 0, 1: 0, 2: 0}
# Fix: range (not the Python-2-only xrange) so the module imports under
# Python 3; iteration semantics are identical.
for i in range(N_SAMPLES):
    k = Zc[0][i]
    DATA[i, 1] = 2 * INDICATORS[k] + counts[k] % 2
    counts[k] += 1


@pytest.fixture(scope='module')
Esempio n. 19
0
def test_dependence_probability():
    """Dependence probabilities must track views and composed-cgpm ancestry."""
    cctypes, distargs = cu.parse_distargs(
        ['normal', 'poisson', 'bernoulli', 'lognormal', 'beta', 'vonmises'])

    T, Zv, Zc = tu.gen_data_table(100, [.5, .5], [[.25, .25, .5], [.3, .7]],
                                  cctypes,
                                  distargs, [.95] * len(cctypes),
                                  rng=gu.gen_rng(100))

    T = T.T
    # Deliberately non-contiguous output indices: 0, 2, 4, 6, 8, 10.
    outputs = range(0, 12, 2)

    # Test for direct dependence for state and engine.
    s = State(T,
              outputs=outputs,
              cctypes=cctypes,
              distargs=distargs,
              Zv={o: z
                  for o, z in zip(outputs, Zv)},
              rng=gu.gen_rng(0))

    e = Engine(T,
               outputs=outputs,
               cctypes=cctypes,
               distargs=distargs,
               Zv={o: z
                   for o, z in zip(outputs, Zv)},
               rng=gu.gen_rng(0))

    # With Zv pinned, dependence is exactly "same view in Zv".
    for C in [s, e]:
        for col0, col1 in itertools.product(outputs, outputs):
            i0 = outputs.index(col0)
            i1 = outputs.index(col1)
            assert (compute_depprob(C.dependence_probability(
                col0, col1)) == (Zv[i0] == Zv[i1]))

    # Hook some cgpms into state.

    # XXX What if Zv has only one unique value? Hopefully not with this rng!
    uniques = list(set(Zv))
    parent_1 = [o for i, o in enumerate(outputs) if Zv[i] == uniques[0]]
    parent_2 = [o for i, o in enumerate(outputs) if Zv[i] == uniques[1]]

    # c1 descends from view 1; c2, c3, and (via c3) c4 descend from view 2.
    c1 = BareBonesCGpm(outputs=[1821, 154], inputs=[parent_1[0]])
    c2 = BareBonesCGpm(outputs=[1721], inputs=[parent_2[0]])
    c3 = BareBonesCGpm(outputs=[9721], inputs=[parent_2[1]])
    c4 = BareBonesCGpm(outputs=[74], inputs=[9721])

    # State takes a single cgpm; Engine takes one per internal state.
    for i, C in enumerate([s, e]):
        C.compose_cgpm(c1 if i == 0 else [c1])
        C.compose_cgpm(c2 if i == 0 else [c2])
        C.compose_cgpm(c3 if i == 0 else [c3])
        C.compose_cgpm(c4 if i == 0 else [c4])

        # Between hooked cgpms and state parents.
        for p in parent_1:
            assert compute_depprob(C.dependence_probability(1821, p)) == 1
            assert compute_depprob(C.dependence_probability(154, p)) == 1
            assert compute_depprob(C.dependence_probability(1721, p)) == 0
            assert compute_depprob(C.dependence_probability(9721, p)) == 0
            assert compute_depprob(C.dependence_probability(74, p)) == 0
        for p in parent_2:
            assert compute_depprob(C.dependence_probability(1821, p)) == 0
            assert compute_depprob(C.dependence_probability(154, p)) == 0
            assert compute_depprob(C.dependence_probability(1721, p)) == 1
            assert compute_depprob(C.dependence_probability(9721, p)) == 1
            assert compute_depprob(C.dependence_probability(74, p)) == 1

        # Between hooked cgpm.
        assert compute_depprob(C.dependence_probability(9721, 1721)) == 1
        assert compute_depprob(C.dependence_probability(1821, 154)) == 1
        assert compute_depprob(C.dependence_probability(74, 9721)) == 1
        assert compute_depprob(C.dependence_probability(74, 1721)) == 1

        # Cgpms rooted in different views stay independent.
        assert compute_depprob(C.dependence_probability(1821, 1721)) == 0
        assert compute_depprob(C.dependence_probability(1821, 74)) == 0
        assert compute_depprob(C.dependence_probability(154, 74)) == 0
Esempio n. 20
0
def test_serialize_composite_cgpm():
    """Round-trip a State/Engine with hooked foreign cgpms through metadata.

    Fix: the two trailing ``check_logpdf_delta`` calls discarded their
    boolean results, so those checks asserted nothing; they are now wrapped
    in ``assert``.
    """
    rng = gu.gen_rng(2)

    # Generate the data.
    cctypes, distargs = cu.parse_distargs([
        'categorical(k=3)',     # RandomForest          0
        'normal',               # LinearRegression      1
        'categorical(k=3)',     # GPMCC                 2
        'poisson',              # GPMCC                 3
        'normal',               # GPMCC                 4
        'lognormal'             # GPMCC                 5
        ])
    T, Zv, Zc = tu.gen_data_table(
        35, [.4, .6], [[.33, .33, .34], [.5, .5]],
        cctypes, distargs, [.2]*len(cctypes), rng=rng)
    D = np.transpose(T)

    # Create GPMCC over columns 2-5 only; 0 and 1 go to the foreign cgpms.
    state = State(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], rng=rng)

    # Create a Forest predicting column 0 from columns 1-4.
    forest = RandomForest(
        outputs=[0],
        inputs=[1,2,3,4],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [1,2,3,4]],
                'statargs': [distargs[i] for i in [1,2,3,4]]},
            'k': distargs[0]['k']},
        rng=rng)

    # Create a Regression predicting column 1 from columns 3-5.
    linreg = LinearRegression(
        outputs=[1],
        inputs=[3,4,5],
        distargs={
            'inputs': {
                'stattypes': [cctypes[i] for i in [3,4,5]],
                'statargs': [distargs[i] for i in [3,4,5]]}},
        rng=rng)

    # Incorporate the data.
    def incorporate_data(cgpm, rowid, row):
        cgpm.incorporate(
            rowid,
            {i: row[i] for i in cgpm.outputs},
            {i: row[i] for i in cgpm.inputs},
        )
    for rowid, row in enumerate(D):
        incorporate_data(forest, rowid, row)
        incorporate_data(linreg, rowid, row)

    # Run state transitions.
    state.transition(N=10, progress=False)
    # Compose CGPMs, instructing State to run the transitions.
    token_forest = state.compose_cgpm(forest)
    token_linreg = state.compose_cgpm(linreg)
    state.transition_foreign(N=10, cols=[forest.outputs[0], linreg.outputs[0]])

    # Now run the serialization.
    metadata = state.to_metadata()
    state2 = State.from_metadata(metadata)

    # Check that the tokens are in state2.
    assert token_forest in state2.hooked_cgpms
    assert token_linreg in state2.hooked_cgpms

    # The hooked cgpms must be unique objects after serialize/deserialize.
    assert state.hooked_cgpms[token_forest] != state2.hooked_cgpms[token_forest]
    assert state.hooked_cgpms[token_linreg] != state2.hooked_cgpms[token_linreg]

    # Check that the log scores of the hooked cgpms agree.
    assert np.allclose(
        state.hooked_cgpms[token_forest].logpdf_score(),
        state2.hooked_cgpms[token_forest].logpdf_score())
    assert np.allclose(
        state.hooked_cgpms[token_linreg].logpdf_score(),
        state2.hooked_cgpms[token_linreg].logpdf_score())

    # Now run some tests for the engine.
    e = Engine(
        D[:,2:], outputs=[2,3,4,5], cctypes=cctypes[2:],
        distargs=distargs[2:], num_states=2, rng=rng)
    e.compose_cgpm([forest, forest], multiprocess=1)
    e.compose_cgpm([linreg, linreg], multiprocess=1)
    e.transition_foreign(N=1, cols=[forest.outputs[0], linreg.outputs[0]])
    e.dependence_probability(0,1)
    e.simulate(-1, [0,1], {2:1}, multiprocess=0)
    e.logpdf(-1, {1:1}, {2:1, 0:0}, multiprocess=0)

    state3 = e.get_state(0)

    # There is no guarantee that the logpdf score improves with inference, but
    # it should reduce by more than a few nats.
    def check_logpdf_delta(before, after):
        return before < after or (after-before) < 5
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_forest].logpdf_score(),
        after=state3.hooked_cgpms[token_forest].logpdf_score())
    assert check_logpdf_delta(
        before=state.hooked_cgpms[token_linreg].logpdf_score(),
        after=state3.hooked_cgpms[token_linreg].logpdf_score())
Esempio n. 21
0
from cgpm.utils import general as gu
from cgpm.utils import test as tu


# Column statistical types for the shared test table (indices annotated).
# NOTE(review): ``cu`` is not among this snippet's visible imports (only
# gu and tu are) — confirm ``from cgpm.utils import config as cu`` exists
# earlier in the file.
CCTYPES, DISTARGS = cu.parse_distargs([
    'normal',        # 0
    'poisson',       # 1
    'bernoulli',     # 2
    'lognormal',     # 3
    'exponential',   # 4
    'geometric',     # 5
    'vonmises'])     # 6


# 10 rows, one view, three clusters, high separation; transposed so rows
# are observations.
T, Zv, Zc = tu.gen_data_table(
    10, [1], [[.33, .33, .34]], CCTYPES, DISTARGS,
    [.95]*len(CCTYPES), rng=gu.gen_rng(0))
T = T.T


def test_incorporate_engine():
    engine = Engine(
        T[:,:2],
        cctypes=CCTYPES[:2],
        distargs=DISTARGS[:2],
        num_states=4,
        rng=gu.gen_rng(0),
    )
    engine.transition(N=5)

    # Incorporate a new dim into with a non-contiguous output.
Esempio n. 22
0
import pytest

from cgpm.crosscat.state import State
from cgpm.utils import config as cu
from cgpm.utils import general as gu
from cgpm.utils import general as gu
from cgpm.utils import test as tu

# Build the shared test fixture: 200 rows over nine mixed-type columns,
# one view, three clusters, high separation.
cctypes, distargs = cu.parse_distargs([
    'normal', 'poisson', 'bernoulli', 'categorical(k=4)', 'lognormal',
    'exponential', 'beta', 'geometric', 'vonmises'
])

T, Zv, Zc = tu.gen_data_table(
    200, [1], [[.25, .25, .5]], cctypes, distargs,
    [.95] * len(cctypes), rng=gu.gen_rng(10))

# Analyze the table once at import time; the tests below query this state.
state = State(T.T, cctypes=cctypes, distargs=distargs, rng=gu.gen_rng(312))
state.transition(N=10, progress=1)


def test_crash_simulate_joint(state):
    # Smoke test: jointly simulating all nine columns must not crash.
    # NOTE(review): ``state`` is taken as a parameter — presumably a pytest
    # fixture shadowing the module-level ``state``; confirm a fixture of
    # that name exists.
    state.simulate(-1, [0, 1, 2, 3, 4, 5, 6, 7, 8], N=10)


def test_crash_logpdf_joint(state):
    state.logpdf(-1, {
        0: 1,
        1: 2,
        2: 1,