Exemplo n.º 1
0
def __draw_population_bar(population_bar_df, metric, color_scale):
    """Build the normalized stacked bar of population share per parity result.

    Group sizes are summed per value of ``{metric}_parity_result`` and stacked
    with ``stack="normalize"``, so each segment's width is its share of the
    total.

    Args:
        population_bar_df: Data for the chart. The encodings read the columns
            ``group_size``, ``{metric}_parity_result``, ``tooltip_group_size``
            and ``tooltip_groups_name_size``.
        metric: Metric name used to build the parity-result column name.
        color_scale: Altair scale used to color the bar segments.

    Returns:
        The Altair bar chart.
    """
    parity_field = f"{metric}_parity_result"

    tooltips = [
        alt.Tooltip(field=parity_field, type="nominal", title="Parity"),
        alt.Tooltip(field="tooltip_group_size", type="nominal", title="Size"),
        alt.Tooltip(
            field="tooltip_groups_name_size", type="nominal", title="Groups"
        ),
    ]

    x_encoding = alt.X("sum(group_size):Q", stack="normalize", axis=no_axis())
    # The bar sits at a constant y of 2.8 on a reversed [3, 1] domain.
    y_encoding = alt.Y(
        "y_position:Q", scale=alt.Scale(domain=[3, 1]), axis=no_axis()
    )
    color_encoding = alt.Color(
        f"{parity_field}:O",
        scale=color_scale,
        legend=alt.Legend(title="Parity Test", padding=20),
    )

    bar = alt.Chart(population_bar_df)
    bar = bar.transform_calculate(y_position="2.8")
    bar = bar.mark_bar(size=6, stroke="white")

    return bar.encode(x_encoding, y_encoding, color_encoding, tooltip=tooltips)
Exemplo n.º 2
0
def __draw_threshold_bands(threshold_df, scales, accessibility_mode=False):
    """Draw the lower and upper fairness threshold bands as filled rectangles.

    Args:
        threshold_df: Data for the bands. The encodings read the columns
            ``metric``, ``lower_threshold_value``, ``lower_end``,
            ``upper_threshold_value`` and ``upper_end``.
        scales: Mapping holding the shared Altair scales under ``"x"`` and
            ``"y"``.
        accessibility_mode: When true, the accessible fill color is used.

    Returns:
        A layered chart: the lower band followed by the upper band.
    """
    if accessibility_mode:
        band_fill = Threshold_Band.color_accessible
    else:
        band_fill = Threshold_Band.color

    def _band(start_shorthand, end_shorthand):
        # Both horizontal limits of a band come from columns of threshold_df.
        rect = alt.Chart(threshold_df).mark_rect(
            fill=band_fill, opacity=Threshold_Band.opacity, tooltip=""
        )
        return rect.encode(
            y=alt.Y(
                field="metric", type="nominal", scale=scales["y"], axis=no_axis()
            ),
            x=alt.X(start_shorthand, scale=scales["x"]),
            x2=end_shorthand,
        )

    lower_band = _band("lower_threshold_value:Q", "lower_end:Q")
    upper_band = _band("upper_threshold_value:Q", "upper_end:Q")

    return lower_band + upper_band
Exemplo n.º 3
0
def __draw_axis_rules(x_metric, y_metric, scales):
    """Draw the horizontal and vertical rules that act as the chart's axes.

    Both rules run from 0 to 1 and start at the origin. The axis attached to
    each rule has its domain line, labels and ticks disabled and carries the
    uppercased metric name as title.

    Args:
        x_metric: Name of the metric on the x axis (uppercased for the title).
        y_metric: Name of the metric on the y axis (uppercased for the title).
        scales: Mapping holding the shared Altair scales under ``"x"`` and
            ``"y"``.

    Returns:
        A layered chart: the x rule followed by the y rule.
    """
    # Single-row datasource: every rule goes from "start" to "end".
    segment_df = pd.DataFrame({"start": 0, "end": 1}, index=[0])
    origin_chart = alt.Chart(segment_df)

    tick_values = [0.0, 0.25, 0.5, 0.75, 1]

    def _titled_axis(orientation, metric_name):
        return alt.Axis(
            values=tick_values,
            orient=orientation,
            domain=False,
            labels=False,
            ticks=False,
            title=metric_name.upper(),
        )

    def _rule_mark():
        return origin_chart.mark_rule(
            strokeWidth=Rule.stroke_width, stroke=Rule.stroke, tooltip=""
        )

    x_axis_config = _titled_axis("bottom", x_metric)
    y_axis_config = _titled_axis("left", y_metric)

    # Horizontal rule along the bottom of the chart.
    horizontal_rule = _rule_mark().encode(
        x=alt.X("start:Q", scale=scales["x"], axis=x_axis_config),
        x2="end:Q",
        y=alt.Y("start:Q", scale=scales["y"], axis=no_axis()),
    )

    # Vertical rule along the left of the chart.
    vertical_rule = _rule_mark().encode(
        x=alt.X("start:Q", scale=scales["x"], axis=no_axis()),
        y=alt.Y("start:Q", scale=scales["y"], axis=y_axis_config),
        y2="end:Q",
    )

    return horizontal_rule + vertical_rule
