def test_plot_individual_conditional_expectation():
    """
    Tests ICE plotting.

    Tests :func:`fatf.vis.feature_influence.plot_individual_conditional_expectation`
    function. The figure created by the plotting function is closed at the
    end so that repeated test runs do not accumulate open pyplot figures.
    """
    feature_name = 'some feature'
    class_index = 1
    class_name = 'middle'

    figure, axis = fvfi.plot_individual_conditional_expectation(
        FAKE_ICE_ARRAY, FAKE_LINESPACE, class_index, feature_name, class_name)
    assert isinstance(figure, plt.Figure)

    p_title, p_x_label, p_x_range, p_y_label, p_y_range = futv.get_plot_data(
        axis)
    # ...check title
    assert p_title == 'Individual Conditional Expectation'
    # ...check x range
    assert np.array_equal(p_x_range, [FAKE_LINESPACE[0], FAKE_LINESPACE[-1]])
    # ...check x label
    assert p_x_label == feature_name
    # ...check y range
    assert np.array_equal(p_y_range, [-0.05, 1.05])
    # ...check y label
    assert p_y_label == '{} class probability'.format(class_name)

    # Test the line collection -- one line per row of FAKE_ICE_ARRAY
    assert len(axis.collections) == 1
    l_data, l_colour, l_alpha, l_label, l_width = futv.get_line_data(
        axis.collections[0], is_collection=True)
    assert len(l_data) == FAKE_ICE_ARRAY.shape[0]
    for i, line_array in enumerate(l_data):
        line_data = np.stack(
            [FAKE_LINESPACE, FAKE_ICE_ARRAY[i, :, class_index]], axis=1)
        assert np.array_equal(line_array, line_data)
    # The collection colour comes back as an RGBA array; the RGB part appears
    # to be matplotlib's 'dimgray' with the 0.5 alpha folded in.
    assert np.allclose(
        l_colour, np.array([[0.412, 0.412, 0.412, 0.5]]), atol=1e-2)
    assert l_alpha == 0.5
    assert l_label == 'ICE'
    assert l_width == 1.75

    # Validate plot legend
    legend = [
        i for i in axis.get_children()
        if isinstance(i, matplotlib.legend.Legend)
    ]
    assert len(legend) == 1
    legend_texts = legend[0].get_texts()
    assert len(legend_texts) == 1
    assert legend_texts[0].get_text() == 'ICE'

    # Release the figure so the suite does not keep it alive in pyplot.
    plt.close(figure)
def test_plot_partial_dependence():
    """
    Tests :func:`fatf.vis.feature_influence.plot_partial_dependence` function.

    The figure created by the plotting function is closed at the end so that
    repeated test runs do not accumulate open pyplot figures.
    """
    feature_name = 'some feature'
    class_index = 1
    class_name = 'middle'

    figure, axis = fvfi.plot_partial_dependence(FAKE_PD_ARRAY, FAKE_LINESPACE,
                                                class_index, feature_name,
                                                class_name)
    assert isinstance(figure, plt.Figure)

    p_title, p_x_label, p_x_range, p_y_label, p_y_range = futv.get_plot_data(
        axis)
    # ...check title
    assert p_title == 'Partial Dependence'
    # ...check x range
    assert np.array_equal(p_x_range, [FAKE_LINESPACE[0], FAKE_LINESPACE[-1]])
    # ...check x label
    assert p_x_label == feature_name
    # ...check y range
    assert np.array_equal(p_y_range, [-0.05, 1.05])
    # ...check y label
    assert p_y_label == '{} class probability'.format(class_name)

    # Test the line -- a single PD curve for the selected class column
    assert len(axis.lines) == 1
    l_data, l_colour, l_alpha, l_label, l_width = futv.get_line_data(
        axis.lines[0])
    line_data = np.stack([FAKE_LINESPACE, FAKE_PD_ARRAY[:, class_index]],
                         axis=1)
    assert np.array_equal(l_data, line_data)
    assert l_colour == 'lightsalmon'
    assert l_alpha == 0.6
    assert l_label == 'PD'
    assert l_width == 7

    # Validate plot legend
    legend = [
        i for i in axis.get_children()
        if isinstance(i, matplotlib.legend.Legend)
    ]
    assert len(legend) == 1
    legend_texts = legend[0].get_texts()
    assert len(legend_texts) == 1
    assert legend_texts[0].get_text() == 'PD'

    # Release the figure so the suite does not keep it alive in pyplot.
    plt.close(figure)
