コード例 #1
0
    def _load_persistable_vars(self, executor, dirname, program):
        """Load the checkpoint variables of ``program`` from ``dirname``.

        Only variables accepted by the nested ``_is_checkpoint_var`` predicate
        are loaded; the loading itself is delegated to ``io.load_vars`` with
        ``filename=None``.

        :param executor: executor handed through to ``io.load_vars``.
        :param dirname: directory the checkpoint variables are read from.
        :param program: program whose variables are filtered and loaded.
        """
        # Substrings of a variable name that exclude it from the checkpoint:
        #   "@GRAD"     - gradient variables
        #   ".trainer_" - distributed-training variables
        #   ".block"    - distributed-training variables
        #   "tmp_"      - presumably temporary variables; TODO confirm
        excluded_name_markers = ("@GRAD", ".trainer_", ".block", "tmp_")

        def _is_checkpoint_var(var):
            """
            Decide whether ``var`` belongs in the checkpoint.

            Variables of type FEED_MINIBATCH, FETCH_LIST or RAW are rejected,
            as is any variable whose name contains one of the excluded
            markers; every other variable is kept iff ``var.persistable``.

            : param var(Variable)
            """
            var_types = core.VarDesc.VarType
            if var.desc.type() in (var_types.FEED_MINIBATCH,
                                   var_types.FETCH_LIST,
                                   var_types.RAW):
                return False
            if any(marker in var.name for marker in excluded_name_markers):
                return False
            return var.persistable

        io.load_vars(
            executor,
            dirname=dirname,
            main_program=program,
            predicate=_is_checkpoint_var,
            filename=None,
        )
コード例 #2
0
def load_model_params(exe, params_path, program):
    """
    Load the ``Parameter`` variables saved under ``params_path`` into ``program``.

    Intended for model initialization. A variable is loaded only if it is a
    ``fluid.framework.Parameter`` AND a file named after it exists directly in
    ``params_path``; Parameters without a saved file are logged as missing
    and skipped.

    :param exe: executor handed through to ``io.load_vars``.
    :param params_path: directory holding one saved file per parameter.
    :param program: program whose variables are filtered and loaded.
    :raises FileNotFoundError: if ``params_path`` does not exist.
    :raises NotADirectoryError: if ``params_path`` exists but is not a directory.
    """
    # Validate with explicit raises rather than ``assert`` so the check still
    # runs under ``python -O``. The path must be a directory: every parameter
    # is looked up as ``os.path.join(params_path, var.name)``, so a regular
    # file could never match anything and would silently load nothing.
    if not os.path.exists(params_path):
        raise FileNotFoundError("[%s] can't be found." % params_path)
    if not os.path.isdir(params_path):
        raise NotADirectoryError("[%s] is not a directory." % params_path)

    # Two-stage filter: (1) only Parameter variables; (2) only those that
    # already have a saved file under params_path.
    def existed_params(var):
        if not isinstance(var, fluid.framework.Parameter):
            return False
        if os.path.exists(os.path.join(params_path, var.name)):
            return True
        logger.info("missing layer: %s", var.name)
        return False

    io.load_vars(exe, params_path, main_program=program, predicate=existed_params)