示例#1
0
文件: slice.py 项目: facebook/Ax
def plot_slice_plotly(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Plot model predictions for one metric along a 1-d slice of the
    parameter space.

    ``param_name`` is swept along the x-axis while the remaining parameters
    are held at fixed values; the predictions for ``metric_name`` are drawn
    on the y-axis together with any arms that lie in the slice.

    Args:
        model: ModelBridge that contains model for predictions.
        param_name: Name of the parameter that is varied along the x-axis.
        metric_name: Name of the metric shown on the y-axis.
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Whether predictions are shown relative to the status quo.
        density: Number of points along the slice at which predictions are
            evaluated.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index, forwarded unchanged to
            ``_get_slice_predictions``.

    Returns:
        go.Figure: plot of objective vs. parameter value
    """
    (
        arm_plot_data,
        name_to_params,
        mean_values,
        raw_data,
        slice_grid,
        _,
        _,
        _,
        fixed_values,
        std_values,
        log_scale,
    ) = _get_slice_predictions(
        model=model,
        param_name=param_name,
        metric_name=metric_name,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )

    # Pack the prediction outputs into an AxPlotConfig and read every
    # plotting input back from its `.data`.
    plot_inputs = AxPlotConfig(
        {
            "arm_data": arm_plot_data,
            "arm_name_to_parameters": name_to_params,
            "f": mean_values,
            "fit_data": raw_data,
            "grid": slice_grid,
            "metric": metric_name,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            "sd": std_values,
            "is_log": log_scale,
        },
        plot_type=AxPlotTypes.GENERIC,
    ).data

    grid = plot_inputs["grid"]
    is_log = plot_inputs["is_log"]
    x_title = plot_inputs["param"]
    y_title = plot_inputs["metric"]

    traces = slice_config_to_trace(
        plot_inputs["arm_data"],
        plot_inputs["arm_name_to_parameters"],
        plot_inputs["f"],
        plot_inputs["fit_data"],
        grid,
        y_title,
        x_title,
        plot_inputs["rel"],
        plot_inputs["setx"],
        plot_inputs["sd"],
        is_log,
        True,  # single-metric plot: traces are visible immediately
    )

    x_axis = {
        "anchor": "y",
        "autorange": False,
        "exponentformat": "e",
        "range": axis_range(grid, is_log),
        "tickfont": {"size": 11},
        "tickmode": "auto",
        "title": x_title,
        "type": "log" if is_log else "linear",
    }
    y_axis = {
        "anchor": "x",
        "tickfont": {"size": 11},
        "tickmode": "auto",
        "title": y_title,
    }
    return go.Figure(
        data=traces,
        layout={"hovermode": "closest", "xaxis": x_axis, "yaxis": y_axis},
    )
示例#2
0
文件: slice.py 项目: facebook/Ax
def interact_slice_plotly(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Create interactive plot with predictions for a 1-d slice of the parameter
    space.

    The figure carries two dropdown menus: one chooses which range parameter
    is on the x-axis (replacing the traces' x/y/error_y data in place), the
    other chooses which metric's traces are visible.

    Args:
        model: ModelBridge that contains model for predictions
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index. It is passed both to
            ``get_fixed_values`` and to ``_get_slice_predictions`` so that the
            model predictions and the plotted arm data describe the same slice.

    Returns:
        go.Figure: interactive plot of objective vs. parameter

    Raises:
        ValueError: If the model has no range parameters or no metrics.
    """
    if generator_runs_dict is None:
        generator_runs_dict = {}

    metric_names = list(model.metric_names)
    range_parameters = get_range_parameters(model)
    param_names = [parameter.name for parameter in range_parameters]
    # Without these, the code below fails with an opaque NameError/IndexError.
    if not param_names:
        raise ValueError(
            "interact_slice_plotly requires a model with at least one range parameter."
        )
    if not metric_names:
        raise ValueError(
            "interact_slice_plotly requires a model with at least one metric."
        )

    # The other parameters are pinned at the same values whichever parameter
    # is on the x-axis, so resolve the fixed features/values once up front.
    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})
    fixed_values = get_fixed_values(model, slice_values, trial_index)

    # Populate `pbuttons`, which allows the user to select 1D slices of parameter
    # space with the chosen parameter on the x-axis.
    pbuttons = []
    init_traces = []
    # Number of traces each metric contributed to `init_traces`, in metric
    # order; the metric dropdown's visibility masks are built from these.
    init_trace_counts = []
    xaxis_init_format = {}
    metrics = metric_names
    for param_idx, param_name in enumerate(param_names):
        is_first_param = param_idx == 0
        pbutton_data_args = {"x": [], "y": [], "error_y": []}
        parameter = get_range_parameter(model, param_name)
        grid = get_grid_for_parameter(parameter, density)

        prediction_features = []
        for x in grid:
            predf = deepcopy(fixed_features)
            predf.parameters = fixed_values.copy()
            predf.parameters[param_name] = x
            prediction_features.append(predf)

        f, cov = model.predict(prediction_features)

        plot_data_dict = {}
        raw_data_dict = {}
        sd_plt_dict: Dict[str, np.ndarray] = {}
        cond_name_to_parameters_dict = {}
        is_log_dict: Dict[str, bool] = {}
        for metric_name in metric_names:
            (
                arm_plot_data,
                cntp,
                _,
                rd,
                _,
                _,
                _,
                _,
                _,
                _,
                ls,
            ) = _get_slice_predictions(
                model=model,
                param_name=param_name,
                metric_name=metric_name,
                generator_runs_dict=generator_runs_dict,
                relative=relative,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
                trial_index=trial_index,
            )

            plot_data_dict[metric_name] = arm_plot_data
            raw_data_dict[metric_name] = rd
            cond_name_to_parameters_dict[metric_name] = cntp
            # Spread is the sqrt of the metric's self-covariance from the
            # batch `model.predict` call above.
            sd_plt_dict[metric_name] = np.sqrt(cov[metric_name][metric_name])
            is_log_dict[metric_name] = ls

        config = {
            "arm_data": plot_data_dict,
            "arm_name_to_parameters": cond_name_to_parameters_dict,
            "f": f,
            "fit_data": raw_data_dict,
            "grid": grid,
            "metrics": metric_names,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            "sd": sd_plt_dict,
            "is_log": is_log_dict,
        }
        config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

        arm_data = config["arm_data"]
        arm_name_to_parameters = config["arm_name_to_parameters"]
        f = config["f"]
        fit_data = config["fit_data"]
        grid = config["grid"]
        metrics = config["metrics"]
        param = config["param"]
        rel = config["rel"]
        setx = config["setx"]
        sd = config["sd"]
        is_log = config["is_log"]

        # layout: the x-axis scale follows the first metric's log flag.
        xrange = axis_range(grid, is_log[metrics[0]])
        xtype = "log" if is_log[metrics[0]] else "linear"

        for i, metric in enumerate(metrics):
            traces = slice_config_to_trace(
                arm_data[metric],
                arm_name_to_parameters[metric],
                f[metric],
                fit_data[metric],
                grid,
                metric,
                param,
                rel,
                setx,
                sd[metric],
                is_log[metric],
                i == 0,  # only the first metric is visible initially
            )
            pbutton_data_args["x"] += [trace["x"] for trace in traces]
            pbutton_data_args["y"] += [trace["y"] for trace in traces]
            pbutton_data_args["error_y"] += [
                {
                    "type": "data",
                    "array": trace["error_y"]["array"],
                    "visible": True,
                    "color": "black",
                }
                if "error_y" in trace and "array" in trace["error_y"]
                else []
                for trace in traces
            ]
            if is_first_param:
                init_traces.extend(traces)
                init_trace_counts.append(len(traces))

        pbuttons.append(
            {
                "args": [
                    pbutton_data_args,
                    {
                        "xaxis.title": param_name,
                        "xaxis.range": xrange,
                        "xaxis.type": xtype,
                    },
                ],
                "label": param_name,
                "method": "update",
            }
        )
        if is_first_param:
            xaxis_init_format = {
                "anchor": "y",
                "autorange": False,
                "exponentformat": "e",
                "range": xrange,
                "tickfont": {"size": 11},
                "tickmode": "auto",
                "title": param_name,
                "type": xtype,
            }

    # Populate mbuttons, which allows the user to select which metric to plot.
    # The masks index into `init_traces` (the figure's actual traces), using
    # the trace count each metric really produced rather than a fixed formula.
    mbuttons = []
    total_traces = len(init_traces)
    offset = 0
    for metric, count in zip(metrics, init_trace_counts):
        visible = [False] * total_traces
        visible[offset : offset + count] = [True] * count
        offset += count
        mbuttons.append(
            {
                "method": "update",
                "args": [{"visible": visible}, {"yaxis.title": metric}],
                "label": metric,
            }
        )

    layout = {
        "title": "Predictions for a 1-d slice of the parameter space",
        "annotations": [
            {
                "showarrow": False,
                "text": "Choose metric:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.455,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "showarrow": False,
                "text": "Choose parameter:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.305,
                "yanchor": "bottom",
                "yref": "paper",
            },
        ],
        "updatemenus": [
            {
                "y": -0.35,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": mbuttons,
                "direction": "up",
            },
            {
                "y": -0.2,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": pbuttons,
                "direction": "up",
            },
        ],
        "hovermode": "closest",
        "xaxis": xaxis_init_format,
        "yaxis": {
            "anchor": "x",
            "autorange": True,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": metrics[0],
        },
    }

    return go.Figure(data=init_traces, layout=layout)
