def before_train(self, logs=None):
        """Select the top feature-interaction pairs before training starts.

        Runs the parent ``before_train`` hook, loads the stage-1 HPO result
        from ``best_config.pickle`` under the trainer's local output path,
        ranks the feature-interaction pairs by absolute score (largest
        first) and keeps the leading ``fis_ratio`` fraction of them in
        ``self.selected_pairs``.  When a model config is registered, the
        selection is also stored on its ``model_desc.custom`` section;
        without one, ``fis_ratio`` defaults to 1.0 and only
        ``self.selected_pairs`` is set.

        :param logs: passed through unchanged to the parent hook.
        """
        super().before_train(logs)
        hpo_result = FileOps.load_pickle(
            FileOps.join_path(self.trainer.local_output_path,
                              'best_config.pickle'))
        logging.info("loading stage1_hpo_result \n{}".format(hpo_result))

        feature_interaction_score = hpo_result['feature_interaction_score']
        print('feature_interaction_score:', feature_interaction_score)
        # Strongest interactions first, regardless of sign.
        sorted_pairs = sorted(feature_interaction_score.items(),
                              key=lambda x: abs(x[1]),
                              reverse=True)

        model_cfg = ClassFactory.__configs__.get('model')
        if model_cfg:
            fis_ratio = model_cfg["model_desc"]["custom"]["fis_ratio"]
        else:
            fis_ratio = 1.0
        # Clamp to [0, 1]: a negative ratio would yield a negative top_k,
        # and the slice below would then drop pairs from the end instead
        # of keeping the first top_k.
        fis_ratio = max(0.0, min(1.0, fis_ratio))
        top_k = int(len(sorted_pairs) * fis_ratio)
        self.selected_pairs = [pair for pair, _ in sorted_pairs[:top_k]]

        # add selected_pairs
        if model_cfg:
            setattr(model_cfg["model_desc"]["custom"], 'selected_pairs',
                    self.selected_pairs)
        else:
            # No model config to update; subscripting it here would raise.
            logging.warning("no model config found, selected_pairs is not "
                            "written to the model description")
    def before_train(self, logs=None):
        """Load the stage-1 feature interactions and hand them to the model.

        Runs the parent ``before_train`` hook, reads ``best_config.pickle``
        from the trainer's local output path, stores its
        ``feature_interaction`` entry in ``self.selected_pairs`` and sets
        the same value as ``selected_pairs`` on the ``model_desc.custom``
        section of the registered model config.

        :param logs: passed through unchanged to the parent hook.
        """
        super().before_train(logs)

        config_file = FileOps.join_path(self.trainer.local_output_path,
                                        'best_config.pickle')
        stage1_result = FileOps.load_pickle(config_file)
        logging.info("loading stage1_hpo_result \n{}".format(stage1_result))

        self.selected_pairs = stage1_result['feature_interaction']
        print('feature_interaction:', self.selected_pairs)

        # Expose the selection to the model through its custom description.
        # NOTE(review): the 'model' config is subscripted without a check,
        # so this presumably requires a registered model config — confirm.
        custom_desc = ClassFactory.__configs__.get('model')["model_desc"]["custom"]
        custom_desc.selected_pairs = self.selected_pairs
Esempio n. 3
0
 def update(self, record):
     """Record a finished worker's performance and save the sampler output.

     Looks up the local result path of the worker named by ``record``,
     loads its performance pickle and, when something was loaded, appends
     it to ``self._total_list``.  The sampler output is then saved to
     ``self.output_path`` and, if ``self.backup_base_path`` is set, the
     output folder is copied there as well.

     :param record: mapping providing the ``step_name`` and ``worker_id``
         keys of the worker whose result should be collected.
     """
     step_name = record.get("step_name")
     worker_id = record.get("worker_id")
     worker_result_path = TaskOps().get_local_worker_path(
         step_name, worker_id)
     performance_file = self.performance_path(worker_result_path)
     logging.info("SpNas.update(), performance file=%s", performance_file)
     info = FileOps.load_pickle(performance_file)
     if info is not None:
         self._total_list.append(info)
     else:
         # Nothing could be loaded for this worker: report it and carry on
         # so the results collected so far are still saved below.
         logging.warning("SpNas.update(), file does not exist, "
                         "performance file=%s", performance_file)
     self.save_output(self.output_path)
     if self.backup_base_path is not None:
         FileOps.copy_folder(self.output_path, self.backup_base_path)