예제 #1
0
def plot_marginal_effects(model: ModelBridge, metric: str) -> AxPlotConfig:
    """
    Calculates and plots the marginal effects -- the effect of changing one
    factor away from the randomized distribution of the experiment and fixing it
    at a particular level.

    The figure has one bar subplot per factor (each distinct ``Name`` in the
    table returned by ``marginal_effects``), with one bar per ``Level`` whose
    height is ``Beta`` and whose error bar is ``SE``. Subplots share the y-axis.

    Args:
        model: Model to use for estimating effects
        metric: The metric for which to plot marginal effects.

    Returns:
        AxPlotConfig of the marginal effects
    """
    plot_data, _, _ = get_plot_data(model, {}, {metric})

    # One single-row frame per in-sample arm: the arm's parameters plus the
    # model's predicted mean ("mean") and standard error ("sem") for `metric`.
    arm_dfs = []
    for arm in plot_data.in_sample.values():
        arm_df = pd.DataFrame(arm.parameters, index=[arm.name])
        arm_df["mean"] = arm.y_hat[metric]
        arm_df["sem"] = arm.se_hat[metric]
        arm_dfs.append(arm_df)
    # `axis` is passed by keyword: since pandas 2.0 every argument of
    # `pd.concat` other than `objs` is keyword-only, so the positional form
    # `pd.concat(arm_dfs, 0)` raises a TypeError.
    effect_table = marginal_effects(pd.concat(arm_dfs, axis=0))

    varnames = effect_table["Name"].unique()
    data: List[Any] = []
    for varname in varnames:
        var_df = effect_table[effect_table["Name"] == varname]
        data += [
            # pyre-ignore[16]
            go.Bar(
                x=var_df["Level"],
                y=var_df["Beta"],
                error_y={
                    "type": "data",
                    "array": var_df["SE"]
                },
                name=varname,
            )
        ]
    fig = tools.make_subplots(
        cols=len(varnames),
        rows=1,
        subplot_titles=list(varnames),
        print_grid=False,
        shared_yaxes=True,
    )
    # `data` holds one trace per factor in the same order as `varnames`, so
    # trace `idx` goes into column `idx + 1` of the single row.
    for idx, item in enumerate(data):
        fig.append_trace(item, 1, idx + 1)
    fig.layout.showlegend = False
    fig.layout.title = "Marginal Effects by Factor"
    fig.layout.yaxis = {
        "title": "% better than experiment average",
        "hoverformat": ".{}f".format(DECIMALS),
    }
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
예제 #2
0
def _get_contour_predictions(
    model: ModelBridge,
    x_param_name: str,
    y_param_name: str,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    density: int,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
) -> ContourPredictions:
    """Compute model predictions of ``metric`` over a 2-d grid spanned by two
    range parameters.

    Args:
        model: ModelBridge used for the predictions.
        x_param_name: Name of the range parameter on the x-axis.
        y_param_name: Name of the range parameter on the y-axis.
        metric: Name of the metric to predict.
        generator_runs_dict: Optional {name: generator run} mapping forwarded
            to ``get_plot_data`` (``None`` is treated as empty).
        density: Number of grid points requested from
            ``get_grid_for_parameter`` along each axis.
        slice_values: A dictionary {param_name: value} for the parameters that
            are being sliced on. Ignored if ``fixed_features`` is given, in
            which case ``fixed_features.parameters`` is used instead.
        fixed_features: ObservationFeatures deep-copied into every prediction
            point (its parameters are then replaced by the slice values plus
            the x/y grid values).

    Returns:
        Tuple ``(plot_data, f_plt, sd_plt, grid_x, grid_y, scales)`` where
        ``f_plt`` / ``sd_plt`` are the predicted means / standard deviations at
        the points of the flattened meshgrid of ``grid_x`` and ``grid_y``, and
        ``scales`` maps ``"x"`` / ``"y"`` to that parameter's ``log_scale``.
    """
    x_param = get_range_parameter(model, x_param_name)
    y_param = get_range_parameter(model, y_param_name)

    plot_data, _, _ = get_plot_data(model,
                                    generator_runs_dict or {}, {metric},
                                    fixed_features=fixed_features)

    grid_x = get_grid_for_parameter(x_param, density)
    grid_y = get_grid_for_parameter(y_param, density)
    scales = {"x": x_param.log_scale, "y": y_param.log_scale}

    grid2_x, grid2_y = np.meshgrid(grid_x, grid_y)

    grid2_x = grid2_x.flatten()
    grid2_y = grid2_y.flatten()

    # Explicit fixed features take precedence over `slice_values`.
    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})

    fixed_values = get_fixed_values(model, slice_values)

    param_grid_obsf = []
    # Zip the flattened meshgrid arrays (always equal length) instead of
    # indexing `range(density**2)`, so every grid point is used even when a
    # grid does not contain exactly `density` values.
    for x_value, y_value in zip(grid2_x, grid2_y):
        # Deep copy so each point keeps its own non-parameter features.
        predf = deepcopy(fixed_features)
        predf.parameters = fixed_values.copy()
        predf.parameters[x_param_name] = x_value
        predf.parameters[y_param_name] = y_value
        param_grid_obsf.append(predf)

    mu, cov = model.predict(param_grid_obsf)

    f_plt = mu[metric]
    sd_plt = np.sqrt(cov[metric][metric])
    # pyre-fixme[7]: Expected `Tuple[PlotData, np.ndarray, np.ndarray, np.ndarray,
    #  np.ndarray, Dict[str, bool]]` but got `Tuple[PlotData, typing.List[float],
    #  typing.Any, np.ndarray, np.ndarray, Dict[str, bool]]`.
    return plot_data, f_plt, sd_plt, grid_x, grid_y, scales
예제 #3
0
def _get_contour_predictions(
    model: ModelBridge,
    x_param_name: str,
    y_param_name: str,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    density: int,
    slice_values: Optional[Dict[str, Any]] = None,
) -> ContourPredictions:
    """Predict ``metric`` on the meshgrid spanned by two range parameters,
    holding every other parameter at a fixed value.

    ``slice_values`` is a dictionary {param_name: value} for the parameters
    that are being sliced on; it is handed to ``get_fixed_values`` to obtain
    the values used for all parameters other than the two plotted ones.

    Returns ``(plot_data, means, standard_deviations, grid_x, grid_y, scales)``.
    """
    x_param = get_range_parameter(model, x_param_name)
    y_param = get_range_parameter(model, y_param_name)

    plot_data, _, _ = get_plot_data(model, generator_runs_dict or {}, {metric})

    grid_x = get_grid_for_parameter(x_param, density)
    grid_y = get_grid_for_parameter(y_param, density)
    scales = {"x": x_param.log_scale, "y": y_param.log_scale}

    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    flat_x = mesh_x.flatten()
    flat_y = mesh_y.flatten()

    base_values = get_fixed_values(model, slice_values)

    def _observation_at(idx: int) -> ObservationFeatures:
        # Each grid point starts from its own copy of the fixed values.
        point = base_values.copy()
        point[x_param_name] = flat_x[idx]
        point[y_param_name] = flat_y[idx]
        return ObservationFeatures(point)

    param_grid_obsf = [_observation_at(idx) for idx in range(density ** 2)]

    mu, cov = model.predict(param_grid_obsf)
    return (
        plot_data,
        mu[metric],
        np.sqrt(cov[metric][metric]),
        grid_x,
        grid_y,
        scales,
    )
