Example #1
def test_server_logprob_shape(model):
    table = TINY_TABLE
    server = TreeCatServer(model)
    logprobs = server.logprob(table.data)
    N = table.num_rows
    assert logprobs.dtype == np.float32
    assert logprobs.shape == (N, )
    assert np.isfinite(logprobs).all()
Example #2
def test_latent_perplexity(N, V, C, M):
    set_random_seed(make_seed(N, V, C, M))
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    perplexity = server.latent_perplexity()
    print(perplexity)
    assert perplexity.shape == (V, )
    assert np.all(1 <= perplexity)
    assert np.all(perplexity <= M)
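
The bounds checked here follow from the definition of perplexity as exp(entropy): a deterministic latent assignment has entropy 0 and hence perplexity 1, while a uniform distribution over the M latent clusters attains the maximum perplexity M. A minimal sketch of that arithmetic in plain numpy (the distributions are invented for illustration):

import numpy as np

def perplexity(p):
    """Perplexity = exp(entropy) of a discrete distribution p."""
    p = np.asarray(p, dtype=np.float64)
    nz = p[p > 0]  # 0 * log(0) contributes zero entropy.
    return np.exp(-np.sum(nz * np.log(nz)))

M = 4
print(perplexity(np.ones(M) / M))        # Uniform over M clusters -> 4.0.
print(perplexity([1.0, 0.0, 0.0, 0.0]))  # Deterministic -> 1.0.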
Example #3
def process_eval_task(task):
    (dataset_path, config, models_dir) = task

    # Load a server with the trained model.
    model_path = os.path.join(models_dir,
                              'model.{}.pkz'.format(serialize_config(config)))
    try:
        model = pickle_load(model_path)
    except (OSError, EOFError):
        return {'config': config}
    print('Eval {}'.format(os.path.basename(model_path)))
    server = TreeCatServer(model)

    # Split data for cross-validation.
    num_parts = config['model_ensemble_size']
    partid = config['seed']
    assert 0 <= partid < num_parts
    dataset = pickle_load(dataset_path)
    table = dataset['table']
    ragged_index = table.ragged_index
    data = table.data
    mask = split_data(ragged_index, table.num_rows, num_parts, partid)
    training_data = data.copy()
    training_data[mask] = 0
    validation_data = data.copy()
    validation_data[~mask] = 0

    # Compute posterior predictive log probability of held-out data:
    # logprob(data) - logprob(training_data) = log P(held-out | training).
    logprob = np.mean(server.logprob(data) - server.logprob(training_data))

    # Compute L1 loss on observed validation data.
    N, R = data.shape
    V = len(ragged_index) - 1
    obs_counts = count_observations(ragged_index, data)
    assert obs_counts.shape == (N, V)
    max_counts = obs_counts.max(axis=0)
    median = server.median(max_counts, training_data)
    observed = (obs_counts == max_counts[np.newaxis, :])
    observed = make_ragged_mask(ragged_index, observed.T).T
    relevant = observed & mask
    validation_data[~relevant] = 0
    median[~relevant] = 0
    l1_loss = 0.5 * np.abs(median - validation_data).sum()
    l1_loss /= relevant.sum() + 0.1  # Smoothed to avoid division by zero.

    return {
        'config': config,
        'logprob': logprob,
        'l1_loss': l1_loss,
        'profiling_stats': model.get('profiling_stats', {}),
    }
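
The held-out evaluation above hinges on a single boolean mask: zeroing the masked cells in one copy and the unmasked cells in another yields complementary training and validation views of the same table. A toy illustration with a hand-made mask (split_data itself is not reproduced here):

import numpy as np

data = np.array([[1, 2, 3],
                 [4, 5, 6]], dtype=np.int8)
mask = np.array([[True, False, False],
                 [False, True, False]])  # Cells held out for validation.

training_data = data.copy()
training_data[mask] = 0     # Hide held-out cells from the model.
validation_data = data.copy()
validation_data[~mask] = 0  # Keep only the held-out cells.

assert not np.any(training_data * validation_data)  # Views are disjoint.
assert np.array_equal(training_data + validation_data, data)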
Example #4
def test_server_marginals(N, V, C, M):
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    # Evaluate on random data.
    table = generate_dataset(N, V, C)['table']
    marginals = server.marginals(table.data)
    for v in range(V):
        beg, end = table.ragged_index[v:v + 2]
        totals = marginals[:, beg:end].sum(axis=1)
        assert np.allclose(totals, 1.0)
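
The ragged_index[v:v + 2] idiom used here recurs throughout these examples: feature v occupies columns ragged_index[v] through ragged_index[v + 1] of the flattened data matrix, so each such block of marginals is one categorical distribution. A small sketch with a hand-built index (the values are invented):

import numpy as np

# Three features with 2, 3, and 2 categories respectively.
ragged_index = np.array([0, 2, 5, 7], dtype=np.int32)
row = np.array([1, 0, 0, 1, 0, 0, 1], dtype=np.int8)  # One multi-one-hot row.

for v in range(len(ragged_index) - 1):
    beg, end = ragged_index[v:v + 2]
    print('feature', v, '->', row[beg:end])
# feature 0 -> [1 0], feature 1 -> [0 1 0], feature 2 -> [0 1]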
Example #5
def test_latent_correlation(N, V, C, M):
    set_random_seed(make_seed(N, V, C, M))
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    correlation = server.latent_correlation()
    print(correlation)
    assert np.all(0 <= correlation)
    assert np.all(correlation <= 1)
    assert np.allclose(correlation, correlation.T)
    for v in range(V):
        assert correlation[v, :].argmax() == v
        assert correlation[:, v].argmax() == v
Example #6
def serve_files(model_path, config_path, num_samples):
    """INTERNAL Serve from pickled model, config."""
    from treecat.serving import TreeCatServer
    import numpy as np
    model = pickle_load(model_path)
    config = pickle_load(config_path)
    model['config'] = config
    server = TreeCatServer(model)
    counts = np.ones(model['tree'].num_vertices, np.int8)
    samples = server.sample(int(num_samples), counts)
    server.logprob(samples)
    server.median(counts, samples)
    server.latent_correlation()
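
serve_files is an internal smoke test that exercises the main serving methods (sample, logprob, median, latent_correlation) end to end; the return values are deliberately discarded. A hypothetical invocation, where the .pkz paths stand in for files produced by a prior training run:

serve_files('model.pkz', 'config.pkz', num_samples=100)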
Example #7
def test_observed_perplexity(N, V, C, M):
    set_random_seed(make_seed(N, V, C, M))
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    for count in [1, 2, 3]:
        if count > 1 and C > 2:
            continue  # This combination raises NotImplementedError.
        counts = count * np.ones(V, dtype=np.int8)
        perplexity = server.observed_perplexity(counts)
        print(perplexity)
        assert perplexity.shape == (V, )
        assert np.all(1 <= perplexity)
        assert np.all(perplexity <= count * C)
Example #8
def test_server_conditional_gof(N, V, C, M):
    set_random_seed(make_seed(N, V, C, M, 1))
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)
    validate_gof(N, V, C, M, server, conditional=True)
Example #9
def test_server_median(N, V, C, M):
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    # Evaluate on random data.
    counts = np.random.randint(10, size=[V], dtype=np.int8)
    table = generate_dataset(N, V, C)['table']
    median = server.median(counts, table.data)
    assert median.shape == table.data.shape
    assert median.dtype == np.int8
    for v in range(V):
        beg, end = table.ragged_index[v:v + 2]
        totals = median[:, beg:end].sum(axis=1)
        assert np.all(totals == counts[v])
Example #10
def test_server_logprob_normalized(N, V, C, M):
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    # The total probability of all categorical rows should be 1.
    ragged_index = model['suffstats']['ragged_index']
    factors = []
    for v in range(V):
        C_v = ragged_index[v + 1] - ragged_index[v]  # Cardinality of feature v.
        factors.append([one_hot(c, C_v) for c in range(C_v)])
    data = np.array(
        [np.concatenate(columns) for columns in itertools.product(*factors)],
        dtype=np.int8)
    logprobs = server.logprob(data)
    logtotal = np.logaddexp.reduce(logprobs)
    assert logtotal == pytest.approx(0.0, abs=1e-5)
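
The enumeration above materializes every fully observed categorical row: with per-feature cardinalities C_v, itertools.product over the per-feature one-hot vectors yields prod_v C_v rows, whose probabilities must sum to 1. A worked miniature with two binary features (one_hot here is a trivial stand-in for the library helper):

import itertools

import numpy as np

def one_hot(c, C):
    value = np.zeros(C, dtype=np.int8)
    value[c] = 1
    return value

factors = [[one_hot(c, 2) for c in range(2)] for _ in range(2)]
data = np.array(
    [np.concatenate(columns) for columns in itertools.product(*factors)],
    dtype=np.int8)
print(data)  # Four rows: [1 0 1 0], [1 0 0 1], [0 1 1 0], [0 1 0 1].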
Example #11
def test_server_sample_shape(model):
    server = TreeCatServer(model)
    validate_sample_shape(TINY_TABLE, server)
Example #12
def generate_clean_dataset(tree, num_rows, num_cats):
    """Generate a dataset whose structure should be easy to learn.

    This generates a highly correlated, uniformly distributed dataset with
    the given tree structure. It is useful for testing that structure
    learning can recover a known structure.

    Args:
        tree: A TreeStructure instance.
        num_rows: The number of rows in the generated dataset.
        num_cats: The number of categories in the generated categorical dataset.
            This will also be used for the number of latent classes.

    Returns:
        A dict with key 'table' and value a Table object.
    """
    assert isinstance(tree, TreeStructure)
    V = tree.num_vertices
    E = V - 1  # Number of edges in a tree.
    K = V * (V - 1) // 2  # Number of undirected vertex pairs.
    C = num_cats
    M = num_cats
    config = make_config(model_num_clusters=M)
    ragged_index = np.arange(0, C * (V + 1), C, np.int32)
    ragged_index.flags.writeable = False

    # Create sufficient statistics that are ideal for structure learning:
    # Correlation should be high enough that (vertex,vertex) correlation can be
    # detected, but low enough that multi-hop correlation can be distinguished
    # from single-hop correlation.
    # Observations should have very low error rate.
    edge_precision = 1
    feat_precision = 100
    vert_ss = np.zeros((V, M), dtype=np.int32)
    edge_ss = np.zeros((E, M, M), dtype=np.int32)
    feat_ss = np.zeros((V * C, M), dtype=np.int32)
    meas_ss = np.zeros([V, M], np.int32)
    vert_ss[...] = edge_precision
    meas_ss[...] = feat_precision
    for e, v1, v2 in tree.tree_grid.T:
        edge_ss[e, :, :] = edge_precision * np.eye(M, dtype=np.int32)
    for v in range(V):
        beg, end = ragged_index[v:v + 2]
        feat_ss[beg:end, :] = feat_precision * np.eye(M, dtype=np.int32)
    model = {
        'config': config,
        'tree': tree,
        'edge_logits': np.zeros(K, np.float32),
        'suffstats': {
            'ragged_index': ragged_index,
            'vert_ss': vert_ss,
            'edge_ss': edge_ss,
            'feat_ss': feat_ss,
            'meas_ss': meas_ss,
        },
    }
    server = TreeCatServer(model)
    data = server.sample(num_rows, counts=np.ones(V, np.int8))
    data.flags.writeable = False
    feature_types = [TY_MULTINOMIAL] * V
    table = Table(feature_types, ragged_index, data)
    return {'table': table}
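
A sketch of how this generator might be driven, assuming TreeStructure(V) from treecat.structure builds a default tree over V vertices (an assumption about the constructor, not verified here):

from treecat.structure import TreeStructure

tree = TreeStructure(8)  # Hypothetical: default tree over 8 vertices.
dataset = generate_clean_dataset(tree, num_rows=100, num_cats=3)
table = dataset['table']
print(table.num_rows, table.data.shape)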