Esempio n. 1
0
    # Build the graph: either from a user-supplied config script whose `Model`
    # class is instantiated and built, or from a metagraph file (args.meta).
    if args.config:
        MODEL = imp.load_source('config_script', args.config).Model
        M = MODEL()
        M.build_graph(M.get_input_vars())
    else:
        M = ModelFromMetaGraph(args.meta)

    # loading...
    if args.model.endswith('.npy'):
        # A `.npy` model is read as a 0-d object array and unwrapped with
        # `.item()`, i.e. it holds a pickled Python object (presumably a
        # {name: value} dict for ParamRestore -- confirm with varmanip).
        # NumPy >= 1.16.3 refuses to unpickle object arrays unless
        # allow_pickle=True is passed explicitly.
        # NOTE(review): unpickling can execute arbitrary code; only load model
        # files from trusted sources.
        init = sessinit.ParamRestore(
            np.load(args.model, allow_pickle=True).item())
    else:
        init = sessinit.SaverRestore(args.model)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    init.init(sess)

    # dump ...
    with sess.as_default():
        # Use the same dotted `.npy` suffix test as the input branch above; the
        # old dot-less check also matched names such as "model_npy".
        if args.output.endswith('.npy'):
            varmanip.dump_session_params(args.output)
        else:
            # Write trainable variables plus the EXTRA_SAVE_VARS_KEY collection
            # to a TF checkpoint, keyed by varmanip's save-name for each one.
            var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
            # If two variables map to the same save-name, the later one wins.
            var_dict = {
                varmanip.get_savename_from_varname(v.name): v for v in var
            }
            logger.info("Variables to dump:")
            logger.info(", ".join(var_dict.keys()))
            saver = tf.train.Saver(var_list=var_dict)
            saver.save(sess, args.output, write_meta_graph=False)
Esempio n. 2
0
    parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
    args = parser.parse_args()

    # Recreate the graph from the metagraph; clear_devices=True strips the
    # device assignments recorded in it.
    tf.train.import_meta_graph(args.meta, clear_devices=True)

    # loading...
    init = get_model_loader(args.input)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Initialize all global and local variables before restoring, presumably
    # so that anything the loader does not restore still has a value.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    init.init(sess)

    # dump ...
    with sess.as_default():
        # NumPy outputs are recognised by their dotted suffix only (the old
        # dot-less test also matched names such as "model_npz").
        if args.output.endswith(('.npy', '.npz')):
            varmanip.dump_session_params(args.output)
        else:
            # Keep TRAINABLE and MODEL collection members that are also global
            # variables, then save them as a TF checkpoint keyed by the
            # save-name varmanip derives from each variable name.
            var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
            gvars = {k.name for k in tf.global_variables()}
            var = [v for v in var if v.name in gvars]
            # A variable present in both collections collapses to one entry.
            var_dict = {
                varmanip.get_savename_from_varname(v.name): v for v in var
            }
            logger.info("Variables to dump:")
            logger.info(", ".join(var_dict.keys()))
            saver = tf.train.Saver(
                var_list=var_dict,
                write_version=tf.train.SaverDef.V2)
            saver.save(sess, args.output, write_meta_graph=False)