Example #1
0
    def test_save(self):
        """Saving a collection writes a file at the path returned by ``save``.

        The returned path is registered for removal right away, so the file is
        cleaned up even if the assertion below fails.
        """
        sbc = StatBookCollection()
        sbc.create_shared_stat_books(self.sb_names, self.ms_names)

        fname = "test_save"
        out_fname = sbc.save(fname)
        # Register cleanup before asserting so a failing check cannot leak
        # the saved file.
        self.addCleanup(os.remove, out_fname)

        # Previously the test asserted nothing; verify the file was written.
        self.assertTrue(os.path.isfile(out_fname))
Example #2
0
    def test_add_stat_book(self):
        """A StatBook added to a collection is listed by name and by object."""
        collection = StatBookCollection()
        collection.create_shared_stat_books(self.sb_names, self.ms_names)

        new_book = StatBook('test_sb')
        collection.add_stat_book(new_book)

        # The book must be discoverable through both accessors.
        self.assertIn(new_book.name, collection.stat_book_names())
        self.assertIn(new_book, collection.stat_book_list())
Example #3
0
 def _build_from_input_dict(cls, input_dict: dict):
     """Rebuild a tracker from a dict produced by ``to_dict``.

     The ``"stat_book_collection"`` and ``"base_kern_freq_names"`` entries are
     consumed here; the remaining entries are passed on to the parent
     implementation.

     Args:
         input_dict: Serialized tracker. It is not modified.

     Returns:
         The tracker built by the parent, with its stat book collection and
         base-kernel frequency names restored.

     Raises:
         KeyError: If either of the two required keys is missing.
     """
     # Work on a shallow copy so popping keys does not mutate the caller's
     # dict, and a KeyError cannot leave it half-consumed.
     input_dict = dict(input_dict)
     stat_book_collection = StatBookCollection.from_dict(
         input_dict.pop("stat_book_collection"))
     base_kern_freq_names = input_dict.pop("base_kern_freq_names")
     # The parent sees the dict without the two keys handled here.
     tracker = super()._build_from_input_dict(input_dict)
     tracker.stat_book_collection = stat_book_collection
     tracker.base_kern_freq_names = base_kern_freq_names
     return tracker
Example #4
0
    def test_from_dict(self):
        """A to_dict/from_dict round trip keeps stat book and multi-stat names."""
        original = StatBookCollection()
        original.create_shared_stat_books(self.sb_names, self.ms_names)

        restored = StatBookCollection.from_dict(original.to_dict())
        self.assertEqual(original.stat_book_names(), restored.stat_book_names())

        # Each stat book must keep the same multi-stat names, in order.
        self.assertListEqual(
            [book.multi_stats_names() for book in original.stat_book_list()],
            [book.multi_stats_names() for book in restored.stat_book_list()],
        )
Example #5
0
    def __init__(self):
        """Define statistic and stat book names and create an empty collection."""
        super().__init__()

        # Names of the statistics used for plotting.
        self.n_hyperparams_name = 'n_hyperparameters'
        self.n_operands_name = 'n_operands'
        self.score_name = 'score'
        self.cov_dists_name = 'cov_dists'
        self.diversity_scores_name = 'diversity_scores'
        self.best_stat_name = 'best'

        # Names of the stat books.
        # NOTE(review): the original author flagged these to be separated
        # from the statistic names above ("separate these!").
        self.evaluations_name = 'evaluations'
        self.active_set_name = 'active_set'
        self.expansion_name = 'expansion'
        self.stat_book_names = [self.evaluations_name,
                                self.expansion_name,
                                self.active_set_name]

        # Starts empty; presumably populated later from the base kernel names.
        self.base_kern_freq_names = []
        self.stat_book_collection = StatBookCollection()
