def __draw_population_bar(population_bar_df, metric, color_scale):
    """Build the normalized stacked bar of group sizes per parity result.

    The bar stacks ``sum(group_size)`` (normalized to the full width) and is
    colored by the ``{metric}_parity_result`` column using ``color_scale``.

    Args:
        population_bar_df: data source for the bar; the encodings below read
            ``group_size``, ``{metric}_parity_result``, ``tooltip_group_size``
            and ``tooltip_groups_name_size`` from it.
        metric: name of the metric whose parity result drives the color.
        color_scale: Altair scale used for the parity-result colors.

    Returns:
        The Altair bar chart.
    """
    parity_field = f"{metric}_parity_result"

    # (field, title) pairs; every tooltip entry is nominal.
    tooltip_specs = (
        (parity_field, "Parity"),
        ("tooltip_group_size", "Size"),
        ("tooltip_groups_name_size", "Groups"),
    )
    tooltips = [
        alt.Tooltip(field=field, type="nominal", title=title)
        for field, title in tooltip_specs
    ]

    x_encoding = alt.X("sum(group_size):Q", stack="normalize", axis=no_axis())
    # Fixed vertical slot inside the shared, inverted [3, 1] domain.
    y_encoding = alt.Y(
        "y_position:Q",
        scale=alt.Scale(domain=[3, 1]),
        axis=no_axis(),
    )
    color_encoding = alt.Color(
        f"{parity_field}:O",
        scale=color_scale,
        legend=alt.Legend(title="Parity Test", padding=20),
    )

    bars = (
        alt.Chart(population_bar_df)
        .transform_calculate(y_position="2.8")
        .mark_bar(size=6, stroke="white")
    )
    return bars.encode(
        x_encoding,
        y_encoding,
        color_encoding,
        tooltip=tooltips,
    )
def __draw_threshold_bands(threshold_df, scales, accessibility_mode=False):
    """Build the shaded fairness-threshold bands for the per-metric chart.

    Two rectangles are drawn per row of ``threshold_df``: one spanning
    ``lower_threshold_value`` to ``lower_end`` and one spanning
    ``upper_threshold_value`` to ``upper_end``.

    Args:
        threshold_df: data source providing ``metric`` and the four
            threshold/end columns named above.
        scales: mapping holding the shared ``"x"`` and ``"y"`` Altair scales.
        accessibility_mode: when truthy, use the accessible band color.

    Returns:
        The layered chart ``lower band + upper band``.
    """
    if accessibility_mode:
        band_fill = Threshold_Band.color_accessible
    else:
        band_fill = Threshold_Band.color

    def band(side):
        # `side` is "lower" or "upper"; both bands differ only in the columns
        # they read. The band end always comes from the `<side>_end` column.
        return (
            alt.Chart(threshold_df)
            .mark_rect(fill=band_fill, opacity=Threshold_Band.opacity, tooltip="")
            .encode(
                y=alt.Y(
                    field="metric",
                    type="nominal",
                    scale=scales["y"],
                    axis=no_axis(),
                ),
                x=alt.X(f"{side}_threshold_value:Q", scale=scales["x"]),
                x2=f"{side}_end:Q",
            )
        )

    lower_band = band("lower")
    upper_band = band("upper")
    return lower_band + upper_band
def __draw_axis_rules(x_metric, y_metric, scales):
    """Build the horizontal and vertical axis lines, with upper-cased titles.

    Args:
        x_metric: metric name shown (upper-cased) as the bottom axis title.
        y_metric: metric name shown (upper-cased) as the left axis title.
        scales: mapping holding the shared ``"x"`` and ``"y"`` Altair scales.

    Returns:
        The layered chart ``x rule + y rule``.
    """
    tick_values = [0.0, 0.25, 0.5, 0.75, 1]

    def titled_axis(orient, title):
        # Only the title is rendered; domain, labels and ticks are disabled.
        return alt.Axis(
            values=tick_values,
            orient=orient,
            domain=False,
            labels=False,
            ticks=False,
            title=title,
        )

    def rule_mark():
        # Single-row source: each rule runs from "start" (0) to "end" (1).
        unit_segment = pd.DataFrame({"start": 0, "end": 1}, index=[0])
        return alt.Chart(unit_segment).mark_rule(
            strokeWidth=Rule.stroke_width,
            stroke=Rule.stroke,
            tooltip="",
        )

    horizontal_rule = rule_mark().encode(
        x=alt.X(
            "start:Q",
            scale=scales["x"],
            axis=titled_axis("bottom", x_metric.upper()),
        ),
        x2="end:Q",
        y=alt.Y("start:Q", scale=scales["y"], axis=no_axis()),
    )

    vertical_rule = rule_mark().encode(
        x=alt.X("start:Q", scale=scales["x"], axis=no_axis()),
        y=alt.Y(
            "start:Q",
            scale=scales["y"],
            axis=titled_axis("left", y_metric.upper()),
        ),
        y2="end:Q",
    )

    return horizontal_rule + vertical_rule
