예제 #1
0
def test_labels_have_the_correct_number_of_lines(n_lines):
    """Check that the generated labels add up to the expected total.

    Labels are generated for 100 samples, 4 slots and 2 dims with the
    requested ``n_lines``; the test passes when the grand total of the
    returned labels equals ``n_lines * slots * n_samples``.
    """
    shape_kwargs = {"n_samples": 100, "slots": 4, "dims": 2}
    labels = combigen.generate_labels(**shape_kwargs, n_lines=n_lines)

    # The overall sum must match the expected count exactly.
    expected_total = (
        n_lines * shape_kwargs["slots"] * shape_kwargs["n_samples"])
    assert labels.sum() == expected_total
예제 #2
0
def test_label_statistics_correspond_to_correct_element_indices(idx):
    """Check that a one-hot ``line_stats`` switches on exactly one cell.

    ``line_stats`` is ten zeros with a single 1 at position ``idx``. The
    label generated from it is expected to hold a 1 at
    ``[0, 0, idx % 5, idx // 5]`` and a 0 at every other ``[0, 0, i, j]``
    for ``i`` in ``range(5)`` and ``j`` in ``range(2)``.
    """
    one_hot = np.zeros(10)
    one_hot[idx] = 1
    label = combigen.generate_labels(slots=1, n_lines=1, line_stats=one_hot)

    # divmod returns (idx // 5, idx % 5), i.e. the hot j index, then the
    # hot i index.
    hot_j, hot_i = divmod(idx, 5)
    for i in range(5):
        for j in range(2):
            expected = 1 if (i == hot_i and j == hot_j) else 0
            assert label[0, 0, i, j] == expected
예제 #3
0
def test_nonuniformly_sampled_labels_return_the_correct_statistics():
    """Check that label sums scale with the entries of ``line_stats``.

    Labels are generated 100 times with ``line_stats = [1, 2, 3, 4, 5]``,
    summed over axis 1 and averaged across the 100 draws. Each of the five
    resulting means must lie within ``slots * 0.05`` of the first mean
    multiplied by the corresponding ``line_stats`` entry.
    """
    slots = 100
    line_stats = np.array([1, 2, 3, 4, 5])
    n_trials = 100

    def sample_slot_sums():
        # One independent draw of labels, summed over axis 1.
        labels = combigen.generate_labels(
            n_samples=1,
            slots=slots,
            dims=1,
            n_lines=1,
            line_stats=line_stats)
        return labels.sum(axis=1)

    # Stack the per-draw sums so they can be averaged across draws.
    slot_sums = np.array([sample_slot_sums() for _ in range(n_trials)])

    # Average across draws and flatten to one mean per line_stats entry.
    means = slot_sums.mean(axis=0).reshape(5)

    # Every mean is compared against the first mean scaled by its weight.
    baseline = means[0]
    tolerance = slots * 0.05
    for weight, observed in zip(line_stats, means):
        assert np.isclose(baseline * weight, observed, atol=tolerance)
예제 #4
0
def test_generate_labels_returns_correct_shapes(shapes):
    """Check that ``generate_labels(*shapes)`` has shape ``shapes``.

    ``shapes`` is unpacked into positional arguments, and the shape of
    the returned labels must equal ``shapes`` element for element.
    """
    labels = combigen.generate_labels(*shapes)
    assert np.array_equal(labels.shape, shapes)
예제 #5
0
def test_combigen_heatmap_runs_without_errors_for_different_input_lengths(
        length):
    """Smoke-test ``cgh.heatmap`` on labels generated for ``length``.

    Nothing is asserted; the test passes as long as generating the labels
    and drawing the heatmap raise no exception.

    Lifted directly from nb0.1 c382.
    """
    labels = cg.generate_labels(length)
    cgh.heatmap(labels)