コード例 #1
0
def test_plot_individual_conditional_expectation():
    """
    Checks a standalone ICE plot: its canvas, line collection and legend.

    Exercises the
    :func:`fatf.vis.feature_influence.plot_individual_conditional_expectation`
    function.
    """
    selected_class = 1
    selected_class_name = 'middle'
    x_axis_name = 'some feature'

    ice_figure, ice_axis = fvfi.plot_individual_conditional_expectation(
        FAKE_ICE_ARRAY, FAKE_LINESPACE, selected_class, x_axis_name,
        selected_class_name)
    assert isinstance(ice_figure, plt.Figure)

    # Canvas: the title first, then the x axis and finally the y axis
    canvas = futv.get_plot_data(ice_axis)
    title, x_label, x_range, y_label, y_range = canvas
    assert title == 'Individual Conditional Expectation'
    expected_x_range = [FAKE_LINESPACE[0], FAKE_LINESPACE[-1]]
    assert np.array_equal(x_range, expected_x_range)
    assert x_label == x_axis_name
    assert np.array_equal(y_range, [-0.05, 1.05])
    assert y_label == '{} class probability'.format(selected_class_name)

    # All of the ICE lines are expected to live in one line collection
    drawn_collections = ice_axis.collections
    assert len(drawn_collections) == 1
    segments, colour, alpha, label, width = futv.get_line_data(
        drawn_collections[0], is_collection=True)

    # One segment per row of the ICE array, each pairing the linespace with
    # the values stored for the selected class
    assert len(segments) == FAKE_ICE_ARRAY.shape[0]
    for row, segment in enumerate(segments):
        expected_segment = np.stack(
            (FAKE_LINESPACE, FAKE_ICE_ARRAY[row, :, selected_class]), axis=1)
        assert np.array_equal(segment, expected_segment)

    # RGBA value that the collection reports -- presumably the dimgray colour
    # with the 0.5 alpha applied
    expected_rgba = np.array([[0.412, 0.412, 0.412, 0.5]])
    colour_matches = np.isclose(colour, expected_rgba, atol=1e-2)
    assert colour_matches.all()
    assert alpha == 0.5
    assert label == 'ICE'
    assert width == 1.75

    # Exactly one legend holding exactly one entry called 'ICE'
    legends = [
        child for child in ice_axis.get_children()
        if isinstance(child, matplotlib.legend.Legend)
    ]
    assert len(legends) == 1
    legend_labels = [text.get_text() for text in legends[0].get_texts()]
    assert legend_labels == ['ICE']
コード例 #2
0
def test_plot_partial_dependence():
    """
    Checks a standalone PD plot: its canvas, the single line and the legend.

    Exercises the :func:`fatf.vis.feature_influence.plot_partial_dependence`
    function.
    """
    selected_class = 1
    selected_class_name = 'middle'
    x_axis_name = 'some feature'

    pd_figure, pd_axis = fvfi.plot_partial_dependence(
        FAKE_PD_ARRAY, FAKE_LINESPACE, selected_class, x_axis_name,
        selected_class_name)
    assert isinstance(pd_figure, plt.Figure)

    # Canvas: the title first, then the x axis and finally the y axis
    canvas = futv.get_plot_data(pd_axis)
    title, x_label, x_range, y_label, y_range = canvas
    assert title == 'Partial Dependence'
    expected_x_range = [FAKE_LINESPACE[0], FAKE_LINESPACE[-1]]
    assert np.array_equal(x_range, expected_x_range)
    assert x_label == x_axis_name
    assert np.array_equal(y_range, [-0.05, 1.05])
    assert y_label == '{} class probability'.format(selected_class_name)

    # The PD curve is expected to be the only plain line on the axis
    drawn_lines = pd_axis.lines
    assert len(drawn_lines) == 1
    points, colour, alpha, label, width = futv.get_line_data(drawn_lines[0])

    # The line pairs the linespace with the PD values of the selected class
    expected_points = np.stack(
        (FAKE_LINESPACE, FAKE_PD_ARRAY[:, selected_class]), axis=1)
    assert np.array_equal(points, expected_points)
    assert colour == 'lightsalmon'
    assert alpha == 0.6
    assert label == 'PD'
    assert width == 7

    # Exactly one legend holding exactly one entry called 'PD'
    legends = [
        child for child in pd_axis.get_children()
        if isinstance(child, matplotlib.legend.Legend)
    ]
    assert len(legends) == 1
    legend_labels = [text.get_text() for text in legends[0].get_texts()]
    assert legend_labels == ['PD']
