def __init__(self,
                 model_config: Point,
                 input_space: Hypergrid,
                 output_space: Hypergrid,
                 logger: logging.Logger = None):
        """Create one RegressionEnhancedRandomForestRegressionModel per objective in ``output_space``.

        :param model_config: configuration for every per-objective RERF model; must be a member of
            ``regression_enhanced_random_forest_config_store.parameter_space``. Each model receives its own copy.
        :param input_space: hypergrid describing the model inputs; shared by all per-objective models.
        :param output_space: hypergrid whose dimensions are the objectives; each dimension gets its own model,
            keyed by the dimension's name in ``self._regressors_by_objective_name``.
        :param logger: forwarded to the base-class constructor. The per-objective models are given ``self.logger``,
            which this block reads but does not set (presumably the base class assigns it, possibly substituting
            a default when ``logger`` is None -- confirm against NaiveMultiObjectiveRegressionModel).
        :raises AssertionError: if ``model_config`` is not in the RERF config store's parameter space.
        """
        NaiveMultiObjectiveRegressionModel.__init__(
            self,
            model_type=RegressionEnhancedRandomForestRegressionModel,
            model_config=model_config,
            input_space=input_space,
            output_space=output_space,
            logger=logger)

        # The model config must belong to regression_enhanced_random_forest_config_store.parameter_space.
        # A more elaborate solution might be needed down the road, but for now this membership check suffices.
        # An explicit raise is used instead of a bare ``assert`` so the check is not stripped under ``python -O``.
        # AssertionError is kept as the exception type so existing callers and tests see no difference.
        if model_config not in regression_enhanced_random_forest_config_store.parameter_space:
            raise AssertionError(
                "model_config does not belong to regression_enhanced_random_forest_config_store.parameter_space")

        for output_dimension in output_space.dimensions:
            # Each objective gets its own copy of model_config rather than sharing one instance across objectives.
            # perform_initial_random_forest_hyper_parameter_search is set to False after the initial fit() call,
            # so that subsequent fit() calls don't pay for the embedded hyper-parameter search again. That
            # per-model change must not affect the other objectives' models.
            rerf_model = RegressionEnhancedRandomForestRegressionModel(
                model_config=model_config.copy(),
                input_space=input_space,
                output_space=SimpleHypergrid(
                    name=f"{output_dimension.name}_objective",
                    dimensions=[output_dimension]),
                logger=self.logger)
            self._regressors_by_objective_name[output_dimension.name] = rerf_model
Exemplo n.º 2
0
    def test_composite_spaces(self):
        """Check membership of flat and nested configurations in the composite (hierarchical) spaces.

        Reads the fixtures ``self.hierarchical_settings``, ``self.emergency_buffer_settings``,
        ``self.emergency_buffer_color`` and ``self.emergency_buffer_settings_with_color``. These are
        presumably built in setUp, which is not visible here -- confirm.
        """
        valid_config_no_emergency_buffer = Point(num_readers=1,
                                                 log2_buffer_size=10,
                                                 use_emergency_buffer=False)
        self.assertIn(valid_config_no_emergency_buffer,
                      self.hierarchical_settings)

        valid_emergency_buffer_config = Point(log2_emergency_buffer_size=2,
                                              use_colors=False)
        self.assertIn(valid_emergency_buffer_config,
                      self.emergency_buffer_settings)

        valid_config_with_emergency_buffer = Point(
            num_readers=1,
            log2_buffer_size=10,
            use_emergency_buffer=True,
            emergency_buffer_config=valid_emergency_buffer_config)
        self.assertIn(valid_config_with_emergency_buffer,
                      self.hierarchical_settings)

        # The color subspace is checked with its pivot dimension ('use_colors') added back in.
        valid_emergency_buffer_color_config = Point(color='Crimson')
        valid_emergency_buffer_color_config_with_pivot_dimension = valid_emergency_buffer_color_config.copy()
        valid_emergency_buffer_color_config_with_pivot_dimension['use_colors'] = True
        self.assertIn(valid_emergency_buffer_color_config_with_pivot_dimension,
                      self.emergency_buffer_color)

        # Likewise, the colorful emergency-buffer subspace is checked with 'use_emergency_buffer' added.
        valid_colorful_emergency_buffer_config = Point(
            log2_emergency_buffer_size=2,
            use_colors=True,
            emergency_buffer_color=valid_emergency_buffer_color_config)
        valid_colorful_emergency_buffer_config_with_pivot_dimension = valid_colorful_emergency_buffer_config.copy()
        valid_colorful_emergency_buffer_config_with_pivot_dimension['use_emergency_buffer'] = True
        self.assertIn(valid_colorful_emergency_buffer_config_with_pivot_dimension,
                      self.emergency_buffer_settings_with_color)

        valid_config_with_emergency_buffer_colors = Point(
            num_readers=1,
            log2_buffer_size=10,
            use_emergency_buffer=True,
            emergency_buffer_config=valid_colorful_emergency_buffer_config)

        # Emergency buffer disabled, yet a flat log2_emergency_buffer_size coordinate is present.
        # The hierarchical space is expected to tolerate such redundant coordinates.
        valid_config_with_emergency_buffer_and_redundant_coordinates = Point(
            num_readers=1,
            log2_buffer_size=10,
            use_emergency_buffer=False,
            log2_emergency_buffer_size=2)
        self.assertIn(valid_config_with_emergency_buffer_and_redundant_coordinates,
                      self.hierarchical_settings)

        # Emergency buffer enabled, but no nested emergency_buffer_config is supplied.
        another_invalid_config_with_emergency_buffer = Point(
            num_readers=1, log2_buffer_size=10, use_emergency_buffer=True)

        # Emergency buffer enabled, no nested emergency_buffer_config, and a flat
        # log2_emergency_buffer_size=40 (presumably outside that dimension's range -- confirm in setUp).
        yet_another_invalid_config_with_emergency_buffer = Point(
            num_readers=1,
            log2_buffer_size=10,
            use_emergency_buffer=True,
            log2_emergency_buffer_size=40)

        # Final sweep of every top-level configuration against the hierarchical space.
        self.assertIn(valid_config_no_emergency_buffer,
                      self.hierarchical_settings)
        self.assertIn(valid_config_with_emergency_buffer,
                      self.hierarchical_settings)
        self.assertIn(valid_config_with_emergency_buffer_colors,
                      self.hierarchical_settings)
        self.assertIn(valid_config_with_emergency_buffer_and_redundant_coordinates,
                      self.hierarchical_settings)
        self.assertNotIn(another_invalid_config_with_emergency_buffer,
                         self.hierarchical_settings)
        self.assertNotIn(yet_another_invalid_config_with_emergency_buffer,
                         self.hierarchical_settings)