예제 #1
0
def plot_metric_bubble_chart(
    disparity_df,
    metrics_list,
    attribute,
    fairness_threshold=1.25,
    chart_height=None,
    chart_width=Metric_Chart.full_width,
    accessibility_mode=False,
):
    """Build a bubble chart of the absolute values of selected metrics for one attribute.

    The inputs are first passed through ``Initializer.prepare_bubble_chart``
    (with ``Metric_Chart`` as the sizing class). The resulting bubble layers are
    layered with an interactive legend. The view, left axis and title are then
    styled before the chart is returned.

    :param disparity_df: a dataframe generated by the Aequitas Bias class
    :type disparity_df: pandas.core.frame.DataFrame
    :param metrics_list: the metrics to display
    :type metrics_list: list
    :param attribute: the attribute whose groups are plotted
    :type attribute: str
    :param fairness_threshold: maximum allowed disparity, defaults to 1.25
    :type fairness_threshold: float, optional
    :param chart_height: chart height in pixels; the value actually used is the
        one returned by ``Initializer.prepare_bubble_chart``
    :type chart_height: int, optional
    :param chart_width: chart width in pixels, defaults to ``Metric_Chart.full_width``;
        the value actually used is the one returned by ``Initializer.prepare_bubble_chart``
    :type chart_width: int, optional
    :param accessibility_mode: whether to use more accessible visual elements, defaults to False
    :type accessibility_mode: bool, optional

    :return: the fully configured metrics chart
    :rtype: Altair chart object
    """
    prepared = Initializer.prepare_bubble_chart(
        disparity_df,
        metrics_list,
        attribute,
        fairness_threshold,
        chart_height,
        chart_width,
        Metric_Chart,
        accessibility_mode,
    )
    (
        table,
        metrics,
        reference_group,
        scales,
        chart_height,
        chart_width,
        group_selection,
    ) = prepared

    # Bubble layers first; the legend is layered on top of them.
    bubble_layers = get_metric_bubble_chart_components(
        table,
        metrics,
        reference_group,
        scales,
        group_selection,
        fairness_threshold,
        chart_height,
        chart_width,
        accessibility_mode,
    )
    layered = bubble_layers + draw_legend(scales, group_selection, chart_width)

    # Styling steps: hide the view border, style the left axis, size/title the
    # chart, then style the title text.
    styled = layered.configure_view(strokeWidth=0)
    styled = styled.configure_axisLeft(
        labelFontSize=Metric_Axis.label_font_size,
        labelColor=Metric_Axis.label_color,
        labelFont=FONT,
    )
    styled = styled.properties(
        height=chart_height,
        width=chart_width,
        title=f"Absolute values by {attribute.title()}",
        padding=Metric_Chart.full_chart_padding,
    )
    styled = styled.configure_title(
        align="center",
        baseline="middle",
        font=FONT,
        fontWeight=Chart_Title.font_weight,
        fontSize=Chart_Title.font_size,
        color=Chart_Title.font_color,
    )
    # The y and size scales are resolved independently across the layers.
    return styled.resolve_scale(y="independent", size="independent")
예제 #2
0
def plot_xy_metrics_chart(
    disparity_df,
    x_metric,
    y_metric,
    attribute,
    fairness_threshold=1.25,
    chart_height=None,
    chart_width=None,
    accessibility_mode=False,
):
    """Build an XY scatterplot of two metrics, with group size encoded as bubble size.

    When ``fairness_threshold`` is not None, threshold rules and bands are
    drawn around the reference group's value. This is done for each plotted
    metric that is not listed in ``Validator.NON_DISPARITY_METRICS_LIST``.

    :param disparity_df: a dataframe generated by the Aequitas Bias class
    :type disparity_df: pandas.core.frame.DataFrame
    :param x_metric: the metric shown on the x axis
    :type x_metric: str
    :param y_metric: the metric shown on the y axis
    :type y_metric: str
    :param attribute: the attribute whose groups are plotted
    :type attribute: str
    :param fairness_threshold: maximum allowed disparity, defaults to 1.25;
        pass None to skip the threshold rules and bands
    :type fairness_threshold: float, optional
    :param chart_height: chart height in pixels; the value actually used is the
        one returned by ``Initializer.prepare_xy_chart``
    :type chart_height: int, optional
    :param chart_width: chart width in pixels; the value actually used is the
        one returned by ``Initializer.prepare_xy_chart``
    :type chart_width: int, optional
    :param accessibility_mode: whether to use more accessible visual elements, defaults to False
    :type accessibility_mode: bool, optional

    :return: the fully configured scatterplot chart
    :rtype: Altair chart object
    """
    (
        table,
        x_metric,
        y_metric,
        reference_group,
        global_scales,
        chart_height,
        chart_width,
        group_selection,
    ) = Initializer.prepare_xy_chart(
        disparity_df,
        x_metric,
        y_metric,
        attribute,
        fairness_threshold,
        chart_height,
        chart_width,
        XY_Chart,
        accessibility_mode,
    )

    # Merge the x/y position scales into the shared (color/size/...) scales.
    scales = dict(global_scales, **__get_position_scales(chart_height, chart_width))

    # Base layers: the axis rules followed by their tick labels.
    axes = __draw_axis_rules(x_metric, y_metric, scales)
    ticks = __draw_tick_labels(scales, chart_height, chart_width)
    chart = axes + ticks

    if fairness_threshold is not None:
        # Look up the reference group's row by index label, then read the
        # x value and the y value (in that order) from it.
        reference_rows = table.loc[table["attribute_value"] == reference_group].index
        reference_values = {
            axis: table.loc[reference_rows, metric].iloc[0]
            for axis, metric in (("x", x_metric), ("y", y_metric))
        }

        # (metric, axis the bands run along, perpendicular axis, extra kwargs).
        # The y bands are layered before the x bands; only the x call passes
        # drawing_x=True.
        band_specs = (
            (y_metric, "y", "x", {}),
            (x_metric, "x", "y", {"drawing_x": True}),
        )
        for metric, main_axis, aux_axis, extra_kwargs in band_specs:
            if metric in Validator.NON_DISPARITY_METRICS_LIST:
                # Absolute-value metrics get no disparity threshold bands.
                continue
            chart += __draw_threshold_bands(
                ref_group_value=reference_values[main_axis],
                fairness_threshold=fairness_threshold,
                main_scale=scales[main_axis],
                aux_scale=scales[aux_axis],
                accessibility_mode=accessibility_mode,
                **extra_kwargs,
            )

    legend = draw_legend(scales, group_selection, chart_width)
    bubbles = __draw_bubbles(
        table,
        x_metric,
        y_metric,
        reference_group,
        scales,
        group_selection,
    )
    # Legend and bubbles are layered together, then added on top of the chart.
    chart += legend + bubbles

    title_text = f"{y_metric.upper()} by {x_metric.upper()} on {attribute.title()}"

    styled = chart.configure_view(strokeWidth=0)
    styled = styled.properties(
        height=chart_height,
        width=chart_width,
        title=title_text,
        padding=XY_Chart.full_chart_padding,
    )
    styled = styled.configure_title(
        align="center",
        baseline="middle",
        font=FONT,
        fontWeight=Chart_Title.font_weight,
        fontSize=Chart_Title.font_size,
        color=Chart_Title.font_color,
    )
    styled = styled.configure_axis(
        titleFont=FONT,
        titleColor=Scatter_Axis.title_color,
        titleFontSize=Scatter_Axis.title_font_size,
        titleFontWeight=Scatter_Axis.title_font_weight,
        titleAngle=0,
    )
    styled = styled.configure_axisLeft(titlePadding=Scatter_Axis.title_padding)
    styled = styled.resolve_scale(y="independent", x="independent", size="independent")
    return styled.resolve_axis(x="shared", y="shared")