コード例 #3
0
def test_get_line_data():
    """
    Tests :func:`fatf.utils.testing.vis.get_line_data` function.

    Two kinds of artists are checked: a ``LineCollection`` (read with
    ``is_collection=True``) and a plain line created with ``Axes.plot`` (read
    both with ``is_collection=False`` and without the ``is_collection``
    argument). Every figure created here is closed once its checks are done,
    regardless of whether they pass, so that it does not linger in pyplot's
    figure registry for the remainder of the test session.
    """
    true_label = 'my label'
    true_colour = 'green'
    true_alpha = 0.5
    true_width = 7

    # Test collection
    true_data = [[[0, 1], [0, 1]], [[4, 3], [0, 5]]]

    plot_figure, plot_axis = plt.subplots(1, 1)
    try:
        line_collection = matplotlib.collections.LineCollection(
            true_data,
            label=true_label,
            color=true_colour,
            alpha=true_alpha,
            linewidth=true_width)
        plot_axis.add_collection(line_collection)

        assert len(plot_axis.collections) == 1
        data, colour, alpha, label, width = futv.get_line_data(
            plot_axis.collections[0], is_collection=True)

        assert np.array_equal(true_data, data)
        # For a collection the colour is compared as an RGBA array (green
        # with the alpha channel applied) rather than as the colour name
        assert np.allclose([[0.0, 0.5, 0.0, 0.5]], colour, atol=1e-2)
        assert true_alpha == alpha
        assert true_label == label
        assert true_width == width
    finally:
        # Figures created via pyplot stay registered until explicitly closed
        plt.close(plot_figure)

    # Test a line
    true_data_x = [0, 1, 2, 3, 4]
    true_data_y = [5, 10, 15, 10, 5]

    plot_figure, plot_axis = plt.subplots(1, 1)
    try:
        plot_axis.plot(true_data_x,
                       true_data_y,
                       color=true_colour,
                       linewidth=true_width,
                       alpha=true_alpha,
                       label=true_label)

        assert len(plot_axis.lines) == 1
        # Calling with an explicit ``is_collection=False`` and calling without
        # the argument must both give the same, correct description
        for call_kwargs in ({'is_collection': False}, {}):
            data, colour, alpha, label, width = futv.get_line_data(
                plot_axis.lines[0], **call_kwargs)

            assert data.shape == (5, 2)
            assert np.array_equal(true_data_x, data[:, 0])
            assert np.array_equal(true_data_y, data[:, 1])
            assert true_colour == colour
            assert true_alpha == alpha
            assert true_label == label
            assert true_width == width
    finally:
        plt.close(plot_figure)
コード例 #4
0
def test_ice_pd_overlay():
    """
    Checks a PD curve drawn on the axis of an already existing ICE plot.

    The ICE plot is created first and its axis is then handed over to the PD
    plotting function; the shared canvas, both sets of lines and the joint
    legend are inspected afterwards.
    """
    selected_class = 1
    selected_class_name = 'middle'
    x_axis_name = 'some feature'

    ice_figure, shared_axis = fvfi.plot_individual_conditional_expectation(
        FAKE_ICE_ARRAY, FAKE_LINESPACE, selected_class, x_axis_name,
        selected_class_name)
    assert isinstance(ice_figure, plt.Figure)
    assert isinstance(shared_axis, plt.Axes)

    # When an axis is supplied no figure is expected back, only the axis
    no_figure, shared_axis = fvfi.plot_partial_dependence(
        FAKE_PD_ARRAY, FAKE_LINESPACE, selected_class, x_axis_name,
        selected_class_name, shared_axis)
    assert no_figure is None
    assert isinstance(shared_axis, plt.Axes)

    # Canvas: the title first, then the x axis and finally the y axis
    canvas = futv.get_plot_data(shared_axis)
    title, x_label, x_range, y_label, y_range = canvas
    expected_title = ('Individual Conditional Expectation &\nPartial '
                      'Dependence')
    assert title == expected_title
    expected_x_range = [FAKE_LINESPACE[0], FAKE_LINESPACE[-1]]
    assert np.array_equal(x_range, expected_x_range)
    assert x_label == x_axis_name
    assert np.array_equal(y_range, [-0.05, 1.05])
    assert y_label == '{} class probability'.format(selected_class_name)

    # ICE part: a single line collection with one segment per ICE array row
    drawn_collections = shared_axis.collections
    assert len(drawn_collections) == 1
    segments, colour, alpha, label, width = futv.get_line_data(
        drawn_collections[0], is_collection=True)
    assert len(segments) == FAKE_ICE_ARRAY.shape[0]
    for row, segment in enumerate(segments):
        expected_segment = np.stack(
            (FAKE_LINESPACE, FAKE_ICE_ARRAY[row, :, selected_class]), axis=1)
        assert np.array_equal(segment, expected_segment)
    # RGBA value that the collection reports -- presumably the dimgray colour
    # with the 0.5 alpha applied
    expected_rgba = np.array([[0.412, 0.412, 0.412, 0.5]])
    colour_matches = np.isclose(colour, expected_rgba, atol=1e-2)
    assert colour_matches.all()
    assert alpha == 0.5
    assert label == 'ICE'
    assert width == 1.75

    # PD part: a single plain line over the same linespace
    drawn_lines = shared_axis.lines
    assert len(drawn_lines) == 1
    points, colour, alpha, label, width = futv.get_line_data(drawn_lines[0])
    expected_points = np.stack(
        (FAKE_LINESPACE, FAKE_PD_ARRAY[:, selected_class]), axis=1)
    assert np.array_equal(points, expected_points)
    assert colour == 'lightsalmon'
    assert alpha == 0.6
    assert label == 'PD'
    assert width == 7

    # Exactly one legend listing 'PD' first and 'ICE' second
    legends = [
        child for child in shared_axis.get_children()
        if isinstance(child, matplotlib.legend.Legend)
    ]
    assert len(legends) == 1
    legend_labels = [text.get_text() for text in legends[0].get_texts()]
    assert legend_labels == ['PD', 'ICE']