Code example #1
def test_SelectionRulePerformancePlotter_plot_regrets():
    with patch(
        "triage.component.audition.selection_rule_performance.plot_cats"
    ) as plot_patch:
        with testing.postgresql.Postgresql() as postgresql:
            engine = create_engine(postgresql.url())
            distance_table, model_groups = create_sample_distance_table(engine)
            plotter = SelectionRulePerformancePlotter(
                SelectionRulePicker(distance_table)
            )
            plotter.plot(
                bound_selection_rules=[
                    BoundSelectionRule(
                        function_name="best_current_value",
                        args={"metric": "precision@", "parameter": "100_abs"},
                    ),
                    BoundSelectionRule(
                        function_name="best_average_value",
                        args={"metric": "precision@", "parameter": "100_abs"},
                    ),
                ],
                regret_metric="precision@",
                regret_parameter="100_abs",
                model_group_ids=[1, 2],
                train_end_times=["2014-01-01", "2015-01-01"],
            )
        assert plot_patch.called
        args, kwargs = plot_patch.call_args
        assert "regret" in kwargs["frame"]
        assert "train_end_time" in kwargs["frame"]
        assert kwargs["x_col"] == "train_end_time"
        assert kwargs["y_col"] == "regret"
Code example #2
def test_SelectionRulePerformancePlotter_plot_metrics():
    with patch(
        'triage.component.audition.selection_rule_performance.plot_cats'
    ) as plot_patch:
        with testing.postgresql.Postgresql() as postgresql:
            engine = create_engine(postgresql.url())
            distance_table, model_groups = create_sample_distance_table(engine)
            plotter = SelectionRulePerformancePlotter(
                SelectionRulePicker(distance_table)
            )
            plotter.plot(
                bound_selection_rules=[
                    BoundSelectionRule(
                        function_name='best_current_value',
                        args={'metric': 'precision@', 'parameter': '100_abs'},
                    ),
                    BoundSelectionRule(
                        function_name='best_average_value',
                        args={'metric': 'precision@', 'parameter': '100_abs'},
                    ),
                ],
                regret_metric='precision@',
                regret_parameter='100_abs',
                model_group_ids=[1, 2],
                train_end_times=['2014-01-01', '2015-01-01'],
                plot_type='metric',
            )
        assert plot_patch.called
        args, kwargs = plot_patch.call_args
        assert 'raw_value_next_time' in kwargs['frame']
        assert 'train_end_time' in kwargs['frame']
        assert kwargs['x_col'] == 'train_end_time'
        assert kwargs['y_col'] == 'raw_value_next_time'
Code example #3
def test_SelectionRulePerformancePlotter_generate_plot_data():
    plotter = SelectionRulePerformancePlotter(MockSelectionRulePicker())
    df = plotter.generate_plot_data(
        bound_selection_rules=[
            BoundSelectionRule(
                function_name="best_current_value",
                args={"metric": "precision@", "parameter": "100_abs"},
            ),
            BoundSelectionRule(
                function_name="best_average_value",
                args={"metric": "precision@", "parameter": "100_abs"},
            ),
        ],
        regret_metric="precision@",
        regret_parameter="100_abs",
        model_group_ids=[1, 2],
        train_end_times=TRAIN_END_TIMES,
    )
    print(df.to_dict("list"))
    assert df.to_dict("list") == {
        "selection_rule": [
            "best_current_value_precision@_100_abs",
            "best_current_value_precision@_100_abs",
            "best_average_value_precision@_100_abs",
            "best_average_value_precision@_100_abs",
        ],
        "train_end_time": TRAIN_END_TIMES + TRAIN_END_TIMES,
        "regret": [0.15, 0.30, 0.15, 0.30],
        "raw_value_next_time": [0.5, 0.4, 0.5, 0.4],
        "model_group_id": [1, 2, 1, 2],
    }
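
MockSelectionRulePicker is not shown in this listing. Judging from the asserted frame, and from the result keys visible in the SelectionRulePicker examples further down, a hypothetical stand-in only needs a results_for_rule method that returns one result dict per train end time; the sketch below is an assumption sufficient to satisfy this test, not the project's actual fixture.

# Hypothetical stand-in (assumption): canned results shaped like the dicts
# returned by SelectionRulePicker.results_for_rule, one per train end time.
class MockSelectionRulePicker:
    def results_for_rule(self, bound_selection_rule, model_group_ids,
                         train_end_times, regret_metric, regret_parameter):
        canned = [
            {"model_group_id": 1,
             "dist_from_best_case_next_time": 0.15,
             "raw_value_next_time": 0.5},
            {"model_group_id": 2,
             "dist_from_best_case_next_time": 0.30,
             "raw_value_next_time": 0.4},
        ]
        # attach each canned result to one of the requested end times
        return [dict(result, train_end_time=end_time)
                for result, end_time in zip(canned, train_end_times)]
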
Code example #4
def test_SelectionPlotter_plot():
    with patch("triage.component.audition.regrets.plot_cats") as plot_patch:
        with testing.postgresql.Postgresql() as postgresql:
            engine = create_engine(postgresql.url())
            distance_table, model_groups = create_sample_distance_table(engine)
            plotter = SelectionRulePlotter(
                selection_rule_picker=SelectionRulePicker(distance_table)
            )
            plotter.plot_all_selection_rules(
                bound_selection_rules=[
                    BoundSelectionRule(
                        descriptive_name="best_current_precision",
                        function=best_current_value,
                        args={"metric": "precision@", "parameter": "100_abs"},
                    ),
                    BoundSelectionRule(
                        descriptive_name="best_avg_precision",
                        function=best_average_value,
                        args={"metric": "precision@", "parameter": "100_abs"},
                    ),
                ],
                model_group_ids=[mg.model_group_id for mg in model_groups.values()],
                train_end_times=["2014-01-01", "2015-01-01", "2016-01-01"],
                regret_metric="precision@",
                regret_parameter="100_abs",
            )
        assert plot_patch.called
        args, kwargs = plot_patch.call_args
        assert "regret" in kwargs["frame"]
        assert "pct_of_time" in kwargs["frame"]
        assert kwargs["x_col"] == "regret"
        assert kwargs["y_col"] == "pct_of_time"
Code example #5
def test_selection_rule_picker_with_args():
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        distance_table, model_groups = create_sample_distance_table(engine)

        def pick_highest_avg(df, train_end_time, metric, parameter):
            assert len(df["train_end_time"].unique()) == 2
            subsetted = df[(df["metric"] == metric)
                           & (df["parameter"] == parameter)]
            mean = subsetted.groupby(["model_group_id"])["raw_value"].mean()
            return [mean.nlargest(1).index[0]]

        selection_rule_picker = SelectionRulePicker(
            distance_from_best_table=distance_table)
        regrets = [
            result["dist_from_best_case_next_time"]
            for result in selection_rule_picker.results_for_rule(
                bound_selection_rule=BoundSelectionRule(
                    descriptive_name="pick_highest_avg",
                    function=pick_highest_avg,
                    args={
                        "metric": "recall@",
                        "parameter": "100_abs"
                    },
                ),
                model_group_ids=[
                    mg.model_group_id for mg in model_groups.values()
                ],
                train_end_times=["2015-01-01"],
                regret_metric="precision@",
                regret_parameter="100_abs",
            )
        ]
        # picking the highest avg recall will pick 'spiky' for this time
        assert regrets == [0.3]
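
Note the plugin contract this example exercises: a selection rule is any callable of the form rule(df, train_end_time, **args) that returns a list of model group ids, and BoundSelectionRule binds the remaining keyword arguments (here metric and parameter) to it.
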
Code example #6
def test_selection_rule_picker():
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        distance_table, model_groups = create_sample_distance_table(engine)

        def pick_spiky(df, train_end_time):
            return [model_groups["spiky"].model_group_id]

        selection_rule_picker = SelectionRulePicker(
            distance_from_best_table=distance_table)

        results = selection_rule_picker.results_for_rule(
            bound_selection_rule=BoundSelectionRule(descriptive_name="spiky",
                                                    function=pick_spiky,
                                                    args={}),
            model_group_ids=[
                mg.model_group_id for mg in model_groups.values()
            ],
            train_end_times=["2014-01-01", "2015-01-01", "2016-01-01"],
            regret_metric="precision@",
            regret_parameter="100_abs",
        )
        assert [result["dist_from_best_case_next_time"]
                for result in results] == [
                    0.19,
                    0.3,
                    0.12,
                ]
        assert [result["raw_value"]
                for result in results] == [0.45, 0.84, 0.45]
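
Because pick_spiky always returns the 'spiky' model group, the three dist_from_best_case_next_time values are that group's regrets, its distance from the best-performing group at the evaluation following each train end time, while raw_value is the group's own precision@100_abs at the listed time.
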
Code example #7
def test_selection_rule_picker_with_args():
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        distance_table, model_groups = create_sample_distance_table(engine)

        def pick_highest_avg(df, train_end_time, metric, parameter):
            assert len(df['train_end_time'].unique()) == 2
            subsetted = df[(df['metric'] == metric)
                           & (df['parameter'] == parameter)]
            mean = subsetted.groupby(['model_group_id'])['raw_value'].mean()
            return [mean.nlargest(1).index[0]]

        selection_rule_picker = SelectionRulePicker(
            distance_from_best_table=distance_table)
        regrets = [
            result['dist_from_best_case_next_time']
            for result in selection_rule_picker.results_for_rule(
                bound_selection_rule=BoundSelectionRule(
                    descriptive_name='pick_highest_avg',
                    function=pick_highest_avg,
                    args={
                        'metric': 'recall@',
                        'parameter': '100_abs'
                    },
                ),
                model_group_ids=[
                    mg.model_group_id for mg in model_groups.values()
                ],
                train_end_times=['2015-01-01'],
                regret_metric='precision@',
                regret_parameter='100_abs',
            )
        ]
        # picking the highest avg recall will pick 'spiky' for this time
        assert regrets == [0.3]
Code example #8
def test_SelectionPlotter_create_plot_dataframe():
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        distance_table, model_groups = create_sample_distance_table(engine)
        plotter = SelectionRulePlotter(
            selection_rule_picker=SelectionRulePicker(distance_table))
        plot_df = plotter.create_plot_dataframe(
            bound_selection_rules=[
                BoundSelectionRule(
                    descriptive_name="best_current_precision",
                    function=best_current_value,
                    args={
                        "metric": "precision@",
                        "parameter": "100_abs"
                    },
                ),
                BoundSelectionRule(
                    descriptive_name="best_avg_precision",
                    function=best_average_value,
                    args={
                        "metric": "precision@",
                        "parameter": "100_abs"
                    },
                ),
            ],
            model_group_ids=[
                mg.model_group_id for mg in model_groups.values()
            ],
            train_end_times=["2014-01-01", "2015-01-01", "2016-01-01"],
            regret_metric="precision@",
            regret_parameter="100_abs",
        )
        # assert the right number of columns, and a row per selection rule for each regret value
        assert plot_df.shape == (100 * 2, 3)

        # both selection rules have a regret of 0.70 or less 100% of the time
        for value in plot_df[plot_df["regret"] == 0.70]["pct_of_time"].values:
            assert np.isclose(value, 1.0)

        # the best_avg_precision rule should be within a regret of 0.14 one third of the time
        for value in plot_df[(plot_df["regret"] == 0.14)
                             & (plot_df["selection_rule"] ==
                                "best_avg_precision")]["pct_of_time"].values:
            assert np.isclose(value, 1.0 / 3)
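
The shape assertion (100 * 2, 3) implies one row per selection rule per regret threshold, with pct_of_time the fraction of train end times whose regret falls at or below that threshold, i.e. an empirical CDF of regret. A minimal sketch of that computation, assuming a 0.01-spaced threshold grid inferred from the assertions rather than taken from the library:

# Assumed computation behind create_plot_dataframe (a sketch, not the
# library code): sweep 100 regret thresholds per rule and record the
# fraction of observed regrets at or below each threshold.
import numpy as np
import pandas as pd

def regret_cdf_frame(regrets_by_rule, n_thresholds=100):
    # regrets_by_rule: {rule_name: sequence of observed regrets}
    rows = []
    for rule_name, regrets in regrets_by_rule.items():
        observed = np.asarray(regrets)
        for threshold in np.linspace(0.01, 1.0, n_thresholds):
            rows.append({
                "selection_rule": rule_name,
                "regret": round(float(threshold), 2),
                "pct_of_time": float((observed <= threshold).mean()),
            })
    return pd.DataFrame(rows)

With two rules this yields the asserted (200, 3) shape, and pct_of_time is 1.0 wherever every observed regret sits at or below the threshold.
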
Code example #9
def test_SelectionPlotter_create_plot_dataframe():
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        distance_table, model_groups = create_sample_distance_table(engine)
        plotter = SelectionRulePlotter(
            selection_rule_picker=SelectionRulePicker(distance_table))
        plot_df = plotter.create_plot_dataframe(
            bound_selection_rules=[
                BoundSelectionRule(descriptive_name='best_current_precision',
                                   function=best_current_value,
                                   args={
                                       'metric': 'precision@',
                                       'parameter': '100_abs'
                                   }),
                BoundSelectionRule(descriptive_name='best_avg_precision',
                                   function=best_average_value,
                                   args={
                                       'metric': 'precision@',
                                       'parameter': '100_abs'
                                   }),
            ],
            model_group_ids=[
                mg.model_group_id for mg in model_groups.values()
            ],
            train_end_times=['2014-01-01', '2015-01-01', '2016-01-01'],
            regret_metric='precision@',
            regret_parameter='100_abs',
        )
        # assert the right number of columns, and a row per selection rule for each regret value
        assert plot_df.shape == (100 * 2, 3)

        # both selection rules have a regret of 0.70 or less 100% of the time
        for value in plot_df[plot_df['regret'] == 0.70]['pct_of_time'].values:
            assert numpy.isclose(value, 1.0)

        # the best_avg_precision rule should be within a regret of 0.14 one third of the time
        for value in plot_df[(plot_df['regret'] == 0.14)
                             & (plot_df['selection_rule'] ==
                                'best_avg_precision')]['pct_of_time'].values:
            assert numpy.isclose(value, 1.0 / 3)
Code example #10
def test_SelectionRulePerformancePlotter_generate_plot_data():
    plotter = SelectionRulePerformancePlotter(MockSelectionRulePicker())
    df = plotter.generate_plot_data(
        bound_selection_rules=[
            BoundSelectionRule(
                function_name='best_current_value',
                args={
                    'metric': 'precision@',
                    'parameter': '100_abs'
                },
            ),
            BoundSelectionRule(
                function_name='best_average_value',
                args={
                    'metric': 'precision@',
                    'parameter': '100_abs'
                },
            )
        ],
        regret_metric='precision@',
        regret_parameter='100_abs',
        model_group_ids=[1, 2],
        train_end_times=TRAIN_END_TIMES,
    )
    print(df.to_dict('list'))
    assert df.to_dict('list') == {
        'selection_rule': [
            'best_current_value_precision@_100_abs',
            'best_current_value_precision@_100_abs',
            'best_average_value_precision@_100_abs',
            'best_average_value_precision@_100_abs'
        ],
        'train_end_time': TRAIN_END_TIMES + TRAIN_END_TIMES,
        'regret': [0.15, 0.30, 0.15, 0.30],
        'raw_value_next_time': [0.5, 0.4, 0.5, 0.4],
        'model_group_id': [1, 2, 1, 2]
    }
Code example #11
def test_SelectionPlotter_plot():
    with patch('triage.component.audition.regrets.plot_cats') as plot_patch:
        with testing.postgresql.Postgresql() as postgresql:
            engine = create_engine(postgresql.url())
            distance_table, model_groups = create_sample_distance_table(engine)
            plotter = SelectionRulePlotter(
                selection_rule_picker=SelectionRulePicker(distance_table))
            plotter.plot_all_selection_rules(
                bound_selection_rules=[
                    BoundSelectionRule(
                        descriptive_name='best_current_precision',
                        function=best_current_value,
                        args={
                            'metric': 'precision@',
                            'parameter': '100_abs'
                        }),
                    BoundSelectionRule(descriptive_name='best_avg_precision',
                                       function=best_average_value,
                                       args={
                                           'metric': 'precision@',
                                           'parameter': '100_abs'
                                       }),
                ],
                model_group_ids=[
                    mg.model_group_id for mg in model_groups.values()
                ],
                train_end_times=['2014-01-01', '2015-01-01', '2016-01-01'],
                regret_metric='precision@',
                regret_parameter='100_abs',
            )
        assert plot_patch.called
        args, kwargs = plot_patch.call_args
        assert 'regret' in kwargs['frame']
        assert 'pct_of_time' in kwargs['frame']
        assert kwargs['x_col'] == 'regret'
        assert kwargs['y_col'] == 'pct_of_time'
Code example #12
def test_selection_rule_grid():
    input_data = [{
        'shared_parameters': [
            {
                'metric': 'precision@',
                'parameter': '100_abs'
            },
            {
                'metric': 'recall@',
                'parameter': '100_abs'
            },
        ],
        'selection_rules': [{
            'name': 'most_frequent_best_dist',
            'dist_from_best_case': [0.1, 0.2, 0.3]
        }, {
            'name': 'best_current_value'
        }]
    }, {
        'shared_parameters': [
            {
                'metric1': 'precision@',
                'parameter1': '100_abs'
            },
        ],
        'selection_rules': [
            {
                'name': 'best_average_two_metrics',
                'metric2': ['recall@'],
                'parameter2': ['100_abs'],
                'metric1_weight': [0.4, 0.5, 0.6]
            },
        ]
    }]

    expected_output = [
        BoundSelectionRule(
            descriptive_name='most_frequent_best_dist_precision@_100_abs_0.1',
            function_name='most_frequent_best_dist',
            args={
                'metric': 'precision@',
                'parameter': '100_abs',
                'dist_from_best_case': 0.1
            }),
        BoundSelectionRule(
            descriptive_name='most_frequent_best_dist_precision@_100_abs_0.2',
            function_name='most_frequent_best_dist',
            args={
                'metric': 'precision@',
                'parameter': '100_abs',
                'dist_from_best_case': 0.2
            }),
        BoundSelectionRule(
            descriptive_name='most_frequent_best_dist_precision@_100_abs_0.3',
            function_name='most_frequent_best_dist',
            args={
                'metric': 'precision@',
                'parameter': '100_abs',
                'dist_from_best_case': 0.3
            }),
        BoundSelectionRule(
            descriptive_name='most_frequent_best_dist_recall@_100_abs_0.1',
            function_name='most_frequent_best_dist',
            args={
                'metric': 'recall@',
                'parameter': '100_abs',
                'dist_from_best_case': 0.1
            }),
        BoundSelectionRule(
            descriptive_name='most_frequent_best_dist_recall@_100_abs_0.2',
            function_name='most_frequent_best_dist',
            args={
                'metric': 'recall@',
                'parameter': '100_abs',
                'dist_from_best_case': 0.2
            }),
        BoundSelectionRule(
            descriptive_name='most_frequent_best_dist_recall@_100_abs_0.3',
            function_name='most_frequent_best_dist',
            args={
                'metric': 'recall@',
                'parameter': '100_abs',
                'dist_from_best_case': 0.3
            }),
        BoundSelectionRule(
            descriptive_name='best_current_value_precision@_100_abs',
            function_name='best_current_value',
            args={
                'metric': 'precision@',
                'parameter': '100_abs'
            }),
        BoundSelectionRule(
            descriptive_name='best_current_value_recall@_100_abs',
            function_name='best_current_value',
            args={
                'metric': 'recall@',
                'parameter': '100_abs'
            }),
        BoundSelectionRule(
            descriptive_name=
            'best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.4',
            function_name='best_average_two_metrics',
            args={
                'metric1': 'precision@',
                'parameter1': '100_abs',
                'metric2': 'recall@',
                'parameter2': '100_abs',
                'metric1_weight': 0.4
            }),
        BoundSelectionRule(
            descriptive_name=
            'best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.5',
            function_name='best_average_two_metrics',
            args={
                'metric1': 'precision@',
                'parameter1': '100_abs',
                'metric2': 'recall@',
                'parameter2': '100_abs',
                'metric1_weight': 0.5
            }),
        BoundSelectionRule(
            descriptive_name=
            'best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.6',
            function_name='best_average_two_metrics',
            args={
                'metric1': 'precision@',
                'parameter1': '100_abs',
                'metric2': 'recall@',
                'parameter2': '100_abs',
                'metric1_weight': 0.6
            })
    ]

    # sort both lists so we can compare them without resorting to a hash
    expected_output.sort(key=lambda x: x.descriptive_name)
    grid = sorted(make_selection_rule_grid(input_data),
                  key=lambda x: x.descriptive_name)
    assert len(grid) == len(expected_output)
    for expected_rule, actual_rule in zip(expected_output, grid):
        assert expected_rule.descriptive_name == actual_rule.descriptive_name
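
The expected output is the cross-product of each shared_parameters entry with each selection rule, with list-valued rule arguments fanned out into one BoundSelectionRule per value. A compact sketch of that expansion, written against the fixture above (an assumption about make_selection_rule_grid's behavior, not its code):

# Assumed expansion logic (sketch): combine every shared-parameter dict with
# every rule, and take the itertools.product over list-valued rule arguments.
import itertools

def expand_rule_grid(grid_config):
    for group in grid_config:
        for shared in group['shared_parameters']:
            for rule in group['selection_rules']:
                extra = {k: v for k, v in rule.items() if k != 'name'}
                keys = list(extra)
                value_lists = [v if isinstance(v, list) else [v]
                               for v in extra.values()]
                for combo in itertools.product(*value_lists):
                    yield {'name': rule['name'],
                           'args': {**shared, **dict(zip(keys, combo))}}

Applied to input_data above, this yields the same eleven argument sets as expected_output (6 + 2 + 3).
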
Code example #13
    def test_create_grid(self):
        """
        input_data = [{
            'shared_parameters': [
                {'metric': 'precision@', 'parameter': '100_abs'},
                {'metric': 'recall@', 'parameter': '100_abs'},
            ],
            'selection_rules': [
                {'name': 'most_frequent_best_dist', 'dist_from_best_case': [0.1, 0.2, 0.3]},
                {'name': 'best_current_value'}
            ]
        }, {
            'shared_parameters': [
                {'metric1': 'precision@', 'parameter1': '100_abs'},
            ],
            'selection_rules': [
                {
                    'name': 'best_average_two_metrics',
                    'metric2': ['recall@'],
                    'parameter2': ['100_abs'],
                    'metric1_weight': [0.4, 0.5, 0.6]
                },
            ]
        }]
        """
        Rule1 = SimpleRuleMaker()
        Rule1.add_rule_best_current_value(metric="precision@",
                                          parameter="100_abs")
        Rule1.add_rule_most_frequent_best_dist(
            metric="recall@",
            parameter="100_abs",
            dist_from_best_case=[0.1, 0.2, 0.3])

        Rule2 = TwoMetricsRuleMaker()
        Rule2.add_rule_best_average_two_metrics(metric1="precision@",
                                                parameter1="100_abs",
                                                metric2="recall@",
                                                parameter2="100_abs",
                                                metric1_weight=[0.4, 0.5, 0.6])

        expected_output = [
            BoundSelectionRule(
                descriptive_name=
                'most_frequent_best_dist_precision@_100_abs_0.1',
                function_name='most_frequent_best_dist',
                args={
                    'metric': 'precision@',
                    'parameter': '100_abs',
                    'dist_from_best_case': 0.1
                }),
            BoundSelectionRule(
                descriptive_name=
                'most_frequent_best_dist_precision@_100_abs_0.2',
                function_name='most_frequent_best_dist',
                args={
                    'metric': 'precision@',
                    'parameter': '100_abs',
                    'dist_from_best_case': 0.2
                }),
            BoundSelectionRule(
                descriptive_name=
                'most_frequent_best_dist_precision@_100_abs_0.3',
                function_name='most_frequent_best_dist',
                args={
                    'metric': 'precision@',
                    'parameter': '100_abs',
                    'dist_from_best_case': 0.3
                }),
            BoundSelectionRule(
                descriptive_name='most_frequent_best_dist_recall@_100_abs_0.1',
                function_name='most_frequent_best_dist',
                args={
                    'metric': 'recall@',
                    'parameter': '100_abs',
                    'dist_from_best_case': 0.1
                }),
            BoundSelectionRule(
                descriptive_name='most_frequent_best_dist_recall@_100_abs_0.2',
                function_name='most_frequent_best_dist',
                args={
                    'metric': 'recall@',
                    'parameter': '100_abs',
                    'dist_from_best_case': 0.2
                }),
            BoundSelectionRule(
                descriptive_name='most_frequent_best_dist_recall@_100_abs_0.3',
                function_name='most_frequent_best_dist',
                args={
                    'metric': 'recall@',
                    'parameter': '100_abs',
                    'dist_from_best_case': 0.3
                }),
            BoundSelectionRule(
                descriptive_name='best_current_value_precision@_100_abs',
                function_name='best_current_value',
                args={
                    'metric': 'precision@',
                    'parameter': '100_abs'
                }),
            BoundSelectionRule(
                descriptive_name='best_current_value_recall@_100_abs',
                function_name='best_current_value',
                args={
                    'metric': 'recall@',
                    'parameter': '100_abs'
                }),
            BoundSelectionRule(
                descriptive_name=
                'best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.4',
                function_name='best_average_two_metrics',
                args={
                    'metric1': 'precision@',
                    'parameter1': '100_abs',
                    'metric2': 'recall@',
                    'parameter2': '100_abs',
                    'metric1_weight': 0.4
                }),
            BoundSelectionRule(
                descriptive_name=
                'best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.5',
                function_name='best_average_two_metrics',
                args={
                    'metric1': 'precision@',
                    'parameter1': '100_abs',
                    'metric2': 'recall@',
                    'parameter2': '100_abs',
                    'metric1_weight': 0.5
                }),
            BoundSelectionRule(
                descriptive_name=
                'best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.6',
                function_name='best_average_two_metrics',
                args={
                    'metric1': 'precision@',
                    'parameter1': '100_abs',
                    'metric2': 'recall@',
                    'parameter2': '100_abs',
                    'metric1_weight': 0.6
                })
        ]
        expected_output.sort(key=lambda x: x.descriptive_name)
        grid = sorted(make_selection_rule_grid(
            create_selection_grid(Rule1, Rule2)),
                      key=lambda x: x.descriptive_name)
        assert len(grid) == len(expected_output)
        for expected_rule, actual_rule in zip(expected_output, grid):
            assert expected_rule.descriptive_name == actual_rule.descriptive_name
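
This test builds the same grid as code example #12, but through the SimpleRuleMaker and TwoMetricsRuleMaker API; the docstring records the equivalent raw configuration, and create_selection_grid(Rule1, Rule2) is expected to reproduce that structure before make_selection_rule_grid expands it into BoundSelectionRules.
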
Code example #14
def test_selection_rule_grid():
    input_data = [
        {
            "shared_parameters": [
                {
                    "metric": "precision@",
                    "parameter": "100_abs"
                },
                {
                    "metric": "recall@",
                    "parameter": "100_abs"
                },
            ],
            "selection_rules": [
                {
                    "name": "most_frequent_best_dist",
                    "dist_from_best_case": [0.1, 0.2, 0.3],
                },
                {
                    "name": "best_current_value"
                },
            ],
        },
        {
            "shared_parameters": [{
                "metric1": "precision@",
                "parameter1": "100_abs"
            }],
            "selection_rules": [{
                "name": "best_average_two_metrics",
                "metric2": ["recall@"],
                "parameter2": ["100_abs"],
                "metric1_weight": [0.4, 0.5, 0.6],
            }],
        },
    ]

    expected_output = [
        BoundSelectionRule(
            descriptive_name="most_frequent_best_dist_precision@_100_abs_0.1",
            function_name="most_frequent_best_dist",
            args={
                "metric": "precision@",
                "parameter": "100_abs",
                "dist_from_best_case": 0.1,
            },
        ),
        BoundSelectionRule(
            descriptive_name="most_frequent_best_dist_precision@_100_abs_0.2",
            function_name="most_frequent_best_dist",
            args={
                "metric": "precision@",
                "parameter": "100_abs",
                "dist_from_best_case": 0.2,
            },
        ),
        BoundSelectionRule(
            descriptive_name="most_frequent_best_dist_precision@_100_abs_0.3",
            function_name="most_frequent_best_dist",
            args={
                "metric": "precision@",
                "parameter": "100_abs",
                "dist_from_best_case": 0.3,
            },
        ),
        BoundSelectionRule(
            descriptive_name="most_frequent_best_dist_recall@_100_abs_0.1",
            function_name="most_frequent_best_dist",
            args={
                "metric": "recall@",
                "parameter": "100_abs",
                "dist_from_best_case": 0.1,
            },
        ),
        BoundSelectionRule(
            descriptive_name="most_frequent_best_dist_recall@_100_abs_0.2",
            function_name="most_frequent_best_dist",
            args={
                "metric": "recall@",
                "parameter": "100_abs",
                "dist_from_best_case": 0.2,
            },
        ),
        BoundSelectionRule(
            descriptive_name="most_frequent_best_dist_recall@_100_abs_0.3",
            function_name="most_frequent_best_dist",
            args={
                "metric": "recall@",
                "parameter": "100_abs",
                "dist_from_best_case": 0.3,
            },
        ),
        BoundSelectionRule(
            descriptive_name="best_current_value_precision@_100_abs",
            function_name="best_current_value",
            args={
                "metric": "precision@",
                "parameter": "100_abs"
            },
        ),
        BoundSelectionRule(
            descriptive_name="best_current_value_recall@_100_abs",
            function_name="best_current_value",
            args={
                "metric": "recall@",
                "parameter": "100_abs"
            },
        ),
        BoundSelectionRule(
            descriptive_name=
            "best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.4",
            function_name="best_average_two_metrics",
            args={
                "metric1": "precision@",
                "parameter1": "100_abs",
                "metric2": "recall@",
                "parameter2": "100_abs",
                "metric1_weight": 0.4,
            },
        ),
        BoundSelectionRule(
            descriptive_name=
            "best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.5",
            function_name="best_average_two_metrics",
            args={
                "metric1": "precision@",
                "parameter1": "100_abs",
                "metric2": "recall@",
                "parameter2": "100_abs",
                "metric1_weight": 0.5,
            },
        ),
        BoundSelectionRule(
            descriptive_name=
            "best_average_two_metrics_precision@_100_abs_recall@_100_abs_0.6",
            function_name="best_average_two_metrics",
            args={
                "metric1": "precision@",
                "parameter1": "100_abs",
                "metric2": "recall@",
                "parameter2": "100_abs",
                "metric1_weight": 0.6,
            },
        ),
    ]

    # sort both lists so we can compare them without resorting to a hash
    expected_output.sort(key=lambda x: x.descriptive_name)
    grid = sorted(make_selection_rule_grid(input_data),
                  key=lambda x: x.descriptive_name)
    assert len(grid) == len(expected_output)
    for expected_rule, actual_rule in zip(expected_output, grid):
        assert expected_rule.descriptive_name == actual_rule.descriptive_name