Exemplo n.º 4
0
def __draw_metrics_rules(metrics, scales, concat_chart):
    """Draw one horizontal rule per metric, spanning x from 0 to 1.

    Args:
        metrics: Iterable of metric names; each is uppercased to form the
            nominal y position of its rule.
        scales: Mapping holding the shared Altair scales under ``"x"`` and
            ``"y"``.
        concat_chart: When truthy the y axis is hidden; otherwise a left axis
            with the metric labels (no domain line, ticks or title) is drawn.

    Returns:
        The Altair rule chart.
    """
    rules_df = pd.DataFrame(
        {
            "y_position": [name.upper() for name in metrics],
            "start": 0,
            "end": 1,
        }
    )

    labelled_axis = (
        no_axis()
        if concat_chart
        else alt.Axis(
            domain=False,
            ticks=False,
            orient="left",
            labelAngle=Metric_Axis.label_angle,
            labelPadding=Metric_Axis.label_padding,
            title="",
        )
    )

    rule_mark = alt.Chart(rules_df).mark_rule(
        strokeWidth=Metric_Axis.stroke_width,
        stroke=Metric_Axis.stroke,
        tooltip="",
    )

    return rule_mark.encode(
        y=alt.Y("y_position:N", scale=scales["y"], axis=labelled_axis),
        x=alt.X("start:Q", scale=scales["x"]),
        x2="end:Q",
    )
Exemplo n.º 5
0
def __draw_threshold_rules(threshold_df, scales, position, accessibility_mode=False):
    """Draw the line marking a fairness threshold on each metric row.

    The line is a stroked rectangle whose ``x`` and ``x2`` both read
    ``{position}_threshold_value``, so it has zero width.

    Args:
        threshold_df: Data for the rules. The encodings read the columns
            ``metric`` and ``{position}_threshold_value``.
        scales: Mapping holding the shared Altair scales under ``"x"`` and
            ``"y"``.
        position: Prefix selecting which threshold column to draw.
        accessibility_mode: When true, the accessible stroke color is used.

    Returns:
        The Altair chart with the threshold rules.
    """
    if accessibility_mode:
        rule_color = Threshold_Rule.stroke_accessible
    else:
        rule_color = Threshold_Rule.stroke

    threshold_field = f"{position}_threshold_value"

    rule_mark = alt.Chart(threshold_df).mark_rect(
        stroke=rule_color,
        opacity=Threshold_Rule.opacity,
        strokeWidth=Threshold_Rule.stroke_width,
        tooltip="",
    )

    return rule_mark.encode(
        x=alt.X(field=threshold_field, type="quantitative", scale=scales["x"]),
        x2=f"{threshold_field}:Q",
        y=alt.Y(field="metric", type="nominal", scale=scales["y"], axis=no_axis()),
    )
