コード例 #1
0
def load_checkpoint(checkpoint_dir, exe, main_program):
    """Restore training state from the checkpoint kept in *checkpoint_dir*.

    The metadata file ``CKPT_FILE_NAME`` inside *checkpoint_dir* is parsed
    when it exists. If it records a model directory, every variable of
    *main_program* that has a file of the same name in that directory is
    loaded through ``fluid.io.load_vars``.

    Args:
        checkpoint_dir: Directory holding the checkpoint metadata file.
        exe: Executor handed to ``fluid.io.load_vars``.
        main_program: Program whose variables are restored.

    Returns:
        A ``(loaded, current_epoch, global_step)`` tuple. When no model
        directory is recorded, ``(False, 1, 0)`` is returned.
    """
    meta_file = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    meta = checkpoint_pb2.CheckPoint()
    logger.info("Try loading checkpoint from {}".format(meta_file))
    if os.path.exists(meta_file):
        with open(meta_file, "rb") as stream:
            meta.ParseFromString(stream.read())

    model_dir = meta.latest_model_dir
    if not model_dir:
        logger.info(
            "PaddleHub model checkpoint not found, start training from scratch...")
        return False, 1, 0

    def _has_saved_file(var):
        # Only variables that actually have a file in model_dir are loaded.
        return os.path.exists(os.path.join(model_dir, var.name))

    fluid.io.load_vars(exe, model_dir, main_program, predicate=_has_saved_file)
    logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
                "global_step={}".format(meta.current_epoch, meta.global_step))
    return True, meta.current_epoch, meta.global_step
コード例 #2
0
ファイル: hubutils.py プロジェクト: onewaymyway/aiutils
def getCheckPointInfo(checkpoint_dir):
    """Read, print and return the ``ckpt.meta`` record in *checkpoint_dir*.

    When ``ckpt.meta`` does not exist, an empty ``CheckPoint`` message is
    printed and returned instead.
    """
    checkpoint = checkpoint_pb2.CheckPoint()
    meta_path = os.path.join(checkpoint_dir, "ckpt.meta")
    if os.path.exists(meta_path):
        with open(meta_path, "rb") as handle:
            payload = handle.read()
        checkpoint.ParseFromString(payload)
    print(checkpoint)
    return checkpoint
コード例 #3
0
ファイル: hubutils.py プロジェクト: onewaymyway/aiutils
def rewriteCheckPoint(checkpoint_dir, newScore):
    """Overwrite the ``best_score`` field of ``ckpt.meta`` in *checkpoint_dir*.

    Before rewriting, the current metadata file is copied next to itself as
    ``ckpt.mete.temp``. The parsed (pre-update) record is printed.

    Args:
        checkpoint_dir: Directory containing ``ckpt.meta``.
        newScore: Value stored into the ``best_score`` field.

    Returns:
        The updated ``CheckPoint`` message.

    Raises:
        Whatever ``shutil.copy`` raises when ``ckpt.meta`` is missing
        (``FileNotFoundError`` on Python 3), as before.
    """
    ckpt_meta_path = os.path.join(checkpoint_dir, "ckpt.meta")
    # Build the backup path from the file name only. The former
    # ``ckpt_meta_path.replace(".meta", ...)`` also rewrote any ".meta" that
    # appeared in the directory part, yielding a path in a non-existent
    # directory. The ".mete.temp" spelling is kept so existing backups stay
    # where other tooling may expect them.
    backup_path = os.path.join(checkpoint_dir, "ckpt.mete.temp")
    # shutil.copy raises if ckpt.meta is missing, so the file is known to
    # exist below and is never rewritten without a backup.
    shutil.copy(ckpt_meta_path, backup_path)

    ckpt = checkpoint_pb2.CheckPoint()
    with open(ckpt_meta_path, "rb") as f:
        ckpt.ParseFromString(f.read())
    print(ckpt)
    ckpt.best_score = newScore
    with open(ckpt_meta_path, "wb") as f:
        f.write(ckpt.SerializeToString())
    return ckpt
