示例#1
0
    def save_metrics_value(self):
        """Append the trained model's metrics as one row of the step's CSV file.

        The row holds the quantization encoding (``nbit_w_list + nbit_a_list``
        of the model), the flops count, the parameter count and the validation
        metric. The metric column is named after the ``valid_metric`` config
        item and falls back to ``"acc"``.

        The CSV lives at ``<local_output_path>/<step_name>/performace.csv``.
        The header is written only when the file does not exist yet; later
        calls append rows without a header. When the trainer has a
        ``backup_base_path``, the whole local output folder is copied there
        afterwards.

        :return: None
        """
        # NOTE: the misspelt file name is kept on purpose; other code may
        # look the file up under this exact name.
        pd_path = FileOps.join_path(self.trainer.local_output_path,
                                    self.trainer.step_name, "performace.csv")
        FileOps.make_base_dir(pd_path)

        encoding = self.model.nbit_w_list + self.model.nbit_a_list
        metric_name = self.cfg.get("valid_metric", "acc")
        df = pd.DataFrame(
            [[encoding, self.flops_count, self.params_count, self.metric]],
            columns=["encoding", "flops", "parameters", metric_name],
        )

        # Append mode creates the file when it is missing, so a single open()
        # covers both the first write (with header) and later appends.
        write_header = not os.path.exists(pd_path)
        # newline="" is what pandas asks for when it is handed an open file
        # object; without it, line endings are translated twice on Windows.
        with open(pd_path, "a", newline="") as file:
            df.to_csv(file, index=False, header=write_header)

        if self.trainer.backup_base_path is not None:
            FileOps.copy_folder(self.trainer.local_output_path,
                                self.trainer.backup_base_path)
示例#2
0
 def _save_best_model(self):
     """Store the model's state dict as ``best_model.pth`` and back it up.

     The weights are written to
     ``<local worker path>/<step_name>/best_model.pth``. If the trainer has a
     ``backup_base_path``, the local worker folder is then copied to
     ``<backup_base_path>/workers/<worker_id>``.
     """
     trainer = self.trainer
     weights_file = FileOps.join_path(
         trainer.get_local_worker_path(), trainer.step_name, "best_model.pth")
     FileOps.make_base_dir(weights_file)
     torch.save(self.model.state_dict(), weights_file)

     # Nothing more to do when no backup location is configured.
     if trainer.backup_base_path is None:
         return

     backup_dir = FileOps.join_path(
         trainer.backup_base_path, "workers", str(trainer.worker_id))
     FileOps.copy_folder(trainer.get_local_worker_path(), backup_dir)
示例#3
0
 def _copy_needed_file(self):
     """Copy the configured pareto-front and random files into the step folder.

     For each of the config items ``pareto_front_file`` and ``random_file``,
     the ``{local_base_path}`` placeholder in the configured path is replaced
     with ``self.local_base_path`` and the file is copied to
     ``<local_output_path>/<step_name>/pareto_front.csv`` or
     ``<local_output_path>/<step_name>/random.csv`` respectively. The
     destination paths are stored on ``self.pareto_front_file`` and
     ``self.random_file``.

     The items are handled in order, so the pareto-front file has already
     been copied by the time a missing ``random_file`` is reported.

     :raises FileNotFoundError: if either config item is absent or None.
     """
     if "pareto_front_file" not in self.cfg or self.cfg.pareto_front_file is None:
         # The message names the key that is actually checked above.
         raise FileNotFoundError("Config item pareto_front_file not found in config file.")
     pareto_source = self.cfg.pareto_front_file.replace("{local_base_path}", self.local_base_path)
     self.pareto_front_file = FileOps.join_path(self.local_output_path, self.cfg.step_name, "pareto_front.csv")
     FileOps.make_base_dir(self.pareto_front_file)
     FileOps.copy_file(pareto_source, self.pareto_front_file)

     if "random_file" not in self.cfg or self.cfg.random_file is None:
         raise FileNotFoundError("Config item random_file not found in config file.")
     random_source = self.cfg.random_file.replace("{local_base_path}", self.local_base_path)
     # Same directory as the pareto-front file, whose base dir was set up above.
     self.random_file = FileOps.join_path(self.local_output_path, self.cfg.step_name, "random.csv")
     FileOps.copy_file(random_source, self.random_file)
示例#4
0
    def _save_hpo_cache(self):
        """Write every cached hpo record to ``self._cache_file`` as CSV.

        The file is rewritten from scratch on each call with the columns
        ``id``, ``hps`` and ``performance``. Each entry of ``self._hps_cache``
        maps an id to a sequence whose first item is written as ``hps`` and
        whose second item is written as ``performance``.

        Saving is best-effort: any failure is logged with its traceback and
        not re-raised.
        """
        csv_columns = ['id', 'hps', 'performance']

        try:
            FileOps.make_base_dir(self._cache_file)
            # The csv module requires newline='' on files it writes to;
            # otherwise extra carriage returns appear on Windows.
            with open(self._cache_file, 'w', newline='') as csv_file:
                writer = csv.DictWriter(csv_file, fieldnames=csv_columns)
                writer.writeheader()
                writer.writerows(
                    {'id': hpo_id, 'hps': value[0], 'performance': value[1]}
                    for hpo_id, value in self._hps_cache.items()
                )
        except Exception:
            # Both records go through the same logger so the traceback ends up
            # next to the message it belongs to.
            logger.error("Failed to save hpo cache, file={}".format(self._cache_file))
            logger.error(traceback.format_exc())