Example #6
0
    def test_load(self):
        """A saved collection loads back with the same book and multi-stat names."""
        sbc = StatBookCollection()
        sbc.create_shared_stat_books(self.sb_names, self.ms_names)

        fname = "test_save"
        out_fname = sbc.save(fname)
        # Register cleanup immediately. It used to be added after the
        # assertions, so any failing assertion left the saved file behind.
        self.addCleanup(os.remove, out_fname)

        new_sbc = StatBookCollection.load(out_fname)

        self.assertEqual(sbc.stat_book_names(), new_sbc.stat_book_names())

        expected_ms_names = [sb.multi_stats_names() for sb in sbc.stat_book_list()]
        actual_ms_names = [sb.multi_stats_names() for sb in new_sbc.stat_book_list()]
        self.assertListEqual(expected_ms_names, actual_ms_names)
Example #7
0
 def load(output_file_name: str):
     """Create a ModelSearchLogger whose stat books are read from a file.

     Only the stat book collection is restored. Every other attribute keeps
     the value set by ``ModelSearchLogger()``.
     """
     logger = ModelSearchLogger()
     logger.stat_book_collection = StatBookCollection.load(output_file_name)
     return logger
Example #8
0
class ModelSearchLogger(Callback, Serializable):
    """Callback that records model-search statistics into stat books.

    ``stat_book_collection`` holds three stat books, updated respectively by
    ``on_evaluate_end`` ('evaluations'), ``on_evaluate_all_end``
    ('active_set') and ``on_propose_new_models_end`` ('expansion').
    The hooks look stat books up by name, so
    :meth:`set_stat_book_collection` presumably must run first to create them.
    """

    def __init__(self):
        super().__init__()
        # statistics used for plotting
        self.n_hyperparams_name = 'n_hyperparameters'
        self.n_operands_name = 'n_operands'
        self.score_name = 'score'
        self.cov_dists_name = 'cov_dists'
        self.diversity_scores_name = 'diversity_scores'
        self.best_stat_name = 'best'

        # Stat book names.
        # NOTE(review): original author flagged these to be separated
        # from the statistic names above ("separate these!").
        self.evaluations_name = 'evaluations'
        self.active_set_name = 'active_set'
        self.expansion_name = 'expansion'
        self.stat_book_names = [
            self.evaluations_name, self.expansion_name, self.active_set_name
        ]
        # Filled in by set_stat_book_collection.
        self.base_kern_freq_names = []

        self.stat_book_collection = StatBookCollection()

    def set_stat_book_collection(self, base_kernel_names: List[str]):
        """Create the shared stat books and attach per-book statistics.

        Every stat book tracks the number of hyperparameters, the number of
        operands and one ``<base kernel>_frequency`` multi-stat per base
        kernel. The active-set book additionally records scores, covariance
        distances, diversity scores and "best" statistics. The evaluations
        book additionally records scores.

        Args:
            base_kernel_names: Names of the grammar's base kernels.
        """
        self.base_kern_freq_names = [
            base_kern_name + '_frequency'
            for base_kern_name in base_kernel_names
        ]

        # All stat books track these variables
        shared_multi_stat_names = [
            self.n_hyperparams_name, self.n_operands_name
        ] + self.base_kern_freq_names

        # raw value statistics
        base_kern_stat_funcs = [
            base_kern_freq(base_kern_name)
            for base_kern_name in base_kernel_names
        ]
        shared_stats = [get_n_hyperparams, get_n_operands
                        ] + base_kern_stat_funcs

        self.stat_book_collection.create_shared_stat_books(
            self.stat_book_names, shared_multi_stat_names, shared_stats)

        sb_active_set = self.stat_book_collection.stat_books[
            self.active_set_name]
        sb_active_set.add_raw_value_stat(self.score_name, get_model_scores)
        sb_active_set.add_raw_value_stat(self.cov_dists_name, get_cov_dists)
        sb_active_set.add_raw_value_stat(self.diversity_scores_name,
                                         get_diversity_scores)
        sb_active_set.multi_stats[self.n_hyperparams_name].add_statistic(
            Statistic(self.best_stat_name, get_best_n_hyperparams))
        sb_active_set.multi_stats[self.n_operands_name].add_statistic(
            Statistic(self.best_stat_name, get_best_n_operands))

        sb_evals = self.stat_book_collection.stat_books[self.evaluations_name]
        sb_evals.add_raw_value_stat(self.score_name, get_model_scores)

    def _update_stat_book(self, stat_book_name: str, models: list, x) -> None:
        """Update the named stat book with ``models`` when any are given.

        The grammar and stat book are looked up before checking ``models``,
        matching the order the hooks originally used.
        """
        grammar = self.model.grammar
        stat_book = self.stat_book_collection.stat_books[stat_book_name]
        if models:
            update_stat_book(stat_book, models, x, grammar.base_kernel_names,
                             grammar.n_dims)

    def on_evaluate_all_end(self, logs: Optional[dict] = None):
        """Record the evaluated batch (``logs['gp_models']``) in the active-set book."""
        logs = logs or {}
        models = logs.get('gp_models', [])
        x = logs.get('x', None)
        self._update_stat_book(self.active_set_name, models, x)

    def on_evaluate_end(self, logs: Optional[dict] = None):
        """Record the single evaluated model (``logs['gp_model']``) in the evaluations book."""
        logs = logs or {}
        model = logs.get('gp_model')
        # Wrap the single model in a list. The previous `[model] or []` was
        # always truthy, so a missing model became `[[]]` and triggered a
        # bogus update.
        models = [] if model is None else [model]
        x = logs.get('x', None)
        self._update_stat_book(self.evaluations_name, models, x)

    def on_propose_new_models_end(self, logs: Optional[dict] = None):
        """Record newly proposed models (``logs['gp_models']``) in the expansion book."""
        logs = logs or {}
        models = logs.get('gp_models', [])
        x = logs.get('x', None)
        self._update_stat_book(self.expansion_name, models, x)

    def to_dict(self) -> dict:
        """Serialize, adding the stat book collection and base-kernel frequency names."""
        output_dict = super().to_dict()
        output_dict[
            "stat_book_collection"] = self.stat_book_collection.to_dict()
        output_dict["base_kern_freq_names"] = self.base_kern_freq_names
        return output_dict

    @classmethod
    def _build_from_input_dict(cls, input_dict: dict):
        """Rebuild a logger from a dict produced by :meth:`to_dict`.

        ``input_dict`` is not modified. Raises ``KeyError`` if
        ``"stat_book_collection"`` or ``"base_kern_freq_names"`` is missing.
        """
        # Shallow copy so popping keys does not mutate the caller's dict.
        input_dict = dict(input_dict)
        stat_book_collection = StatBookCollection.from_dict(
            input_dict.pop("stat_book_collection"))
        base_kern_freq_names = input_dict.pop("base_kern_freq_names")
        tracker = super()._build_from_input_dict(input_dict)
        tracker.stat_book_collection = stat_book_collection
        tracker.base_kern_freq_names = base_kern_freq_names
        return tracker

    def save(self, output_file_name: str):
        """Save only the stat book collection to ``output_file_name``."""
        self.stat_book_collection.save(output_file_name)

    @staticmethod
    def load(output_file_name: str):
        """Create a logger whose stat book collection is loaded from a file.

        Only the stat book collection is restored. ``base_kern_freq_names``
        keeps the empty list set by ``__init__``.
        """
        mst = ModelSearchLogger()
        sbc = StatBookCollection.load(output_file_name)
        mst.stat_book_collection = sbc
        return mst
Example #9
0
    def test_to_dict(self):
        """to_dict stores each stat book's own dict, in order, under "stat_books"."""
        collection = StatBookCollection()
        collection.create_shared_stat_books(self.sb_names, self.ms_names)

        serialized = collection.to_dict()
        expected_books = [book.to_dict() for book in collection.stat_book_list()]
        self.assertEqual(expected_books, serialized["stat_books"])