コード例 #4
0
def save_checkpoint(checkpoint_dir,
                    current_epoch,
                    global_step,
                    best_score,
                    exe,
                    main_program=None):
    """Save model parameters and training metadata under *checkpoint_dir*.

    The persistable variables of *main_program* are written to
    ``<checkpoint_dir>/step_<global_step>``. The metadata file
    ``CKPT_FILE_NAME`` in *checkpoint_dir* is then overwritten so that it
    points at that directory.

    Args:
        checkpoint_dir: Root directory of the checkpoint.
        current_epoch: Epoch number recorded in the metadata.
        global_step: Step number; also names the model sub-directory.
        best_score: Best evaluation score recorded in the metadata.
        exe: Executor used by ``fluid.io.save_persistables``.
        main_program: Program to save. ``None`` (the default) means the
            default main program current at call time.
    """
    # Resolve the default at call time. A ``fluid.default_main_program()``
    # default would be evaluated once at import, so a later program switch
    # would not be honoured.
    if main_program is None:
        main_program = fluid.default_main_program()

    ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    ckpt = checkpoint_pb2.CheckPoint()

    model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step)
    # The metadata below records model_saved_dir as the latest model, so the
    # parameters must actually be written there. Otherwise a loader that
    # filters variables by file existence silently restores nothing.
    logger.info("Saving model checkpoint to {}".format(model_saved_dir))
    fluid.io.save_persistables(exe,
                               dirname=model_saved_dir,
                               main_program=main_program)

    ckpt.current_epoch = current_epoch
    ckpt.global_step = global_step
    ckpt.latest_model_dir = model_saved_dir
    ckpt.best_score = best_score
    with open(ckpt_meta_path, "wb") as f:
        f.write(ckpt.SerializeToString())
コード例 #5
0
def save_checkpoint(checkpoint_dir,
                    current_epoch,
                    global_step,
                    exe,
                    main_program=None):
    """Save model parameters and training metadata under *checkpoint_dir*.

    The persistable variables of *main_program* are written to
    ``<checkpoint_dir>/step_<global_step>``. The metadata file
    ``CKPT_FILE_NAME`` in *checkpoint_dir* is then overwritten so that it
    points at that directory.

    Args:
        checkpoint_dir: Root directory of the checkpoint.
        current_epoch: Epoch number recorded in the metadata.
        global_step: Step number; also names the model sub-directory.
        exe: Executor used by ``fluid.io.save_persistables``.
        main_program: Program to save. ``None`` (the default) means the
            default main program current at call time.
    """
    # Resolve the default at call time. A ``fluid.default_main_program()``
    # default would be evaluated once at import, so a later program switch
    # would not be honoured.
    if main_program is None:
        main_program = fluid.default_main_program()

    ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    ckpt = checkpoint_pb2.CheckPoint()

    model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step)
    logger.info("Saving model checkpoint to {}".format(model_saved_dir))
    fluid.io.save_persistables(exe,
                               dirname=model_saved_dir,
                               main_program=main_program)

    ckpt.current_epoch = current_epoch
    ckpt.global_step = global_step
    ckpt.latest_model_dir = model_saved_dir
    with open(ckpt_meta_path, "wb") as f:
        f.write(ckpt.SerializeToString())
コード例 #6
0
def load_checkpoint(checkpoint_dir, exe, main_program):
    """Restore training state, including the best score, from *checkpoint_dir*.

    The metadata file ``CKPT_FILE_NAME`` inside *checkpoint_dir* is parsed
    when it exists. If it records a model directory, every variable of
    *main_program* that has a file of the same name in that directory is
    loaded through ``fluid.io.load_vars``.

    Args:
        checkpoint_dir: Directory holding the checkpoint metadata file.
        exe: Executor handed to ``fluid.io.load_vars``.
        main_program: Program whose variables are restored.

    Returns:
        A ``(loaded, current_epoch, global_step, best_score)`` tuple. When no
        model directory is recorded, ``(False, 1, 0, -999)`` is returned.
        ``best_score`` is also -999 when the checkpoint message type has no
        ``best_score`` attribute.
    """
    # Sentinel reported when no best score is available.
    missing_best_score = -999

    ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    ckpt = checkpoint_pb2.CheckPoint()
    logger.info("Try loading checkpoint from {}".format(ckpt_meta_path))
    if os.path.exists(ckpt_meta_path):
        with open(ckpt_meta_path, "rb") as f:
            ckpt.ParseFromString(f.read())
    current_epoch = 1
    global_step = 0
    best_score = missing_best_score

    def if_exist(var):
        # Only variables that actually have a file in the model dir are loaded.
        return os.path.exists(os.path.join(ckpt.latest_model_dir, var.name))

    if ckpt.latest_model_dir:
        fluid.io.load_vars(exe,
                           ckpt.latest_model_dir,
                           main_program,
                           predicate=if_exist)

        # Compatible with older versions without best_score in checkpoint_pb2.
        # getattr tolerates only the missing-attribute case. The previous
        # bare ``except:`` also hid unrelated errors, even KeyboardInterrupt.
        best_score = getattr(ckpt, "best_score", missing_best_score)

        logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
                    "global_step={}, best_score={:.5f}".format(
                        ckpt.current_epoch, ckpt.global_step, best_score))

        return True, ckpt.current_epoch, ckpt.global_step, best_score

    logger.info("PaddleHub model checkpoint not found, start from scratch...")

    return False, current_epoch, global_step, best_score