def load_checkpoint(checkpoint_dir, exe, main_program):
    """Restore persistable variables from the latest recorded checkpoint.

    Reads the checkpoint meta file (``CKPT_FILE_NAME``) under
    ``checkpoint_dir`` if it exists. When the record names a
    ``latest_model_dir``, every variable of ``main_program`` that has a
    same-named file in that directory is loaded with ``fluid.io.load_vars``.

    Args:
        checkpoint_dir: directory that holds the checkpoint meta file.
        exe: executor handed through to ``fluid.io.load_vars``.
        main_program: program whose variables are restored.

    Returns:
        ``(True, current_epoch, global_step)`` read from the meta record when
        a model directory was recorded, otherwise ``(False, 1, 0)``.

    NOTE(review): a later ``load_checkpoint`` in this file also returns a
    best score; if both live in one module the later definition wins —
    confirm which version is intended.
    """
    meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    meta = checkpoint_pb2.CheckPoint()
    logger.info("Try loading checkpoint from {}".format(meta_path))
    if os.path.exists(meta_path):
        with open(meta_path, "rb") as meta_file:
            meta.ParseFromString(meta_file.read())

    model_dir = meta.latest_model_dir
    if not model_dir:
        # No model directory recorded (or no meta file at all): fresh start.
        logger.info(
            "PaddleHub model checkpoint not found, start training from scratch...")
        return False, 1, 0

    def _has_saved_file(var):
        # Only load variables that actually have a file in the checkpoint dir.
        return os.path.exists(os.path.join(model_dir, var.name))

    fluid.io.load_vars(exe, model_dir, main_program, predicate=_has_saved_file)
    logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
                "global_step={}".format(meta.current_epoch, meta.global_step))
    return True, meta.current_epoch, meta.global_step
def getCheckPointInfo(checkpoint_dir):
    """Read, print and return the ``ckpt.meta`` record in ``checkpoint_dir``.

    Args:
        checkpoint_dir: directory expected to contain ``ckpt.meta``.

    Returns:
        The parsed ``checkpoint_pb2.CheckPoint``; an empty message when the
        meta file does not exist.

    NOTE(review): this hard-codes "ckpt.meta" while save/load use
    CKPT_FILE_NAME — confirm the two agree.
    """
    meta = checkpoint_pb2.CheckPoint()
    meta_path = os.path.join(checkpoint_dir, "ckpt.meta")
    if os.path.exists(meta_path):
        with open(meta_path, "rb") as meta_file:
            raw = meta_file.read()
        meta.ParseFromString(raw)
    print(meta)
    return meta
def rewriteCheckPoint(checkpoint_dir, newScore):
    """Overwrite ``best_score`` in the ``ckpt.meta`` record of ``checkpoint_dir``.

    If the meta file exists, it is first copied to ``ckpt.mete.temp`` in the
    same directory as a backup, then parsed so that all other fields are
    preserved. If it does not exist, no backup is made and a fresh record
    carrying only ``best_score`` is written.

    Args:
        checkpoint_dir: directory containing the checkpoint meta file.
        newScore: value stored into ``best_score``.

    Returns:
        The ``checkpoint_pb2.CheckPoint`` message that was written.
    """
    ckpt_meta_path = os.path.join(checkpoint_dir, "ckpt.meta")
    # Build the backup path from the directory rather than str.replace on the
    # full path, which would also rewrite any ".meta" inside directory names.
    # NOTE(review): "mete" looks like a typo for "meta"; kept unchanged so any
    # existing restore tooling still finds the backup.
    backup_path = os.path.join(checkpoint_dir, "ckpt.mete.temp")
    ckpt = checkpoint_pb2.CheckPoint()
    if os.path.exists(ckpt_meta_path):
        # Back up only a file that exists; copying unconditionally raised
        # FileNotFoundError and made this existence check unreachable.
        shutil.copy(ckpt_meta_path, backup_path)
        with open(ckpt_meta_path, "rb") as f:
            ckpt.ParseFromString(f.read())
    print(ckpt)
    ckpt.best_score = newScore
    with open(ckpt_meta_path, "wb") as f:
        f.write(ckpt.SerializeToString())
    return ckpt
