def _match_vars(self, func):
    """Pair graph variables with checkpoint tensors, never restoring the global step.

    Every variable from ``tf.global_variables()`` is mapped to its save-name
    (``get_savename_from_varname`` with ``self.prefix`` as the varname prefix).
    If the checkpoint at ``self.path`` has a tensor of that name,
    ``func(reader, name, v)`` is called. Names that match ``"global_step:0"``
    are skipped entirely.

    Mismatches in both directions are logged as warnings, unless
    ``is_training_name`` says the name should be ignored:
      * graph variables with no tensor in the checkpoint;
      * checkpoint tensors that were neither matched nor deliberately skipped.

    Args:
        func: callable invoked as ``func(reader, name, v)`` for each matched
            variable. ``reader`` is the object returned by
            ``_read_checkpoint_vars``, ``name`` the save-name, ``v`` the graph
            variable.
    """
    reader, chkpt_vars = SaverRestoreNoGlobalStep._read_checkpoint_vars(self.path)
    graph_vars = tf.global_variables()
    chkpt_vars_used = set()
    # Checkpoint names we intentionally did not restore. They exist in the
    # graph, so they must not be reported below as "not found in the graph".
    chkpt_vars_skipped = set()
    for v in graph_vars:
        name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
        # skip global step (presumably so the step counter is not carried
        # over from the checkpoint -- confirm with the callers of this class)
        if name == "global_step:0":
            logger.info("skip restoring global step!")
            chkpt_vars_skipped.add(name)
            continue
        if reader.has_tensor(name):
            func(reader, name, v)
            chkpt_vars_used.add(name)
        else:
            vname = v.op.name
            if not is_training_name(vname):
                logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
    # Plain set difference: no length-based shortcut, so unused names are
    # reported even if `used` happened to contain names outside `chkpt_vars`.
    unused = chkpt_vars - chkpt_vars_used - chkpt_vars_skipped
    for name in sorted(unused):
        if not is_training_name(name):
            logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
# Build the graph: either from a Python config script defining `Model`,
# or from an exported meta graph.
if args.config:
    # NOTE(review): `imp` is deprecated and removed in Python 3.12;
    # importlib.util.spec_from_file_location is the modern replacement.
    MODEL = imp.load_source('config_script', args.config).Model
    M = MODEL()
    M.build_graph(M.get_input_vars())
else:
    M = ModelFromMetaGraph(args.meta)

# loading...
if args.model.endswith('.npy'):
    # The .npy file stores an object array; `.item()` unwraps the single
    # Python object inside (presumably a {name: value} dict for ParamRestore).
    # NumPy >= 1.16.3 refuses to unpickle object arrays unless
    # allow_pickle=True. Unpickling can execute code, so only load model
    # files from trusted sources.
    init = sessinit.ParamRestore(np.load(args.model, allow_pickle=True).item())
else:
    init = sessinit.SaverRestore(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
init.init(sess)

# dump ...
with sess.as_default():
    if args.output.endswith('npy'):
        # NumPy output: let varmanip write all session params.
        varmanip.dump_session_params(args.output)
    else:
        # Checkpoint output: save trainable variables plus the extra
        # variables registered under EXTRA_SAVE_VARS_KEY, keyed by save-name.
        var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
        var_dict = {}
        for v in var:
            name = varmanip.get_savename_from_varname(v.name)
            var_dict[name] = v
        logger.info("Variables to dump:")
        logger.info(", ".join(var_dict.keys()))
        saver = tf.train.Saver(var_list=var_dict)
        saver.save(sess, args.output, write_meta_graph=False)
args = parser.parse_args()

# Recreate the graph from the meta graph file, dropping any recorded
# device assignments (clear_devices=True).
tf.train.import_meta_graph(args.meta, clear_devices=True)

# loading...
init = get_model_loader(args.input)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# Initialize every variable first; the loader then overwrites whatever
# values the input model provides.
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
init.init(sess)

# dump ...
with sess.as_default():
    if args.output.endswith(('npy', 'npz')):
        # NumPy formats: varmanip writes the session params itself.
        varmanip.dump_session_params(args.output)
    else:
        # TF checkpoint: trainable + model variables, limited to those that
        # are also global variables, keyed by their save-name.
        candidates = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        candidates.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
        global_names = {gv.name for gv in tf.global_variables()}
        var_dict = {
            varmanip.get_savename_from_varname(cand.name): cand
            for cand in candidates
            if cand.name in global_names
        }
        logger.info("Variables to dump:")
        logger.info(", ".join(var_dict))
        saver = tf.train.Saver(
            var_list=var_dict,
            write_version=tf.train.SaverDef.V2,
        )
        saver.save(sess, args.output, write_meta_graph=False)