Exemplo n.º 6
0
def __draw_group_circles(plot_df, metric, scales, size_constants):
    """Draw one circle per group, colored by its parity test result.

    Circles are placed horizontally by ``{metric}_disparity_rank`` and all
    share the constant y position 2 on a reversed [3, 1] domain.

    Args:
        plot_df: Data for the chart. The encodings read ``attribute_value``,
            ``tooltip_group_size``, ``{metric}``, ``{metric}_disparity_rank``,
            ``{metric}_parity_result`` and the two
            ``tooltip_*_explanation_{metric}`` columns.
        metric: Metric name used to build the column names.
        scales: Mapping holding the Altair scales under ``"circles_x"`` and
            ``"color"``.
        size_constants: Mapping providing ``"group_circle_size"``.

    Returns:
        The Altair circle chart.
    """
    metric_field = f"{metric}"

    def _nominal_tooltip(field, title):
        return alt.Tooltip(field=field, type="nominal", title=title)

    tooltips = [
        _nominal_tooltip("attribute_value", "Group"),
        _nominal_tooltip("tooltip_group_size", "Group Size"),
        _nominal_tooltip(f"tooltip_parity_test_explanation_{metric}", "Parity Test"),
        _nominal_tooltip(f"tooltip_disparity_explanation_{metric}", "Disparity"),
        alt.Tooltip(
            field=metric_field,
            type="quantitative",
            format=".2f",
            title=metric_field.upper(),
        ),
    ]

    circles = alt.Chart(plot_df)
    circles = circles.transform_calculate(y_position="2")
    circles = circles.mark_circle(opacity=1)

    return circles.encode(
        alt.X(
            f"{metric}_disparity_rank:Q",
            scale=scales["circles_x"],
            axis=no_axis(),
        ),
        alt.Y("y_position:Q", scale=alt.Scale(domain=[3, 1]), axis=no_axis()),
        alt.Color(
            f"{metric}_parity_result:O",
            scale=scales["color"],
            legend=alt.Legend(title=""),
        ),
        size=alt.value(size_constants["group_circle_size"]),
        tooltip=tooltips,
    )
Exemplo n.º 7
0
def __draw_tick_labels(scales, chart_height, chart_width):
    """Draw the numeric tick labels for both axes.

    The x axis is labelled at 0, 0.25, 0.5, 0.75 and 1. The y axis is labelled
    at the same values except 0.

    Args:
        scales: Mapping holding the shared Altair scales under ``"x"`` and
            ``"y"``.
        chart_height: Currently unused; kept for interface compatibility.
        chart_width: Currently unused; kept for interface compatibility.

    Returns:
        A layered chart: the x tick labels followed by the y tick labels.
    """
    axis_values = [0, 0.25, 0.5, 0.75, 1]

    x_axis_df = pd.DataFrame({
        "main_axis_values": axis_values,
        "aux_axis_position": 0
    })

    # The y axis skips the 0 label. Use a separate frame instead of dropping
    # the row in place: the x-axis chart keeps a reference to its DataFrame,
    # which is only serialized when the chart is rendered, so an in-place drop
    # would also remove the 0 label from the x axis.
    y_axis_df = x_axis_df.drop(0)

    x_tick_labels = (alt.Chart(x_axis_df).mark_text(
        yOffset=Scatter_Axis.label_font_size * 1.5,
        tooltip="",
        align="center",
        fontSize=Scatter_Axis.label_font_size,
        color=Scatter_Axis.label_color,
        fontWeight=Scatter_Axis.label_font_weight,
        font=FONT,
    ).encode(
        text=alt.Text("main_axis_values:Q"),
        x=alt.X("main_axis_values:Q", scale=scales["x"], axis=no_axis()),
        y=alt.Y("aux_axis_position:Q", scale=scales["y"], axis=no_axis()),
    ))

    y_tick_labels = (alt.Chart(y_axis_df).mark_text(
        baseline="middle",
        xOffset=-Scatter_Axis.label_font_size * 1.5,
        tooltip="",
        align="center",
        fontSize=Scatter_Axis.label_font_size,
        fontWeight=Scatter_Axis.label_font_weight,
        color=Scatter_Axis.label_color,
        font=FONT,
    ).encode(
        text=alt.Text("main_axis_values:Q"),
        x=alt.X("aux_axis_position:Q", scale=scales["x"], axis=no_axis()),
        y=alt.Y("main_axis_values:Q", scale=scales["y"], axis=no_axis()),
    ))

    return x_tick_labels + y_tick_labels
Exemplo n.º 8
0
def __draw_parity_result_text(parity_result, color_scale):
    """Draw the uppercased parity test result as colored text.

    Args:
        parity_result: Result string; it is stored in a one-row datasource for
            the color encoding and uppercased for the displayed text.
        color_scale: Altair scale mapping the result to a color.

    Returns:
        The Altair text chart, placed at the constant y position 1 on a
        reversed [3, 1] domain.
    """
    result_df = pd.DataFrame({"parity_result": parity_result}, index=[0])

    text_mark = (
        alt.Chart(result_df)
        .transform_calculate(y_position="1")
        .mark_text(
            align="center",
            baseline="middle",
            font=FONT,
            size=Parity_Result.font_size,
            fontWeight=Parity_Result.font_weight,
        )
    )

    y_encoding = alt.Y(
        "y_position:Q", scale=alt.Scale(domain=[3, 1]), axis=no_axis()
    )
    color_encoding = alt.Color(
        "parity_result:O", scale=color_scale, legend=alt.Legend(title="")
    )

    return text_mark.encode(
        y_encoding,
        color_encoding,
        text=alt.value(parity_result.upper()),
    )
