Example #1
0
def interact_batch_comparison(
    observations: List[Observation],
    experiment: Experiment,
    batch_x: int,
    batch_y: int,
    rel: bool = False,
    status_quo_name: Optional[str] = None,
) -> AxPlotConfig:
    """Compare repeated arms from two trials; select metric via dropdown.

    Args:
        observations: List of observations to compute comparison.
        experiment: Experiment the observations come from. If it is a
            ``MultiTypeExperiment``, the observations are first passed through
            ``convert_mt_observations``. Its ``status_quo`` arm, if set, is
            used as the default status quo.
        batch_x: Index of batch for x-axis.
        batch_y: Index of batch for y-axis.
        rel: Whether to relativize data against status_quo arm.
        status_quo_name: Name of the status_quo arm. If not given (or empty),
            defaults to the name of ``experiment.status_quo`` when the
            experiment has one.

    Returns:
        An ``AxPlotConfig`` of type ``AxPlotTypes.GENERIC`` wrapping the
        comparison figure, titled "Repeated arms across trials".
    """
    # Multi-type experiments need their observations converted before the
    # two batches can be compared metric-by-metric.
    if isinstance(experiment, MultiTypeExperiment):
        observations = convert_mt_observations(observations, experiment)
    # Fall back to the experiment's own status quo arm when none is supplied.
    if not status_quo_name and experiment.status_quo:
        status_quo_name = not_none(experiment.status_quo).name
    plot_data = _get_batch_comparison_plot_data(
        observations,
        batch_x,
        batch_y,
        rel=rel,
        status_quo_name=status_quo_name,
    )
    fig = _obs_vs_pred_dropdown_plot(
        data=plot_data,
        rel=rel,
        xlabel="Batch {}".format(batch_x),
        ylabel="Batch {}".format(batch_y),
    )
    fig["layout"]["title"] = "Repeated arms across trials"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
Example #2
0
    def testConvertMetricNames(self):
        """Metric names are canonicalized by conversion and restored on untransform.

        ``convert_mt_observations`` should rename every observation's metric to
        "m1". ``ConvertMetricNames.untransform_observation_data`` should then
        map it back: trial 0 keeps "m1", and every other trial gets "m2".
        """
        transform = ConvertMetricNames(
            None, self.observation_features, self.observation_data, config=self.tconfig
        )

        transformed_observations = convert_mt_observations(
            self.observations, self.experiment
        )
        transformed_observation_data = [o.data for o in transformed_observations]
        transformed_observation_features = [
            o.features for o in transformed_observations
        ]

        # All trials should have canonical name "m1"
        for obsd in transformed_observation_data:
            self.assertEqual(obsd.metric_names[0], "m1")

        untransformed_observation_data = transform.untransform_observation_data(
            transformed_observation_data, transformed_observation_features
        )

        # Guard the pairing below: zip would silently truncate on a length
        # mismatch, so check explicitly that nothing was dropped.
        self.assertEqual(
            len(untransformed_observation_data), len(self.observations)
        )

        # Should have original metric_name
        for obsf, obsd in zip(
            self.observation_features, untransformed_observation_data
        ):
            metric_name = "m1" if obsf.trial_index == 0 else "m2"
            self.assertEqual(obsd.metric_names[0], metric_name)