def __draw_metrics_rules(metrics, scales, concat_chart):
    """Build one horizontal rule per metric, on which the bubbles are placed.

    Args:
        metrics: iterable of metric names; each is upper-cased to form the
            nominal y position of its rule.
        scales: mapping holding the shared ``"x"`` and ``"y"`` Altair scales.
        concat_chart: when truthy the y axis is hidden; otherwise a left axis
            showing only the metric labels is used.

    Returns:
        The Altair rule chart.
    """
    rule_rows = pd.DataFrame(
        {
            "y_position": [name.upper() for name in metrics],
            "start": 0,
            "end": 1,
        }
    )

    labels_axis = (
        no_axis()
        if concat_chart
        else alt.Axis(
            domain=False,
            ticks=False,
            orient="left",
            labelAngle=Metric_Axis.label_angle,
            labelPadding=Metric_Axis.label_padding,
            title="",
        )
    )

    rule_mark = alt.Chart(rule_rows).mark_rule(
        strokeWidth=Metric_Axis.stroke_width,
        stroke=Metric_Axis.stroke,
        tooltip="",
    )
    return rule_mark.encode(
        y=alt.Y("y_position:N", scale=scales["y"], axis=labels_axis),
        x=alt.X("start:Q", scale=scales["x"]),
        x2="end:Q",
    )
def __draw_threshold_rules(threshold_df, scales, position, accessibility_mode=False):
    """Build the line marking one fairness threshold in the per-metric chart.

    The line is a stroked rectangle whose ``x`` and ``x2`` both read the
    ``{position}_threshold_value`` column.

    Args:
        threshold_df: data source providing ``metric`` and the
            ``{position}_threshold_value`` column.
        scales: mapping holding the shared ``"x"`` and ``"y"`` Altair scales.
        position: prefix of the threshold column to draw (the column read is
            ``f"{position}_threshold_value"``).
        accessibility_mode: when truthy, use the accessible stroke color.

    Returns:
        The Altair chart with the threshold line.
    """
    if accessibility_mode:
        line_color = Threshold_Rule.stroke_accessible
    else:
        line_color = Threshold_Rule.stroke

    threshold_field = f"{position}_threshold_value"

    line_mark = alt.Chart(threshold_df).mark_rect(
        stroke=line_color,
        opacity=Threshold_Rule.opacity,
        strokeWidth=Threshold_Rule.stroke_width,
        tooltip="",
    )
    return line_mark.encode(
        x=alt.X(field=threshold_field, type="quantitative", scale=scales["x"]),
        x2=f"{threshold_field}:Q",
        y=alt.Y(field="metric", type="nominal", scale=scales["y"], axis=no_axis()),
    )
def __draw_group_circles(plot_df, metric, scales, size_constants):
    """Build one circle per group, colored by the metric's parity result.

    Circles are positioned horizontally by ``{metric}_disparity_rank`` and
    share a fixed vertical slot.

    Args:
        plot_df: data source; the encodings read ``attribute_value``,
            ``tooltip_group_size``, ``{metric}``, ``{metric}_disparity_rank``,
            ``{metric}_parity_result`` and the two
            ``tooltip_*_explanation_{metric}`` columns.
        metric: name of the metric being drawn.
        scales: mapping holding the ``"circles_x"`` and ``"color"`` scales.
        size_constants: mapping holding ``"group_circle_size"``.

    Returns:
        The Altair circle chart.
    """
    tooltips = [
        alt.Tooltip(field="attribute_value", type="nominal", title="Group"),
        alt.Tooltip(field="tooltip_group_size", type="nominal", title="Group Size"),
        alt.Tooltip(
            field=f"tooltip_parity_test_explanation_{metric}",
            type="nominal",
            title="Parity Test",
        ),
        alt.Tooltip(
            field=f"tooltip_disparity_explanation_{metric}",
            type="nominal",
            title="Disparity",
        ),
        alt.Tooltip(
            field=f"{metric}",
            type="quantitative",
            format=".2f",
            title=f"{metric}".upper(),
        ),
    ]

    x_encoding = alt.X(
        f"{metric}_disparity_rank:Q",
        scale=scales["circles_x"],
        axis=no_axis(),
    )
    # Fixed vertical slot inside the shared, inverted [3, 1] domain.
    y_encoding = alt.Y(
        "y_position:Q",
        scale=alt.Scale(domain=[3, 1]),
        axis=no_axis(),
    )
    color_encoding = alt.Color(
        f"{metric}_parity_result:O",
        scale=scales["color"],
        legend=alt.Legend(title=""),
    )

    circles = (
        alt.Chart(plot_df)
        .transform_calculate(y_position="2")
        .mark_circle(opacity=1)
    )
    return circles.encode(
        x_encoding,
        y_encoding,
        color_encoding,
        size=alt.value(size_constants["group_circle_size"]),
        tooltip=tooltips,
    )
