예제 #1
0
    def test_get_search_algorithm_manager(self):
        """Check the iteration config produced for each search algorithm.

        Grid search (default factory content) and random search
        (early-stopping spec) must yield no iteration config at all, while
        hyperband and BO must yield their dedicated config classes when given
        an iteration payload.
        """
        # Algorithms without iteration state: grid first, then random search.
        stateless_factory_kwargs = (
            {},
            {'content': experiment_group_spec_content_early_stopping},
        )
        for factory_kwargs in stateless_factory_kwargs:
            group = ExperimentGroupFactory(**factory_kwargs)
            assert get_iteration_config(group.search_algorithm) is None

        hyperband_iteration = dict(
            iteration=1,
            bracket_iteration=0,
            experiment_ids=[1, 2, 3],
            experiments_metrics=None,
        )
        # NOTE(review): confirm the BO iteration schema expects the singular
        # 'experiment_configs' key used here.
        bo_iteration = dict(
            iteration=1,
            experiment_ids=[1, 2, 3],
            experiment_configs=[[1, {1: 1}], [2, {2: 2}], [3, {3: 3}]],
            experiments_metrics=None,
        )

        # Algorithms with iteration state: hyperband, then BO.
        stateful_cases = (
            (experiment_group_spec_content_hyperband, hyperband_iteration,
             HyperbandIterationConfig),
            (experiment_group_spec_content_bo, bo_iteration,
             BOIterationConfig),
        )
        for spec_content, payload, expected_class in stateful_cases:
            group = ExperimentGroupFactory(content=spec_content)
            config = get_iteration_config(group.search_algorithm,
                                          iteration=payload)
            assert isinstance(config, expected_class)
예제 #2
0
    def test_get_search_algorithm_manager(self):
        """Check the iteration config class produced for each search algorithm.

        Grid search (default factory content) and random search
        (early-stopping spec) share one minimal payload and must yield a
        ``BaseIterationConfig``. Hyperband and BO get richer payloads and must
        yield their dedicated config classes.
        """
        # One payload object, passed to both grid and random search.
        minimal_iteration = dict(iteration=1, num_suggestions=12)
        for factory_kwargs in (
                {},
                {'content': experiment_group_spec_content_early_stopping}):
            group = ExperimentGroupFactory(**factory_kwargs)
            config = get_iteration_config(group.search_algorithm,
                                          iteration=minimal_iteration)
            assert isinstance(config, BaseIterationConfig)

        hyperband_iteration = dict(
            iteration=1,
            num_suggestions=12,
            bracket_iteration=0,
            experiment_ids=[1, 2, 3],
            experiments_metrics=None,
        )
        bo_iteration = dict(
            iteration=1,
            num_suggestions=12,
            experiment_ids=[1, 2, 3],
            experiments_configs=[[1, {1: 1}], [2, {2: 2}], [3, {3: 3}]],
            experiments_metrics=None,
        )

        # Hyperband first, then BO, each with its own spec and payload.
        stateful_cases = (
            (experiment_group_spec_content_hyperband, hyperband_iteration,
             HyperbandIterationConfig),
            (experiment_group_spec_content_bo, bo_iteration,
             BOIterationConfig),
        )
        for spec_content, payload, expected_class in stateful_cases:
            group = ExperimentGroupFactory(content=spec_content)
            config = get_iteration_config(group.search_algorithm,
                                          iteration=payload)
            assert isinstance(config, expected_class)
예제 #3
0
    def iteration_config(self):
        """Build the iteration config for the current search iteration.

        Delegates to ``hpsearch.schemas.get_iteration_config`` with
        ``self.search_algorithm`` and ``self.iteration_data``.

        Returns:
            The object returned by ``get_iteration_config``, or ``None`` when
            either ``iteration_data`` or ``search_algorithm`` is falsy.
        """
        # Function-scope import; presumably avoids a circular import at module
        # load time -- TODO confirm.
        from hpsearch.schemas import get_iteration_config

        # Read ``iteration_data`` once so the truthiness check and the call use
        # the same value. If it is a computed attribute, this also avoids doing
        # that work twice.
        iteration_data = self.iteration_data
        if iteration_data and self.search_algorithm:
            return get_iteration_config(search_algorithm=self.search_algorithm,
                                        iteration=iteration_data)
        return None
예제 #4
0
    def iteration_config(self) -> 'BaseIterationConfig':
        """Return the iteration config for this object's current iteration.

        Delegates to ``hpsearch.schemas.get_iteration_config`` using
        ``self.search_algorithm`` and ``self.iteration_data``.

        Returns:
            The config built by ``get_iteration_config``, or ``None`` when
            either ``iteration_data`` or ``search_algorithm`` is falsy.
            NOTE(review): the annotation omits ``Optional`` even though
            ``None`` can be returned; consider ``Optional[BaseIterationConfig]``.
        """
        # Imported inside the method; presumably to sidestep a circular import
        # at module load time -- TODO confirm.
        from hpsearch.schemas import get_iteration_config

        # Take one snapshot so the guard and the call see the same value.
        data = self.iteration_data
        if not (data and self.search_algorithm):
            return None
        return get_iteration_config(search_algorithm=self.search_algorithm,
                                    iteration=data)