예제 #4
0
파일: slice.py 프로젝트: facebook/Ax
def _get_slice_predictions(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> SlicePredictions:
    """Computes slice prediction configuration values for a single metric name.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index, forwarded to ``get_fixed_values``
            when determining the fixed values of the other parameters.

    Returns: Configuration values for AxPlotConfig, as the tuple
        ``(plot_data, arm_name_to_parameters, f, fit_data, grid, metric_name,
        param_name, relative, fixed_values, sd, is_log)``.
    """
    runs = generator_runs_dict if generator_runs_dict is not None else {}

    parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(parameter, density)

    plot_data, raw_data, cond_name_to_parameters = get_plot_data(
        model=model,
        generator_runs_dict=runs,
        metric_names={metric_name},
        fixed_features=fixed_features,
    )

    # Explicit fixed features take precedence over `slice_values`.
    if fixed_features is None:
        template = ObservationFeatures(parameters={})
    else:
        template = fixed_features
        slice_values = fixed_features.parameters
    fixed_values = get_fixed_values(model, slice_values, trial_index)

    def _features_at(value: Any) -> ObservationFeatures:
        # Deep copy per point so no two points share feature objects.
        features = deepcopy(template)
        features.parameters = fixed_values.copy()
        features.parameters[param_name] = value
        return features

    prediction_features = [_features_at(value) for value in grid]

    means, cov = model.predict(prediction_features)
    f_plt = means[metric_name]
    sd_plt = np.sqrt(cov[metric_name][metric_name])
    # pyre-fixme[7]: Expected `Tuple[PlotData, List[Dict[str, Union[float, str]]],
    #  List[float], np.ndarray, np.ndarray, str, str, bool, Dict[str, Union[None, bool,
    #  float, int, str]], np.ndarray, bool]` but got `Tuple[PlotData, Dict[str,
    #  Dict[str, Union[None, bool, float, int, str]]], List[float], List[Dict[str,
    #  Union[float, str]]], np.ndarray, str, str, bool, Dict[str, Union[None, bool,
    #  float, int, str]], typing.Any, bool]`.
    return (
        plot_data,
        cond_name_to_parameters,
        f_plt,
        raw_data,
        grid,
        metric_name,
        param_name,
        relative,
        fixed_values,
        sd_plt,
        parameter.log_scale,
    )
예제 #5
0
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
    arm_noun: str = "arm",
):
    """Table of means and confidence intervals.

    Table is of the form (one row per metric, one column per arm):

    +------------+------------+-----------+
    |    arms    |    0_0     |    0_1    |
    +============+============+===========+
    |  metric_1  | mean +- CI |    ...    |
    +------------+------------+-----------+
    |  metric_2  |    ...     |    ...    |
    +------------+------------+-----------+

    Args:
        experiment: Experiment whose metrics are tabulated. Only model metrics
            that are keys of ``experiment.metrics`` are shown, so metrics
            "exploded" out of a collection metric are dropped.
        data: Data passed to the model factory.
        use_empirical_bayes: If True, build the model with
            ``get_empirical_bayes_thompson``, otherwise with ``get_thompson``.
        only_data_frame: If True, return ``(means_df, cis_df)`` -- two
            DataFrames indexed by metric name with one column per arm, the
            second holding ``Z * SE`` -- instead of a plot.
        arm_noun: Noun shown (pluralized with "s") in the header of the first
            column.

    Returns:
        A tuple of two DataFrames if ``only_data_frame`` is True, otherwise an
        AxPlotConfig wrapping a plotly table. Values are relative to the
        status quo when the plot data has a status quo arm.

    Raises:
        ValueError: If a plot is requested but no metric is left to tabulate.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)

    # We don't want to include metrics from a collection,
    # or the chart will be too big to read easily.
    # Example:
    # experiment.metrics = {
    #   'regular_metric': Metric(),
    #   'collection_metric: CollectionMetric()', # collection_metric =[metric1, metric2]
    # }
    # model.metric_names = [regular_metric, metric1, metric2] # "exploded" out
    # We want to filter model.metric_names and get rid of metric1, metric2
    metric_names = [
        metric_name for metric_name in model.metric_names
        if metric_name in experiment.metrics
    ]

    metric_name_to_lower_is_better = {
        metric_name: experiment.metrics[metric_name].lower_is_better
        for metric_name in metric_names
    }

    plot_data, _, _ = get_plot_data(
        model=model,
        generator_runs_dict={},
        # pyre-fixme[6]: Expected `Optional[typing.Set[str]]` for 3rd param but got
        #  `List[str]`.
        metric_names=metric_names,
    )

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    # One row per metric in each of these; each row has one entry per arm.
    records = []
    colors = []
    records_with_mean = []
    records_with_ci = []
    for metric_name in metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            # pyre-fixme[6]: Expected
            #  `List[typing.Union[ax.plot.base.PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st param but got
            #  `List[ax.plot.base.PlotInSampleArm]`.
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, pred=True, rel=rel),
            x_axis_var=None,
            status_quo_arm=status_quo_arm,
        )

        results_by_arm = list(zip(arm_names, ys, ys_se))
        colors.append([
            get_color(
                x=y,
                ci=Z * y_se,
                rel=rel,
                # pyre-fixme[6]: Expected `bool` for 4th param but got
                #  `Optional[bool]`.
                reverse=metric_name_to_lower_is_better[metric_name],
            ) for (_, y, y_se) in results_by_arm
        ])
        records.append([
            "{:.3f} ± {:.3f}".format(y, Z * y_se)
            for (_, y, y_se) in results_by_arm
        ])
        records_with_mean.append(
            {arm_name: y
             for (arm_name, y, _) in results_by_arm})
        records_with_ci.append(
            {arm_name: Z * y_se
             for (arm_name, _, y_se) in results_by_arm})

    if only_data_frame:
        return tuple(
            pd.DataFrame.from_records(metric_records, index=metric_names)
            for metric_records in [records_with_mean, records_with_ci])

    if not metric_names:
        # `arm_names` is only bound inside the metric loop above, so an empty
        # table cannot be rendered; fail clearly instead of with a NameError.
        raise ValueError(
            "table_view_plot: no experiment metrics are available to tabulate."
        )

    def transpose(m):
        # Swap rows and columns: go.Table takes cell values column by column.
        return [list(row) for row in zip(*m)]

    records = [[name.replace(":", " : ")
                for name in metric_names]] + transpose(records)
    colors = [["#ffffff"] * len(metric_names)] + transpose(colors)
    # pyre-fixme[6]: Expected `List[str]` for 1st param but got `List[float]`.
    header = [f"<b>{x}</b>" for x in [f"{arm_noun}s"] + arm_names]
    column_widths = [300] + [150] * len(arm_names)

    trace = go.Table(
        header={
            "values": header,
            "align": ["left"]
        },
        cells={
            "values": records,
            "align": ["left"],
            "fill": {
                "color": colors
            }
        },
        columnwidth=column_widths,
    )
    layout = go.Layout(
        width=sum(column_widths),
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
예제 #6
0
파일: slice.py 프로젝트: zorrock/Ax
def plot_slice(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
) -> AxPlotConfig:
    """Plot predictions for a 1-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.

    Returns:
        AxPlotConfig of type ``AxPlotTypes.SLICE`` whose config holds the
        predicted mean (``"f"``) and standard deviation (``"sd"``) of
        ``metric_name`` at each point of the grid for ``param_name``.
    """
    runs = {} if generator_runs_dict is None else generator_runs_dict

    parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(parameter, density)

    plot_data, raw_data, cond_name_to_parameters = get_plot_data(
        model=model, generator_runs_dict=runs, metric_names={metric_name}
    )

    fixed_values = get_fixed_values(model, slice_values)

    def _features_at(value):
        point = fixed_values.copy()
        point[param_name] = value
        # Here we assume context is None
        return ObservationFeatures(parameters=point)

    prediction_features = [_features_at(value) for value in grid]

    means, cov = model.predict(prediction_features)

    config = {
        "arm_data": plot_data,
        "arm_name_to_parameters": cond_name_to_parameters,
        "f": means[metric_name],
        "fit_data": raw_data,
        "grid": grid,
        "metric": metric_name,
        "param": param_name,
        "rel": relative,
        "setx": fixed_values,
        "sd": np.sqrt(cov[metric_name][metric_name]),
        "is_log": parameter.log_scale,
    }
    return AxPlotConfig(config, plot_type=AxPlotTypes.SLICE)
예제 #7
0
파일: scatter.py 프로젝트: danielrjiang/Ax
def _single_metric_traces(
    model: ModelBridge,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel: bool,
    show_arm_details_on_hover: bool = True,
    showlegend: bool = True,
    show_CI: bool = True,
    arm_noun: str = "arm",
    fixed_features: Optional[ObservationFeatures] = None,
) -> Traces:
    """Plot scatterplots with errors for a single metric (y-axis).

    Arms are plotted on the x-axis. The first trace holds the in-sample arms;
    it is followed by one trace per generator run in the out-of-sample plot
    data, each colored by its position in that mapping.

    Args:
        model: model to draw predictions from.
        metric: name of metric to plot.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, plot relative predictions.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        showlegend: if True, show legend for trace.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        fixed_features: Fixed features to use when making model predictions.

    """
    plot_data, _, _ = get_plot_data(model,
                                    generator_runs_dict or {}, {metric},
                                    fixed_features=fixed_features)

    if plot_data.status_quo_name is None:
        status_quo_arm = None
    else:
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)

    def _trace(arms, **extra):
        # Options shared by every trace; a fresh PlotMetric is built per call.
        return _error_scatter_trace(
            arms,
            x_axis_var=None,
            y_axis_var=PlotMetric(metric, pred=True, rel=rel),
            status_quo_arm=status_quo_arm,
            showlegend=showlegend,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
            **extra,
        )

    traces = [
        _trace(list(plot_data.in_sample.values()), legendgroup="In-sample"),
    ]

    # Candidates: color index 0 is left to the in-sample trace.
    candidates = plot_data.out_of_sample or {}
    for color_idx, (generator_run_name, cand_arms) in enumerate(
        candidates.items(), start=1
    ):
        traces.append(
            _trace(
                list(cand_arms.values()),
                name=generator_run_name,
                color=DISCRETE_COLOR_SCALE[color_idx],
                legendgroup=generator_run_name,
            )
        )
    return traces
예제 #8
0
파일: scatter.py 프로젝트: danielrjiang/Ax
def lattice_multiple_metrics(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
) -> AxPlotConfig:
    """Plot raw values or predictions of combinations of two metrics for arms.

    Builds an N x N grid of subplots for the N metrics in
    ``model.metric_names``. The cell at row ``j`` / column ``i`` (1-based)
    scatters metric ``i`` (x-axis) against metric ``j`` (y-axis); diagonal
    cells show box plots of that single metric instead. Every cell receives,
    in order: an observed in-sample trace (initially hidden), a predicted
    in-sample trace, then one predicted trace per out-of-sample generator run.
    The "Modeled" / "In-sample" buttons toggle between predicted and observed
    traces using visibility masks built from that fixed per-cell order; the
    "No" / "Yes" buttons toggle error-bar visibility (hidden initially).

    Args:
        model: model to draw predictions from.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.

    Returns:
        AxPlotConfig of type ``AxPlotTypes.GENERIC`` wrapping the figure.
    """
    metrics = model.metric_names
    # N x N grid of subplots with independent (unshared) axes.
    fig = tools.make_subplots(
        rows=len(metrics),
        cols=len(metrics),
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict if generator_runs_dict is not None else {},
        metrics)
    status_quo_arm = (
        None if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name))

    # iterate over all combinations of metrics and generate scatter traces.
    # o1 (column i) is the x-axis metric and o2 (row j) the y-axis metric.
    # Each cell must receive exactly [observed, predicted, one per generator
    # run] traces, in that order: the "visible" masks in `updatemenus` below
    # repeat that pattern once per cell. Only the (row 2, column 1) cell sets
    # showlegend, so each legend group appears once.
    for i, o1 in enumerate(metrics, start=1):
        for j, o2 in enumerate(metrics, start=1):
            if o1 != o2:
                # in-sample observed and predicted
                obs_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    # PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    # `ax.plot.scatter._error_scatter_trace` but got
                    # `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=False, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=False, rel=rel),
                    status_quo_arm=status_quo_arm,
                    showlegend=(i == 1 and j == 2),
                    legendgroup="In-sample",
                    visible=False,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                predicted_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    # PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    # `ax.plot.scatter._error_scatter_trace` but got
                    # `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                    status_quo_arm=status_quo_arm,
                    legendgroup="In-sample",
                    showlegend=(i == 1 and j == 2),
                    visible=True,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                fig.append_trace(obs_insample_trace, j, i)
                fig.append_trace(predicted_insample_trace, j, i)

                # iterate over models here (one predicted trace per generator
                # run; color index 0 is not used for candidates)
                for k, (generator_run_name, cand_arms) in enumerate(
                    (plot_data.out_of_sample or {}).items(), start=1):
                    fig.append_trace(
                        _error_scatter_trace(
                            list(cand_arms.values()),
                            x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                            y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                            status_quo_arm=status_quo_arm,
                            name=generator_run_name,
                            color=DISCRETE_COLOR_SCALE[k],
                            showlegend=(i == 1 and j == 2),
                            legendgroup=generator_run_name,
                            show_arm_details_on_hover=show_arm_details_on_hover,
                        ),
                        j,
                        i,
                    )
            else:
                # diagonal cell (o1 == o2): box plots of the metric's observed
                # values (hidden initially), its predicted values, and each
                # generator run's predictions -- same per-cell trace order as
                # the off-diagonal cells.
                fig.append_trace(
                    go.Box(
                        y=[arm.y[o1] for arm in plot_data.in_sample.values()],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        visible=False,
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )
                fig.append_trace(
                    go.Box(
                        y=[
                            arm.y_hat[o1]
                            for arm in plot_data.in_sample.values()
                        ],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )

                for k, (generator_run_name, cand_arms) in enumerate(
                    (plot_data.out_of_sample or {}).items(), start=1):
                    fig.append_trace(
                        go.Box(
                            y=[arm.y_hat[o1] for arm in cand_arms.values()],
                            name=None,
                            marker={"color": rgba(DISCRETE_COLOR_SCALE[k])},
                            showlegend=False,
                            legendgroup=generator_run_name,
                            hoverinfo="none",
                        ),
                        j,
                        i,
                    )

    fig["layout"].update(
        height=800,
        width=960,
        font={"size": 10},
        hovermode="closest",
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1.05,
            "xanchor": "left",
            "yanchor": "middle",
        },
        # First menu ("No" / "Yes", labeled by the "Show CI" annotation):
        # restyles error-bar width/thickness on every trace.
        # Second menu ("Modeled" / "In-sample", labeled by the "Type"
        # annotation): visibility masks repeating the per-cell pattern
        # [observed, predicted, *generator runs] once for each of the
        # len(metrics)**2 cells.
        updatemenus=[
            {
                "x":
                0.35,
                "y":
                1.08,
                "xanchor":
                "left",
                "yanchor":
                "middle",
                "buttons": [
                    {
                        "args": [{
                            "error_x.width": 0,
                            "error_x.thickness": 0,
                            "error_y.width": 0,
                            "error_y.thickness": 0,
                        }],
                        "label":
                        "No",
                        "method":
                        "restyle",
                    },
                    {
                        "args": [{
                            "error_x.width": 4,
                            "error_x.thickness": 2,
                            "error_y.width": 4,
                            "error_y.thickness": 2,
                        }],
                        "label":
                        "Yes",
                        "method":
                        "restyle",
                    },
                ],
            },
            {
                "x":
                0.1,
                "y":
                1.08,
                "xanchor":
                "left",
                "yanchor":
                "middle",
                "buttons": [
                    {
                        "args": [{
                            "visible":
                            (([False, True] +
                              [True] * len(plot_data.out_of_sample or {})) *
                             (len(metrics)**2))
                        }],
                        "label":
                        "Modeled",
                        "method":
                        "restyle",
                    },
                    {
                        "args": [{
                            "visible":
                            (([True, False] +
                              [False] * len(plot_data.out_of_sample or {})) *
                             (len(metrics)**2))
                        }],
                        "label":
                        "In-sample",
                        "method":
                        "restyle",
                    },
                ],
            },
        ],
        annotations=[
            {
                "x": 0.02,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Type",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 0.30,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Show CI",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
        ],
    )

    # Label the bottom row's x-axes and the leftmost column's y-axes with the
    # metric names. Assumes subplot axes are numbered row-major from 1, so
    # the bottom row's x-axes are n*n - n + 1 .. n*n and the leftmost
    # column's y-axes are 1, 1 + n, 1 + 2n, ... (n = number of metrics).
    for i, o in enumerate(metrics):
        pos_x = len(metrics) * len(metrics) - len(metrics) + i + 1
        pos_y = 1 + (len(metrics) * i)
        fig["layout"]["xaxis{}".format(pos_x)].update(title=_wrap_metric(o),
                                                      titlefont={"size": 10})
        fig["layout"]["yaxis{}".format(pos_y)].update(title=_wrap_metric(o),
                                                      titlefont={"size": 10})

    # do not put x-axis ticks for boxplots
    boxplot_xaxes = []
    for trace in fig["data"]:
        if trace["type"] == "box":
            # stores the xaxes which correspond to boxplot subplots
            # since we use xaxis1, xaxis2, etc, in plotly.py
            # (trace["xaxis"] is e.g. "x3"; dropping the leading "x" yields
            # the layout key "xaxis3")
            boxplot_xaxes.append("xaxis{}".format(trace["xaxis"][1:]))
        else:
            # clear all error bars since default is no CI
            trace["error_x"].update(width=0, thickness=0)
            trace["error_y"].update(width=0, thickness=0)
    for xaxis in boxplot_xaxes:
        fig["layout"][xaxis]["showticklabels"] = False

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
예제 #9
0
파일: scatter.py 프로젝트: danielrjiang/Ax
def _multiple_metric_traces(
    model: ModelBridge,
    metric_x: str,
    metric_y: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel_x: bool,
    rel_y: bool,
    fixed_features: Optional[ObservationFeatures] = None,
) -> Traces:
    """Plot traces for multiple metrics given a model and metrics.

    Produces, in order: the observed in-sample trace (hidden), the predicted
    in-sample trace (visible), and one predicted trace per out-of-sample
    generator run.

    Args:
        model: model to draw predictions from.
        metric_x: metric to plot on the x-axis.
        metric_y: metric to plot on the y-axis.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel_x: if True, use relative effects on metric_x.
        rel_y: if True, use relative effects on metric_y.
        fixed_features: Fixed features to use when making model predictions.

    """
    plot_data, _, _ = get_plot_data(
        model,
        {} if generator_runs_dict is None else generator_runs_dict,
        {metric_x, metric_y},
        fixed_features=fixed_features,
    )

    if plot_data.status_quo_name is None:
        status_quo_arm = None
    else:
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)

    def _trace(arms, pred, **extra):
        # Both axes use the same `pred` flag; each axis keeps its own `rel`.
        return _error_scatter_trace(
            arms,
            x_axis_var=PlotMetric(metric_x, pred=pred, rel=rel_x),
            y_axis_var=PlotMetric(metric_y, pred=pred, rel=rel_y),
            status_quo_arm=status_quo_arm,
            **extra,
        )

    traces = [
        _trace(list(plot_data.in_sample.values()), pred=False, visible=False),
        _trace(list(plot_data.in_sample.values()), pred=True, visible=True),
    ]

    candidates = plot_data.out_of_sample or {}
    for color_idx, (generator_run_name, cand_arms) in enumerate(
        candidates.items(), start=1
    ):
        traces.append(
            _trace(
                list(cand_arms.values()),
                pred=True,
                name=generator_run_name,
                color=DISCRETE_COLOR_SCALE[color_idx],
            )
        )
    return traces
예제 #10
0
def interact_contour(
    model: ModelBridge,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Build the config for an interactive contour plot over 2-d slices.

    Predictions are computed for every ordered pair of range parameters, so
    the consumer of the config can switch which parameters sit on the x- and
    y-axes without recomputing anything.

    Args:
        model: ModelBridge whose model supplies the predictions.
        metric_name: Name of the metric to plot.
        generator_runs_dict: Optional {name: generator run} mapping; arms of
            these runs are plotted if they lie in the slice.
        relative: Whether predictions are relative to the status quo.
        density: Number of points along each slice axis at which predictions
            are evaluated.
        slice_values: Optional {name: val} mapping for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Whether lower values of the metric are better.
        fixed_features: ObservationFeatures holding the values of features
            (including non-parameter features like context) to be set in the
            slice.

    Returns:
        AxPlotConfig of type INTERACT_CONTOUR holding the grids, predicted
        means / SEMs for each parameter pair, and the arm data.
    """
    range_parameters = get_range_parameters(model)
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features
    )

    # TODO T38563759: Sort parameters by feature importances
    param_names = [p.name for p in range_parameters]
    is_log_dict: Dict[str, bool] = {p.name: p.log_scale for p in range_parameters}
    grid_dict: Dict[str, np.ndarray] = {
        p.name: get_grid_for_parameter(p, density) for p in range_parameters
    }

    # Nested {x_param: {y_param: values}} tables of predicted mean and SEM,
    # filled row by row in parameter order.
    f_dict: Dict[str, Dict[str, np.ndarray]] = {}
    sd_dict: Dict[str, Dict[str, np.ndarray]] = {}
    for x_name in param_names:
        f_row = f_dict.setdefault(x_name, {})
        sd_row = sd_dict.setdefault(x_name, {})
        for y_name in param_names:
            _, f_plt, sd_plt, _, _, _ = _get_contour_predictions(
                model=model,
                x_param_name=x_name,
                y_param_name=y_name,
                metric=metric_name,
                generator_runs_dict=generator_runs_dict,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
            )
            f_row[y_name] = f_plt
            sd_row[y_name] = sd_plt

    return AxPlotConfig(
        {
            "arm_data": plot_data,
            "blue_scale": BLUE_SCALE,
            "density": density,
            "f_dict": f_dict,
            "green_scale": GREEN_SCALE,
            "green_pink_scale": GREEN_PINK_SCALE,
            "grid_dict": grid_dict,
            "lower_is_better": lower_is_better,
            "metric": metric_name,
            "rel": relative,
            "sd_dict": sd_dict,
            "is_log_dict": is_log_dict,
            "param_names": param_names,
        },
        plot_type=AxPlotTypes.INTERACT_CONTOUR,
    )