def __draw_tick_labels(scales, chart_height, chart_width):
    """Draws the numeric tick labels (0, 0.25, 0.5, 0.75, 1) for both axes.

    The x axis shows every value. The y axis omits the 0 label so the origin
    is labelled only once.

    Args:
        scales: mapping holding the shared ``"x"`` and ``"y"`` Altair scales.
        chart_height: unused by this function; kept so existing callers keep
            working.
        chart_width: unused by this function; kept so existing callers keep
            working.

    Returns:
        The layered chart ``x tick labels + y tick labels``.
    """
    axis_values = [0, 0.25, 0.5, 0.75, 1]
    x_axis_df = pd.DataFrame({
        "main_axis_values": axis_values,
        "aux_axis_position": 0,
    })
    # Use a separate frame for the y labels instead of dropping the row in
    # place: the x-label chart keeps referencing the DataFrame it was built
    # from, so an in-place drop would silently remove the 0 label from the
    # x axis as well.
    y_axis_df = x_axis_df.drop(0)

    # Text properties shared by the labels of both axes.
    label_style = dict(
        tooltip="",
        align="center",
        fontSize=Scatter_Axis.label_font_size,
        fontWeight=Scatter_Axis.label_font_weight,
        color=Scatter_Axis.label_color,
        font=FONT,
    )
    # Distance between the axis line and its labels.
    label_offset = Scatter_Axis.label_font_size * 1.5

    x_tick_labels = (
        alt.Chart(x_axis_df)
        .mark_text(yOffset=label_offset, **label_style)
        .encode(
            text=alt.Text("main_axis_values:Q"),
            x=alt.X("main_axis_values:Q", scale=scales["x"], axis=no_axis()),
            y=alt.Y("aux_axis_position:Q", scale=scales["y"], axis=no_axis()),
        )
    )

    y_tick_labels = (
        alt.Chart(y_axis_df)
        .mark_text(baseline="middle", xOffset=-label_offset, **label_style)
        .encode(
            text=alt.Text("main_axis_values:Q"),
            x=alt.X("aux_axis_position:Q", scale=scales["x"], axis=no_axis()),
            y=alt.Y("main_axis_values:Q", scale=scales["y"], axis=no_axis()),
        )
    )

    return x_tick_labels + y_tick_labels
def __draw_parity_result_text(parity_result, color_scale):
    """Build the upper-cased parity result label, colored by ``color_scale``.

    Args:
        parity_result: parity test result string; it is stored in the data
            source to pick the color and upper-cased for the displayed text.
        color_scale: Altair scale mapping parity results to colors.

    Returns:
        The Altair text chart.
    """
    result_row = pd.DataFrame({"parity_result": parity_result}, index=[0])

    text_mark = (
        alt.Chart(result_row)
        .transform_calculate(y_position="1")
        .mark_text(
            align="center",
            baseline="middle",
            font=FONT,
            size=Parity_Result.font_size,
            fontWeight=Parity_Result.font_weight,
        )
    )

    # Fixed vertical slot inside the shared, inverted [3, 1] domain.
    y_encoding = alt.Y(
        "y_position:Q",
        scale=alt.Scale(domain=[3, 1]),
        axis=no_axis(),
    )
    color_encoding = alt.Color(
        "parity_result:O",
        scale=color_scale,
        legend=alt.Legend(title=""),
    )

    return text_mark.encode(
        y_encoding,
        color_encoding,
        text=alt.value(parity_result.upper()),
    )
