示例#1
0
def test_add_catplot():
    """Exercise ``UpSet.add_catplot`` with Series and DataFrame input.

    Checks that:

    * a seaborn catplot can be added to a Series-backed plot and drawn;
    * passing ``value`` with Series input raises ValueError, and the failed
      call leaves the object still plottable;
    * with DataFrame input ``value`` is required and must name a column;
    * an unknown plot kind is accepted by ``add_catplot`` and only fails
      (with AttributeError) once the figure is drawn.
    """
    pytest.importorskip('seaborn')

    # --- Series input -----------------------------------------------------
    series_upset = UpSet(generate_data(n_samples=100))
    series_upset.add_catplot('violin')  # smoke test
    fig = matplotlib.figure.Figure()
    series_upset.plot(fig)

    # ``value`` names a column, which a Series does not have.
    with pytest.raises(ValueError):
        series_upset.add_catplot('violin', value='foo')
    # The rejected call must not have broken the object's state.
    series_upset.plot(fig)

    # --- DataFrame input --------------------------------------------------
    series = generate_data(n_samples=100)
    series.name = 'foo'
    frame = series.to_frame()
    frame_upset = UpSet(frame, sum_over=False)

    # ``value`` is mandatory for a DataFrame ...
    with pytest.raises(ValueError):
        frame_upset.add_catplot('violin')
    frame_upset.add_catplot('violin', value='foo')
    # ... and must be one of its columns.
    with pytest.raises(ValueError):
        frame_upset.add_catplot('violin', value='bar')
    frame_upset.plot(fig)

    # Plot kinds are not validated up front; the error surfaces at draw time.
    frame_upset.add_catplot('foobar', value='foo')
    with pytest.raises(AttributeError):
        frame_upset.plot(fig)
示例#2
0
def test_not_aggregated(sort_by, sort_sets_by):
    """Raw data whose values are all 1 processes like ``aggregated=True`` data.

    ``_process_data`` must return intersection and total series that are
    equal (ignoring dtype) for ``generate_data(aggregated=True)`` and for
    the raw ``generate_data()`` output with every value overwritten by 1.
    """
    # FIXME: this is not testing if aggregation used is count or sum
    options = dict(sort_by=sort_by, sort_sets_by=sort_sets_by)

    agg_intersections, agg_totals = _process_data(
        generate_data(aggregated=True), **options)

    ones = generate_data()
    ones.loc[:] = 1
    ones_intersections, ones_totals = _process_data(ones, **options)

    assert_series_equal(agg_intersections, ones_intersections,
                        check_dtype=False)
    assert_series_equal(agg_totals, ones_totals, check_dtype=False)
示例#3
0
def test_not_aggregated(sort_by, sort_sets_by):
    """Raw data whose values are all 1 processes like ``aggregated=True`` data.

    Besides equal (dtype-insensitive) intersection and total series, the
    processed frames must have exactly the ``'_value'`` and ``'_bin'``
    columns, the raw input must keep one row per input row, and the number
    of distinct ``'_bin'`` values must equal the number of intersections.
    """
    # FIXME: this is not testing if aggregation used is count or sum
    options = dict(sort_by=sort_by, sort_sets_by=sort_sets_by, sum_over=None)

    agg_df, agg_intersections, agg_totals = _process_data(
        generate_data(aggregated=True), **options)

    ones = generate_data()
    ones.loc[:] = 1
    ones_df, ones_intersections, ones_totals = _process_data(ones, **options)

    assert_series_equal(agg_intersections, ones_intersections,
                        check_dtype=False)
    assert_series_equal(agg_totals, ones_totals, check_dtype=False)

    for processed in (agg_df, ones_df):
        assert set(processed.columns) == {'_value', '_bin'}
    assert len(ones_df) == len(ones)
    assert ones_df['_bin'].nunique() == len(ones_intersections)
示例#4
0
def test_plot_smoke_test(kw):
    """Smoke-test ``plot`` with the parametrized keyword arguments ``kw``.

    First renders onto an explicit Figure and saves it to an in-memory PNG.
    Then calls ``plot`` without a figure and checks that exactly one new
    pyplot figure, with axes, was created.

    Pyplot keeps every figure it creates alive until it is explicitly
    closed. That figure is therefore closed afterwards (even if an
    assertion fails), so parametrized runs do not accumulate open figures.
    """
    fig = matplotlib.figure.Figure()
    X = generate_data(n_samples=100)
    plot(X, fig, **kw)
    fig.savefig(io.BytesIO(), format='png')

    # Also check fig is optional
    fignums_before = set(plt.get_fignums())
    plot(X, **kw)
    new_fignums = set(plt.get_fignums()) - fignums_before
    try:
        assert len(new_fignums) == 1
        assert plt.gcf().axes
    finally:
        # Close only what this call created; leave unrelated figures alone.
        for num in new_fignums:
            plt.close(num)
示例#5
0
def test_vertical():
    """The vertical layout has a smaller width/height ratio than horizontal.

    After ``make_grid`` the horizontal figure must be wider than tall, and
    its width/height ratio must exceed the vertical figure's.
    """
    X = generate_data(n_samples=100)

    def grid_size(orientation):
        # Lay out a fresh figure and report its resulting (width, height).
        figure = matplotlib.figure.Figure()
        UpSet(X, orientation=orientation).make_grid(figure)
        return figure.get_figwidth(), figure.get_figheight()

    horz_width, horz_height = grid_size('horizontal')
    assert horz_height < horz_width

    vert_width, vert_height = grid_size('vertical')
    assert horz_width / horz_height > vert_width / vert_height

    # TODO: test axes positions, plot order, bar orientation
