예제 #1
0
def _get_cv_plot_data(cv_results: List[CVResult]) -> PlotData:
    """Convert cross-validation results into observed-vs-predicted plot data.

    Every CV result yields one in-sample arm. Observed means fill ``y`` and
    predicted means fill ``y_hat``; ``se`` and ``se_hat`` are the square roots
    of the matching diagonal covariance entries.

    Args:
        cv_results: Cross-validation results, each exposing an ``observed``
            observation and a ``predicted`` observation-data object.

    Returns:
        PlotData whose ``in_sample`` mapping is keyed by
        ``"<arm_name>_<position in cv_results>"``. An empty input yields a
        PlotData with no metrics and no arms.
    """
    if not cv_results:
        return PlotData(
            metrics=[], in_sample={}, out_of_sample=None, status_quo_name=None
        )

    # The metric list is taken from the first result only; the input is
    # assumed to be well formed, i.e. every result reports the same metrics.
    metric_names = cv_results[0].predicted.metric_names

    # "<arm_name>_<index>" -> arm data
    in_sample: Dict[str, PlotInSampleArm] = {}
    for idx, result in enumerate(cv_results):
        observed = result.observed
        predicted = result.predicted
        arm_name = observed.arm_name
        parameters = observed.features.parameters

        y: Dict[str, float] = {}
        se: Dict[str, float] = {}
        for pos, metric in enumerate(observed.data.metric_names):
            y[metric] = observed.data.means[pos]
            se[metric] = np.sqrt(observed.data.covariance[pos][pos])

        y_hat: Dict[str, float] = {}
        se_hat: Dict[str, float] = {}
        for pos, metric in enumerate(predicted.metric_names):
            y_hat[metric] = predicted.means[pos]
            se_hat[metric] = np.sqrt(predicted.covariance[pos][pos])

        # The index suffix keeps repeated arm names from overwriting each other.
        in_sample[f"{arm_name}_{idx}"] = PlotInSampleArm(
            name=arm_name,
            y=y,
            se=se,
            parameters=parameters,
            y_hat=y_hat,
            se_hat=se_hat,
            context_stratum=None,
        )

    return PlotData(
        metrics=metric_names,
        in_sample=in_sample,
        out_of_sample=None,
        status_quo_name=None,
    )
예제 #2
0
def interact_empirical_model_validation(batch: BatchTrial, data: Data) -> AxPlotConfig:
    """Plot stored model predictions for a batch's arms against observed data.

    Only generator runs that carry ``model_predictions`` contribute arms; runs
    without predictions are skipped.

    Args:
        batch: Batch whose generator runs hold the model predictions.
        data: Observed data for the batch; its dataframe is read through the
            ``arm_name``, ``metric_name``, ``mean`` and ``sem`` columns.

    Returns:
        AxPlotConfig wrapping the observed-vs-predicted dropdown figure.

    Raises:
        ValueError: If no generator run on the batch has model predictions.
    """
    df = data.df
    metric_names = list(df["metric_name"].unique())

    in_sample: Dict[str, PlotInSampleArm] = {}
    for struct in batch.generator_run_structs:
        predictions = struct.generator_run.model_predictions
        if predictions is None:
            continue

        for arm_idx, arm in enumerate(struct.generator_run.arms):
            arm_key = arm.name_or_short_signature
            y: Dict[str, float] = {}
            se: Dict[str, float] = {}
            y_hat: Dict[str, float] = {}
            se_hat: Dict[str, float] = {}

            # One dataframe row per (arm, metric) observation for this arm.
            arm_rows = df[df["arm_name"] == arm_key]
            for _, row in arm_rows.iterrows():
                metric = row["metric_name"]
                y[metric] = row["mean"]
                se[metric] = row["sem"]
                # Predictions are indexed by the arm's position in the run.
                y_hat[metric] = predictions[0][metric][arm_idx]
                # NOTE(review): this value is stored as `se_hat` unchanged. If
                # predictions[1] holds (co)variances, a square root may be
                # intended here -- confirm against the model_predictions format.
                se_hat[metric] = predictions[1][metric][metric][arm_idx]

            # An arm without observed rows is still recorded, with empty dicts.
            in_sample[arm_key] = PlotInSampleArm(
                name=arm_key,
                y=y,
                se=se,
                parameters=arm.parameters,
                y_hat=y_hat,
                se_hat=se_hat,
                context_stratum=None,
            )

    if not in_sample:
        raise ValueError("No model predictions present on the batch.")

    fig = _obs_vs_pred_dropdown_plot(
        data=PlotData(
            metrics=metric_names,
            in_sample=in_sample,
            out_of_sample=None,
            status_quo_name=None,
        ),
        rel=False,
    )
    fig["layout"]["title"] = "Cross-validation"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