def __draw_threshold_bands(
    ref_group_value,
    fairness_threshold,
    main_scale,
    aux_scale,
    accessibility_mode=False,
    drawing_x=False,
):
    """Build the threshold rules and shaded bands for one axis of the chart.

    The lower threshold is ``ref_group_value / fairness_threshold`` and the
    upper threshold is ``ref_group_value * fairness_threshold`` capped at 1.
    A rule is drawn at each threshold; the lower band extends from the lower
    threshold to 0 and the upper band from the upper threshold to 1.

    Args:
        ref_group_value: metric value of the reference group.
        fairness_threshold: factor applied to ``ref_group_value``.
        main_scale: Altair scale for the axis the thresholds lie on.
        aux_scale: Altair scale for the other axis, which the marks span
            from 0 to 1.
        accessibility_mode: when truthy, use the accessible colors.
        drawing_x: when truthy the thresholds lie on the x axis, otherwise on
            the y axis.

    Returns:
        The layered chart ``lower rule + upper rule + lower band + upper band``.
    """
    # DATASOURCE
    thresholds = pd.DataFrame(
        {
            "start": 0,
            "end": 1,
            "lower_threshold": ref_group_value / fairness_threshold,
            "upper_threshold": min(ref_group_value * fairness_threshold, 1),
        },
        index=[0],
    )

    # COLORS
    if accessibility_mode:
        rule_color = Threshold_Rule.stroke_accessible
        band_color = Threshold_Band.color_accessible
    else:
        rule_color = Threshold_Rule.stroke
        band_color = Threshold_Band.color

    # BASES
    rules = alt.Chart(thresholds).mark_rule(
        stroke=rule_color,
        strokeWidth=Threshold_Rule.stroke_width,
        opacity=Threshold_Rule.opacity,
        tooltip="",
    )
    bands = alt.Chart(thresholds).mark_rect(
        fill=band_color,
        opacity=Threshold_Band.opacity,
        tooltip="",
    )

    # CHANNELS: the "main" channel carries the threshold, the "aux" channel
    # is spanned from start to end.
    if drawing_x:
        main_name, main_channel = "x", alt.X
        aux_name, aux_channel = "y", alt.Y
    else:
        main_name, main_channel = "y", alt.Y
        aux_name, aux_channel = "x", alt.X

    span_encoding = {
        aux_name: aux_channel("start:Q", scale=aux_scale, axis=no_axis()),
        f"{aux_name}2": "end:Q",
    }
    upper_encoding = {
        main_name: main_channel("upper_threshold:Q", scale=main_scale, axis=no_axis()),
    }
    lower_encoding = {
        main_name: main_channel("lower_threshold:Q", scale=main_scale, axis=no_axis()),
    }
    lower_band_end = {f"{main_name}2": "start:Q"}
    upper_band_end = {f"{main_name}2": "end:Q"}

    # RULES
    lower_rule = rules.encode(**span_encoding, **lower_encoding)
    upper_rule = rules.encode(**span_encoding, **upper_encoding)

    # BANDS
    lower_band = bands.encode(**span_encoding, **lower_encoding, **lower_band_end)
    upper_band = bands.encode(**span_encoding, **upper_encoding, **upper_band_end)

    return lower_rule + upper_rule + lower_band + upper_band
