コード例 #1
0
ファイル: scatter.py プロジェクト: danielrjiang/Ax
def tile_fitted(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
    show_CI: bool = True,
    arm_noun: str = "arm",
    metrics: Optional[List[str]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Tile version of fitted outcome plots.

    Builds one subplot per metric, laid out in a grid of at most two columns,
    plus a "Sort By" dropdown that reorders the x-axis of every subplot either
    by arm name or by predicted effect size.

    Args:
        model: model to use for predictions.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        metrics: List of metric names to restrict to when plotting. If None
            or empty, all of ``model.metric_names`` are plotted.
        fixed_features: Fixed features to use when making model predictions.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the tiled figure.

    Raises:
        ValueError: if there are no metrics to plot.
    """
    metrics = metrics or list(model.metric_names)
    if not metrics:
        # Without this guard the layout math below divides by ``nrows == 0``.
        raise ValueError("tile_fitted requires at least one metric to plot.")
    nrows = int(np.ceil(len(metrics) / 2))
    ncols = min(len(metrics), 2)

    # make subplots (plot per row)
    fig = tools.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
        subplot_titles=tuple(metrics),
        horizontal_spacing=0.05,
        vertical_spacing=0.30 / nrows,
    )

    # Relayout arguments for the "Name" / "Effect Size" dropdown buttons, and
    # the initial per-subplot axis settings (which default to name order).
    name_order_args: Dict[str, Any] = {}
    name_order_axes: Dict[str, Dict[str, Any]] = {}
    effect_order_args: Dict[str, Any] = {}

    for i, metric in enumerate(metrics):
        data = _single_metric_traces(
            model,
            metric,
            generator_runs_dict,
            rel,
            showlegend=i == 0,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
            fixed_features=fixed_features,
        )

        # Flatten x (arm names) and y (predictions) across all traces once.
        all_x = np.concatenate([d["x"] for d in data])
        all_y = np.concatenate([d["y"] for d in data])

        # order arm name sorting arm numbers within batch
        names_by_arm = sorted(np.unique(all_x), key=arm_name_to_tuple)

        # get arm names sorted by effect size; an arm appearing in several
        # traces is positioned by its smallest y value (first key kept).
        names_by_effect = list(
            OrderedDict.fromkeys(
                all_x.flatten().take(np.argsort(all_y.flatten()))
            )
        )

        # options for ordering arms (x-axis)
        # Note that xaxes need to be referenced as xaxis, xaxis2, xaxis3, etc.
        # for the purposes of updatemenus argument (dropdown) in layout.
        # However, when setting the initial ordering layout, the keys should be
        # xaxis1, xaxis2, xaxis3, etc. Note the discrepancy for the initial
        # axis.
        label = "" if i == 0 else i + 1
        name_order_args["xaxis{}.categoryorder".format(label)] = "array"
        name_order_args["xaxis{}.categoryarray".format(label)] = names_by_arm
        effect_order_args["xaxis{}.categoryorder".format(label)] = "array"
        effect_order_args[
            "xaxis{}.categoryarray".format(label)
        ] = names_by_effect
        name_order_axes["xaxis{}".format(i + 1)] = {
            "categoryorder": "array",
            "categoryarray": names_by_arm,
            "type": "category",
        }
        name_order_axes["yaxis{}".format(i + 1)] = {
            "ticksuffix": "%" if rel else "",
            "zerolinecolor": "red",
        }

        # Subplots fill the grid row by row (1-indexed for plotly).
        row, col = divmod(i, ncols)
        for d in data:
            fig.append_trace(d, row + 1, col + 1)

    order_options = [
        {
            "args": [name_order_args],
            "label": "Name",
            "method": "relayout"
        },
        {
            "args": [effect_order_args],
            "label": "Effect Size",
            "method": "relayout"
        },
    ]

    # When the grid has more cells than metrics (an odd count > 1), the last
    # cell is blank; remove the axes `tools.make_subplots` generated for it.
    # A single metric gives a 1x1 grid with no blank cell, so its axes must
    # be left alone.
    if nrows * ncols > len(metrics):
        del fig["layout"]["xaxis{}".format(nrows * ncols)]
        del fig["layout"]["yaxis{}".format(nrows * ncols)]

    # Width depends on the number of columns; allocate 300 px height per row.
    fig["layout"].update(
        margin={"t": 0},
        hovermode="closest",
        updatemenus=[{
            "x": 0.15,
            "y": 1 + 0.40 / nrows,
            "buttons": order_options,
            "xanchor": "left",
            "yanchor": "middle",
        }],
        font={"size": 10},
        width=650 if ncols == 1 else 950,
        height=300 * nrows,
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1 + 0.20 / nrows,
            "xanchor": "left",
            "yanchor": "middle",
        },
        **name_order_axes,
    )

    # append the figure title and the "Sort By" label next to the dropdown
    fig["layout"]["annotations"] += (
        {
            "x": 0.5,
            "y": 1 + 0.40 / nrows,
            "xref": "paper",
            "yref": "paper",
            "font": {
                "size": 14
            },
            "text": "Predicted Outcomes",
            "showarrow": False,
            "xanchor": "center",
            "yanchor": "middle",
        },
        {
            "x": 0.05,
            "y": 1 + 0.40 / nrows,
            "xref": "paper",
            "yref": "paper",
            "text": "Sort By",
            "showarrow": False,
            "xanchor": "left",
            "yanchor": "middle",
        },
    )

    fig = resize_subtitles(figure=fig, size=10)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
コード例 #2
0
ファイル: scatter.py プロジェクト: danielrjiang/Ax
def plot_fitted(
    model: ModelBridge,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    custom_arm_order: Optional[List[str]] = None,
    custom_arm_order_name: str = "Custom",
    show_CI: bool = True,
) -> AxPlotConfig:
    """Plot model predictions for a single metric, one point per arm.

    The figure carries a "Sort By" dropdown whose buttons reorder the x-axis
    by arm name or by predicted effect size. When ``custom_arm_order`` is
    supplied it becomes the initial ordering and gets its own button, listed
    first.

    Args:
        model: model to use for predictions.
        metric: metric to plot predictions for.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        custom_arm_order: a list of arm names in the
            order corresponding to how they should be plotted on the x-axis.
            If not None, this is the default ordering.
        custom_arm_order_name: name for custom ordering to
            show in the ordering dropdown. Default is 'Custom'.
        show_CI: if True, render confidence intervals.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the figure.
    """
    traces = _single_metric_traces(
        model, metric, generator_runs_dict, rel, show_CI=show_CI
    )

    # Pool arm names and predicted values across every trace.
    arm_labels = np.concatenate([trace["x"] for trace in traces])
    predictions = np.concatenate([trace["y"] for trace in traces])

    # Arm names ordered so arm numbers sort naturally within each batch.
    by_name = sorted(np.unique(arm_labels), key=arm_name_to_tuple)

    # Arm names ordered by ascending predicted value, duplicates dropped.
    ascending = np.argsort(predictions.flatten())
    by_effect = list(
        OrderedDict.fromkeys(arm_labels.flatten().take(ascending))
    )

    def _sort_button(button_label, categories):
        # One dropdown entry that relayouts the x-axis to `categories` order.
        return {
            "args": [{
                "xaxis.categoryorder": "array",
                "xaxis.categoryarray": categories,
            }],
            "label": button_label,
            "method": "relayout",
        }

    buttons = [
        _sort_button("Name", by_name),
        _sort_button("Effect Size", by_effect),
    ]
    initial_order = by_name
    if custom_arm_order is not None:
        # A caller-provided order is the default and is offered first.
        initial_order = custom_arm_order
        buttons.insert(0, _sort_button(custom_arm_order_name, custom_arm_order))

    pct_suffix = " (%)" if rel else ""
    layout = go.Layout(
        title="Predicted Outcomes",
        hovermode="closest",
        updatemenus=[{
            "x": 1.25,
            "y": 0.67,
            "buttons": list(buttons),
            "yanchor": "middle",
            "xanchor": "left",
        }],
        yaxis={
            "zerolinecolor": "red",
            "title": f"{metric}{pct_suffix}",
        },
        xaxis={
            "tickangle": 45,
            "categoryorder": "array",
            "categoryarray": initial_order,
        },
        annotations=[{
            "x": 1.18,
            "y": 0.72,
            "xref": "paper",
            "yref": "paper",
            "text": "Sort By",
            "showarrow": False,
            "yanchor": "middle",
        }],
        font={"size": 10},
    )

    return AxPlotConfig(
        data=go.Figure(data=traces, layout=layout),
        plot_type=AxPlotTypes.GENERIC,
    )