예제 #3
0
파일: helper.py 프로젝트: jeffersonp317/Ax
def get_plot_data(
    model: ModelBridge,
    generator_runs_dict: Dict[str, GeneratorRun],
    metric_names: Optional[Set[str]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    data_selector: Optional[Callable[[Observation], bool]] = None,
) -> Tuple[PlotData, RawData, Dict[str, TParameterization]]:
    """Build plot data covering in-sample and out-of-sample arms.

    In-sample arms get both observed and predicted metrics; out-of-sample
    arms, supplied through ``generator_runs_dict``, get predicted metrics only.

    In PlotData, in-sample observations are merged with IVW. In RawData they
    stay un-merged, as a list of dictionaries (one per observation) with the
    keys 'arm_name', 'mean', and 'sem'.

    Args:
        model: The model.
        generator_runs_dict: Mapping from generator run name to generator run.
        metric_names: Restrict predictions to this set. If None, all metrics
            in the model are used.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.

    Returns:
        A tuple containing

        - PlotData object with in-sample and out-of-sample predictions.
        - List of observations like::

            {'metric_name': 'likes', 'arm_name': '0_1', 'mean': 1., 'sem': 0.1}.

        - Mapping from arm name to parameters.
    """
    # Fall back to every metric the model knows when no subset is requested.
    if metric_names is None:
        selected_metrics = model.metric_names
    else:
        selected_metrics = metric_names

    in_sample_arms, raw_observations, arm_name_to_parameters = _get_in_sample_arms(
        model=model,
        metric_names=selected_metrics,
        fixed_features=fixed_features,
        data_selector=data_selector,
    )
    out_of_sample_arms = _get_out_of_sample_arms(
        model=model,
        generator_runs_dict=generator_runs_dict,
        metric_names=selected_metrics,
        fixed_features=fixed_features,
    )

    if model.status_quo is None:
        status_quo_name = None
    else:
        # pyre-fixme[16]: `Optional` has no attribute `arm_name`.
        status_quo_name = model.status_quo.arm_name

    return (
        PlotData(
            metrics=list(selected_metrics),
            in_sample=in_sample_arms,
            out_of_sample=out_of_sample_arms,
            status_quo_name=status_quo_name,
        ),
        raw_observations,
        arm_name_to_parameters,
    )
예제 #4
0
def _get_batch_comparison_plot_data(
    observations: List[Observation],
    batch_x: int,
    batch_y: int,
    rel: bool = False,
    status_quo_name: Optional[str] = None,
) -> PlotData:
    """Compute PlotData for comparing repeated arms across trials.

    Arms observed in both trials are paired up: the ``batch_x`` observation
    fills ``y``/``se`` and the ``batch_y`` observation fills
    ``y_hat``/``se_hat``. The ``se`` values are the square roots of the
    diagonal covariance entries. Arms present in only one trial are dropped.

    Args:
        observations: List of observations.
        batch_x: Batch for x-axis.
        batch_y: Batch for y-axis.
        rel: Whether to relativize data against status_quo arm. Within this
            function the flag is only validated; no relativization is
            performed here.
        status_quo_name: Name of the status_quo arm.

    Returns:
        PlotData: a plot data object. An empty ``observations`` list yields a
        PlotData with no metrics and no arms.

    Raises:
        ValueError: If ``rel`` is True but ``status_quo_name`` is None.
    """
    if rel and status_quo_name is None:
        raise ValueError("Experiment status quo must be set for rel=True")

    # Without any observations there is nothing to compare; return an empty
    # result instead of failing on `observations[0]` below.
    if not observations:
        return PlotData(
            metrics=[],
            in_sample={},
            out_of_sample=None,
            status_quo_name=status_quo_name,
        )

    # arm_name -> observation, per trial. If a trial holds several
    # observations for one arm name, the last one wins.
    x_observations = {
        observation.arm_name: observation
        for observation in observations
        if observation.features.trial_index == batch_x
    }
    y_observations = {
        observation.arm_name: observation
        for observation in observations
        if observation.features.trial_index == batch_y
    }

    # Assume input is well formed and metric_names are consistent across observations
    metric_names = observations[0].data.metric_names
    insample_data: Dict[str, PlotInSampleArm] = {}
    for arm_name, x_observation in x_observations.items():
        # Restrict to arms present in both trials
        if arm_name not in y_observations:
            continue

        y_observation = y_observations[arm_name]
        arm_data = {
            "name": arm_name,
            "y": {},
            "se": {},
            "parameters": x_observation.features.parameters,
            "y_hat": {},
            "se_hat": {},
            "context_stratum": None,
        }
        # x-axis trial supplies the "observed" side of the comparison.
        for i, mname in enumerate(x_observation.data.metric_names):
            # pyre-fixme[16]: Optional type has no attribute `__setitem__`.
            arm_data["y"][mname] = x_observation.data.means[i]
            arm_data["se"][mname] = np.sqrt(
                x_observation.data.covariance[i][i])
        # y-axis trial is stored in the "predicted" slots.
        for i, mname in enumerate(y_observation.data.metric_names):
            arm_data["y_hat"][mname] = y_observation.data.means[i]
            arm_data["se_hat"][mname] = np.sqrt(
                y_observation.data.covariance[i][i])
        # Expected `str` for 2nd anonymous parameter to call `dict.__setitem__` but got
        # `Optional[str]`.
        # pyre-fixme[6]:
        insample_data[arm_name] = PlotInSampleArm(**arm_data)

    return PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=status_quo_name,
    )