示例#3
0
def interact_slice(
    model: ModelBridge,
    param_name: str,
    metric_name: str = "",
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 1-d slice of the parameter
    space.

    Every metric in ``model.metric_names`` gets its own set of traces; a
    dropdown menu toggles which metric's traces are visible.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Currently unused. All of the model's metrics are plotted
            and the first one is shown initially. Kept for interface
            compatibility.
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index. It is passed both to
            ``get_fixed_values`` and to ``_get_slice_predictions``, matching
            the other slice-plot functions.

    Returns:
        AxPlotConfig: a GENERIC-type config wrapping the plotly figure.

    Raises:
        ValueError: If the model has no metrics.
    """
    if generator_runs_dict is None:
        generator_runs_dict = {}

    metric_names = list(model.metric_names)

    parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(parameter, density)

    # Without a metric the layout below would fail on `metrics[0]` with an
    # opaque IndexError.
    if not metric_names:
        raise ValueError("interact_slice requires a model with at least one metric.")

    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})
    fixed_values = get_fixed_values(model, slice_values, trial_index)

    prediction_features = []
    for x in grid:
        predf = deepcopy(fixed_features)
        predf.parameters = fixed_values.copy()
        predf.parameters[param_name] = x
        prediction_features.append(predf)

    f, cov = model.predict(prediction_features)

    plot_data_dict = {}
    raw_data_dict = {}
    sd_plt_dict: Dict[str, np.ndarray] = {}
    cond_name_to_parameters_dict = {}
    is_log_dict: Dict[str, bool] = {}
    # `name` (not `metric_name`) avoids shadowing the function argument.
    for name in metric_names:
        (
            arm_plot_data,
            cntp,
            _,
            rd,
            _,
            _,
            _,
            _,
            _,
            _,
            ls,
        ) = _get_slice_predictions(
            model=model,
            param_name=param_name,
            metric_name=name,
            generator_runs_dict=generator_runs_dict,
            relative=relative,
            density=density,
            slice_values=slice_values,
            fixed_features=fixed_features,
            trial_index=trial_index,
        )

        plot_data_dict[name] = arm_plot_data
        raw_data_dict[name] = rd
        cond_name_to_parameters_dict[name] = cntp
        # Spread is the sqrt of the metric's self-covariance from the batch
        # `model.predict` call above.
        sd_plt_dict[name] = np.sqrt(cov[name][name])
        is_log_dict[name] = ls

    config = {
        "arm_data": plot_data_dict,
        "arm_name_to_parameters": cond_name_to_parameters_dict,
        "f": f,
        "fit_data": raw_data_dict,
        "grid": grid,
        "metrics": metric_names,
        "param": param_name,
        "rel": relative,
        "setx": fixed_values,
        "sd": sd_plt_dict,
        "is_log": is_log_dict,
    }
    config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

    arm_data = config["arm_data"]
    arm_name_to_parameters = config["arm_name_to_parameters"]
    f = config["f"]
    fit_data = config["fit_data"]
    grid = config["grid"]
    metrics = config["metrics"]
    param = config["param"]
    rel = config["rel"]
    setx = config["setx"]
    sd = config["sd"]
    is_log = config["is_log"]

    traces = []
    # Number of traces each metric contributed, in metric order. The
    # visibility masks below are built from these actual counts rather than
    # assuming a fixed number of traces per metric.
    trace_counts = []
    for i, metric in enumerate(metrics):
        metric_traces = slice_config_to_trace(
            arm_data[metric],
            arm_name_to_parameters[metric],
            f[metric],
            fit_data[metric],
            grid,
            metric,
            param,
            rel,
            setx,
            sd[metric],
            is_log[metric],
            i == 0,  # only the first metric is visible initially
        )
        traces.extend(metric_traces)
        trace_counts.append(len(metric_traces))

    # layout: the x-axis scale follows the first metric's log flag.
    xrange = axis_range(grid, is_log[metrics[0]])
    xtype = "log" if is_log[metrics[0]] else "linear"

    buttons = []
    offset = 0
    for metric, count in zip(metrics, trace_counts):
        visible = [False] * len(traces)
        visible[offset : offset + count] = [True] * count
        offset += count
        buttons.append(
            {
                "method": "update",
                "args": [{"visible": visible}, {"yaxis.title": metric}],
                "label": metric,
            }
        )

    layout = {
        "title": "Predictions for a 1-d slice of the parameter space",
        "annotations": [
            {
                "showarrow": False,
                "text": "Choose metric:",
                "x": 0.225,
                "xanchor": "center",
                "xref": "paper",
                "y": 1.005,
                "yanchor": "bottom",
                "yref": "paper",
            }
        ],
        "updatemenus": [{"y": 1.1, "x": 0.5, "yanchor": "top", "buttons": buttons}],
        "hovermode": "closest",
        "xaxis": {
            "anchor": "y",
            "autorange": False,
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": param,
            "type": xtype,
        },
        "yaxis": {
            "anchor": "x",
            "autorange": True,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": metrics[0],
        },
    }

    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)