示例#6
0
def test_show_counts(orientation):
    """Check ``show_counts`` in the parametrized ``orientation``.

    * ``show_counts=True`` adds more than 6 artists compared with the default.
    * A printf-style format string (``'%0.2g'``) adds the same number of
      artists as ``True``.
    * An invalid format string (``'%0.2h'``) raises ValueError.

    ``orientation`` is forwarded to every ``plot`` call so that each
    parametrization actually exercises its layout.
    """
    X = generate_data(n_samples=100)

    fig = matplotlib.figure.Figure()
    plot(X, fig, orientation=orientation)
    n_artists_no_sizes = _count_descendants(fig)

    fig = matplotlib.figure.Figure()
    plot(X, fig, orientation=orientation, show_counts=True)
    n_artists_yes_sizes = _count_descendants(fig)
    assert n_artists_yes_sizes - n_artists_no_sizes > 6

    fig = matplotlib.figure.Figure()
    plot(X, fig, orientation=orientation, show_counts='%0.2g')
    assert n_artists_yes_sizes == _count_descendants(fig)

    # Only the plot call belongs inside the raises block, so that an
    # unrelated ValueError (e.g. from setup) cannot satisfy it.
    fig = matplotlib.figure.Figure()
    with pytest.raises(ValueError):
        plot(X, fig, orientation=orientation, show_counts='%0.2h')
示例#7
0
def test_element_size():
    """``element_size`` drives figure size; ``None`` leaves the size unchanged.

    Over increasing element sizes the figure must get wider. Its aspect
    ratio (width / height) must strictly decrease, but by a factor of less
    than 1.1 between consecutive sizes.
    """
    X = generate_data(n_samples=100)

    def size_for(element_size):
        # Lay out a fresh figure and report its resulting (width, height).
        fig = matplotlib.figure.Figure()
        UpSet(X, element_size=element_size).make_grid(fig)
        return fig.get_figwidth(), fig.get_figheight()

    sizes = np.array([size_for(es) for es in range(10, 50, 5)])
    widths, heights = sizes[:, 0], sizes[:, 1]

    # Absolute width increases
    assert np.all(np.diff(widths) > 0)
    aspect = widths / heights
    # Font size stays constant, so aspect ratio decreases ...
    assert np.all(np.diff(aspect) < 0)
    # ... but only slightly from one size to the next
    assert np.all(aspect[:-1] / aspect[1:] < 1.1)

    fig = matplotlib.figure.Figure()
    size_before = fig.get_figwidth(), fig.get_figheight()
    UpSet(X, element_size=None).make_grid(fig)
    assert (fig.get_figwidth(), fig.get_figheight()) == size_before
示例#8
0
def test_param_validation(kw):
    """Constructing ``UpSet`` with the parametrized ``kw`` raises ValueError.

    ``kw`` is presumably a set of invalid options supplied by a parametrize
    decorator outside this view — confirm against the decorator.
    """
    data = generate_data(n_samples=100)
    with pytest.raises(ValueError):
        UpSet(data, **kw)
示例#9
0
import matplotlib.figure
import matplotlib.pyplot as plt

from upsetplot import plot
from upsetplot import UpSet
from upsetplot import generate_data
from upsetplot.plotting import _process_data


def is_ascending(seq):
    """Return True if the items of *seq* are in non-decreasing order.

    *seq* may be any iterable. It is materialized exactly once, so one-shot
    iterables such as generators and iterators are handled correctly.
    An empty input counts as ascending.
    """
    # Iterating ``seq`` twice (once for sorted(), once for list()) would
    # exhaust an iterator on the first pass and compare against [].
    items = list(seq)
    return sorted(items) == items


@pytest.mark.parametrize('X', [
    generate_data(aggregated=True),
    generate_data(aggregated=True).iloc[1:-2],
])
@pytest.mark.parametrize('sort_by', ['cardinality', 'degree'])
@pytest.mark.parametrize('sort_sets_by', [None, 'cardinality'])
def test_process_data_series(X, sort_by, sort_sets_by):
    with pytest.raises(ValueError, match='sum_over is not applicable'):
        _process_data(X, sort_by=sort_by, sort_sets_by=sort_sets_by,
                      sum_over=False)

    df, intersections, totals = _process_data(X,
                                              sort_by=sort_by,
                                              sort_sets_by=sort_sets_by,
                                              sum_over=None)
    assert intersections.name == 'value'
    X_reordered = (X
示例#10
0
"""
====================
Vertical orientation
====================

This example illustrates the effect of passing ``orientation='vertical'`` to ``plot``.
"""

from matplotlib import pyplot as plt
from upsetplot import generate_data, plot

example = generate_data(aggregated=True)

# Draw the same data twice in the vertical orientation: first plain, then
# with counts shown using the '%d' format.
for title, extra_kwargs in [
        ('A vertical plot', {}),
        ('A vertical plot with counts shown', {'show_counts': '%d'}),
]:
    plot(example, orientation='vertical', **extra_kwargs)
    plt.suptitle(title)
    plt.show()
示例#11
0
def test_generate_data_warning():
    """Calling ``generate_data`` with default arguments emits a DeprecationWarning."""
    expected_warning = DeprecationWarning
    with pytest.warns(expected_warning):
        generate_data()