def __draw_bubbles(
    plot_table,
    x_metric,
    y_metric,
    ref_group,
    scales,
    interactive_selection_group,
):
    """Build the group bubbles (center points plus sized areas) of the scatter.

    Args:
        plot_table: source DataFrame; ``group_size``, ``attribute_value``,
            ``total_entities`` and the two metric columns are copied from it.
            It is not modified.
        x_metric: column plotted on the x axis.
        y_metric: column plotted on the y axis.
        ref_group: not used by this function.
        scales: mapping holding the ``"x"``, ``"y"``, ``"color"``, ``"shape"``
            and ``"bubble_size"`` Altair scales.
        interactive_selection_group: Altair selection; groups outside it are
            drawn in the faded bubble color.

    Returns:
        The layered chart ``bubble centers + bubble areas``.
    """
    # DATASOURCE: work on a trimmed deep copy of the caller's table.
    kept_columns = [
        "group_size",
        "attribute_value",
        "total_entities",
        x_metric,
        y_metric,
    ]
    bubbles_df = plot_table[kept_columns].copy(deep=True)

    def group_size_text(row):
        return get_tooltip_text_group_size(row["group_size"], row["total_entities"])

    bubbles_df["tooltip_group_size"] = plot_table.apply(group_size_text, axis=1)

    # COLOR ENCODING: group color when selected, faded otherwise.
    color_encoding = alt.condition(
        interactive_selection_group,
        alt.Color("attribute_value:N", scale=scales["color"], legend=None),
        alt.value(Bubble.color_faded),
    )

    # TOOLTIP ENCODING
    tooltips = [
        alt.Tooltip(field="attribute_value", type="nominal", title="Group"),
        alt.Tooltip(field="tooltip_group_size", type="nominal", title="Group Size"),
    ]
    for metric_column in (x_metric, y_metric):
        tooltips.append(
            alt.Tooltip(
                field=metric_column,
                type="quantitative",
                format=".2f",
                title=metric_column.upper(),
            )
        )

    def position_encodings():
        # Fresh x/y channel objects for each layer.
        return {
            "x": alt.X(f"{x_metric}:Q", scale=scales["x"], axis=no_axis()),
            "y": alt.Y(f"{y_metric}:Q", scale=scales["y"], axis=no_axis()),
        }

    # BUBBLE CENTERS
    center_marks = alt.Chart(bubbles_df).mark_point(
        filled=True,
        size=Bubble.center_size,
    )
    centers = center_marks.encode(
        tooltip=tooltips,
        color=color_encoding,
        shape=alt.Shape("attribute_value:N", scale=scales["shape"], legend=None),
        **position_encodings(),
    )

    # BUBBLE AREAS
    area_marks = alt.Chart(bubbles_df).mark_circle(opacity=Bubble.opacity)
    areas = area_marks.encode(
        size=alt.Size("group_size:Q", legend=None, scale=scales["bubble_size"]),
        tooltip=tooltips,
        color=color_encoding,
        **position_encodings(),
    )

    return centers + areas
def __draw_bubbles(
    plot_table,
    metrics,
    ref_group,
    scales,
    selection,
):
    """Build the group bubbles for every metric row of the chart.

    For each metric a layer of center points and a layer of sized circles is
    added, positioned on the row named after the upper-cased metric.

    Note:
        ``plot_table`` is modified in place: a ``tooltip_group_size`` column
        and one ``tooltip_disparity_explanation_{metric}`` column per metric
        are added to it.

    Args:
        plot_table: source DataFrame; reads ``group_size``,
            ``total_entities``, ``attribute_value``, ``{metric}`` and
            ``{metric}_disparity_scaled``.
        metrics: iterable of metric names to draw.
        ref_group: reference group passed to the disparity tooltip text.
        scales: mapping holding the ``"x"``, ``"y"``, ``"color"``, ``"shape"``
            and ``"bubble_size"`` Altair scales.
        selection: Altair selection; groups outside it are drawn in the faded
            bubble color.

    Returns:
        The layered chart ``bubble areas + bubble centers``.
    """
    # X AXIS GRIDLINES
    gridline_axis = alt.Axis(
        values=[0.25, 0.5, 0.75],
        ticks=False,
        domain=False,
        labels=False,
        title=None,
    )

    # COLOR: group color when selected, faded otherwise.
    color_encoding = alt.condition(
        selection,
        alt.Color("attribute_value:N", scale=scales["color"], legend=None),
        alt.value(Bubble.color_faded),
    )

    # CHART INITIALIZATION: empty charts that the per-metric layers are
    # added onto.
    centers_layer = alt.Chart().mark_point()
    areas_layer = alt.Chart().mark_circle()

    def group_size_text(row):
        return get_tooltip_text_group_size(row["group_size"], row["total_entities"])

    def disparity_text(row, metric_name):
        return get_tooltip_text_disparity_explanation(
            row[f"{metric_name}_disparity_scaled"],
            row["attribute_value"],
            metric_name,
            ref_group,
        )

    def position_encodings(metric_name):
        # Fresh x/y channel objects for each layer.
        return {
            "x": alt.X(f"{metric_name}:Q", scale=scales["x"], axis=gridline_axis),
            "y": alt.Y("metric_variable:N", scale=scales["y"], axis=no_axis()),
        }

    plot_table["tooltip_group_size"] = plot_table.apply(group_size_text, axis=1)

    # LAYERING THE METRICS
    for metric in metrics:
        # TOOLTIP
        explanation_column = f"tooltip_disparity_explanation_{metric}"
        plot_table[explanation_column] = plot_table.apply(
            disparity_text,
            axis=1,
            args=(metric,),
        )

        tooltips = [
            alt.Tooltip(field="attribute_value", type="nominal", title="Group"),
            alt.Tooltip(field="tooltip_group_size", type="nominal", title="Group Size"),
            alt.Tooltip(field=explanation_column, type="nominal", title="Disparity"),
            alt.Tooltip(
                field=f"{metric}",
                type="quantitative",
                format=".2f",
                title=f"{metric}".upper(),
            ),
        ]

        # Constant expression naming the row this metric is drawn on.
        row_label_expr = f"'{metric.upper()}'"

        # BUBBLE CENTERS
        centers_trigger = alt.selection_multi(empty="all", fields=["attribute_value"])
        metric_centers = (
            alt.Chart(plot_table)
            .transform_calculate(metric_variable=row_label_expr)
            .mark_point(filled=True, size=Bubble.center_size)
            .encode(
                tooltip=tooltips,
                color=color_encoding,
                shape=alt.Shape(
                    "attribute_value:N",
                    scale=scales["shape"],
                    legend=None,
                ),
                **position_encodings(metric),
            )
            .add_selection(centers_trigger)
        )
        centers_layer += metric_centers

        # BUBBLE AREAS
        areas_trigger = alt.selection_multi(empty="all", fields=["attribute_value"])
        metric_areas = (
            alt.Chart(plot_table)
            .mark_circle(opacity=Bubble.opacity)
            .transform_calculate(metric_variable=row_label_expr)
            .encode(
                tooltip=tooltips,
                color=color_encoding,
                size=alt.Size(
                    "group_size:Q",
                    legend=None,
                    scale=scales["bubble_size"],
                ),
                **position_encodings(metric),
            )
            .add_selection(areas_trigger)
        )
        areas_layer += metric_areas

    return areas_layer + centers_layer