def test_get_line_data():
    """
    Tests :func:`fatf.utils.testing.vis.get_line_data` function.

    Covers a :class:`matplotlib.collections.LineCollection` (queried with
    ``is_collection=True``) and a single plotted line (queried both with
    ``is_collection=False`` and with the default argument). Every figure
    created here is closed before the test returns.
    """
    true_label = 'my label'
    true_colour = 'green'
    true_alpha = 0.5
    true_width = 7

    # Test collection
    true_data = [[[0, 1], [0, 1]], [[4, 3], [0, 5]]]
    plot_figure, plot_axis = plt.subplots(1, 1)
    line_collection = matplotlib.collections.LineCollection(
        true_data,
        label=true_label,
        color=true_colour,
        alpha=true_alpha,
        linewidth=true_width)
    plot_axis.add_collection(line_collection)
    assert len(plot_axis.collections) == 1
    data, colour, alpha, label, width = futv.get_line_data(
        plot_axis.collections[0], is_collection=True)
    assert np.array_equal(true_data, data)
    # For a collection the colour is reported as RGBA: 'green' with the
    # 0.5 alpha folded into the fourth channel.
    assert np.allclose([[0.0, 0.5, 0.0, 0.5]], colour, atol=1e-2)
    assert true_alpha == alpha
    assert true_label == label
    assert true_width == width
    plt.close(plot_figure)

    # Test a line
    true_data_x = [0, 1, 2, 3, 4]
    true_data_y = [5, 10, 15, 10, 5]
    plot_figure, plot_axis = plt.subplots(1, 1)
    plot_axis.plot(
        true_data_x,
        true_data_y,
        color=true_colour,
        linewidth=true_width,
        alpha=true_alpha,
        label=true_label)
    assert len(plot_axis.lines) == 1
    # The explicit ``is_collection=False`` call and the default call must
    # yield the same result, so run the same checks for both.
    for call_kwargs in ({'is_collection': False}, {}):
        data, colour, alpha, label, width = futv.get_line_data(
            plot_axis.lines[0], **call_kwargs)
        assert data.shape == (5, 2)
        assert np.array_equal(true_data_x, data[:, 0])
        assert np.array_equal(true_data_y, data[:, 1])
        assert true_colour == colour
        assert true_alpha == alpha
        assert true_label == label
        assert true_width == width
    plt.close(plot_figure)
def test_ice_pd_overlay():
    """
    Tests overlaying PD plot on top of an ICE plot.

    The ICE plot creates the figure; the PD plot is then drawn onto the same
    axis, and the combined canvas (title, axes, both line sets and the
    legend) is inspected. The figure is closed at the end so that repeated
    test runs do not accumulate open pyplot figures.
    """
    f_name = 'some feature'
    c_index = 1
    c_name = 'middle'

    figure, axis = fvfi.plot_individual_conditional_expectation(
        FAKE_ICE_ARRAY, FAKE_LINESPACE, c_index, f_name, c_name)
    assert isinstance(figure, plt.Figure)
    assert isinstance(axis, plt.Axes)

    # Plotting onto an existing axis is expected to return no new figure.
    pd_figure, axis = fvfi.plot_partial_dependence(
        FAKE_PD_ARRAY, FAKE_LINESPACE, c_index, f_name, c_name, axis)
    assert pd_figure is None
    assert isinstance(axis, plt.Axes)

    # Inspect the canvas
    p_title, p_x_label, p_x_range, p_y_label, p_y_range = futv.get_plot_data(
        axis)
    # ...check title
    assert p_title == ('Individual Conditional Expectation &\nPartial '
                       'Dependence')
    # ...check x range
    assert np.array_equal(p_x_range, [FAKE_LINESPACE[0], FAKE_LINESPACE[-1]])
    # ...check x label
    assert p_x_label == f_name
    # ...check y range
    assert np.array_equal(p_y_range, [-0.05, 1.05])
    # ...check y label
    assert p_y_label == '{} class probability'.format(c_name)

    # Check ICE
    assert len(axis.collections) == 1
    l_data, l_colour, l_alpha, l_label, l_width = futv.get_line_data(
        axis.collections[0], is_collection=True)
    assert len(l_data) == FAKE_ICE_ARRAY.shape[0]
    for i, line_array in enumerate(l_data):
        line_data = np.stack(
            [FAKE_LINESPACE, FAKE_ICE_ARRAY[i, :, c_index]], axis=1)
        assert np.array_equal(line_array, line_data)
    # RGBA colour; the RGB part appears to be matplotlib's 'dimgray'.
    assert np.allclose(
        l_colour, np.array([[0.412, 0.412, 0.412, 0.5]]), atol=1e-2)
    assert l_alpha == 0.5
    assert l_label == 'ICE'
    assert l_width == 1.75

    # Check PD
    assert len(axis.lines) == 1
    l_data, l_colour, l_alpha, l_label, l_width = futv.get_line_data(
        axis.lines[0])
    line_data = np.stack([FAKE_LINESPACE, FAKE_PD_ARRAY[:, c_index]], axis=1)
    assert np.array_equal(l_data, line_data)
    assert l_colour == 'lightsalmon'
    assert l_alpha == 0.6
    assert l_label == 'PD'
    assert l_width == 7

    # Validate plot legend -- one legend listing both overlays
    legend = [
        i for i in axis.get_children()
        if isinstance(i, matplotlib.legend.Legend)
    ]
    assert len(legend) == 1
    legend_texts = legend[0].get_texts()
    assert len(legend_texts) == 2
    assert legend_texts[0].get_text() == 'PD'
    assert legend_texts[1].get_text() == 'ICE'

    # Release the figure so the suite does not keep it alive in pyplot.
    plt.close(figure)