예제 #11
0
def interact_contour(
    model: ModelBridge,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 2-d slice of the parameter
    space.

    The figure has two contour panels side by side: predicted mean on the
    left and standard error on the right. Two dropdown menus select which
    range parameters are on the x- and y-axes. Predictions for every
    (x, y) parameter pair are computed up front and embedded in the
    dropdown buttons.

    Args:
        model: ModelBridge that contains model for predictions
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.

    Returns:
        AxPlotConfig of type GENERIC wrapping the plotly figure.
    """
    range_parameters = get_range_parameters(model)
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features
    )

    # TODO T38563759: Sort parameters by feature importances
    param_names = [parameter.name for parameter in range_parameters]

    is_log_dict: Dict[str, bool] = {}
    grid_dict: Dict[str, np.ndarray] = {}
    for parameter in range_parameters:
        is_log_dict[parameter.name] = parameter.log_scale
        grid_dict[parameter.name] = get_grid_for_parameter(parameter, density)

    # Predicted mean (f) and standard error (sd) for every (x, y) parameter
    # pair; the empty lists are placeholders overwritten in the loop below.
    f_dict: Dict[str, Dict[str, np.ndarray]] = {
        param1: {param2: [] for param2 in param_names} for param1 in param_names
    }
    sd_dict: Dict[str, Dict[str, np.ndarray]] = {
        param1: {param2: [] for param2 in param_names} for param1 in param_names
    }
    for param1 in param_names:
        for param2 in param_names:
            _, f_plt, sd_plt, _, _, _ = _get_contour_predictions(
                model=model,
                x_param_name=param1,
                y_param_name=param2,
                metric=metric_name,
                generator_runs_dict=generator_runs_dict,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
            )
            f_dict[param1][param2] = f_plt
            sd_dict[param1][param2] = sd_plt

    config = {
        "arm_data": plot_data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f_dict": f_dict,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_dict": grid_dict,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd_dict": sd_dict,
        "is_log_dict": is_log_dict,
        "param_names": param_names,
    }

    # Round-trip through AxPlotConfig. The code below indexes the result with
    # plain dict syntax (e.g. arm_data["in_sample"]), so it relies on
    # AxPlotConfig converting `plot_data` into dicts/lists.
    config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

    arm_data = config["arm_data"]
    density = config["density"]
    grid_dict = config["grid_dict"]
    f_dict = config["f_dict"]
    lower_is_better = config["lower_is_better"]
    metric = config["metric"]
    rel = config["rel"]
    sd_dict = config["sd_dict"]
    is_log_dict = config["is_log_dict"]
    param_names = config["param_names"]

    green_scale = config["green_scale"]
    green_pink_scale = config["green_pink_scale"]
    blue_scale = config["blue_scale"]

    CONTOUR_CONFIG = {
        "autocolorscale": False,
        "autocontour": True,
        "contours": {"coloring": "heatmap"},
        "hoverinfo": "x+y+z",
        "ncontours": int(density / 2),
        "type": "contour",
    }

    if rel:
        # `reversed()` returns an iterator with no len(); materialize it so the
        # colorscale comprehension below can call len() on it.
        f_scale = (
            list(reversed(green_pink_scale)) if lower_is_better else green_pink_scale
        )
    else:
        f_scale = green_scale

    f_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 0.45,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)],
        "xaxis": "x",
        "yaxis": "y",
        # zmax and zmin are ignored if zauto is true
        "zauto": not rel,
    }

    sd_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 1,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [
            (i / (len(blue_scale) - 1), rgb(v)) for i, v in enumerate(blue_scale)
        ],
        "xaxis": "x2",
        "yaxis": "y2",
    }

    f_contour_trace_base.update(CONTOUR_CONFIG)
    sd_contour_trace_base.update(CONTOUR_CONFIG)

    # {param_name: [value of that param for each in-sample arm]}
    insample_param_values = {}
    for param_name in param_names:
        insample_param_values[param_name] = []
        for arm_name in arm_data["in_sample"].keys():
            insample_param_values[param_name].append(
                arm_data["in_sample"][arm_name]["parameters"][param_name]
            )

    insample_arm_text = list(arm_data["in_sample"].keys())

    # {param_name: {generator_run_name: [value for each candidate arm]}}
    out_of_sample_param_values = {}
    for param_name in param_names:
        out_of_sample_param_values[param_name] = {}
        for generator_run_name in arm_data["out_of_sample"].keys():
            out_of_sample_param_values[param_name][generator_run_name] = []
            for arm_name in arm_data["out_of_sample"][generator_run_name].keys():
                out_of_sample_param_values[param_name][generator_run_name].append(
                    arm_data["out_of_sample"][generator_run_name][arm_name][
                        "parameters"
                    ][param_name]
                )

    out_of_sample_arm_text = {}
    for generator_run_name in arm_data["out_of_sample"].keys():
        out_of_sample_arm_text[generator_run_name] = [
            "<em>Candidate " + arm_name + "</em>"
            for arm_name in arm_data["out_of_sample"][generator_run_name].keys()
        ]

    # Number of traces for each pair of parameters: mean contour, SD contour,
    # in-sample markers on each panel, plus two marker traces per generator run.
    trace_cnt = 4 + (len(arm_data["out_of_sample"]) * 2)

    def _axis_type(param_name: str) -> str:
        # Same rule the initial layout uses for xaxis/yaxis "type".
        return "log" if is_log_dict[param_name] else "linear"

    xbuttons = []
    ybuttons = []

    for xvar in param_names:
        xbutton_data_args = {"x": [], "y": [], "z": []}
        for yvar in param_names:
            res = relativize_data(
                f_dict[xvar][yvar], sd_dict[xvar][yvar], rel, arm_data, metric
            )
            f_final = res[0]
            sd_final = res[1]
            # Reshape the flat predictions into rows of `density` values.
            f_plt = []
            for ind in range(0, len(f_final), density):
                f_plt.append(f_final[ind : ind + density])
            sd_plt = []
            for ind in range(0, len(sd_final), density):
                sd_plt.append(sd_final[ind : ind + density])

            # grid + in-sample
            xbutton_data_args["x"] += [
                grid_dict[xvar],
                grid_dict[xvar],
                insample_param_values[xvar],
                insample_param_values[xvar],
            ]
            xbutton_data_args["y"] += [
                grid_dict[yvar],
                grid_dict[yvar],
                insample_param_values[yvar],
                insample_param_values[yvar],
            ]
            xbutton_data_args["z"] += [f_plt, sd_plt, [], []]

            for generator_run_name in out_of_sample_param_values[xvar]:
                generator_run_x_vals = out_of_sample_param_values[xvar][
                    generator_run_name
                ]
                xbutton_data_args["x"] += [generator_run_x_vals] * 2
            for generator_run_name in out_of_sample_param_values[yvar]:
                generator_run_y_vals = out_of_sample_param_values[yvar][
                    generator_run_name
                ]
                xbutton_data_args["y"] += [generator_run_y_vals] * 2
                xbutton_data_args["z"] += [[]] * 2

        x_range = axis_range(grid_dict[xvar], is_log_dict[xvar])
        xbutton_args = [
            xbutton_data_args,
            {
                "xaxis.title": short_name(xvar),
                "xaxis2.title": short_name(xvar),
                "xaxis.range": x_range,
                "xaxis2.range": x_range,
                # Keep the axis type in sync with the range, as the initial
                # layout does; otherwise switching between log and linear
                # parameters leaves a stale axis type.
                "xaxis.type": _axis_type(xvar),
                "xaxis2.type": _axis_type(xvar),
            },
        ]
        xbuttons.append({"args": xbutton_args, "label": xvar, "method": "update"})

    # No y button for first param so initial value is sane
    for y_idx in range(1, len(param_names)):
        visible = [False] * (len(param_names) * trace_cnt)
        for i in range(y_idx * trace_cnt, (y_idx + 1) * trace_cnt):
            visible[i] = True
        y_param = param_names[y_idx]
        y_range = axis_range(grid_dict[y_param], is_log_dict[y_param])
        ybuttons.append(
            {
                "args": [
                    {"visible": visible},
                    {
                        "yaxis.title": short_name(y_param),
                        "yaxis.range": y_range,
                        "yaxis2.range": y_range,
                        "yaxis.type": _axis_type(y_param),
                        "yaxis2.type": _axis_type(y_param),
                    },
                ],
                "label": param_names[y_idx],
                "method": "update",
            }
        )

    # calculate max of abs(outcome), used for colorscale
    # TODO(T37079623) Make this work for relative outcomes
    # let f_absmax = Math.max(Math.abs(Math.min(...f_final)), Math.max(...f_final))

    traces = []
    xvar = param_names[0]
    # The traces that start visible are those with yvar_idx == 1 (see
    # `cur_visible`), so the initial y-axis must describe that parameter.
    # With a single parameter, fall back to it.
    init_yvar = param_names[1] if len(param_names) > 1 else param_names[0]

    for yvar_idx, yvar in enumerate(param_names):
        cur_visible = yvar_idx == 1
        f_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx]
        sd_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx + 1]

        # create traces
        f_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": f_start,
            "visible": cur_visible,
        }
        f_trace.update(f_contour_trace_base)

        sd_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": sd_start,
            "visible": cur_visible,
        }
        sd_trace.update(sd_contour_trace_base)

        base_in_sample_arm_config = {
            "hoverinfo": "text",
            "legendgroup": "In-sample",
            "marker": {"color": "black", "symbol": 1, "opacity": 0.5},
            "mode": "markers",
            "name": "In-sample",
            "text": insample_arm_text,
            "type": "scatter",
            "visible": cur_visible,
            "x": insample_param_values[xvar],
            "y": insample_param_values[yvar],
        }
        f_in_sample_arm_trace = {"xaxis": "x", "yaxis": "y"}
        f_in_sample_arm_trace.update(base_in_sample_arm_config)
        sd_in_sample_arm_trace = {"showlegend": False, "xaxis": "x2", "yaxis": "y2"}
        sd_in_sample_arm_trace.update(base_in_sample_arm_config)

        traces += [f_trace, sd_trace, f_in_sample_arm_trace, sd_in_sample_arm_trace]

        # Iterate over out-of-sample arms. Symbols start at 2 (in-sample arms
        # use 1) and depend only on the generator run, so a run keeps the same
        # marker shape whichever y-param is selected.
        for run_idx, generator_run_name in enumerate(arm_data["out_of_sample"].keys()):
            symbol = 2 + run_idx
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
                    "mode": "markers",
                    "name": generator_run_name,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "xaxis": "x",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "yaxis": "y",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "visible": cur_visible,
                }
            )
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
                    "mode": "markers",
                    "name": generator_run_name,
                    "showlegend": False,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "xaxis": "x2",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "yaxis": "y2",
                    "visible": cur_visible,
                }
            )

    xrange = axis_range(grid_dict[xvar], is_log_dict[xvar])
    yrange = axis_range(grid_dict[init_yvar], is_log_dict[init_yvar])

    xtype = _axis_type(xvar)
    ytype = _axis_type(init_yvar)

    layout = {
        "annotations": [
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Mean",
                "x": 0.25,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Standard Error",
                "x": 0.8,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "x": 0.26,
                "y": -0.26,
                "xref": "paper",
                "yref": "paper",
                "text": "x-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
            {
                "x": 0.26,
                "y": -0.4,
                "xref": "paper",
                "yref": "paper",
                "text": "y-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
        ],
        "updatemenus": [
            {
                "x": 0.35,
                "y": -0.29,
                "buttons": xbuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
            {
                "x": 0.35,
                "y": -0.43,
                "buttons": ybuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
        ],
        "autosize": False,
        "height": 450,
        "hovermode": "closest",
        "legend": {"orientation": "v", "x": 0, "y": -0.2, "yanchor": "top"},
        "margin": {"b": 100, "l": 35, "pad": 0, "r": 35, "t": 35},
        "width": 950,
        "xaxis": {
            "anchor": "y",
            "autorange": False,
            "domain": [0.05, 0.45],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "xaxis2": {
            "anchor": "y2",
            "autorange": False,
            "domain": [0.6, 1],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "yaxis": {
            "anchor": "x",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(init_yvar),
            "type": ytype,
        },
        "yaxis2": {
            "anchor": "x2",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "type": ytype,
        },
    }

    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
예제 #12
0
파일: table_view.py 프로젝트: stjordanis/Ax
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
):
    """Table of means and confidence intervals.

    Table is of the form:

    +-------+------------+-----------+
    |  arm  |  metric_1  |  metric_2 |
    +=======+============+===========+
    |  0_0  | mean +- CI |    ...    |
    +-------+------------+-----------+
    |  0_1  |    ...     |    ...    |
    +-------+------------+-----------+

    Values are relative to the status quo when the plot data has a status
    quo arm, and absolute otherwise.

    Args:
        experiment: Experiment whose metrics provide the ``lower_is_better``
            flag used to color the cells.
        data: Data the Thompson-sampling model is fit on.
        use_empirical_bayes: Fit with ``get_empirical_bayes_thompson`` if True,
            otherwise with ``get_thompson``.
        only_data_frame: If True, skip plotting and return DataFrames instead.

    Returns:
        If ``only_data_frame`` is True, a tuple ``(means, sems)`` of
        DataFrames indexed by arm with one column per metric. The second
        frame holds the raw standard errors, not scaled by ``Z``. Otherwise,
        an AxPlotConfig wrapping a plotly table whose cells show
        ``mean ± Z * se``.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)
    metric_name_to_lower_is_better = {
        metric.name: metric.lower_is_better for metric in experiment.metrics.values()
    }

    plot_data, _, _ = get_plot_data(
        model=model, generator_runs_dict={}, metric_names=model.metric_names
    )

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    results = {}
    records_with_mean = []
    records_with_ci = []
    # Bound up front so that an empty `model.metric_names` produces an empty
    # table instead of a NameError when the table cells are built below.
    arms = []
    for metric_name in model.metric_names:
        arms, _, ys, ys_se = _error_scatter_data(
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, True),
            x_axis_var=None,
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        # results[metric] will hold a list of tuples, one tuple per arm
        tuples = list(zip(arms, ys, ys_se))
        results[metric_name] = tuples
        # used if only_data_frame == True
        records_with_mean.append({arm: y for (arm, y, _) in tuples})
        records_with_ci.append({arm: y_se for (arm, _, y_se) in tuples})

    if only_data_frame:
        return tuple(
            pd.DataFrame.from_records(records, index=model.metric_names).transpose()
            for records in [records_with_mean, records_with_ci]
        )

    # cells and colors are both lists of lists
    # each top-level list corresponds to a column,
    # so the first is a list of arms
    cells = [[f"<b>{x}</b>" for x in arms]]
    colors = [["#ffffff"] * len(arms)]
    metric_names = []
    for metric_name, list_of_tuples in sorted(results.items()):
        cells.append(
            [
                "{:.3f} &plusmn; {:.3f}".format(y, Z * y_se)
                for (_, y, y_se) in list_of_tuples
            ]
        )
        metric_names.append(metric_name.replace(":", " : "))

        color_vec = []
        for (_, y, y_se) in list_of_tuples:
            color_vec.append(
                get_color(
                    x=y,
                    ci=Z * y_se,
                    rel=rel,
                    reverse=metric_name_to_lower_is_better[metric_name],
                )
            )
        colors.append(color_vec)

    header = ["arms"] + metric_names
    header = [f"<b>{x}</b>" for x in header]
    trace = go.Table(
        header={"values": header, "align": ["left"]},
        cells={"values": cells, "align": ["left"], "fill": {"color": colors}},
    )
    layout = go.Layout(
        height=min(400, len(arms) * 20 + 200),
        width=175 * len(header),
        # Plain dict instead of the deprecated `go.Margin` class.
        margin={"l": 0, "r": 20, "b": 20, "t": 20, "pad": 4},
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
예제 #13
0
파일: contour.py 프로젝트: kjanoudi/Ax
def interact_contour_plotly(
    model: ModelBridge,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Create interactive plot with predictions for a 2-d slice of the parameter
    space.

    Args:
        model: ModelBridge that contains model for predictions
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.

    Returns:
        go.Figure: interactive plot of objective vs. parameters
    """

    # NOTE: This implements a hack to allow Plotly to specify two parameters
    # simultaneously. It is not possible within Plotly to specify a third,
    # so `metric_name` must be specified and cannot be selected via dropdown
    # by the user.

    if trial_index is not None:
        if slice_values is None:
            slice_values = {}
        slice_values["TRIAL_PARAM"] = str(trial_index)

    range_parameters = get_range_parameters(model)
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features
    )

    # TODO T38563759: Sort parameters by feature importances
    param_names = [parameter.name for parameter in range_parameters]

    is_log_dict: Dict[str, bool] = {}
    grid_dict: Dict[str, np.ndarray] = {}
    for parameter in range_parameters:
        is_log_dict[parameter.name] = parameter.log_scale
        grid_dict[parameter.name] = get_grid_for_parameter(parameter, density)

    # Populate `f_dict` (the predicted expectation value of `metric_name`) and
    # `sd_dict` (the predicted SEM), each of which represents a 2D array of plots
    # where each parameter can be assigned to each of the x or y axes.

    # pyre-fixme[9]: f_dict has type `Dict[str, Dict[str, np.ndarray]]`; used as
    #  `Dict[str, Dict[str, typing.List[Variable[_T]]]]`.
    f_dict: Dict[str, Dict[str, np.ndarray]] = {
        param1: {param2: [] for param2 in param_names} for param1 in param_names
    }
    # pyre-fixme[9]: sd_dict has type `Dict[str, Dict[str, np.ndarray]]`; used as
    #  `Dict[str, Dict[str, typing.List[Variable[_T]]]]`.
    sd_dict: Dict[str, Dict[str, np.ndarray]] = {
        param1: {param2: [] for param2 in param_names} for param1 in param_names
    }

    for param1 in param_names:
        for param2 in param_names:
            _, f_plt, sd_plt, _, _, _ = _get_contour_predictions(
                model=model,
                x_param_name=param1,
                y_param_name=param2,
                metric=metric_name,
                generator_runs_dict=generator_runs_dict,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
            )
            f_dict[param1][param2] = f_plt
            sd_dict[param1][param2] = sd_plt

    # Set plotting defaults for all subplots

    config = {
        "arm_data": plot_data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f_dict": f_dict,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_dict": grid_dict,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd_dict": sd_dict,
        "is_log_dict": is_log_dict,
        "param_names": param_names,
    }

    config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

    arm_data = config["arm_data"]
    density = config["density"]
    grid_dict = config["grid_dict"]
    f_dict = config["f_dict"]
    lower_is_better = config["lower_is_better"]
    metric = config["metric"]
    rel = config["rel"]
    sd_dict = config["sd_dict"]
    is_log_dict = config["is_log_dict"]
    param_names = config["param_names"]

    green_scale = config["green_scale"]
    green_pink_scale = config["green_pink_scale"]
    blue_scale = config["blue_scale"]

    CONTOUR_CONFIG = {
        "autocolorscale": False,
        "autocontour": True,
        "contours": {"coloring": "heatmap"},
        "hoverinfo": "x+y+z",
        "ncontours": int(density / 2),
        "type": "contour",
    }

    if rel:
        f_scale = reversed(green_pink_scale) if lower_is_better else green_pink_scale
    else:
        f_scale = green_scale

    f_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 0.45,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)],
        "xaxis": "x",
        "yaxis": "y",
        # zmax and zmin are ignored if zauto is true
        "zauto": not rel,
    }

    sd_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 1,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [
            (i / (len(blue_scale) - 1), rgb(v)) for i, v in enumerate(blue_scale)
        ],
        "xaxis": "x2",
        "yaxis": "y2",
    }

    # pyre-fixme[6]: Expected `Mapping[str, typing.Union[Dict[str,
    #  typing.Union[Dict[str, int], float, str]], typing.List[Tuple[float, str]], bool,
    #  str]]` for 1st param but got `Dict[str, typing.Union[Dict[str, str], int,
    #  str]]`.
    f_contour_trace_base.update(CONTOUR_CONFIG)
    # pyre-fixme[6]: Expected `Mapping[str, typing.Union[Dict[str,
    #  typing.Union[Dict[str, int], float, str]], typing.List[Tuple[float, str]],
    #  str]]` for 1st param but got `Dict[str, typing.Union[Dict[str, str], int,
    #  str]]`.
    sd_contour_trace_base.update(CONTOUR_CONFIG)

    # Format and add hovertext to contour plots.

    insample_param_values = {}
    for param_name in param_names:
        insample_param_values[param_name] = []
        for arm_name in arm_data["in_sample"].keys():
            insample_param_values[param_name].append(
                arm_data["in_sample"][arm_name]["parameters"][param_name]
            )

    insample_arm_text = []
    for arm_name in arm_data["in_sample"].keys():
        atext = f"Arm {arm_name}"
        params = arm_data["in_sample"][arm_name]["parameters"]
        ys = arm_data["in_sample"][arm_name]["y"]
        ses = arm_data["in_sample"][arm_name]["se"]
        for yname in ys.keys():
            sem_str = f"{ses[yname]}" if ses[yname] is None else f"{ses[yname]:.6g}"
            y_str = f"{ys[yname]}" if ys[yname] is None else f"{ys[yname]:.6g}"
            atext += f"<br>{yname}: {y_str} (SEM: {sem_str})"
        for pname in params.keys():
            pval = params[pname]
            pstr = f"{pval:.6g}" if isinstance(pval, float) else f"{pval}"
            atext += f"<br>{pname}: {pstr}"
        insample_arm_text.append(atext)

    out_of_sample_param_values = {}
    for param_name in param_names:
        out_of_sample_param_values[param_name] = {}
        for generator_run_name in arm_data["out_of_sample"].keys():
            out_of_sample_param_values[param_name][generator_run_name] = []
            for arm_name in arm_data["out_of_sample"][generator_run_name].keys():
                out_of_sample_param_values[param_name][generator_run_name].append(
                    arm_data["out_of_sample"][generator_run_name][arm_name][
                        "parameters"
                    ][param_name]
                )

    out_of_sample_arm_text = {}
    for generator_run_name in arm_data["out_of_sample"].keys():
        out_of_sample_arm_text[generator_run_name] = [
            "<em>Candidate " + arm_name + "</em>"
            for arm_name in arm_data["out_of_sample"][generator_run_name].keys()
        ]

    # Populate `xbuttons`, which allows the user to select 1D slices of `f_dict` and
    # `sd_dict`, corresponding to all plots that have a certain parameter on the x-axis.

    # Number of traces for each pair of parameters
    trace_cnt = 4 + (len(arm_data["out_of_sample"]) * 2)

    xbuttons = []
    ybuttons = []

    for xvar in param_names:
        xbutton_data_args = {"x": [], "y": [], "z": []}
        for yvar in param_names:
            res = relativize_data(
                f_dict[xvar][yvar], sd_dict[xvar][yvar], rel, arm_data, metric
            )
            f_final = res[0]
            sd_final = res[1]
            # transform to nested array
            f_plt = []
            for ind in range(0, len(f_final), density):
                f_plt.append(f_final[ind : ind + density])
            sd_plt = []
            for ind in range(0, len(sd_final), density):
                sd_plt.append(sd_final[ind : ind + density])

            # grid + in-sample
            xbutton_data_args["x"] += [
                grid_dict[xvar],
                grid_dict[xvar],
                insample_param_values[xvar],
                insample_param_values[xvar],
            ]
            xbutton_data_args["y"] += [
                grid_dict[yvar],
                grid_dict[yvar],
                insample_param_values[yvar],
                insample_param_values[yvar],
            ]
            xbutton_data_args["z"] += [f_plt, sd_plt, [], []]

            for generator_run_name in out_of_sample_param_values[xvar]:
                generator_run_x_vals = out_of_sample_param_values[xvar][
                    generator_run_name
                ]
                xbutton_data_args["x"] += [generator_run_x_vals] * 2
            for generator_run_name in out_of_sample_param_values[yvar]:
                generator_run_y_vals = out_of_sample_param_values[yvar][
                    generator_run_name
                ]
                xbutton_data_args["y"] += [generator_run_y_vals] * 2
                xbutton_data_args["z"] += [[]] * 2

        xbutton_args = [
            xbutton_data_args,
            {
                "xaxis.title": short_name(xvar),
                "xaxis2.title": short_name(xvar),
                "xaxis.range": axis_range(grid_dict[xvar], is_log_dict[xvar]),
                "xaxis2.range": axis_range(grid_dict[xvar], is_log_dict[xvar]),
                "xaxis.type": "log" if is_log_dict[xvar] else "linear",
                "xaxis2.type": "log" if is_log_dict[xvar] else "linear",
            },
        ]
        xbuttons.append({"args": xbutton_args, "label": xvar, "method": "update"})

    # Populate `ybuttons`, which uses the `visible` arg to mask the 1D slice of plots
    # produced by `xbuttons`, down to a single plot, so that only one element `f_dict`
    # and `sd_dict` remain.

    # No y button for first param so initial value is sane
    for y_idx in range(1, len(param_names)):
        visible = [False] * (len(param_names) * trace_cnt)
        for i in range(y_idx * trace_cnt, (y_idx + 1) * trace_cnt):
            visible[i] = True
        y_param = param_names[y_idx]
        ybuttons.append(
            {
                "args": [
                    {"visible": visible},
                    {
                        "yaxis.title": short_name(y_param),
                        "yaxis.range": axis_range(
                            grid_dict[y_param], is_log_dict[y_param]
                        ),
                        "yaxis2.range": axis_range(
                            grid_dict[y_param], is_log_dict[y_param]
                        ),
                        "yaxis.type": "log" if is_log_dict[y_param] else "linear",
                        "yaxis2.type": "log" if is_log_dict[y_param] else "linear",
                    },
                ],
                "label": param_names[y_idx],
                "method": "update",
            }
        )

    # calculate max of abs(outcome), used for colorscale
    # TODO(T37079623) Make this work for relative outcomes
    # let f_absmax = Math.max(Math.abs(Math.min(...f_final)), Math.max(...f_final))

    traces = []
    xvar = param_names[0]
    base_in_sample_arm_config = None

    # start symbol at 2 for out-of-sample candidate markers
    i = 2

    for yvar_idx, yvar in enumerate(param_names):
        cur_visible = yvar_idx == 1
        f_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx]
        sd_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx + 1]

        # create traces
        f_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": f_start,
            "visible": cur_visible,
        }

        for key in f_contour_trace_base.keys():
            f_trace[key] = f_contour_trace_base[key]

        sd_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": sd_start,
            "visible": cur_visible,
        }

        for key in sd_contour_trace_base.keys():
            sd_trace[key] = sd_contour_trace_base[key]

        f_in_sample_arm_trace = {"xaxis": "x", "yaxis": "y"}

        sd_in_sample_arm_trace = {"showlegend": False, "xaxis": "x2", "yaxis": "y2"}
        base_in_sample_arm_config = {
            "hoverinfo": "text",
            "legendgroup": "In-sample",
            "marker": {"color": "black", "symbol": 1, "opacity": 0.5},
            "mode": "markers",
            "name": "In-sample",
            "text": insample_arm_text,
            "type": "scatter",
            "visible": cur_visible,
            "x": insample_param_values[xvar],
            "y": insample_param_values[yvar],
        }

        for key in base_in_sample_arm_config.keys():
            f_in_sample_arm_trace[key] = base_in_sample_arm_config[key]
            sd_in_sample_arm_trace[key] = base_in_sample_arm_config[key]

        traces += [f_trace, sd_trace, f_in_sample_arm_trace, sd_in_sample_arm_trace]

        # iterate over out-of-sample arms
        for generator_run_name in arm_data["out_of_sample"].keys():
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": i, "opacity": 0.5},
                    "mode": "markers",
                    "name": generator_run_name,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "xaxis": "x",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "yaxis": "y",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "visible": cur_visible,
                }
            )
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": i, "opacity": 0.5},
                    "mode": "markers",
                    "name": "In-sample",
                    "showlegend": False,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "xaxis": "x2",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "yaxis": "y2",
                    "visible": cur_visible,
                }
            )
            i += 1

    # Initially visible yvar
    yvar = param_names[1]

    xrange = axis_range(grid_dict[xvar], is_log_dict[xvar])
    yrange = axis_range(grid_dict[yvar], is_log_dict[yvar])

    xtype = "log" if is_log_dict[xvar] else "linear"
    ytype = "log" if is_log_dict[yvar] else "linear"

    layout = {
        "annotations": [
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Mean",
                "x": 0.25,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Standard Error",
                "x": 0.8,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "x": 0.26,
                "y": -0.26,
                "xref": "paper",
                "yref": "paper",
                "text": "x-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
            {
                "x": 0.26,
                "y": -0.4,
                "xref": "paper",
                "yref": "paper",
                "text": "y-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
        ],
        "updatemenus": [
            {
                "x": 0.35,
                "y": -0.29,
                "buttons": xbuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
            {
                "x": 0.35,
                "y": -0.43,
                "buttons": ybuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
        ],
        "autosize": False,
        "height": 450,
        "hovermode": "closest",
        "legend": {"orientation": "v", "x": 0, "y": -0.2, "yanchor": "top"},
        "margin": {"b": 100, "l": 35, "pad": 0, "r": 35, "t": 35},
        "width": 950,
        "xaxis": {
            "anchor": "y",
            "autorange": False,
            "domain": [0.05, 0.45],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "xaxis2": {
            "anchor": "y2",
            "autorange": False,
            "domain": [0.6, 1],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "yaxis": {
            "anchor": "x",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(yvar),
            "type": ytype,
        },
        "yaxis2": {
            "anchor": "x2",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "type": ytype,
        },
    }

    return go.Figure(data=traces, layout=layout)