# Example #1 (scrape-site separator; "Beispiel" is German for "example")
# 0
    # --- Fragment: keep only TRAINABLE/MODEL variables from a checkpoint. ---
    # NOTE(review): this chunk starts mid-script; the argparse setup defining
    # `args`, and the imports for tf/np/six/varmanip/get_op_tensor_name, are
    # above the visible region.

    # this script does not need GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    try:
        # Rebuild the graph from the metagraph so the variable collections
        # queried below are populated.
        tf.train.import_meta_graph(args.meta, clear_devices=True)
    except KeyError:
        # import_meta_graph raises KeyError when the graph references an op
        # that is not registered in this process; tell the user, then re-raise.
        print(
            "If your graph contains non-standard ops, you need to import the relevant library first."
        )
        raise

    # loading...
    if args.input.endswith('.npz'):
        dic = np.load(args.input)
    else:
        dic = varmanip.load_chkpt_vars(args.input)
    # Re-key the loaded dict by get_op_tensor_name(k)[1] -- presumably the full
    # tensor name (with ':0'), matching the variable `.name` values used below.
    # TODO(review): confirm against tensorpack's get_op_tensor_name.
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

    # save variables that are GLOBAL, and either TRAINABLE or MODEL
    var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    if len(set(var_to_dump)) != len(var_to_dump):
        print("TRAINABLE and MODEL variables have duplication!")
    var_to_dump = list(set(var_to_dump))
    globvarname = set([k.name for k in tf.global_variables()])
    # Intersect with global variables: keep only names present in both sets.
    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])

    # Every variable we intend to dump must exist in the loaded model.
    for name in var_to_dump:
        assert name in dic, "Variable {} not found in the model!".format(name)

    # NOTE(review): the consumer of dic_to_dump (the code that writes it out)
    # is below the visible region of this chunk.
    dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
from tensorpack.tfutils.varmanip import load_chkpt_vars
from tensorpack.utils import logger

if __name__ == '__main__':
    # Inspect variables stored in a model file (.npy / .npz / TF checkpoint),
    # optionally dump them to an .npz, optionally drop into an IPython shell.
    parser = argparse.ArgumentParser()
    parser.add_argument('model')
    parser.add_argument('--dump', help='dump to an npz file')
    parser.add_argument('--shell',
                        action='store_true',
                        help='start a shell with the params')
    args = parser.parse_args()

    if args.model.endswith('.npy'):
        # The .npy format here stores a pickled dict; numpy >= 1.16.3 refuses
        # to unpickle unless allow_pickle=True is passed explicitly.
        # NOTE(review): unpickling can execute arbitrary code -- only load
        # model files from trusted sources.
        params = np.load(args.model, encoding='latin1',
                         allow_pickle=True).item()
    elif args.model.endswith('.npz'):
        params = dict(np.load(args.model))
    else:
        # Anything else is treated as a TF checkpoint prefix.
        params = load_chkpt_vars(args.model)
    logger.info("Variables in the model:")
    logger.info(str(params.keys()))

    if args.dump:
        assert args.dump.endswith('.npz'), args.dump
        np.savez(args.dump, **params)

    if args.shell:
        # params is a dict. play with it
        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
        # NOTE(review): orphan continuation line -- the matching
        # `parser = argparse.ArgumentParser(` is above the visible chunk.
        description='Keep only TRAINABLE and MODEL variables in a checkpoint.')
    parser.add_argument('--meta', help='metagraph file', required=True)
    # Positional arguments: source model and destination file.
    parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint')
    parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
    args = parser.parse_args()

    # this script does not need GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    # Rebuild the graph from the metagraph so the variable collections
    # queried below are populated.
    tf.train.import_meta_graph(args.meta, clear_devices=True)

    # loading...
    if args.input.endswith('.npz'):
        dic = np.load(args.input)
    else:
        dic = varmanip.load_chkpt_vars(args.input)
    # Re-key by get_op_tensor_name(k)[1] -- presumably the full tensor name
    # (with ':0'), matching variable `.name` below. TODO(review): confirm.
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

    # save variables that are GLOBAL, and either TRAINABLE or MODEL
    var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    if len(set(var_to_dump)) != len(var_to_dump):
        print("TRAINABLE and MODEL variables have duplication!")
    var_to_dump = list(set(var_to_dump))
    globvarname = set([k.name for k in tf.global_variables()])
    # Intersect with global variables: keep only names present in both sets.
    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])

    # Every variable we intend to dump must exist in the loaded model.
    for name in var_to_dump:
        assert name in dic, "Variable {} not found in the model!".format(name)

    # NOTE(review): the consumer of dic_to_dump (the code writing args.output)
    # is below the visible region of this chunk.
    dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
# Example #4 (scrape-site separator; "Beispiel" is German for "example")
# 0
            # --- Fragment: retry importing a metagraph, pulling in missing
            # ops on each failure. Starts mid-loop (the enclosing `while`/`for`
            # and the definitions of `meta` and `_import_external_ops` are
            # above the visible region). ---
            try:
                tf.reset_default_graph()
                tf.train.import_meta_graph(meta, clear_devices=True)
            except KeyError as e:
                # KeyError's first arg carries the name of the unregistered op;
                # try to import whatever library provides it, then retry.
                op_name = e.args[0]
                _import_external_ops(op_name)
            except tf.errors.NotFoundError as e:
                _import_external_ops(e.message)
            else:
                # Import succeeded -- leave the retry loop.
                break

    # loading...
    # NOTE(review): `input` here shadows the builtin; presumably a path string
    # defined above the visible region.
    if input.endswith('.npz'):
        dic = np.load(input)
    else:
        dic = varmanip.load_chkpt_vars(input)
    # Re-key by get_op_tensor_name(k)[1] -- presumably the full tensor name
    # (with ':0'), matching variable `.name` below. TODO(review): confirm.
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

    if args.meta is not None:
        # save variables that are GLOBAL, and either TRAINABLE or MODEL
        var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
        if len(set(var_to_dump)) != len(var_to_dump):
            logger.warn("TRAINABLE and MODEL variables have duplication!")
        var_to_dump = list(set(var_to_dump))
        globvarname = set([k.name for k in tf.global_variables()])
        # Intersect with global variables: keep only names present in both.
        var_to_dump = set(
            [k.name for k in var_to_dump if k.name in globvarname])

        # Every variable we intend to dump must exist in the loaded model.
        # NOTE(review): the statement below is truncated -- the `.format(`
        # argument continues past the end of this chunk.
        for name in var_to_dump:
            assert name in dic, "Variable {} not found in the model!".format(