def tile_fitted(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
    show_CI: bool = True,
    arm_noun: str = "arm",
    metrics: Optional[List[str]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Tile version of fitted outcome plots.

    Lays out one subplot per metric on a grid with at most two columns and
    adds a dropdown that re-orders the x-axis (arm names) of every subplot
    either by arm name or by predicted effect size.

    Args:
        model: model to use for predictions.
        generator_runs_dict: a mapping from generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        metrics: List of metric names to restrict to when plotting. If None
            or empty, all of ``model.metric_names`` are plotted.
        fixed_features: Fixed features to use when making model predictions.

    Returns:
        An ``AxPlotConfig`` wrapping the tiled figure.

    Raises:
        ValueError: if there are no metrics to plot.
    """
    metrics = metrics or list(model.metric_names)
    if not metrics:
        # Without this guard the layout math below divides by ``nrows == 0``.
        raise ValueError("No metrics to plot.")
    nrows = int(np.ceil(len(metrics) / 2))
    ncols = min(len(metrics), 2)

    # make subplots (plot per row)
    fig = tools.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
        subplot_titles=tuple(metrics),
        horizontal_spacing=0.05,
        vertical_spacing=0.30 / nrows,
    )

    name_order_args: Dict[str, Any] = {}
    name_order_axes: Dict[str, Dict[str, Any]] = {}
    effect_order_args: Dict[str, Any] = {}

    for i, metric in enumerate(metrics):
        data = _single_metric_traces(
            model,
            metric,
            generator_runs_dict,
            rel,
            showlegend=i == 0,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
            fixed_features=fixed_features,
        )

        # order arm name sorting arm numbers within batch
        names_by_arm = sorted(
            np.unique(np.concatenate([d["x"] for d in data])),
            key=arm_name_to_tuple,
        )

        # get arm names sorted by effect size (ascending y); fromkeys keeps
        # the first occurrence of each name, deduplicating while preserving
        # that order.
        names_by_effect = list(
            OrderedDict.fromkeys(
                np.concatenate([d["x"] for d in data])
                .flatten()
                .take(np.argsort(np.concatenate([d["y"] for d in data]).flatten()))
            )
        )

        # options for ordering arms (x-axis)
        # Note that xaxes need to be referenced as xaxis, xaxis2, xaxis3, etc.
        # for the purposes of updatemenus argument (dropdown) in layout.
        # However, when setting the initial ordering layout, the keys should be
        # xaxis1, xaxis2, xaxis3, etc. Note the discrepancy for the initial
        # axis.
        label = "" if i == 0 else i + 1
        name_order_args["xaxis{}.categoryorder".format(label)] = "array"
        name_order_args["xaxis{}.categoryarray".format(label)] = names_by_arm
        effect_order_args["xaxis{}.categoryorder".format(label)] = "array"
        effect_order_args["xaxis{}.categoryarray".format(label)] = names_by_effect
        name_order_axes["xaxis{}".format(i + 1)] = {
            "categoryorder": "array",
            "categoryarray": names_by_arm,
            "type": "category",
        }
        name_order_axes["yaxis{}".format(i + 1)] = {
            "ticksuffix": "%" if rel else "",
            "zerolinecolor": "red",
        }
        # Subplot grid positions are 1-indexed and filled row by row.
        for d in data:
            fig.append_trace(d, i // ncols + 1, i % ncols + 1)

    order_options = [
        {"args": [name_order_args], "label": "Name", "method": "relayout"},
        {"args": [effect_order_args], "label": "Effect Size", "method": "relayout"},
    ]

    # If the grid has more cells than metrics (an odd number of metrics > 1),
    # `tools.make_subplots` generated a trailing blank subplot; remove its axes.
    # A single metric uses a 1x1 grid with no blank cell, so its only axes
    # must be kept.
    if nrows * ncols > len(metrics):
        del fig["layout"]["xaxis{}".format(nrows * ncols)]
        del fig["layout"]["yaxis{}".format(nrows * ncols)]

    # allocate 300 px of height per row of plots
    fig["layout"].update(
        margin={"t": 0},
        hovermode="closest",
        updatemenus=[
            {
                "x": 0.15,
                "y": 1 + 0.40 / nrows,
                "buttons": order_options,
                "xanchor": "left",
                "yanchor": "middle",
            }
        ],
        font={"size": 10},
        width=650 if ncols == 1 else 950,
        height=300 * nrows,
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1 + 0.20 / nrows,
            "xanchor": "left",
            "yanchor": "middle",
        },
        **name_order_axes,
    )

    # append dropdown annotations
    fig["layout"]["annotations"] += (
        {
            "x": 0.5,
            "y": 1 + 0.40 / nrows,
            "xref": "paper",
            "yref": "paper",
            "font": {"size": 14},
            "text": "Predicted Outcomes",
            "showarrow": False,
            "xanchor": "center",
            "yanchor": "middle",
        },
        {
            "x": 0.05,
            "y": 1 + 0.40 / nrows,
            "xref": "paper",
            "yref": "paper",
            "text": "Sort By",
            "showarrow": False,
            "xanchor": "left",
            "yanchor": "middle",
        },
    )

    fig = resize_subtitles(figure=fig, size=10)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_fitted(
    model: ModelBridge,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    custom_arm_order: Optional[List[str]] = None,
    custom_arm_order_name: str = "Custom",
    show_CI: bool = True,
) -> AxPlotConfig:
    """Plot model predictions of a single metric for every arm.

    The x-axis lists arms and comes with a "Sort By" dropdown offering
    orderings by arm name and by predicted effect size (plus an optional
    caller-supplied ordering, which then becomes the initial one).

    Args:
        model: model to use for predictions.
        metric: metric to plot predictions for.
        generator_runs_dict: a mapping from generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        custom_arm_order: a list of arm names in the order corresponding to
            how they should be plotted on the x-axis. If not None, this is
            the default ordering and the first dropdown entry.
        custom_arm_order_name: name for custom ordering to show in the
            ordering dropdown. Default is 'Custom'.
        show_CI: if True, render confidence intervals.

    Returns:
        An ``AxPlotConfig`` wrapping the resulting figure.
    """

    def sort_button(button_label, arm_order):
        # One dropdown entry that relayouts the x-axis to ``arm_order``.
        return {
            "args": [
                {"xaxis.categoryorder": "array", "xaxis.categoryarray": arm_order}
            ],
            "label": button_label,
            "method": "relayout",
        }

    traces = _single_metric_traces(
        model, metric, generator_runs_dict, rel, show_CI=show_CI
    )

    arm_names = np.concatenate([trace["x"] for trace in traces])
    effects = np.concatenate([trace["y"] for trace in traces])

    # order arm name sorting arm numbers within batch
    by_name = sorted(np.unique(arm_names), key=arm_name_to_tuple)

    # Arm names by ascending predicted value; fromkeys drops repeats while
    # keeping each name at its first (lowest-value) position.
    ranked_names = arm_names.flatten().take(np.argsort(effects.flatten()))
    by_effect = list(OrderedDict.fromkeys(ranked_names))

    buttons = [sort_button("Name", by_name), sort_button("Effect Size", by_effect)]
    initial_order = by_name

    # A caller-supplied ordering is shown first and used as the default.
    if custom_arm_order is not None:
        buttons.insert(0, sort_button(custom_arm_order_name, custom_arm_order))
        initial_order = custom_arm_order

    sort_menu = {
        "x": 1.25,
        "y": 0.67,
        "buttons": list(buttons),
        "yanchor": "middle",
        "xanchor": "left",
    }
    sort_label = {
        "x": 1.18,
        "y": 0.72,
        "xref": "paper",
        "yref": "paper",
        "text": "Sort By",
        "showarrow": False,
        "yanchor": "middle",
    }
    y_title = "{}{}".format(metric, " (%)" if rel else "")

    layout = go.Layout(
        title="Predicted Outcomes",
        hovermode="closest",
        updatemenus=[sort_menu],
        yaxis={"zerolinecolor": "red", "title": y_title},
        xaxis={
            "tickangle": 45,
            "categoryorder": "array",
            "categoryarray": initial_order,
        },
        annotations=[sort_label],
        font={"size": 10},
    )

    return AxPlotConfig(
        data=go.Figure(data=traces, layout=layout), plot_type=AxPlotTypes.GENERIC
    )