Exemplo n.º 9
0
def __draw_threshold_bands(
    ref_group_value,
    fairness_threshold,
    main_scale,
    aux_scale,
    accessibility_mode=False,
    drawing_x=False,
):
    """Draw the threshold rules and bands for one axis of the chart.

    The lower threshold is ``ref_group_value / fairness_threshold`` and the
    upper threshold is ``ref_group_value * fairness_threshold`` capped at 1.
    A rule is drawn at each threshold. The lower band fills from the lower
    threshold to 0 and the upper band from the upper threshold to 1. Every
    mark spans 0 to 1 along the other axis.

    Args:
        ref_group_value: Metric value of the reference group.
        fairness_threshold: Factor used to derive both thresholds.
        main_scale: Altair scale of the axis the thresholds are measured on.
        aux_scale: Altair scale of the other axis.
        accessibility_mode: When true, the accessible colors are used.
        drawing_x: When true the thresholds lie on the x axis, otherwise on
            the y axis.

    Returns:
        A layered chart: lower rule, upper rule, lower band, upper band.
    """
    # Single-row datasource shared by the rules and the bands.
    limits_df = pd.DataFrame(
        {
            "start": 0,
            "end": 1,
            "lower_threshold": ref_group_value / fairness_threshold,
            "upper_threshold": min(ref_group_value * fairness_threshold, 1),
        },
        index=[0],
    )

    if accessibility_mode:
        rule_color = Threshold_Rule.stroke_accessible
        band_color = Threshold_Band.color_accessible
    else:
        rule_color = Threshold_Rule.stroke
        band_color = Threshold_Band.color

    rule_mark = alt.Chart(limits_df).mark_rule(
        stroke=rule_color,
        strokeWidth=Threshold_Rule.stroke_width,
        opacity=Threshold_Rule.opacity,
        tooltip="",
    )
    band_mark = alt.Chart(limits_df).mark_rect(
        fill=band_color, opacity=Threshold_Band.opacity, tooltip=""
    )

    # Pick which channel carries the thresholds and which carries the span.
    if drawing_x:
        main_name, aux_name = "x", "y"
        main_channel, aux_channel = alt.X, alt.Y
    else:
        main_name, aux_name = "y", "x"
        main_channel, aux_channel = alt.Y, alt.X

    span_encoding = {
        aux_name: aux_channel("start:Q", scale=aux_scale, axis=no_axis()),
        f"{aux_name}2": "end:Q",
    }
    upper_encoding = {
        main_name: main_channel(
            "upper_threshold:Q", scale=main_scale, axis=no_axis()
        )
    }
    lower_encoding = {
        main_name: main_channel(
            "lower_threshold:Q", scale=main_scale, axis=no_axis()
        )
    }
    lower_far_edge = {f"{main_name}2": "start:Q"}
    upper_far_edge = {f"{main_name}2": "end:Q"}

    lower_rule = rule_mark.encode(**span_encoding, **lower_encoding)
    upper_rule = rule_mark.encode(**span_encoding, **upper_encoding)

    lower_band = band_mark.encode(
        **span_encoding, **lower_encoding, **lower_far_edge
    )
    upper_band = band_mark.encode(
        **span_encoding, **upper_encoding, **upper_far_edge
    )

    return lower_rule + upper_rule + lower_band + upper_band
