Beispiel #1
0
def test_server_conditional_gof(N, V, C, M):
    """Run the conditional goodness-of-fit check against a fake TreeCat model.

    The random seed is derived from the test parameters (plus the extra
    salt ``1``) so the fake model is reproducible per parameter combination.
    """
    set_random_seed(make_seed(N, V, C, M, 1))
    fake_model = generate_fake_model(N, V, C, M)
    tiny_config = TINY_CONFIG.copy()
    tiny_config.update(model_num_clusters=M)
    fake_model['config'] = tiny_config
    validate_gof(N, V, C, M, TreeCatServer(fake_model), conditional=True)
Beispiel #2
0
def test_latent_perplexity(N, V, C, M):
    """Check that per-feature latent perplexity is a length-V vector in [1, M]."""
    set_random_seed(make_seed(N, V, C, M))
    fake_model = generate_fake_model(N, V, C, M)
    tiny_config = TINY_CONFIG.copy()
    tiny_config.update(model_num_clusters=M)
    fake_model['config'] = tiny_config
    treecat = TreeCatServer(fake_model)

    result = treecat.latent_perplexity()
    print(result)
    assert result.shape == (V, )
    # Perplexity of a distribution over M clusters is bounded by [1, M].
    assert np.all(result >= 1)
    assert np.all(result <= M)
Beispiel #3
0
def test_server_marginals(N, V, C, M):
    """Check that marginals on random data sum to one within each feature's columns."""
    fake_model = generate_fake_model(N, V, C, M)
    tiny_config = TINY_CONFIG.copy()
    tiny_config.update(model_num_clusters=M)
    fake_model['config'] = tiny_config
    treecat = TreeCatServer(fake_model)

    # Evaluate on random data.
    random_table = generate_dataset(N, V, C)['table']
    probs = treecat.marginals(random_table.data)
    offsets = random_table.ragged_index
    for feature in range(V):
        # Columns [lo, hi) hold the categories of this feature.
        lo, hi = offsets[feature:feature + 2]
        row_sums = probs[:, lo:hi].sum(axis=1)
        assert np.allclose(row_sums, 1.0)
Beispiel #4
0
def test_latent_correlation(N, V, C, M):
    """Check that the latent correlation matrix is in [0, 1], symmetric, and diagonally maximal."""
    set_random_seed(make_seed(N, V, C, M))
    fake_model = generate_fake_model(N, V, C, M)
    tiny_config = TINY_CONFIG.copy()
    tiny_config.update(model_num_clusters=M)
    fake_model['config'] = tiny_config
    treecat = TreeCatServer(fake_model)

    corr = treecat.latent_correlation()
    print(corr)
    assert np.all(corr >= 0)
    assert np.all(corr <= 1)
    assert np.allclose(corr, corr.T)
    # Every feature should be most correlated with itself, by row and by column.
    for i in range(V):
        assert corr[i, :].argmax() == i
        assert corr[:, i].argmax() == i
Beispiel #5
0
def test_observed_perplexity(N, V, C, M):
    """Check observed perplexity for several observation counts.

    For each ``count``, the per-feature perplexity must be a length-V vector
    bounded by ``[1, count * C]``. Counts above 1 are skipped when ``C > 2``
    (per the original note, the server raises NotImplementedError there).
    """
    set_random_seed(make_seed(N, V, C, M))
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    for count in [1, 2, 3]:
        if count > 1 and C > 2:
            continue  # NotImplementedError.
        # Pass the loop's count. Previously a constant 1 was passed, so the
        # count=2/3 cases never exercised the server.
        perplexity = server.observed_perplexity(count)
        print(perplexity)
        assert perplexity.shape == (V, )
        assert np.all(1 <= perplexity)
        assert np.all(perplexity <= count * C)