def draw_legend(global_scales, selection, chart_width):
    """Draws the interactive group's colors legend for the chart.

    The legend is placed to the right of the chart and consists of a title,
    a subtitle explaining the interaction, and one symbol + label per group.
    The first group of the color scale's domain is labelled as the reference
    group (suffix " [REF]").

    Args:
        global_scales: mapping holding the ``"color"`` and ``"shape"`` Altair
            scales; the domain of the color scale provides the groups.
        selection: Altair selection added to the legend symbols; groups
            outside it are drawn in the faded legend color.
        chart_width: width of the chart, used as the x position where the
            legend starts.

    Returns:
        The layered chart
        ``entry symbols + entry labels + subtitle + title``.
    """
    groups = global_scales["color"].domain
    labels = list(groups)
    # Guard against an empty domain; there is no reference group to tag then.
    if labels:
        labels[0] = f"{labels[0]} [REF]"
    legend_df = pd.DataFrame({"attribute_value": groups, "label": labels})

    # Position the legend to the right of the chart
    title_text_x_position = chart_width
    title_text_height = Legend.title_font_size + Legend.title_margin_bottom
    subtitle_text_height = Legend.font_size + Legend.vertical_spacing
    entries_circles_x_position = title_text_x_position + Legend.horizontal_spacing
    entries_text_x_position = (
        title_text_x_position
        + 2 * Legend.circle_radius
        + Legend.horizontal_spacing
    )

    # Title of the legend.
    title_text = (
        alt.Chart(DUMMY_DF)
        .mark_text(
            align="left",
            baseline="middle",
            color=Legend.font_color,
            fontSize=Legend.title_font_size,
            font=FONT,
            fontWeight=Legend.title_font_weight,
        )
        .encode(
            x=alt.value(title_text_x_position),
            y=alt.value(Legend.margin_top),
            text=alt.value("Groups"),
        )
    )

    # Subtitle text that explains how to interact with the legend.
    subtitle_text = (
        alt.Chart(DUMMY_DF)
        .mark_text(
            align="left",
            baseline="middle",
            color=Legend.font_color,
            fontSize=Legend.font_size,
            font=FONT,
            fontWeight=Legend.font_weight,
        )
        .encode(
            x=alt.value(title_text_x_position),
            y=alt.value(Legend.margin_top + title_text_height),
            text=alt.value("Click to highlight a group."),
        )
    )

    # Conditionally color each legend item: a selected group uses the group
    # color scale, any other group is faded.
    color_encoding = alt.condition(
        selection,
        alt.Color("attribute_value:N", scale=global_scales["color"], legend=None),
        alt.value(Legend.color_faded),
    )

    # Offset the positioning of the legend items after the subtitle text
    legend_start_y_position = (
        Legend.margin_top + title_text_height + subtitle_text_height
    )
    y_scale = alt.Scale(
        domain=groups,
        range=[
            legend_start_y_position,
            legend_start_y_position
            # number of legend elements x text size
            + (len(groups) * Legend.font_size)
            # (number of "spacings" + start and end "spacings") x spacing
            + ((len(groups) + 1) * Legend.vertical_spacing),
        ],
    )

    # Mark size is an area in square pixels, so a circle of the configured
    # radius needs pi * r^2 (the text offset above assumes a 2 * r diameter).
    entries_circle_size = math.pi * Legend.circle_radius ** 2

    # Draw one filled symbol per group, shaped by the shared shape scale.
    entries_circles = (
        alt.Chart(legend_df)
        .mark_point(filled=True, opacity=1, size=entries_circle_size)
        .encode(
            x=alt.value(entries_circles_x_position),
            y=alt.Y("attribute_value:N", scale=y_scale, axis=no_axis()),
            color=color_encoding,
            shape=alt.Shape(
                "attribute_value:N",
                scale=global_scales["shape"],
                legend=None,
            ),
        )
        .add_selection(selection)
    )

    trigger_text = alt.selection_multi(empty="all", fields=["attribute_value"])

    # Draw colored label for each group
    entries_text = (
        alt.Chart(legend_df)
        .mark_text(
            align="left",
            baseline="middle",
            font=FONT,
            fontSize=Legend.font_size,
            fontWeight=Legend.font_weight,
        )
        .encode(
            x=alt.value(entries_text_x_position),
            y=alt.Y("attribute_value:N", scale=y_scale, axis=no_axis()),
            text=alt.Text("label:N"),
            color=color_encoding,
        )
        .add_selection(trigger_text)
    )

    return entries_circles + entries_text + subtitle_text + title_text