def plot_concatenated_bubble_charts(
    disparity_df,
    metrics_list,
    attribute,
    fairness_threshold=1.25,
    chart_height=None,
    chart_width=Sizes.Concat_Chart.full_width,
    accessibility_mode=False,
):
    """Place the disparity bubble chart and the metric-value bubble chart side by side.

    Both sub-charts share the data, scales and selection returned by
    ``Initializer.prepare_bubble_chart`` (called with ``Sizes.Disparity_Chart``).
    The total width is split between them by ``__get_chart_sizes``. Each
    sub-chart gets its own title, and the legend is attached to the metric
    sub-chart.

    :param disparity_df: a dataframe generated by the Aequitas Bias class
    :type disparity_df: pandas.core.frame.DataFrame
    :param metrics_list: the metrics to display
    :type metrics_list: list
    :param attribute: the attribute whose groups are plotted
    :type attribute: str
    :param fairness_threshold: maximum allowed disparity, defaults to 1.25
    :type fairness_threshold: float, optional
    :param chart_height: chart height in pixels; the value actually used is the
        one returned by ``Initializer.prepare_bubble_chart``
    :type chart_height: int, optional
    :param chart_width: total width in pixels, defaults to ``Sizes.Concat_Chart.full_width``
    :type chart_width: int, optional
    :param accessibility_mode: whether to use more accessible visual elements, defaults to False
    :type accessibility_mode: bool, optional

    :return: the horizontally concatenated disparities and metrics chart
    :rtype: Altair chart object
    """
    (
        table,
        metrics,
        reference_group,
        scales,
        chart_height,
        chart_width,
        group_selection,
    ) = Initializer.prepare_bubble_chart(
        disparity_df,
        metrics_list,
        attribute,
        fairness_threshold,
        chart_height,
        chart_width,
        Sizes.Disparity_Chart,
        accessibility_mode,
    )

    # Split the total width between the two sub-charts.
    sizes = __get_chart_sizes(chart_width)
    disparity_width = sizes["disparity_chart_width"]
    metric_width = sizes["metric_chart_width"]

    disparity_title = draw_chart_title("DISPARITIES", disparity_width)
    metric_title = draw_chart_title("METRICS", metric_width)

    # Left panel: disparity bubbles plus title; y and size scales independent.
    disparity_layers = get_disparity_bubble_chart_components(
        table,
        metrics,
        reference_group,
        scales,
        group_selection,
        fairness_threshold,
        chart_height,
        disparity_width,
        accessibility_mode,
        concat_chart=True,
    )
    disparity_chart = (
        (disparity_layers + disparity_title)
        .resolve_scale(y="independent", size="independent")
        .properties(height=chart_height, width=disparity_width)
    )

    # Right panel: metric bubbles, then title, then legend; only y is independent.
    metric_layers = get_metric_bubble_chart_components(
        table,
        metrics,
        reference_group,
        scales,
        group_selection,
        fairness_threshold,
        chart_height,
        metric_width,
        accessibility_mode,
        concat_chart=True,
    )
    metric_layers = metric_layers + metric_title
    metric_layers = metric_layers + draw_legend(scales, group_selection, metric_width)
    metric_chart = metric_layers.resolve_scale(y="independent").properties(
        height=chart_height,
        width=metric_width,
    )

    side_by_side = alt.hconcat(
        disparity_chart,
        metric_chart,
        bounds="flush",
        spacing=20,
    )
    return side_by_side.configure_view(strokeWidth=0).configure_axisLeft(
        labelFontSize=Metric_Axis.label_font_size,
        labelColor=Metric_Axis.label_color,
        labelFont=FONT,
    )