Beispiel #6
0
def test_server_median(N, V, C, M):
    """Check that the median has the data's shape and int8 dtype, and that each
    feature's columns sum to the requested count."""
    fake_model = generate_fake_model(N, V, C, M)
    tiny_config = TINY_CONFIG.copy()
    tiny_config.update(model_num_clusters=M)
    fake_model['config'] = tiny_config
    treecat = TreeCatServer(fake_model)

    # Evaluate on random data.
    requested = np.random.randint(10, size=[V], dtype=np.int8)
    random_table = generate_dataset(N, V, C)['table']
    result = treecat.median(requested, random_table.data)
    assert result.shape == random_table.data.shape
    assert result.dtype == np.int8
    offsets = random_table.ragged_index
    for feature, expected in enumerate(requested):
        lo, hi = offsets[feature:feature + 2]
        row_sums = result[:, lo:hi].sum(axis=1)
        assert np.all(row_sums == expected)
Beispiel #7
0
def test_server_logprob_normalized(N, V, C, M):
    """Check that log-probabilities over every possible categorical row sum to 1.

    For each feature ``v``, the set of one-hot encodings of its categories is
    built. Then the Cartesian product over all features is enumerated, so the
    number of rows grows exponentially in V. Only use this with small settings.
    """
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    # The total probability of all categorical rows should be 1.
    ragged_index = model['suffstats']['ragged_index']
    factors = []
    for v in range(V):
        # Per-feature category count. Named `size` rather than reusing `C`
        # to avoid shadowing the test parameter.
        size = ragged_index[v + 1] - ragged_index[v]
        factors.append([one_hot(c, size) for c in range(size)])
    data = np.array(
        [np.concatenate(columns) for columns in itertools.product(*factors)],
        dtype=np.int8)
    logprobs = server.logprob(data)
    logtotal = np.logaddexp.reduce(logprobs)
    assert logtotal == pytest.approx(0.0, abs=1e-5)
Beispiel #8
0
def test_e2e(model_type):
    """End-to-end smoke test of the pipeline on the tiny test dataset.

    Steps: guess a schema, load it with groups, load the data, train a model
    (``'single'`` or ``'ensemble'``), serve it, run queries, inspect latent
    structure, and plot. Median/mode queries that raise NotImplementedError
    emit a warning instead of failing.

    Raises:
        ValueError: if ``model_type`` is neither ``'single'`` nor ``'ensemble'``.
    """
    with tempdir() as dirname:
        data_csv = os.path.join(TESTDATA, 'tiny_data.csv')
        config = TINY_CONFIG.copy()

        print('Guess schema.')
        types_csv = os.path.join(dirname, 'types.csv')
        values_csv = os.path.join(dirname, 'values.csv')
        guess_schema(data_csv, types_csv, values_csv)

        print('Load schema')
        groups_csv = os.path.join(TESTDATA, 'tiny_groups.csv')
        schema = load_schema(types_csv, values_csv, groups_csv)
        ragged_index = schema['ragged_index']
        tree_prior = schema['tree_prior']

        print('Load data')
        data = load_data(schema, data_csv)
        feature_types = [TY_MULTINOMIAL] * len(schema['feature_names'])
        table = Table(feature_types, ragged_index, data)
        dataset = {
            'schema': schema,
            'table': table,
        }

        print('Train model')
        if model_type == 'single':
            model = train_model(table, tree_prior, config)
        elif model_type == 'ensemble':
            model = train_ensemble(table, tree_prior, config)
        else:
            raise ValueError(model_type)

        print('Serve model')
        server = serve_model(dataset, model)

        print('Query model')
        evidence = {'genre': 'drama'}
        server.logprob([evidence])
        samples = server.sample(100)
        server.logprob(samples)
        samples = server.sample(100, evidence)
        server.logprob(samples)
        # Keep only the optional query inside each try. A NotImplementedError
        # raised by logprob() must surface instead of being mistaken for a
        # missing median/mode implementation.
        try:
            median = server.median([evidence])
        except NotImplementedError:
            warn('{} median not implemented'.format(model_type))
        else:
            server.logprob(median)
        try:
            mode = server.mode([evidence])
        except NotImplementedError:
            warn('{} mode not implemented'.format(model_type))
        else:
            server.logprob(mode)

        print('Examine latent structure')
        server.feature_density()
        server.observed_perplexity()
        server.latent_perplexity()
        server.latent_correlation()
        server.estimate_tree()
        server.sample_tree(10)

        print('Plotting latent structure')
        plot_circular(server)