def fit(self, pipeline_config, run_result_dir):
        """Read the 'instance.info' file of a run and derive the problem type.

        Arguments:
            pipeline_config -- the pipeline configuration (not read by this method)
            run_result_dir {str} -- directory that contains the 'instance.info' file

        Returns:
            dict -- 'instance_info' holds the parsed options of the info file,
                    'data_manager' holds a DataManager whose problem_type is set
        """
        instance_file_config_parser = ConfigFileParser([
            ConfigOption(name='path', type='directory', required=True),
            ConfigOption(name='is_classification', type=to_bool,
                         required=True),
            ConfigOption(name='is_multilabel', type=to_bool, required=True),
            ConfigOption(name='num_features', type=int, required=True),
            # The builtin bool() is True for every non-empty string, so a
            # textual "False" entry would be read as True. Use the same
            # to_bool converter as the other boolean options of this file.
            ConfigOption(name='categorical_features',
                         type=to_bool,
                         required=True,
                         list=True),
            ConfigOption(
                name='instance_shape',
                type=[ast.literal_eval, lambda x: isinstance(x, tuple)],
                required=True)
        ])
        info_file = os.path.join(run_result_dir, 'instance.info')
        instance_info = instance_file_config_parser.read(info_file)
        instance_info = instance_file_config_parser.set_defaults(instance_info)

        dm = DataManager()
        if instance_info["is_multilabel"]:
            dm.problem_type = ProblemType.FeatureMultilabel
        elif instance_info["is_classification"]:
            dm.problem_type = ProblemType.FeatureClassification
        else:
            # Neither multilabel nor classification: treat as regression.
            # This branch used to repeat FeatureClassification, which made the
            # 'is_classification' flag meaningless.
            # NOTE(review): confirm ProblemType defines FeatureRegression.
            dm.problem_type = ProblemType.FeatureRegression

        return {'instance_info': instance_info, 'data_manager': dm}
예제 #2
0
    def get_pipeline_config(self,
                            throw_error_if_invalid=True,
                            **pipeline_config):
        """Complete a partial pipeline config and check the pipeline conditions

        Keyword Arguments:
            throw_error_if_invalid {bool} -- Re-raise when a condition fails; it is also forwarded to the parser's set_defaults (default: {True})
            **pipeline_config -- the partial config, given as keyword arguments

        Returns:
            dict -- the full config for the pipeline, containing values for all options
        """
        available_options = self.get_pipeline_config_options()
        condition_checks = self.get_pipeline_config_conditions()

        full_config = ConfigFileParser(available_options).set_defaults(
            pipeline_config, throw_error_if_invalid=throw_error_if_invalid)

        # run every condition on the completed config, e.g. max_budget > min_budget
        for check in condition_checks:
            try:
                check(full_config)
            except Exception as error:
                if not throw_error_if_invalid:
                    # lenient mode: report the failed condition and keep going
                    print(error)
                    traceback.print_exc()
                    continue
                raise

        return full_config
예제 #3
0
    def get_pipeline_config(self,
                            throw_error_if_invalid=True,
                            **pipeline_config):
        """Fill the given config via the option parser and run all conditions on it

        Keyword Arguments:
            throw_error_if_invalid {bool} -- Passed on to set_defaults; if True, an exception raised by a condition is re-raised, otherwise it is printed together with its traceback (default: {True})
            **pipeline_config -- the config values supplied by the caller

        Returns:
            the config returned by the parser's set_defaults call
        """
        option_specs = self.get_pipeline_config_options()
        validators = self.get_pipeline_config_conditions()

        defaults_parser = ConfigFileParser(option_specs)
        completed = defaults_parser.set_defaults(
            pipeline_config,
            throw_error_if_invalid=throw_error_if_invalid,
        )

        # each condition is a callable that receives the completed config
        for validate in validators:
            try:
                validate(completed)
            except Exception as exc:
                if throw_error_if_invalid:
                    raise
                # best-effort mode: show the problem but do not abort
                print(exc)
                traceback.print_exc()

        return completed