def save_checkpoint(checkpoint_dir,
                    current_epoch,
                    global_step,
                    best_score,
                    exe,
                    main_program=fluid.default_main_program()):
    """Save model persistables and a checkpoint meta record with best_score.

    Parameters are written to ``<checkpoint_dir>/step_<global_step>`` via
    ``fluid.io.save_persistables``. Then the meta file ``CKPT_FILE_NAME`` in
    ``checkpoint_dir`` is overwritten with the epoch, step, model directory
    and best score.

    Args:
        checkpoint_dir: directory receiving the meta file and step folder.
        current_epoch: epoch number stored in the record.
        global_step: step number stored in the record and used in the folder name.
        best_score: best evaluation score stored in the record.
        exe: executor used to save persistables.
        main_program: program whose persistables are saved. Note that the
            default is evaluated once, when this function is defined.

    NOTE(review): a later ``save_checkpoint`` in this file (without
    best_score) would shadow this one if both live in one module — confirm
    which version is intended.
    """
    ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    ckpt = checkpoint_pb2.CheckPoint()

    model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step)
    logger.info("Saving model checkpoint to {}".format(model_saved_dir))
    # The record below points at model_saved_dir, so the parameters must be
    # written there. Previously exe/main_program were unused and the record
    # referenced a directory this function never wrote.
    fluid.io.save_persistables(
        exe, dirname=model_saved_dir, main_program=main_program)

    ckpt.current_epoch = current_epoch
    ckpt.global_step = global_step
    ckpt.latest_model_dir = model_saved_dir
    ckpt.best_score = best_score
    with open(ckpt_meta_path, "wb") as f:
        f.write(ckpt.SerializeToString())
def save_checkpoint(checkpoint_dir, current_epoch, global_step, exe,
                    main_program=fluid.default_main_program()):
    """Persist model variables and write the checkpoint meta record.

    The persistables of ``main_program`` go to
    ``<checkpoint_dir>/step_<global_step>``. The meta file
    ``CKPT_FILE_NAME`` in ``checkpoint_dir`` is then overwritten with the
    epoch, step and that directory.

    Args:
        checkpoint_dir: directory receiving the meta file and step folder.
        current_epoch: epoch number stored in the record.
        global_step: step number stored in the record and used in the folder name.
        exe: executor used by ``fluid.io.save_persistables``.
        main_program: program to save. The default is evaluated once, when
            this function is defined.
    """
    target_dir = os.path.join(checkpoint_dir, "step_%d" % global_step)
    meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)

    logger.info("Saving model checkpoint to {}".format(target_dir))
    fluid.io.save_persistables(exe, dirname=target_dir, main_program=main_program)

    record = checkpoint_pb2.CheckPoint()
    record.current_epoch = current_epoch
    record.global_step = global_step
    record.latest_model_dir = target_dir
    with open(meta_path, "wb") as meta_file:
        meta_file.write(record.SerializeToString())
def load_checkpoint(checkpoint_dir, exe, main_program):
    """Restore variables and training progress from the latest checkpoint.

    Reads the meta file ``CKPT_FILE_NAME`` under ``checkpoint_dir`` if it
    exists. When the record names a ``latest_model_dir``, each variable of
    ``main_program`` that has a same-named file in that directory is loaded
    with ``fluid.io.load_vars``.

    Args:
        checkpoint_dir: directory containing the checkpoint meta file.
        exe: executor passed to ``fluid.io.load_vars``.
        main_program: program whose variables are restored.

    Returns:
        ``(True, current_epoch, global_step, best_score)`` from the record when
        a model directory was recorded, otherwise ``(False, 1, 0, -999)``.
        ``best_score`` falls back to -999 when the generated ``CheckPoint``
        class has no ``best_score`` attribute.
    """
    ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    ckpt = checkpoint_pb2.CheckPoint()
    logger.info("Try loading checkpoint from {}".format(ckpt_meta_path))
    if os.path.exists(ckpt_meta_path):
        with open(ckpt_meta_path, "rb") as f:
            ckpt.ParseFromString(f.read())
    current_epoch = 1
    global_step = 0
    best_score = -999

    def if_exist(var):
        # Skip variables with no saved file in the checkpoint directory.
        return os.path.exists(os.path.join(ckpt.latest_model_dir, var.name))

    if ckpt.latest_model_dir:
        fluid.io.load_vars(
            exe, ckpt.latest_model_dir, main_program, predicate=if_exist)
        # Compatible with older checkpoint_pb2 versions that lack best_score.
        # getattr only absorbs the missing-attribute case, unlike the previous
        # bare ``except:``, which swallowed every exception.
        # NOTE(review): if the field is declared but absent from an old meta
        # file, protobuf presumably returns its default (0), not -999 — confirm.
        best_score = getattr(ckpt, "best_score", -999)
        logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
                    "global_step={}, best_score={:.5f}".format(
                        ckpt.current_epoch, ckpt.global_step, best_score))
        return True, ckpt.current_epoch, ckpt.global_step, best_score
    logger.info("PaddleHub model checkpoint not found, start from scratch...")
    return False, current_epoch, global_step, best_score