def _match_vars(self, func):
    """Match graph variables against checkpoint variables and apply ``func``.

    For every global variable in the current graph, the corresponding
    tensor is looked up in the checkpoint at ``self.path`` (after the
    variable name is normalized with ``self.prefix``). When found,
    ``func(reader, name, var)`` is called to perform the restore and the
    name is recorded as used; otherwise a warning is logged, unless the
    variable is a training-only one. Checkpoint variables that were never
    matched to any graph variable are reported at the end.

    Args:
        func: callable ``(reader, savename, variable)`` invoked for each
            graph variable that has a matching tensor in the checkpoint.
    """
    reader, chkpt_vars = SaverRestoreNoGlobalStep._read_checkpoint_vars(self.path)
    graph_vars = tf.global_variables()
    chkpt_vars_used = set()
    for v in graph_vars:
        name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
        # This class deliberately never restores the global step.
        # FIX: route the message through the project logger instead of a
        # bare print(), consistent with the other diagnostics below.
        if name == "global_step:0":
            logger.info("skip restoring global step!")
            continue

        if reader.has_tensor(name):
            func(reader, name, v)
            chkpt_vars_used.add(name)
        else:
            vname = v.op.name
            # Training-only variables (e.g. optimizer state) are expected
            # to be absent from a checkpoint -- do not warn about those.
            if not is_training_name(vname):
                logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
    # Report checkpoint tensors that no graph variable consumed.
    if len(chkpt_vars_used) < len(chkpt_vars):
        unused = chkpt_vars - chkpt_vars_used
        for name in sorted(unused):
            if not is_training_name(name):
                logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
# Example 2
    # Build the model graph: either from a user-supplied config script
    # (which must expose a `Model` class) or from a saved meta-graph file.
    if args.config:
        MODEL = imp.load_source('config_script', args.config).Model
        M = MODEL()
        M.build_graph(M.get_input_vars())
    else:
        M = ModelFromMetaGraph(args.meta)

    # Load the weights: a `.npy` file is treated as a pickled dict of
    # numpy arrays, anything else as a TensorFlow checkpoint path.
    if args.model.endswith('.npy'):
        init = sessinit.ParamRestore(np.load(args.model).item())
    else:
        init = sessinit.SaverRestore(args.model)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    init.init(sess)

    # Dump the restored parameters, either to a numpy archive or to a
    # TensorFlow checkpoint.
    with sess.as_default():
        if args.output.endswith('npy'):
            varmanip.dump_session_params(args.output)
        else:
            # Collect trainable variables plus any extra variables the
            # project registered under EXTRA_SAVE_VARS_KEY.
            var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
            var_dict = {}
            for v in var:
                # Map each variable to its canonical save name
                # (presumably strips scope/device prefixes -- see
                # varmanip.get_savename_from_varname).
                name = varmanip.get_savename_from_varname(v.name)
                var_dict[name] = v
            logger.info("Variables to dump:")
            logger.info(", ".join(var_dict.keys()))
            saver = tf.train.Saver(var_list=var_dict)
            saver.save(sess, args.output, write_meta_graph=False)
# Example 3
    args = parser.parse_args()

    # Rebuild the graph from the meta-graph file; clear_devices drops any
    # device placements baked into the saved graph.
    tf.train.import_meta_graph(args.meta, clear_devices=True)

    # Initialize all variables first so that variables missing from the
    # checkpoint still have a defined value, then restore from the input.
    init = get_model_loader(args.input)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    init.init(sess)

    # Dump the restored parameters, either to a numpy archive or to a
    # TensorFlow V2 checkpoint.
    with sess.as_default():
        if args.output.endswith('npy') or args.output.endswith('npz'):
            varmanip.dump_session_params(args.output)
        else:
            # Collect trainable + model variables, but keep only those
            # that are actual global variables in this graph.
            var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
            gvars = set([k.name for k in tf.global_variables()])
            var = [v for v in var if v.name in gvars]
            var_dict = {}
            for v in var:
                # Map each variable to its canonical save name
                # (presumably strips scope/device prefixes -- see
                # varmanip.get_savename_from_varname).
                name = varmanip.get_savename_from_varname(v.name)
                var_dict[name] = v
            logger.info("Variables to dump:")
            logger.info(", ".join(var_dict.keys()))
            saver = tf.train.Saver(
                var_list=var_dict,
                write_version=tf.train.SaverDef.V2)
            saver.save(sess, args.output, write_meta_graph=False)