Exemplo n.º 10
0
def __draw_bubbles(
    plot_table,
    x_metric,
    y_metric,
    ref_group,
    scales,
    interactive_selection_group,
):
    """Draw one bubble per group, positioned by the two metrics.

    Each group gets a small filled point (the center) and a translucent circle
    sized by ``group_size`` (the area). Both are colored by group while the
    selection matches and use the faded color otherwise.

    Args:
        plot_table: DataFrame with the columns ``group_size``,
            ``attribute_value``, ``total_entities`` and both metric columns.
            It is not modified; a trimmed deep copy is charted.
        x_metric: Name of the metric column on the x axis.
        y_metric: Name of the metric column on the y axis.
        ref_group: Not used by this function.
        scales: Mapping holding the Altair scales under ``"x"``, ``"y"``,
            ``"color"``, ``"shape"`` and ``"bubble_size"``.
        interactive_selection_group: Altair selection driving the conditional
            color.

    Returns:
        A layered chart: the bubble centers followed by the bubble areas.
    """
    # Chart only the columns that the encodings and tooltips need.
    kept_columns = [
        "group_size",
        "attribute_value",
        "total_entities",
        x_metric,
        y_metric,
    ]
    bubbles_df = plot_table[kept_columns].copy(deep=True)

    def _group_size_text(row):
        return get_tooltip_text_group_size(row["group_size"], row["total_entities"])

    bubbles_df["tooltip_group_size"] = plot_table.apply(_group_size_text, axis=1)

    color_by_selection = alt.condition(
        interactive_selection_group,
        alt.Color("attribute_value:N", scale=scales["color"], legend=None),
        alt.value(Bubble.color_faded),
    )

    def _metric_tooltip(metric_name):
        return alt.Tooltip(
            field=metric_name,
            type="quantitative",
            format=".2f",
            title=metric_name.upper(),
        )

    tooltips = [
        alt.Tooltip(field="attribute_value", type="nominal", title="Group"),
        alt.Tooltip(field="tooltip_group_size", type="nominal", title="Group Size"),
        _metric_tooltip(x_metric),
        _metric_tooltip(y_metric),
    ]

    def _position():
        # Fresh channel objects for each layer, as both layers share a position.
        return {
            "x": alt.X(f"{x_metric}:Q", scale=scales["x"], axis=no_axis()),
            "y": alt.Y(f"{y_metric}:Q", scale=scales["y"], axis=no_axis()),
        }

    centers = (
        alt.Chart(bubbles_df)
        .mark_point(filled=True, size=Bubble.center_size)
        .encode(
            tooltip=tooltips,
            color=color_by_selection,
            shape=alt.Shape("attribute_value:N", scale=scales["shape"], legend=None),
            **_position(),
        )
    )

    areas = (
        alt.Chart(bubbles_df)
        .mark_circle(opacity=Bubble.opacity)
        .encode(
            size=alt.Size("group_size:Q", legend=None, scale=scales["bubble_size"]),
            tooltip=tooltips,
            color=color_by_selection,
            **_position(),
        )
    )

    return centers + areas
