def _load_persistable_vars(self, executor, dirname, program):
    """Load the checkpoint-eligible variables of ``program`` from ``dirname``.

    Delegates to ``io.load_vars`` with a predicate that selects which
    variables are loaded; ``filename=None`` is passed through unchanged.

    :param executor: forwarded to ``io.load_vars`` as its first argument
    :param dirname: forwarded to ``io.load_vars`` as ``dirname``
    :param program: forwarded to ``io.load_vars`` as ``main_program``
    """
    # Variable types that are never part of a checkpoint.
    excluded_types = (
        core.VarDesc.VarType.FEED_MINIBATCH,
        core.VarDesc.VarType.FETCH_LIST,
        core.VarDesc.VarType.RAW,
    )
    # Name fragments that disqualify a variable:
    #   "@GRAD"     - gradient variables
    #   ".trainer_" - distributed-training variables
    #   ".block"    - distributed-training variables
    #   "tmp_"      - also excluded (no reason given in the original code)
    excluded_fragments = ("@GRAD", ".trainer_", ".block", "tmp_")

    def _is_checkpoint_var(var):
        """
        Decide whether ``var`` takes part in the checkpoint.

        A variable is rejected when its type is FEED_MINIBATCH, FETCH_LIST
        or RAW, or when its name contains any of the excluded fragments.
        Otherwise the answer is the variable's ``persistable`` flag.

        : param var(Variable)
        """
        if var.desc.type() in excluded_types:
            return False
        if any(fragment in var.name for fragment in excluded_fragments):
            return False
        return var.persistable

    io.load_vars(
        executor,
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None,
    )
def load_model_params(exe, params_path, program):
    """Load the Parameter-type variables stored under ``params_path``.

    Can be used to initialise a model. Only variables of ``program`` that
    are ``fluid.framework.Parameter`` instances AND have a file of the same
    name under ``params_path`` are loaded; a parameter without a matching
    file is reported through ``logger.info`` and skipped.

    Args:
        exe: forwarded to ``io.load_vars`` as its first argument.
        params_path: directory expected to hold one file per parameter name.
        program: forwarded to ``io.load_vars`` as ``main_program``.

    Raises:
        AssertionError: if ``params_path`` does not exist.
    """
    # Make sure the parameter directory exists. This is raised explicitly
    # instead of using an ``assert`` statement so the check is not stripped
    # under ``python -O``; without it the predicate below would reject every
    # variable and nothing would be loaded, with no error. AssertionError is
    # kept so existing callers that catch it keep working.
    if not os.path.exists(params_path):
        raise AssertionError("[%s] can't be found." % params_path)

    # Two-stage filter: the variable must be a Parameter, and a file for it
    # must already exist under params_path.
    def existed_params(var):
        if not isinstance(var, fluid.framework.Parameter):
            return False
        if os.path.exists(os.path.join(params_path, var.name)):
            return True
        logger.info("missing layer: {}".format(var.name))
        return False

    io.load_vars(exe, params_path, main_program=program, predicate=existed_params)