def save_metrics_value(self):
    """Append the trained model's metrics as one row of ``performace.csv``.

    The row contains the encoding (``self.model.nbit_w_list`` concatenated
    with ``self.model.nbit_a_list``), ``self.flops_count``,
    ``self.params_count`` and ``self.metric``. The metric column is named
    after ``self.cfg.get("valid_metric", "acc")``. The CSV lives at
    ``<trainer.local_output_path>/<trainer.step_name>/performace.csv``.
    The header row is written only when the file does not exist yet.
    Subsequent calls append data rows.

    If ``self.trainer.backup_base_path`` is not None, the whole local output
    folder is then copied to it.

    :return: None
    """
    # NOTE: the historical spelling "performace" is kept on purpose so that
    # any existing consumers of this file still find it.
    pd_path = FileOps.join_path(self.trainer.local_output_path, self.trainer.step_name, "performace.csv")
    FileOps.make_base_dir(pd_path)
    encoding = self.model.nbit_w_list + self.model.nbit_a_list
    metric_column = self.cfg.get("valid_metric", "acc")
    df = pd.DataFrame(
        [[encoding, self.flops_count, self.params_count, self.metric]],
        columns=["encoding", "flops", "parameters", metric_column])
    # Emit the header only when creating the file; later calls just append.
    write_header = not os.path.exists(pd_path)
    # Passing the path (not an open text handle) lets pandas open the file
    # with the newline handling it needs; mode="a" also creates the file.
    df.to_csv(pd_path, mode="a", header=write_header, index=False)
    if self.trainer.backup_base_path is not None:
        FileOps.copy_folder(self.trainer.local_output_path, self.trainer.backup_base_path)
def _save_best_model(self):
    """Persist the model's ``state_dict`` as ``best_model.pth``.

    The checkpoint is written under
    ``<trainer.get_local_worker_path()>/<trainer.step_name>/best_model.pth``.
    Its parent directory is created first. When ``trainer.backup_base_path``
    is set, the local worker folder is additionally copied to
    ``<backup_base_path>/workers/<trainer.worker_id>``.
    """
    checkpoint_path = FileOps.join_path(
        self.trainer.get_local_worker_path(),
        self.trainer.step_name,
        "best_model.pth",
    )
    FileOps.make_base_dir(checkpoint_path)
    torch.save(self.model.state_dict(), checkpoint_path)

    # Nothing more to do unless a backup location is configured.
    if self.trainer.backup_base_path is None:
        return
    worker_backup_dir = FileOps.join_path(
        self.trainer.backup_base_path,
        "workers",
        str(self.trainer.worker_id),
    )
    FileOps.copy_folder(self.trainer.get_local_worker_path(), worker_backup_dir)
def _copy_needed_file(self):
    """Copy the pareto-front and random CSV files named in the config locally.

    For each of the config items ``pareto_front_file`` and ``random_file``:
    the ``{local_base_path}`` placeholder is replaced with
    ``self.local_base_path``. The file is copied to
    ``<self.local_output_path>/<self.cfg.step_name>/<name>.csv``, with
    ``pareto_front.csv`` and ``random.csv`` as the target names. The
    destination path is stored on ``self.pareto_front_file`` or
    ``self.random_file`` respectively.

    :raises FileNotFoundError: if either config item is absent from
        ``self.cfg`` or is None.
    """

    def _stage(cfg_key, local_name):
        # Validate one config item, copy its file locally, return the destination.
        if cfg_key not in self.cfg or getattr(self.cfg, cfg_key) is None:
            raise FileNotFoundError("Config item {} not found in config file.".format(cfg_key))
        source = getattr(self.cfg, cfg_key).replace("{local_base_path}", self.local_base_path)
        destination = FileOps.join_path(self.local_output_path, self.cfg.step_name, local_name)
        # Harmless when the directory already exists; keeps each copy self-sufficient.
        FileOps.make_base_dir(destination)
        FileOps.copy_file(source, destination)
        return destination

    self.pareto_front_file = _stage("pareto_front_file", "pareto_front.csv")
    self.random_file = _stage("random_file", "random.csv")
def _save_hpo_cache(self):
    """Save all hpo info to ``self._cache_file`` as CSV.

    The file is overwritten with a header row ``id,hps,performance``. This is
    followed by one row per entry of ``self._hps_cache``: the key becomes
    ``id``, and index 0 / index 1 of the value become ``hps`` / ``performance``.
    Saving is best-effort: any failure is logged with its traceback and not
    re-raised.
    """
    csv_columns = ['id', 'hps', 'performance']
    try:
        FileOps.make_base_dir(self._cache_file)
        # The csv module requires newline='' so it controls line endings itself.
        with open(self._cache_file, 'w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=csv_columns)
            writer.writeheader()
            writer.writerows(
                {'id': hpo_id, 'hps': value[0], 'performance': value[1]}
                for hpo_id, value in self._hps_cache.items())
    except Exception:
        # Single record on the module logger, including the traceback.
        logger.exception("Failed to save hpo cache, file={}".format(self._cache_file))