Exemplo n.º 11
0
def __draw_bubbles(
    plot_table,
    metrics,
    ref_group,
    scales,
    selection,
):
    """Draw the bubbles of every metric, one row of bubbles per metric.

    For each metric, a layer of small filled points (centers) and a layer of
    translucent circles sized by ``group_size`` (areas) are added. Each layer
    registers its own multi-selection on ``attribute_value``.

    Note:
        ``plot_table`` is modified in place: the column ``tooltip_group_size``
        and one ``tooltip_disparity_explanation_{metric}`` column per metric
        are added to it.

    Args:
        plot_table: DataFrame with ``group_size``, ``total_entities``,
            ``attribute_value`` and, per metric, the ``{metric}`` and
            ``{metric}_disparity_scaled`` columns.
        metrics: Iterable of metric names to draw.
        ref_group: Reference group passed on to the disparity tooltip text.
        scales: Mapping holding the Altair scales under ``"x"``, ``"y"``,
            ``"color"``, ``"shape"`` and ``"bubble_size"``.
        selection: Altair selection driving the conditional color.

    Returns:
        A layered chart: all bubble areas followed by all bubble centers.
    """
    # The x axis only contributes gridlines at these values.
    gridline_axis = alt.Axis(
        values=[0.25, 0.5, 0.75],
        ticks=False,
        domain=False,
        labels=False,
        title=None,
    )

    color_by_selection = alt.condition(
        selection,
        alt.Color("attribute_value:N", scale=scales["color"], legend=None),
        alt.value(Bubble.color_faded),
    )

    # Empty starting charts that the per-metric layers are added onto.
    centers_layers = alt.Chart().mark_point()
    areas_layers = alt.Chart().mark_circle()

    def _group_size_text(row):
        return get_tooltip_text_group_size(row["group_size"], row["total_entities"])

    plot_table["tooltip_group_size"] = plot_table.apply(_group_size_text, axis=1)

    for metric in metrics:
        disparity_tooltip_field = f"tooltip_disparity_explanation_{metric}"

        def _disparity_text(row):
            # Called immediately by apply(), so `metric` is the current one.
            return get_tooltip_text_disparity_explanation(
                row[f"{metric}_disparity_scaled"],
                row["attribute_value"],
                metric,
                ref_group,
            )

        plot_table[disparity_tooltip_field] = plot_table.apply(
            _disparity_text, axis=1
        )

        tooltips = [
            alt.Tooltip(field="attribute_value", type="nominal", title="Group"),
            alt.Tooltip(field="tooltip_group_size", type="nominal", title="Group Size"),
            alt.Tooltip(
                field=disparity_tooltip_field, type="nominal", title="Disparity"
            ),
            alt.Tooltip(
                field=f"{metric}",
                type="quantitative",
                format=".2f",
                title=f"{metric}".upper(),
            ),
        ]

        # Quoted so that the calculate expression yields a string constant.
        row_label_expr = f"'{metric.upper()}'"

        # Centers layer, with its own selection.
        centers_selection = alt.selection_multi(
            empty="all", fields=["attribute_value"]
        )
        metric_centers = (
            alt.Chart(plot_table)
            .transform_calculate(metric_variable=row_label_expr)
            .mark_point(filled=True, size=Bubble.center_size)
            .encode(
                x=alt.X(f"{metric}:Q", scale=scales["x"], axis=gridline_axis),
                y=alt.Y("metric_variable:N", scale=scales["y"], axis=no_axis()),
                tooltip=tooltips,
                color=color_by_selection,
                shape=alt.Shape(
                    "attribute_value:N", scale=scales["shape"], legend=None
                ),
            )
            .add_selection(centers_selection)
        )
        centers_layers += metric_centers

        # Areas layer, with its own selection.
        areas_selection = alt.selection_multi(
            empty="all", fields=["attribute_value"]
        )
        metric_areas = (
            alt.Chart(plot_table)
            .mark_circle(opacity=Bubble.opacity)
            .transform_calculate(metric_variable=row_label_expr)
            .encode(
                x=alt.X(f"{metric}:Q", scale=scales["x"], axis=gridline_axis),
                y=alt.Y("metric_variable:N", scale=scales["y"], axis=no_axis()),
                tooltip=tooltips,
                color=color_by_selection,
                size=alt.Size(
                    "group_size:Q", legend=None, scale=scales["bubble_size"]
                ),
            )
            .add_selection(areas_selection)
        )
        areas_layers += metric_areas

    return areas_layers + centers_layers