def __draw_metric_line_titles(metrics, size_constants):
    """Build the left-hand column of row titles, one block per metric.

    Each block stacks three texts: the upper-cased metric name, "Groups" and
    "% Pop.". The blocks are concatenated vertically below an empty spacer.

    Args:
        metrics: iterable of metric names.
        size_constants: mapping holding ``"line_height"``,
            ``"metric_titles_width"`` and ``"line_spacing"``.

    Returns:
        The vertically concatenated Altair chart.
    """

    def text_row(y_position, label, **text_style):
        # One centered text placed at a fixed slot of the inverted [3, 1]
        # vertical domain.
        return (
            alt.Chart(DUMMY_DF)
            .transform_calculate(y_position=y_position)
            .mark_text(align="center", baseline="middle", font=FONT, **text_style)
            .encode(
                alt.Y(
                    "y_position:Q",
                    scale=alt.Scale(domain=[3, 1]),
                    axis=no_axis(),
                ),
                text=alt.value(label),
            )
        )

    def metric_block(metric):
        # METRIC TITLE
        title = text_row(
            "1.2",
            metric.upper(),
            fontWeight=Title.font_weight,
            size=Title.font_size,
            color=Title.font_color,
        )
        # GROUPS TEXT
        groups_label = text_row(
            "2",
            "Groups",
            size=Subtitle.font_size,
            color=Subtitle.font_color,
        )
        # PERCENT. POP TEXT
        population_label = text_row(
            "2.7",
            "% Pop.",
            size=Subtitle.font_size,
            color=Subtitle.font_color,
        )
        return (title + groups_label + population_label).properties(
            height=size_constants["line_height"],
            width=size_constants["metric_titles_width"],
        )

    title_blocks = [metric_block(metric) for metric in metrics]

    # EMPTY CORNER SPACE
    # The attribute columns start with a title, so this column needs a blank
    # cell of the same size to stay aligned. __draw_attribute_title is reused
    # with an empty string so that nothing is actually drawn.
    corner_spacer = __draw_attribute_title(
        "",
        size_constants["metric_titles_width"],
        size_constants,
    )

    # CONCATENATE SUBPLOTS
    return alt.vconcat(
        corner_spacer,
        *title_blocks,
        spacing=size_constants["line_spacing"],
        bounds="flush",
    )