# this script does not need GPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''

try:
    # Rebuild the graph from the metagraph so we know which variables it declares.
    tf.train.import_meta_graph(args.meta, clear_devices=True)
except KeyError:
    # A KeyError here is reported to the user as a missing (non-standard) op.
    print(
        "If your graph contains non-standard ops, you need to import the relevant library first."
    )
    raise

# loading...
if args.input.endswith('.npz'):
    # Read every array into memory and close the underlying zip file,
    # instead of keeping the NpzFile handle open for the rest of the script.
    with np.load(args.input) as npz:
        dic = dict(npz)
else:
    dic = varmanip.load_chkpt_vars(args.input)
# Re-key by get_op_tensor_name(k)[1] (presumably the ':0'-suffixed tensor
# name) so the keys can be compared against tf.Variable.name below.
dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

# save variables that are GLOBAL, and either TRAINABLE or MODEL
var_to_dump = (tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) +
               tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
if len(set(var_to_dump)) != len(var_to_dump):
    print("TRAINABLE and MODEL variables have duplication!")
globvarname = {v.name for v in tf.global_variables()}
# Collecting names into a set both drops the duplicates reported above and
# filters out variables that are not GLOBAL.
var_to_dump = {v.name for v in var_to_dump if v.name in globvarname}

# Every selected variable must exist in the loaded model. Without this check,
# a missing one would simply be absent from dic_to_dump (an incomplete
# output). An explicit raise is used because asserts are stripped under -O.
missing = sorted(name for name in var_to_dump if name not in dic)
if missing:
    raise ValueError(
        "Variables not found in the model: {}".format(", ".join(missing)))
dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
from tensorpack.tfutils.varmanip import load_chkpt_vars
from tensorpack.utils import logger


def _load_params(path):
    """Load a name -> value parameter dict from a model file.

    Args:
        path (str): ``.npy`` file holding a pickled dict (unwrapped with
            ``.item()``), ``.npz`` archive, or anything else, which is handed
            to ``load_chkpt_vars`` as a TF checkpoint.

    Returns:
        dict: variable name -> value.
    """
    if path.endswith('.npy'):
        # The dict is stored as a pickled object; numpy >= 1.16.3 refuses to
        # unpickle unless allow_pickle=True is passed explicitly.
        # NOTE: unpickling can execute arbitrary code -- only load trusted files.
        return np.load(path, encoding='latin1', allow_pickle=True).item()
    if path.endswith('.npz'):
        # Materialize all arrays, then close the archive's file handle.
        with np.load(path) as npz:
            return dict(npz)
    return load_chkpt_vars(path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model')
    parser.add_argument('--dump', help='dump to an npz file')
    parser.add_argument('--shell', action='store_true', help='start a shell with the params')
    args = parser.parse_args()

    # Validate CLI input before the (possibly slow) model load. parser.error
    # prints usage and exits; an assert would be stripped under `python -O`.
    if args.dump and not args.dump.endswith('.npz'):
        parser.error("--dump must be an .npz file, got {}".format(args.dump))

    params = _load_params(args.model)
    logger.info("Variables in the model:")
    logger.info(str(params.keys()))

    if args.dump:
        np.savez(args.dump, **params)

    if args.shell:
        # params is a dict. play with it
        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
description='Keep only TRAINABLE and MODEL variables in a checkpoint.') parser.add_argument('--meta', help='metagraph file', required=True) parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint') parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint') args = parser.parse_args() # this script does not need GPU os.environ['CUDA_VISIBLE_DEVICES'] = '' tf.train.import_meta_graph(args.meta, clear_devices=True) # loading... if args.input.endswith('.npz'): dic = np.load(args.input) else: dic = varmanip.load_chkpt_vars(args.input) dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)} # save variables that are GLOBAL, and either TRAINABLE or MODEL var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) if len(set(var_to_dump)) != len(var_to_dump): print("TRAINABLE and MODEL variables have duplication!") var_to_dump = list(set(var_to_dump)) globvarname = set([k.name for k in tf.global_variables()]) var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname]) for name in var_to_dump: assert name in dic, "Variable {} not found in the model!".format(name) dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
try: tf.reset_default_graph() tf.train.import_meta_graph(meta, clear_devices=True) except KeyError as e: op_name = e.args[0] _import_external_ops(op_name) except tf.errors.NotFoundError as e: _import_external_ops(e.message) else: break # loading... if input.endswith('.npz'): dic = np.load(input) else: dic = varmanip.load_chkpt_vars(input) dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)} if args.meta is not None: # save variables that are GLOBAL, and either TRAINABLE or MODEL var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) if len(set(var_to_dump)) != len(var_to_dump): logger.warn("TRAINABLE and MODEL variables have duplication!") var_to_dump = list(set(var_to_dump)) globvarname = set([k.name for k in tf.global_variables()]) var_to_dump = set( [k.name for k in var_to_dump if k.name in globvarname]) for name in var_to_dump: assert name in dic, "Variable {} not found in the model!".format(