# Build the graph: either from a user-supplied config script that must define
# a `Model` class, or by re-importing a saved meta-graph.
if args.config:
    # NOTE(review): `imp` is deprecated (removed in Python 3.12);
    # importlib.util.spec_from_file_location is the modern replacement.
    MODEL = imp.load_source('config_script', args.config).Model
    M = MODEL()
    M.build_graph(M.get_input_vars())
else:
    M = ModelFromMetaGraph(args.meta)

# loading...
if args.model.endswith('.npy'):
    # The .npy file holds a single pickled Python object (presumably a
    # {param_name: value} dict, hence `.item()` on the 0-d array).
    # NumPy >= 1.16.3 refuses to unpickle object arrays unless
    # allow_pickle=True is passed explicitly.
    # SECURITY: this unpickles the file -- only load trusted model files.
    init = sessinit.ParamRestore(
        np.load(args.model, allow_pickle=True).item())
else:
    init = sessinit.SaverRestore(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
init.init(sess)

# dump ...
with sess.as_default():
    if args.output.endswith('npy'):
        varmanip.dump_session_params(args.output)
    else:
        # Concatenate into a fresh list instead of extend()-ing the list
        # returned by get_collection, so the graph's collection is never
        # mutated even on TF versions that return the underlying list.
        var = (tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) +
               tf.get_collection(EXTRA_SAVE_VARS_KEY))
        # Map checkpoint save-name -> variable; when two variables share a
        # save-name, the later one wins (same as the original loop).
        var_dict = {varmanip.get_savename_from_varname(v.name): v
                    for v in var}
        logger.info("Variables to dump:")
        logger.info(", ".join(var_dict.keys()))
        saver = tf.train.Saver(var_list=var_dict)
        saver.save(sess, args.output, write_meta_graph=False)
# Both .npy and .npz suffixes are routed to dump_session_params below, so the
# help text lists both.
parser.add_argument(dest='output',
                    help='output model file, can be npy/npz or TF checkpoint')
args = parser.parse_args()

# Rebuild the graph from the meta-graph; clear_devices=True drops the device
# placements recorded in it.
tf.train.import_meta_graph(args.meta, clear_devices=True)

# loading...
init = get_model_loader(args.input)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# Initialize every variable first so that anything the loader does not
# restore still holds a value; init.init then (presumably) overwrites the
# variables found in args.input.
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
init.init(sess)

# dump ...
with sess.as_default():
    if args.output.endswith(('npy', 'npz')):
        varmanip.dump_session_params(args.output)
    else:
        var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
        # Keep only variables that are also global variables; anything in
        # the collections above that is not global is not saved.
        gvars = {k.name for k in tf.global_variables()}
        # Map checkpoint save-name -> variable. A variable listed in both
        # collections collapses to one entry; on a save-name clash the later
        # variable wins (same as the original loop).
        var_dict = {varmanip.get_savename_from_varname(v.name): v
                    for v in var if v.name in gvars}
        logger.info("Variables to dump:")
        logger.info(", ".join(var_dict.keys()))
        saver = tf.train.Saver(
            var_list=var_dict,
            write_version=tf.train.SaverDef.V2)
        saver.save(sess, args.output, write_meta_graph=False)