def load_checkpoint(self, worker_id=None, step_name=None, saved_folder=None):
    """Load a pickled model (and, on torch, its weights) from a worker folder.

    When ``saved_folder`` is not given it is resolved with
    ``self.get_local_worker_path(step_name, worker_id)``, where ``step_name``
    and ``worker_id`` each default to the value stored on ``self``.

    The folder must contain ``self.model_pickle_file_name`` (a pickled model
    object). On the torch backend, the weights stored under the ``'weight'``
    key of ``self.checkpoint_file_name`` are loaded into that model on CPU,
    and the model is then moved to GPU when ``self.config.cuda`` is set. On
    the TF backend, ``saved_folder`` is copied into the folder returned by
    ``self.get_local_worker_path()``.

    Loading is best-effort. On success ``self.model`` is replaced. On any
    failure ``self.model`` is left untouched and the problem is logged.

    :param worker_id: worker whose folder to read; defaults to ``self.worker_id``.
    :param step_name: step whose folder to read; defaults to ``self.step_name``.
    :param saved_folder: explicit folder to read; when given, the two
        parameters above are ignored.
    """
    if saved_folder is None:
        if worker_id is None:
            worker_id = self.worker_id
        if step_name is None:
            step_name = self.step_name
        saved_folder = self.get_local_worker_path(step_name, worker_id)
    checkpoint_file = FileOps.join_path(saved_folder, self.checkpoint_file_name)
    model_pickle_file = FileOps.join_path(saved_folder, self.model_pickle_file_name)
    try:
        # NOTE: unpickling can execute arbitrary code; only point this at
        # folders written by this pipeline itself.
        with open(model_pickle_file, 'rb') as f:
            model = pickle.load(f)
        if vega.is_torch_backend():
            # Load onto CPU first so GPU-saved checkpoints work on any host.
            ckpt = torch.load(checkpoint_file, map_location=torch.device('cpu'))
            model.load_state_dict(ckpt['weight'])
            if self.config.cuda:
                model = model.cuda()
        elif vega.is_tf_backend():
            FileOps.copy_folder(saved_folder, self.get_local_worker_path())
    except FileNotFoundError:
        logging.info('Checkpoint file is not existed, use default model now.')
    except Exception:
        # Keep the best-effort contract, but do not disguise real load errors
        # (corrupt pickle, mismatched state dict, device errors) as a missing file.
        logging.warning(
            'Failed to load checkpoint from %s, use default model now.',
            saved_folder, exc_info=True)
    else:
        self.model = model
def _backup(self):
    """Copy this worker's local result folder into the backup location.

    The copy happens only when ``self.need_backup`` is exactly ``True`` (not
    merely truthy) and ``self.backup_base_path`` is set. The destination is
    ``self.get_worker_subpath()`` joined onto ``self.backup_base_path``.
    """
    if self.need_backup is not True or self.backup_base_path is None:
        return
    destination = FileOps.join_path(self.backup_base_path, self.get_worker_subpath())
    source = self.get_local_worker_path(self.step_name, self.worker_id)
    FileOps.copy_folder(source, destination)
def _load_pretrained_model(self):
    """Initialise ``self.model`` from ``self.config.pretrained_model_file``.

    Does nothing when there is no model yet or no pretrained file is
    configured. The path is made absolute first.

    On the torch backend, the file is read with ``torch.load`` and handed
    directly to ``self.model.load_state_dict``, so it is expected to hold a
    state dict. On the TF backend, the folder containing the file is copied
    into the folder returned by ``self.get_local_worker_path()``.
    """
    if self.model is None:
        return
    model_file = self.config.pretrained_model_file
    if model_file is None:
        return
    model_file = os.path.abspath(model_file)
    if vega.is_torch_backend():
        # Map to CPU so a file saved on GPU also loads on CPU-only hosts;
        # load_state_dict copies the values into the model's existing
        # parameters on whatever device they already live.
        ckpt = torch.load(model_file, map_location=torch.device('cpu'))
        self.model.load_state_dict(ckpt)
    elif vega.is_tf_backend():
        model_folder = os.path.dirname(model_file)
        FileOps.copy_folder(model_folder, self.get_local_worker_path())
def update(self, record):
    """Add one finished worker's performance record to the sampler's results.

    ``record["step_name"]`` and ``record["worker_id"]`` (read with ``.get``)
    locate the worker's local result folder. ``self.performance_path`` derives
    the performance file from that folder. A non-``None`` result of
    ``FileOps.load_pickle`` on that file is appended to ``self._total_list``.
    The output is then saved to ``self.output_path`` whether or not this
    record contributed a result. When ``self.backup_base_path`` is set, the
    output is also copied there.

    :param record: mapping describing the finished worker (presumably a dict
        supplied by the scheduler; confirm against callers).
    """
    step_name = record.get("step_name")
    worker_id = record.get("worker_id")
    worker_result_path = TaskOps().get_local_worker_path(step_name, worker_id)
    performance_file = self.performance_path(worker_result_path)
    logging.info("SpNas.update(), performance file=%s", performance_file)
    info = FileOps.load_pickle(performance_file)
    if info is not None:
        self._total_list.append(info)
    else:
        # Assumes load_pickle returns None when the file is missing; confirm.
        logging.info(
            "SpNas.update(), file does not exist, performance file=%s",
            performance_file)
    self.save_output(self.output_path)
    if self.backup_base_path is not None:
        FileOps.copy_folder(self.output_path, self.backup_base_path)
def backup_output_path(self):
    """Copy the task's local output folder to the configured backup path.

    Source and destination come from ``TaskOps().local_output_path`` and
    ``TaskOps().backup_base_path``. Does nothing when the backup path is
    ``None``.
    """
    # Build TaskOps once and read both paths from the same instance.
    task_ops = TaskOps()
    backup_path = task_ops.backup_base_path
    if backup_path is None:
        return
    FileOps.copy_folder(task_ops.local_output_path, backup_path)
def _backup_output_path(self):
    """Mirror ``self.task.local_output_path`` into ``self.task.backup_base_path``.

    Skipped entirely when no backup base path is configured.
    """
    # TODO: only backup step output path
    task = self.task
    destination = task.backup_base_path
    if destination is not None:
        FileOps.copy_folder(task.local_output_path, destination)