Exemplo n.º 12
0
def draw_legend(global_scales, selection, chart_width):
    """Draw the interactive legend listing each group with its color and shape.

    The legend is positioned at x = ``chart_width`` and consists of a title,
    a hint line, and one symbol plus label per group. The first group in the
    color scale's domain is labelled with a " [REF]" suffix.

    Args:
        global_scales: Mapping holding the Altair scales under ``"color"`` and
            ``"shape"``. The color scale's ``domain`` supplies the group names
            and must contain at least one entry.
        selection: Altair selection attached to the legend symbols; it also
            drives the conditional color of symbols and labels.
        chart_width: Horizontal pixel position where the legend starts.

    Returns:
        A layered chart: symbols, labels, hint line, title.
    """
    group_names = global_scales["color"].domain
    display_labels = group_names.copy()
    display_labels[0] = display_labels[0] + " [REF]"
    entries_df = pd.DataFrame(
        {"attribute_value": group_names, "label": display_labels}
    )

    # Horizontal layout: legend text starts at the chart's right edge.
    legend_x = chart_width
    symbols_x = legend_x + Legend.horizontal_spacing
    labels_x = legend_x + 2 * Legend.circle_radius + Legend.horizontal_spacing

    # Vertical layout: title, then hint line, then the entries.
    title_block_height = Legend.title_font_size + Legend.title_margin_bottom
    hint_block_height = Legend.font_size + Legend.vertical_spacing
    entries_top = Legend.margin_top + title_block_height + hint_block_height

    def _static_text(content, y_pixels, font_size, font_weight):
        return (
            alt.Chart(DUMMY_DF)
            .mark_text(
                align="left",
                baseline="middle",
                color=Legend.font_color,
                fontSize=font_size,
                font=FONT,
                fontWeight=font_weight,
            )
            .encode(
                x=alt.value(legend_x),
                y=alt.value(y_pixels),
                text=alt.value(content),
            )
        )

    title_text = _static_text(
        "Groups",
        Legend.margin_top,
        Legend.title_font_size,
        Legend.title_font_weight,
    )
    hint_text = _static_text(
        "Click to highlight a group.",
        Legend.margin_top + title_block_height,
        Legend.font_size,
        Legend.font_weight,
    )

    # Selected groups keep their scale color; the rest use the faded color.
    color_by_selection = alt.condition(
        selection,
        alt.Color("attribute_value:N", scale=global_scales["color"], legend=None),
        alt.value(Legend.color_faded),
    )

    # Entries range: one text height per group plus a spacing between and
    # around the groups.
    group_count = len(group_names)
    entries_bottom = (
        entries_top
        + group_count * Legend.font_size
        + (group_count + 1) * Legend.vertical_spacing
    )
    entries_y_scale = alt.Scale(
        domain=group_names, range=[entries_top, entries_bottom]
    )

    # NOTE(review): this evaluates to radius * pi**2, not pi * radius**2.
    # Confirm which was intended before changing it.
    symbol_size = Legend.circle_radius * math.pi**2

    # Symbol for each group; this layer carries the caller's selection.
    entries_symbols = (
        alt.Chart(entries_df)
        .mark_point(filled=True, opacity=1, size=symbol_size)
        .encode(
            x=alt.value(symbols_x),
            y=alt.Y("attribute_value:N", scale=entries_y_scale, axis=no_axis()),
            color=color_by_selection,
            shape=alt.Shape(
                "attribute_value:N", scale=global_scales["shape"], legend=None
            ),
        )
        .add_selection(selection)
    )

    # Label for each group; this layer gets a separate selection of its own.
    label_selection = alt.selection_multi(empty="all", fields=["attribute_value"])
    entries_labels = (
        alt.Chart(entries_df)
        .mark_text(
            align="left",
            baseline="middle",
            font=FONT,
            fontSize=Legend.font_size,
            fontWeight=Legend.font_weight,
        )
        .encode(
            x=alt.value(labels_x),
            y=alt.Y("attribute_value:N", scale=entries_y_scale, axis=no_axis()),
            text=alt.Text("label:N"),
            color=color_by_selection,
        )
        .add_selection(label_selection)
    )

    return entries_symbols + entries_labels + hint_text + title_text
Exemplo n.º 13
0
def __draw_metric_line_titles(metrics, size_constants):
    """Draw the left-hand column of titles, one cell per metric.

    Each cell stacks three centered texts on a reversed [3, 1] y domain: the
    uppercased metric name at 1.2, "Groups" at 2 and "% Pop." at 2.7.

    Args:
        metrics: Iterable of metric names.
        size_constants: Mapping providing ``"line_height"``,
            ``"metric_titles_width"`` and ``"line_spacing"``.

    Returns:
        A vertical concatenation of a blank corner cell followed by one title
        cell per metric.
    """

    def _text_line(y_position_expr, content, **style):
        return (
            alt.Chart(DUMMY_DF)
            .transform_calculate(y_position=y_position_expr)
            .mark_text(align="center", baseline="middle", font=FONT, **style)
            .encode(
                alt.Y(
                    "y_position:Q",
                    scale=alt.Scale(domain=[3, 1]),
                    axis=no_axis(),
                ),
                text=alt.value(content),
            )
        )

    title_cells = []
    for metric in metrics:
        metric_name_text = _text_line(
            "1.2",
            metric.upper(),
            fontWeight=Title.font_weight,
            size=Title.font_size,
            color=Title.font_color,
        )
        groups_text = _text_line(
            "2",
            "Groups",
            size=Subtitle.font_size,
            color=Subtitle.font_color,
        )
        population_text = _text_line(
            "2.7",
            "% Pop.",
            size=Subtitle.font_size,
            color=Subtitle.font_color,
        )

        cell = (metric_name_text + groups_text + population_text).properties(
            height=size_constants["line_height"],
            width=size_constants["metric_titles_width"],
        )
        title_cells.append(cell)

    # Blank corner cell: drawing an attribute title with an empty string
    # reserves the same space as the attribute titles, so that this column
    # lines up with the attribute columns.
    blank_corner = __draw_attribute_title(
        "", size_constants["metric_titles_width"], size_constants
    )

    return alt.vconcat(
        blank_corner,
        *title_cells,
        spacing=size_constants["line_spacing"],
        bounds="flush",
    )