def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph_def: A `GraphDef`.
      input_saver_def: A `SaverDef` (optional).
      input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
        priority. Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.
      output_node_names: The name(s) of the output nodes, comma separated.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: String where to write the frozen `GraphDef`.
      clear_devices: A Bool whether to remove device specifications.
      initializer_nodes: Comma separated string of initializer nodes to run
        before freezing.
      variable_names_whitelist: The set of variable names to convert
        (optional, by default, all variables are converted).
      variable_names_blacklist: The set of variable names to omit converting
        to constants (optional).
      input_meta_graph_def: A `MetaGraphDef` (optional).
      input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
        and variables (optional).
      saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
        load, in string format (optional).
      checkpoint_version: Tensorflow variable file format
        (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)

    Returns:
      The frozen `GraphDef`. It is also serialized to `output_graph` when that
      path is non-empty.

    Raises:
      ValueError: If no SavedModel dir is given and the checkpoint does not
        exist, if `output_node_names` is empty, or if a `Saver` cannot be built
        over the checkpoint because the model has partition variables or no
        variables at all.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir and
            not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(
                saver_def=input_saver_def, write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(
                input_meta_graph_def, clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # Names of all tensors that look like slices of a partitioned
            # variable. The test is a name heuristic, so the list may contain
            # false positives.
            partition_pattern = re.compile(r"/part_\d+/")
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if partition_pattern.search(tensor.name)
            ]
            has_partition_var = False
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                if any(key in name for name in all_partition_variable_names):
                    has_partition_var = True
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(
                    var_list=var_list, write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to
                # Variable tensors. Partition variables are Identity tensors
                # that cannot be handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                # Models that have been frozen previously do not contain Variables.
                if _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice.")
                raise e
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (
            variable_names_whitelist.replace(" ", "").split(",")
            if variable_names_whitelist else None)
        variable_names_blacklist = (
            variable_names_blacklist.replace(" ", "").split(",")
            if variable_names_blacklist else None)

        # Freeze the meta graph's GraphDef when one was supplied, otherwise the
        # plain input GraphDef.
        graph_def_to_freeze = (input_meta_graph_def.graph_def
                               if input_meta_graph_def else input_graph_def)
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            output_node_names.replace(" ", "").split(","),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
def restore(self, save_path):
    """Restore a training checkpoint into `root_checkpointable` and its dependencies.

    Variables that already exist receive their values right away. When graph
    building, restore ops are added to the graph but not run. Objects created
    later are matched against the checkpoint when they appear. This includes
    dependencies attached to `root_checkpointable` after this call.

    To forbid deferred matching, assert on the returned status::

        saver = Saver(root)
        saver.restore(path).assert_consumed()

    This raises unless every checkpointed object was matched and its variables
    already exist. While graph building, `assert_consumed()` means every restore
    op for this checkpoint has been created. Run them with::

        saver.restore(path).assert_consumed().run_restore_ops()

    Until the checkpoint is fully consumed, the set of restore ops keeps growing
    as new objects join the dependency graph.

    Checkpoints written by the name-based `tf.train.Saver` are also accepted:
    - Variables are matched by name, and nothing is deferred.
    - No restore ops are created or run until `run_restore_ops()` or
      `initialize_or_restore()` is called on the returned status, even when
      executing eagerly.
    - Such checkpoints should be re-saved with this object-based `Saver.save`
      as soon as possible.

    Args:
      save_path: Checkpoint path as returned by `save` or
        `tf.train.latest_checkpoint`. `None` (e.g. there is no latest
        checkpoint) yields a status that may only run initializers for objects
        in the dependency graph.

    Returns:
      A load status used for assertions and for running initialization/restore
      ops:
      - `InitializationOnlyStatus` when `save_path` is `None`.
      - `NameBasedSaverStatus` when the checkpoint has no object graph.
      - Otherwise `CheckpointLoadStatus`.

    Raises:
      NotImplementedError: When graph building and this Saver has already
        restored a different object graph.
    """
    if save_path is None:
        return InitializationOnlyStatus(self._root_checkpointable)

    eager = context.executing_eagerly()
    if eager:
        # Pin the filename constant to the CPU.
        with ops.device("/cpu:0"):
            prefix = constant_op.constant(save_path)
        feed_dict = None
    else:
        prefix = self._file_prefix_placeholder
        feed_dict = {self._file_prefix_placeholder: save_path}

    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    try:
        serialized_graph = reader.get_tensor(
            checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
        # No serialized object graph: fall back to name-based matching.
        return NameBasedSaverStatus(self, save_path)

    graph_proto = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
    graph_proto.ParseFromString(serialized_graph)

    # In graph mode, restoring the same object graph again reuses the
    # coordinator (and thus the restore ops) built last time.
    reuse_previous = (not eager) and graph_proto == self._last_restore_object_graph
    if reuse_previous:
        coordinator = self._last_restore_checkpoint
    else:
        coordinator = _CheckpointRestoreCoordinator(
            object_graph_proto=graph_proto,
            save_path=prefix,
            dtype_map=reader.get_variable_to_dtype_map() if eager else None)
        if not eager:
            if self._last_restore_object_graph is not None:
                raise NotImplementedError(
                    "Using a single Saver to restore different object graphs is not "
                    "currently supported when graph building. Use a different Saver "
                    "for each object graph (restore ops will be duplicated), or "
                    "file a feature request if this limitation bothers you.")
            self._last_restore_checkpoint = coordinator
            self._last_restore_object_graph = graph_proto

    root_position = checkpointable_lib._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=coordinator, proto_id=0)
    root_position.restore(self._root_checkpointable)
    return CheckpointLoadStatus(coordinator, feed_dict=feed_dict)
Note that Inception Layer 5b, Branch 2, Convolution 0a and 0b are both incorrectly labeled 0a. This causes 0b to overwrite 0a. To get around this issue, consider removing the '\dx\d' regex to obtain both sets of weights """ # tf.reset_default_graph() # saver = tf.train.import_meta_graph('TFtransformer_files/imagenet_inception_v1.ckpt.meta') # with tf.Session() as sess: # new_saver = tf.train.import_meta_graph('TFtransformer_files/imagenet_inception_v1.ckpt.meta') # new_saver.restore(sess, tf.train.latest_checkpoint('C:/Users/Eugene/Documents/UCLA/Research/TFtransformer_files/imagenet_inception_v1.ckpt')) # new_saver.restore(sess,'C:/Users/Eugene/Documents/UCLA/Research/TFtransformer_files/imagenet_inception_v1.ckpt') # print("Model restored.") # chkp.print_tensors_in_checkpoint_file('C:/Users/Eugene/Documents/UCLA/Research/TFtransformer_files/imagenet_inception_v1.ckpt', tensor_name='', all_tensors=False, all_tensor_names=True) # dir = 'C:/Users/Eugene/Documents/UCLA/Research/TFtransformer_files/imagenet_inception_v1.ckpt' dir = 'C:/Users/Eugene/Documents/UCLA/Research/inception_v1/inception_v1.ckpt' reader = pywrap_tensorflow.NewCheckpointReader(dir) var_to_shape_map = reader.get_variable_to_shape_map() for key in sorted(var_to_shape_map): name = re.sub('InceptionV1/', '', key) name = re.sub('/', '_', name) name = re.sub('Conv2d', 'conv', name) name = re.sub('\dx\d_', '', name) name = re.sub('BatchNorm_moving_mean', 'batchmean', name) name = re.sub('BatchNorm_moving_variance', 'batchvar', name) name = re.sub('BatchNorm_beta', 'batchbeta', name) scipy.io.savemat(name, dict(weights=reader.get_tensor(key)))
def main():
    """Convert a TensorFlow BERT checkpoint into a PyTorch state dict.

    The conversion proceeds in five steps:
    1. Read every tensor from ``--input_model_path``.
    2. Transpose tensors whose names contain an entry of the module-level
       ``tensors_to_transopse``.
    3. Prepend an all-zero row to the token-type (segment) embedding.
    4. Rename the tensors to the target model's parameter names, keeping only
       the first 512 position embeddings.
    5. Save the result with ``torch.save`` to ``--output_model_path``.
    """
    import numpy as np  # Only needed to pad the segment embedding.

    parser = argparse.ArgumentParser()
    parser.add_argument("--layers_num", type=int, default=12, help=".")
    parser.add_argument("--input_model_path", default="models/bert_base_chinese/bert_model.ckpt", type=str,
                        help=".")
    parser.add_argument("--output_model_path", default=None, type=str, required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()

    reader = pywrap_tensorflow.NewCheckpointReader(args.input_model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    input_model = collections.OrderedDict()
    for key in var_to_shape_map:
        value = reader.get_tensor(key)
        if any(x in key for x in tensors_to_transopse):
            value = value.T
        if key == "bert/embeddings/token_type_embeddings":
            # Prepend an all-zero row, shifting segment ids by one in the
            # converted model (row 0 is presumably reserved for padding --
            # confirm against the target model). Plain NumPy replaces a
            # throwaway TF session/variable that was never closed.
            zero_row = np.zeros((1, value.shape[1]), dtype=np.float32)
            value = np.concatenate([zero_row, value], axis=0)
        input_model[key] = torch.Tensor(value)

    output_model = collections.OrderedDict()

    output_model["embedding.word_embedding.weight"] = input_model["bert/embeddings/word_embeddings"]
    output_model["embedding.position_embedding.weight"] = input_model["bert/embeddings/position_embeddings"][:512]
    output_model["embedding.segment_embedding.weight"] = input_model["bert/embeddings/token_type_embeddings"]
    output_model["embedding.layer_norm.gamma"] = input_model["bert/embeddings/LayerNorm/gamma"]
    output_model["embedding.layer_norm.beta"] = input_model["bert/embeddings/LayerNorm/beta"]

    # (target suffix, source suffix) for every transformer layer, listed in
    # the order the parameters are inserted into the output state dict.
    layer_params = (
        ("self_attn.linear_layers.0.weight", "attention/self/query/kernel"),
        ("self_attn.linear_layers.0.bias", "attention/self/query/bias"),
        ("self_attn.linear_layers.1.weight", "attention/self/key/kernel"),
        ("self_attn.linear_layers.1.bias", "attention/self/key/bias"),
        ("self_attn.linear_layers.2.weight", "attention/self/value/kernel"),
        ("self_attn.linear_layers.2.bias", "attention/self/value/bias"),
        ("self_attn.final_linear.weight", "attention/output/dense/kernel"),
        ("self_attn.final_linear.bias", "attention/output/dense/bias"),
        ("layer_norm_1.gamma", "attention/output/LayerNorm/gamma"),
        ("layer_norm_1.beta", "attention/output/LayerNorm/beta"),
        ("feed_forward.linear_1.weight", "intermediate/dense/kernel"),
        ("feed_forward.linear_1.bias", "intermediate/dense/bias"),
        ("feed_forward.linear_2.weight", "output/dense/kernel"),
        ("feed_forward.linear_2.bias", "output/dense/bias"),
        ("layer_norm_2.gamma", "output/LayerNorm/gamma"),
        ("layer_norm_2.beta", "output/LayerNorm/beta"),
    )
    for i in range(args.layers_num):
        dst_prefix = "encoder.transformer." + str(i) + "."
        src_prefix = "bert/encoder/layer_" + str(i) + "/"
        for dst_suffix, src_suffix in layer_params:
            output_model[dst_prefix + dst_suffix] = input_model[src_prefix + src_suffix]

    output_model["target.nsp_linear_1.weight"] = input_model["bert/pooler/dense/kernel"]
    output_model["target.nsp_linear_1.bias"] = input_model["bert/pooler/dense/bias"]
    output_model["target.nsp_linear_2.weight"] = input_model["cls/seq_relationship/output_weights"]
    output_model["target.nsp_linear_2.bias"] = input_model["cls/seq_relationship/output_bias"]
    output_model["target.mlm_linear_1.weight"] = input_model["cls/predictions/transform/dense/kernel"]
    output_model["target.mlm_linear_1.bias"] = input_model["cls/predictions/transform/dense/bias"]
    output_model["target.layer_norm.gamma"] = input_model["cls/predictions/transform/LayerNorm/gamma"]
    output_model["target.layer_norm.beta"] = input_model["cls/predictions/transform/LayerNorm/beta"]
    # The MLM output projection reuses the word-embedding matrix.
    output_model["target.mlm_linear_2.weight"] = input_model["bert/embeddings/word_embeddings"]
    output_model["target.mlm_linear_2.bias"] = input_model["cls/predictions/output_bias"]

    torch.save(output_model, args.output_model_path)
# Read tensor shapes from an InceptionV1 checkpoint and create non-trainable
# TF variables with the shapes of two of its tensors.
import sys
import math
import numpy as np
import tensorflow as tf
from utils import configs
from utils.ops import load_image
from PIL import Image
import os
import scipy
from scipy import io
from tensorflow.python import pywrap_tensorflow

file_name = '/media/best/Coding Disk/sae_ws/inception_v1.ckpt'  # path to the .ckpt checkpoint
name_variable_to_restore = 'InceptionV1/Conv2d_1a_7x7/weights'  # checkpoint variable whose weights are read
name_variable2_to_restore = 'InceptionV1/Conv2d_1a_7x7/BatchNorm/beta'
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
# Disabled debugging aid (kept as a bare string): list every tensor name.
''' for key in var_to_shape_map: print("tensor_name: ", key) '''
# Define the variables that receive the weights. Only the shapes are taken from
# the checkpoint; nothing in this snippet loads the checkpoint values into them.
conv1 = tf.get_variable("Conv2d_1a_7x7", var_to_shape_map[name_variable_to_restore], trainable=False)
bias1 = tf.get_variable("Conv2d_1a_7x7/BatchNorm", var_to_shape_map[name_variable2_to_restore], trainable=False)
print(conv1)
print(bias1)
def _load_conv_weights(sess, checkpoint_dir):
    """Copy ``conv2d`` weights from the latest checkpoint into trainable variables.

    Each checkpoint entry whose name contains "conv2d" but not "Adam" is
    assigned to the first trainable variable whose name contains that entry's
    name.

    Raises:
      ValueError: If ``checkpoint_dir`` holds no checkpoint.
    """
    ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)
    if not ckpt_state or not ckpt_state.model_checkpoint_path:
        raise ValueError(
            "No checkpoint found in '%s' to continue training from." % checkpoint_dir)
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_state.model_checkpoint_path)
    target_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    for key in reader.get_variable_to_shape_map():
        # Skip non-conv tensors and Adam optimizer slot variables.
        if "conv2d" not in key or "Adam" in key:
            continue
        for var in target_vars:
            if key in var.name:
                sess.run(var.assign(reader.get_tensor(key)))
                break


def train(cont):
    """Train the DQN agent on ``Game``.

    Args:
      cont: If truthy, after initializing all variables, copy the ``conv2d``
        weights (excluding Adam slots) from the latest checkpoint in
        ``model/`` into the matching trainable variables.

    Raises:
      ValueError: If ``cont`` is set but ``model/`` holds no checkpoint.
    """
    sess = tf.Session()

    game = Game(SCREEN_WIDTH, SCREEN_HEIGHT, OBS_NUM, BUN_NUM, show_game=False)
    brain = DQN(sess, SCREEN_WIDTH, SCREEN_HEIGHT, CHANNEL, NUM_ACTION)

    rewards = tf.placeholder(tf.float32, [None])
    tf.summary.scalar('avg.reward/ep.', tf.reduce_mean(rewards))

    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    if cont:
        # Warm start: only conv weights are copied. Everything else keeps its
        # fresh initialization.
        _load_conv_weights(sess, 'model')

    writer = tf.summary.FileWriter('logs', sess.graph)
    summary_merged = tf.summary.merge_all()

    # Initialize the target network.
    brain.update_target_network()

    # Probability of taking a random action instead of asking the DQN.
    epsilon = 1.0
    # Number of frames played so far.
    time_step = 0
    total_reward_list = []

    try:
        for episode in range(MAX_EPISODE):
            terminal = False
            total_reward = 0

            state = game.reset()
            brain.init_state(state)

            if episode > OBSERVE:
                epsilon = 0.01

            while not terminal:
                if np.random.rand() < epsilon:
                    action = random.randrange(NUM_ACTION)
                else:
                    action = brain.get_action()

                # NOTE(review): epsilon *grows* every frame, i.e. exploration
                # increases over time -- confirm this is intended.
                epsilon += 0.00001

                state, reward, terminal = game.step(action)
                total_reward += reward

                brain.remember(state, action, reward, terminal)

                if time_step > OBSERVE and time_step % TRAIN_INTERVAL == 0:
                    brain.train()

                if time_step % TARGET_UPDATE_INTERVAL == 0:
                    brain.update_target_network()

                time_step += 1

            if episode % 10 == 0:
                print('Games: %d Score: %d' % (episode + 1, total_reward))

            total_reward_list.append(total_reward)

            if episode % 10 == 0:
                summary = sess.run(summary_merged, feed_dict={rewards: total_reward_list})
                writer.add_summary(summary, time_step)
                total_reward_list = []

            if episode % 10000 == 0:
                saver.save(sess, 'model/dqn.ckpt', global_step=episode)
    finally:
        writer.close()
        sess.close()
#!/usr/bin/env python # -*- coding: UTF-8 -*- # coding=utf-8 """ @author: Li Tian @contact: [email protected] @software: pycharm @file: find_error.py @time: 2019/4/29 12:17 @desc: 报错参数名字不一样,检查问题 """ import os from tensorflow.python import pywrap_tensorflow model_dir = './' checkpoint_path = os.path.join(model_dir, 'attention_ckpt-2800') reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) var_to_shape_map = reader.get_variable_to_shape_map() for key in var_to_shape_map: print("tensor_name: ", key) print(reader.get_tensor(key))
# Check that each weight in the "st" checkpoint matches the
# ExponentialMovingAverage of the same weight in the "te" checkpoint
# (presumably student and teacher WaveNet models -- confirm).
import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow

te_ckpt_dir = '/data/logs-wavenet/nsynth_wavenet_mol'
st_ckpt_dir = '/data/logs-wavenet/nsynth_parallel_wavenet'

# Explicit checks rather than `assert`: asserts vanish under `python -O`, which
# would turn this checking script into a silent no-op.
te_ckpt = tf.train.latest_checkpoint(te_ckpt_dir)
if te_ckpt is None or not tf.train.checkpoint_exists(te_ckpt):
    raise IOError('No checkpoint found in {}'.format(te_ckpt_dir))
st_ckpt = tf.train.latest_checkpoint(st_ckpt_dir)
if st_ckpt is None or not tf.train.checkpoint_exists(st_ckpt):
    raise IOError('No checkpoint found in {}'.format(st_ckpt_dir))

te_reader = pywrap_tensorflow.NewCheckpointReader(te_ckpt)
te_var_to_shape_map = te_reader.get_variable_to_shape_map()
te_var_names = list(te_var_to_shape_map.keys())

st_reader = pywrap_tensorflow.NewCheckpointReader(st_ckpt)
st_var_to_shape_map = st_reader.get_variable_to_shape_map()
st_var_names = list(st_var_to_shape_map.keys())

# Weight-like variables, selected by name suffix.
inspect_var_names = [vn for vn in te_var_names
                     if vn.endswith(('W', 'biases', 'bias', 'kernel'))]

# Collect every mismatch instead of stopping at the first one.
mismatched = [
    ivn for ivn in inspect_var_names
    if not np.allclose(
        st_reader.get_tensor(ivn),
        te_reader.get_tensor('{}/ExponentialMovingAverage'.format(ivn)))
]
if mismatched:
    raise AssertionError(
        'Weights differ from the ExponentialMovingAverage values for: {}'.format(
            ', '.join(mismatched)))
def main(args):
    """Restore a trained agent from a checkpoint, run evaluation and print metrics.

    Args:
      args: Not used by this function; settings come from the module-level
        ``flags`` object.
    """
    evaluate = None
    sess = None
    print_flags_info()
    try:
        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        # Grow GPU memory on demand instead of reserving it all up front.
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        # NOTE(review): initializers run before Evaluate() is constructed; any
        # variables it creates are only set by the checkpoint restore below --
        # confirm this is intended.
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        evaluate = Evaluate()
        saver = tf.train.Saver()

        # Prefer the explicitly named checkpoint; fall back to the latest one.
        checkpoint_path = None
        if flags.checkpoint != "":
            explicit_path = os.path.join(flags.checkpoint_dir, flags.checkpoint)
            # checkpoint_exists understands V2 checkpoint prefixes (stored as
            # "<prefix>.index" / "<prefix>.data-*"); os.path.exists does not,
            # so an explicitly requested V2 checkpoint used to be ignored.
            if tf.train.checkpoint_exists(explicit_path):
                checkpoint_path = explicit_path
        if checkpoint_path is None:
            checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)
            if checkpoint and checkpoint.model_checkpoint_path:
                checkpoint_path = checkpoint.model_checkpoint_path

        if checkpoint_path:
            # List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
            print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='',
                                             all_tensors=False, all_tensor_names=True)
            print("Checkpoint file path:", checkpoint_path)
            if flags.segnet == 0:
                from tensorflow.python import pywrap_tensorflow
                reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
                ckpt_names = list(reader.get_variable_to_shape_map())
                # Map every graph variable to the first checkpoint entry whose
                # name contains the variable's last name component (with the
                # ":0" output suffix stripped). Raises IndexError if none does.
                old_ckpt_to_new_ckpt = {}
                for var in tf.global_variables():
                    ending = var.name.split('/')[-1][:-2]
                    matches = [k for k in ckpt_names if ending in k]
                    old_ckpt_to_new_ckpt[matches[0]] = var
                saver1 = tf.train.Saver(var_list=old_ckpt_to_new_ckpt)
                saver1.restore(sess, checkpoint_path)
            else:
                saver.restore(sess, checkpoint_path)
            print("checkpoint loaded:", checkpoint_path)
            tokens = checkpoint_path.split("-")
            # set global step
            if 'best' in checkpoint_path:
                i = 3 if len(tokens) > 3 else 2
            else:
                i = 2 if len(tokens) > 3 else 1
            global_t = int(tokens[i])
            print(">>> global step set: ", global_t)
        else:
            print("Could not find old checkpoint")

        if flags.segnet >= 2:
            sess.run([evaluate.global_network.reset_evaluation_vars])
        while not evaluate.is_done():
            evaluate.update(sess)

        evaluate.episode_roomtype = np.array(evaluate.episode_roomtype)
        evaluate.episode_reward = np.array(evaluate.episode_reward[:-1])  # last is unnecessary
        n_episode = len(evaluate.episode_reward)
        evaluate.success_rate = np.array(evaluate.success_rate)

        if flags.segnet >= 2:
            score_miou = sess.run(evaluate.global_network.evaluation)
            print("Global mIoU: {}".format(score_miou))

        print("Success Rate:{} ".format(np.sum(evaluate.success_rate) / n_episode))
        print("RoomType distribution")
        for k, v in evaluate.roomType_dict.items():
            fraq = np.mean(evaluate.episode_roomtype == k)
            print("RoomType {0}: {1:.3%}".format(v, fraq), end="\n")
        for k, v in evaluate.roomType_dict.items():
            roomtype_ind = evaluate.episode_roomtype == k
            fraq_succ = np.sum(evaluate.success_rate[roomtype_ind]) / np.sum(roomtype_ind)
            av_reward = np.sum(evaluate.episode_reward[roomtype_ind]) / np.sum(roomtype_ind)
            print("RoomType {0} success rate: {1:.6%}, average episode reward: {2:.4}".format(
                v, fraq_succ, av_reward))
        for k, v in evaluate.segnet_class_dict.items():
            sim_all = np.array(v)
            print("For class id {} accuracy is {:.5}".format(
                k, np.sum(sim_all[:, 0]) / np.sum(sim_all[:, 1])))
    except Exception:
        # Top-level boundary: report the failure but still run cleanup below.
        print(traceback.format_exc())
    finally:
        if evaluate is not None:
            evaluate.environment.stop()
        if sess is not None:
            sess.close()
# Dump the tensor names of the "test" checkpoint, then print its fc3 weights
# and the fc3/fc6 weights of the "trained" checkpoint from the same directory.
from tensorflow.python import pywrap_tensorflow
import os

model_dir = "/home/sid/rddnn/models/"
checkpoint_path = os.path.join(model_dir, "test.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))
#print(reader.get_tensor("fc4_weights"))
print("===========fc3_wieghts(test)================")
print(reader.get_tensor("fc3_weights"))
# NOTE(review): this header is printed with no value after it, because the
# fc6 lookup on the test checkpoint just below is commented out.
print("===========fc6_wieghts(test)================")
#print(reader.get_tensor("fc6_weights"))

train_checkpoint_path = os.path.join(model_dir, "trained.ckpt")
reader_train = pywrap_tensorflow.NewCheckpointReader(train_checkpoint_path)
print("===========fc3_wieghts(train)================")
print(reader_train.get_tensor("fc3_weights"))
print("===========fc6_wieghts(train)================")
print(reader_train.get_tensor("fc6_weights"))
print("Train tensor names ==========================")
var_to_shape_map_train = reader_train.get_variable_to_shape_map()
for key in var_to_shape_map_train:
    print("tensor_name: ", key)
def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False): """Creates an operation to assign specific variables from a checkpoint. Args: model_path: The full path to the model checkpoint. To get latest checkpoint use `model_path = tf.train.latest_checkpoint(checkpoint_dir)` var_list: A list of (possibly partitioned) `Variable` objects or a dictionary mapping names in the checkpoint to the corresponding variables or list of variables to initialize from that checkpoint value. For partitioned Variables, the name in the checkpoint must be the full variable, not the name of the partitioned variable, eg. "my_var" rather than "my_var/part_4". If empty, returns no_op(), {}. ignore_missing_vars: Boolean, if True ignore variables missing in the checkpoint with a warning instead of failing. Returns: the restore_op and the feed_dict that need to be run to restore var_list. Raises: ValueError: If `ignore_missing_vars` is False and the checkpoint specified at `model_path` is missing one of the variables in `var_list`. """ # Normalize var_list into a dictionary mapping names in the # checkpoint to the list of variables to initialize from that # checkpoint variable. Sliced (including partitioned) variables will # end up under the same key. grouped_vars = {} if isinstance(var_list, (tuple, list)): for var in var_list: ckpt_name = get_variable_full_name(var) if ckpt_name not in grouped_vars: grouped_vars[ckpt_name] = [] grouped_vars[ckpt_name].append(var) else: for ckpt_name, value in var_list.items(): if isinstance(value, (tuple, list)): grouped_vars[ckpt_name] = value else: grouped_vars[ckpt_name] = [value] # Read each checkpoint entry. Create a placeholder variable and # add the (possibly sliced) data from the checkpoint to the feed_dict. 
reader = pywrap_tensorflow.NewCheckpointReader(model_path) feed_dict = {} assign_ops = [] for ckpt_name in grouped_vars: if not reader.has_tensor(ckpt_name): log_str = 'Checkpoint is missing variable [%s]' % ckpt_name if ignore_missing_vars: logging.warning(log_str) continue else: raise ValueError(log_str) ckpt_value = reader.get_tensor(ckpt_name) for var in grouped_vars[ckpt_name]: placeholder_tensor = array_ops.placeholder( dtype=var.dtype.base_dtype, shape=var.get_shape(), name='placeholder/' + var.op.name) assign_ops.append(var.assign(placeholder_tensor)) if not var._save_slice_info: if var.get_shape() != ckpt_value.shape: raise ValueError( 'Total size of new array must be unchanged for %s ' 'lh_shape: [%s], rh_shape: [%s]' % (ckpt_name, str(ckpt_value.shape), str(var.get_shape()))) feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape) else: slice_dims = zip(var._save_slice_info.var_offset, var._save_slice_info.var_shape) slice_dims = [(start, start + size) for (start, size) in slice_dims] slice_dims = [slice(*x) for x in slice_dims] slice_value = ckpt_value[slice_dims] slice_value = slice_value.reshape(var._save_slice_info.var_shape) feed_dict[placeholder_tensor] = slice_value assign_op = control_flow_ops.group(*assign_ops) return assign_op, feed_dict
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir="/Users/dapicella/Desktop/Tensorboard_demo/modeltest",
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Freeze a graph by replacing its variables with constant values.

    Variable values are taken from the first available source, in this order:
    1. `input_saver_def`
    2. `input_meta_graph_def`
    3. the SavedModel at `input_saved_model_dir`
    4. a name-matched read of `input_checkpoint`

    Comma-separated name arguments may contain spaces; they are stripped.

    NOTE(review): `input_saved_model_dir` defaults to a machine-specific path.
    Unless callers pass a falsy value, the SavedModel branch therefore wins
    over the plain-checkpoint branch and the checkpoint existence check is
    skipped.

    Returns:
      The frozen `GraphDef`, also written to `output_graph` if that is
      non-empty. Returns -1 after printing a message when a required
      checkpoint is missing or no output node names are given.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    def _names(csv):
        # "a, b,c" -> ["a", "b", "c"]
        return csv.replace(" ", "").split(",")

    # A V2 checkpoint is addressed by its prefix, hence checkpoint_exists.
    needs_checkpoint = not input_saved_model_dir
    if needs_checkpoint and not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Stripping explicit device placements makes the frozen graph portable.
    if clear_devices:
        if input_meta_graph_def:
            nodes = input_meta_graph_def.graph_def.node
        elif input_graph_def:
            nodes = input_graph_def.node
        else:
            nodes = ()
        for node in nodes:
            node.device = ""

    if input_graph_def:
        importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            saver_lib.Saver(
                saver_def=input_saver_def,
                write_version=checkpoint_version).restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            saver_lib.import_meta_graph(
                input_meta_graph_def,
                clear_devices=True).restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_names(initializer_nodes))
        elif input_saved_model_dir:
            tags = [] if saved_model_tags is None else saved_model_tags
            loader.load(sess, tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            for key in reader.get_variable_to_shape_map():
                try:
                    var_list[key] = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # Not present in the graph (e.g. 'global_step'): skip it.
                    pass
            saver_lib.Saver(
                var_list=var_list,
                write_version=checkpoint_version).restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_names(initializer_nodes))

        whitelist = _names(variable_names_whitelist) if variable_names_whitelist else None
        blacklist = _names(variable_names_blacklist) if variable_names_blacklist else None
        source_graph_def = (input_meta_graph_def.graph_def
                            if input_meta_graph_def else input_graph_def)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            source_graph_def,
            _names(output_node_names),
            variable_names_whitelist=whitelist,
            variable_names_blacklist=blacklist)

    # Persist the frozen graph only when a destination was given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
def freeze_graph_with_def_protos(graph_dir_def, input_saver_def, input_checkpoint,
                                 output_node_names, restore_op_name,
                                 filename_tensor_name, frozen_graph,
                                 clear_devices, initializer_nodes,
                                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        graph_dir_def: The `GraphDef` to freeze. Its nodes are modified in
            place when `clear_devices` is set.
        input_saver_def: A `SaverDef` used to restore the checkpoint
            (optional). When falsy, a `Saver` is built from every checkpoint
            variable that also exists in the imported graph.
        input_checkpoint: Checkpoint prefix to restore from.
        output_node_names: Comma separated names of the output nodes.
        restore_op_name: Unused.
        filename_tensor_name: Unused.
        frozen_graph: Path the serialized frozen `GraphDef` is written to;
            nothing is written when falsy.
        clear_devices: Whether to strip the device field of every node.
        initializer_nodes: Comma separated string of nodes to run after the
            checkpoint-reader restore. A non-string value is passed to
            `Session.run` unchanged.
        variable_names_blacklist: Comma separated variable names that are
            left as variables instead of being converted to constants.

    Returns:
        The frozen `GraphDef`, or -1 if the checkpoint does not exist or no
        output node names were supplied.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in graph_dir_def.node:
            node.device = ""

    _ = importer.import_graph_def(graph_dir_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def)
            saver.restore(sess, input_checkpoint)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    print(
                        "[WARNING] skip %s, cause this tensor doesn't exist in the graph."
                        % key)
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                # Session.run needs a list of node names; a single "a,b"
                # string would be looked up as one (non-existent) node.
                if isinstance(initializer_nodes, str):
                    initializer_nodes = initializer_nodes.replace(" ", "").split(",")
                sess.run(initializer_nodes)

        variable_names_blacklist = (
            variable_names_blacklist.replace(" ", "").split(",")
            if variable_names_blacklist else None)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_dir_def,
            # Node names never contain spaces, so "a, b" is normalised to a/b.
            output_node_names.replace(" ", "").split(","),
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if frozen_graph:
        with gfile.GFile(frozen_graph, "wb") as f:
            f.write(frozen_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(frozen_graph_def.node))
    return frozen_graph_def
def read_and_decode_tfrecords(filename_queue,
                              checkpoint_path='C:/work/tensorflow_template/log/model.ckpt',
                              embedding_var_name='embeddings/Variable'):
    """Read one TFRecord example, batch it and build the model input layers.

    Args:
        filename_queue: Queue of TFRecord file names consumed by
            `tf.TFRecordReader.read`.
        checkpoint_path: Checkpoint prefix holding the pretrained embedding
            table. Defaults to the location that used to be hard-coded.
        embedding_var_name: Name of the embedding variable inside that
            checkpoint.

    Returns:
        Tuple `(label, deep_features)`:
          * `label` is the input layer built from the label columns.
          * `deep_features` concatenates (axis 1) the deep-column input layer
            with the "sum"-combined embedding lookup of the sparse
            `education_num` ids.
    """
    from tensorflow.python import pywrap_tensorflow

    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(filename_queue)
    wide_columns, deep_columns, label_columns = build_model_columns2()

    # The embedding table is read once from the checkpoint as a NumPy value.
    ckpt_reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    embedding_table = ckpt_reader.get_tensor(embedding_var_name)

    examples = tf.parse_single_example(
        serialized_example,
        features={
            "education_num": tf.VarLenFeature(tf.int64),
            'workclass': tf.FixedLenFeature([], tf.string),
            'fnlwgt': tf.FixedLenFeature([], tf.int64),
            'education': tf.FixedLenFeature([], tf.string),
            'marital_status': tf.FixedLenFeature([], tf.string),
            'occupation': tf.FixedLenFeature([], tf.string),
            'relationship': tf.FixedLenFeature([], tf.string),
            'race': tf.FixedLenFeature([], tf.string),
            'gender': tf.FixedLenFeature([], tf.string),
            'capital_gain': tf.FixedLenFeature([], tf.int64),
            'capital_loss': tf.FixedLenFeature([], tf.int64),
            'hours_per_week': tf.FixedLenFeature([], tf.int64),
            'native_country': tf.FixedLenFeature([], tf.string),
            'age': tf.FixedLenFeature([], tf.int64),
            'income_bracket': tf.FixedLenFeature([], tf.string),
        })

    # dynamic_pad lets the VarLen `education_num` feature batch with varying
    # lengths.
    batch_features = tf.train.batch(examples,
                                    batch_size=FLAGS.batch_size,
                                    dynamic_pad=True)

    item2vec = tf.nn.embedding_lookup_sparse(embedding_table,
                                             batch_features['education_num'],
                                             None,
                                             combiner="sum")

    # NOTE(review): wide_features is built (adding ops to the graph) but not
    # returned — confirm whether callers were meant to receive it.
    wide_features = tf.feature_column.input_layer(batch_features, wide_columns)
    label = tf.feature_column.input_layer(batch_features, label_columns)
    deep_features = tf.concat([
        tf.feature_column.input_layer(batch_features, deep_columns),
        item2vec,
    ], 1)
    return label, deep_features
import argparse
from tensorflow.python import pywrap_tensorflow

if __name__ == '__main__':
    # Print the name of every tensor stored in saved-models/<restore-model-name>.
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore-model-name', type=str, default=None)
    args = parser.parse_args()

    saved_model_dir = 'saved-models/'
    if not args.restore_model_name:
        # Without a name there is no checkpoint to inspect, so fail loudly
        # instead of silently doing nothing.
        raise ValueError('Please provide restore model name')

    restore_model_path = saved_model_dir + args.restore_model_name
    reader = pywrap_tensorflow.NewCheckpointReader(restore_model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: Path to the `GraphDef` file to freeze.
        input_saver: Path to a `SaverDef` file (optional).
        input_binary: True if the graph/saver files are binary protobufs,
            False for text format.
        input_checkpoint: Checkpoint prefix to restore from.
        output_node_names: Comma separated names of the output nodes.
        restore_op_name: Unused.
        filename_tensor_name: Unused.
        output_graph: Path the serialized frozen `GraphDef` is written to.
        clear_devices: Whether to strip the device field of every node.
        initializer_nodes: Comma separated string of nodes to run after the
            checkpoint-reader restore. A non-string value is passed to
            `Session.run` unchanged.
        variable_names_blacklist: Comma separated variable names that are
            left as variables instead of being converted to constants.

    Returns:
        -1 if an input file is missing or no output node names were given;
        otherwise None, after writing `output_graph`.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            graph_text = f.read()
            # Text-mode reads return str on Python 3; only decode real bytes.
            if isinstance(graph_text, bytes):
                graph_text = graph_text.decode("utf-8")
            text_format.Merge(graph_text, input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    # Import without a name prefix: the variable lookups below, the SaverDef
    # op names and convert_variables_to_constants all use the unprefixed
    # names from input_graph_def.
    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    saver_text = f.read()
                    if isinstance(saver_text, bytes):
                        saver_text = saver_text.decode("utf-8")
                    text_format.Merge(saver_text, saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                # Session.run needs a list of node names, not one "a,b" string.
                if isinstance(initializer_nodes, str):
                    initializer_nodes = initializer_nodes.replace(" ", "").split(",")
                sess.run(initializer_nodes)

        variable_names_blacklist = (
            variable_names_blacklist.replace(" ", "").split(",")
            if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.replace(" ", "").split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph_def: The `GraphDef` to freeze. Its nodes are modified in
            place when `clear_devices` is set.
        input_saver_def: A `SaverDef` used to restore the checkpoint
            (optional). When falsy, a `Saver` is built from every checkpoint
            variable that also exists in the imported graph.
        input_checkpoint: Checkpoint prefix to restore from.
        output_node_names: Comma separated names of the output nodes.
        restore_op_name: Unused.
        filename_tensor_name: Unused.
        clear_devices: Whether to strip the device field of every node.
        initializer_nodes: Comma separated string of nodes to run after the
            checkpoint-reader restore. A non-string value is passed to
            `Session.run` unchanged.
        variable_names_blacklist: Comma separated variable names that are
            left as variables instead of being converted to constants.

    Returns:
        The frozen `GraphDef`.

    Raises:
        ValueError: If the checkpoint does not exist or no output node names
            were supplied.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            "Input checkpoint '" + input_checkpoint + "' does not exist!")

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    # Freeze inside a fresh graph so the caller's default graph is untouched.
    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')

        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping element)
                        # so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    # Session.run needs a list of node names, not "a,b".
                    if isinstance(initializer_nodes, str):
                        initializer_nodes = (
                            initializer_nodes.replace(' ', '').split(','))
                    sess.run(initializer_nodes)

            variable_names_blacklist = (
                variable_names_blacklist.replace(' ', '').split(',')
                if variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(' ', '').split(','),
                variable_names_blacklist=variable_names_blacklist)

    return output_graph_def
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# List the name of every variable saved in the pretrained checkpoint.
model_reader = pywrap_tensorflow.NewCheckpointReader(
    './pretrained_models/20190416_021402_l2_softmax.ckpt-43000'
)
var_dict = model_reader.get_variable_to_shape_map()

for key in var_dict:
    # Only names are printed; model_reader.get_tensor(key) returns the value.
    print("variable name: ", key)
# -*- coding: utf-8 -*-
from tensorflow.python import pywrap_tensorflow

# Dump every variable in the checkpoint: its name, then its stored value.
reader = pywrap_tensorflow.NewCheckpointReader(r"./cc3/train_model.ckpt1")
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print("tensor_name: ", key)
    # Remove the next line if you want to print only variable names.
    print(reader.get_tensor(key))
import os
from tensorflow.python import pywrap_tensorflow
import tensorflow as tf

# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader('./YOLOV2COCO/yolo2_coco.ckpt')
var_to_shape_map = reader.get_variable_to_shape_map()

# Print tensor name and values.
# The NumPy value from the reader is printed directly. Wrapping it in
# tf.constant only printed a symbolic Tensor (name/shape/dtype) in TF1 graph
# mode, and copied every weight into the default graph.
for key in var_to_shape_map:
    print("tensor_name: ", key)
    a = reader.get_tensor(key)
    print(a)
def load_tf_weights(inputbase, config):
    """
    Load the weights from the tensorflow checkpoint.

    Args:
        inputbase: Checkpoint prefix opened with ``pyTF.NewCheckpointReader``.
        config: Model config. The attributes ``num_attention_heads``,
            ``head_size``, ``use_int8`` and ``interleaved`` are read.

    Returns:
        dict mapping TensorRT weight names to ``trt.Weights``:

          * One entry per checkpoint variable whose name does not contain
            "adam", "global_step" or "pooler".
            - Names containing "encoder" become
              ``l<first number in name>_<lower-cased path tokens from index 3>``.
            - Other names become the lower-cased path joined with "_".
            - Tensors whose name contains "kernel" are stored transposed, with
              the untransposed copy under ``<name>_notrans``.
          * For every key containing ``BQ``: fused ``WQKV`` and ``BQKV`` entries,
            plus ``WQKV + "_notrans"``, built from the matching Q/K/V weights
            and biases.

        Any exception is logged through ``TRT_LOGGER`` at ERROR level, and the
        weights collected up to that point are returned.
    """
    weights_dict = dict()
    # Created before the try so the final update() works even when reading
    # the checkpoint fails (previously an UnboundLocalError masked the
    # logged error).
    additional_dict = dict()
    try:
        reader = pyTF.NewCheckpointReader(inputbase)
        tensor_dict = reader.get_variable_to_shape_map()

        # There might be training-related variables in the checkpoint that
        # can be discarded
        param_names = [
            key for key in sorted(tensor_dict)
            if "adam" not in key and "global_step" not in key
            and "pooler" not in key
        ]
        count = len(param_names)
        TRT_LOGGER.log(TRT_LOGGER.INFO,
                       "Found {:} entries in weight map".format(count))

        for pn in param_names:
            toks = pn.lower().split("/")
            if "encoder" in pn:
                assert ("layer" in pn)
                layer_idx = re.findall(r"\d+", pn)[0]
                outname = "l{}_".format(layer_idx) + "_".join(toks[3:])
            else:
                outname = "_".join(toks)

            tensor = reader.get_tensor(pn)
            if pn.find("kernel") != -1:
                weights_dict[outname + "_notrans"] = trt.Weights(
                    np.ascontiguousarray(tensor).flatten())

                TRT_LOGGER.log(TRT_LOGGER.VERBOSE,
                               "Transposing {}\n".format(pn))
                tensor = np.transpose(tensor)

            shape = tensor.shape
            flat_tensor = tensor.flatten()
            shape_str = "{} ".format(len(shape)) + " ".join(
                [str(d) for d in shape])
            weights_dict[outname] = trt.Weights(flat_tensor)

            TRT_LOGGER.log(
                TRT_LOGGER.VERBOSE,
                "Original name: {:}, TensorRT name: {:}, shape: {:}".format(
                    pn, outname, shape_str))

        N = config.num_attention_heads
        H = config.head_size

        # Fuse the separate Q/K/V weights and biases of each attention block
        # into single QKV tensors.
        for key, value in weights_dict.items():
            pos = key.find(BQ)
            if pos != -1:
                hidden_size = value.size
                prefix = key[:pos]

                Bq_ = value
                Bk_ = weights_dict[prefix + BK]
                Bv_ = weights_dict[prefix + BV]
                Wq_ = weights_dict[prefix + WQ]
                Wk_ = weights_dict[prefix + WK]
                Wv_ = weights_dict[prefix + WV]

                mat_size = hidden_size * hidden_size
                wcount = 3 * mat_size
                Wall = np.zeros(wcount, np.float32)
                bcount = 3 * hidden_size
                Ball = np.zeros(bcount, np.float32)
                Wall[0:mat_size] = Wq_.numpy()[0:mat_size]
                Wall[mat_size:2 * mat_size] = Wk_.numpy()[0:mat_size]
                Wall[2 * mat_size:3 * mat_size] = Wv_.numpy()[0:mat_size]
                Ball[0:hidden_size] = Bq_.numpy()[0:hidden_size]
                Ball[hidden_size:2 * hidden_size] = Bk_.numpy()[0:hidden_size]
                Ball[2 * hidden_size:3 * hidden_size] = Bv_.numpy()[0:hidden_size]

                if config.use_int8 and config.interleaved:
                    Wall = np.ascontiguousarray(Wall.reshape((3, N, H, N, H)),
                                                dtype=np.float32)
                    Ball = np.ascontiguousarray(Ball.reshape((3, N, H)),
                                                dtype=np.float32)
                else:
                    # Move the head axis in front of the Q/K/V axis.
                    Wall = np.ascontiguousarray(
                        Wall.reshape((3, N, H, N, H)).transpose((1, 0, 2, 3, 4)),
                        dtype=np.float32)
                    Ball = np.ascontiguousarray(
                        Ball.reshape((3, N, H)).transpose((1, 0, 2)),
                        dtype=np.float32)

                additional_dict[prefix + WQKV] = trt.Weights(Wall)
                additional_dict[prefix + BQKV] = trt.Weights(Ball)

                additional_dict[prefix + WQKV + "_notrans"] = trt.Weights(
                    Wall.T)

    except Exception as error:
        TRT_LOGGER.log(TRT_LOGGER.ERROR, str(error))

    weights_dict.update(additional_dict)
    return weights_dict
def __init__(self,
             net_name,
             snapshot_path,
             feature_norm_method=None,
             should_restore_classifier=False,
             gpu_memory_fraction=None,
             vgg_16_heads=None):
    """
    Build the network graph, open a session and restore weights from a snapshot.

    Args:
        net_name: Network name passed to ``nets_factory`` and
            ``preprocessing_factory``. It must be 'vgg_16_multihead' exactly
            when ``vgg_16_heads`` is given.
        snapshot_path: Checkpoint path, or a directory of checkpoints (the
            latest checkpoint in the directory is used).
        feature_norm_method: None, 'signed_sqrt' or 'unit_norm', or a list of
            them. Always stored as a list in ``self.feature_norm_method``.
        should_restore_classifier: If True, the last (classifier) layer is
            also restored and the number of classes is read from the
            snapshot. If False, variables under ``classifier_scope[net_name]``
            are not restored. Forced to True when ``vgg_16_heads`` is not
            None.
        gpu_memory_fraction: Forwarded to
            ``GPUOptions.per_process_gpu_memory_fraction``.
        vgg_16_heads: Number of classes for 'vgg_16_multihead'; None for
            every other net.

    Raises:
        ValueError: On an inconsistent ``net_name``/``vgg_16_heads`` pair or
            an unknown feature norm method.
    """
    self.net_name = net_name
    if net_name != 'vgg_16_multihead' and vgg_16_heads is not None:
        raise ValueError(
            'vgg_16_heads must be not None only for vgg_16_multihead')
    if net_name == 'vgg_16_multihead' and vgg_16_heads is None:
        raise ValueError(
            'vgg_16_heads must be not None for vgg_16_multihead')

    if tf.gfile.IsDirectory(snapshot_path):
        snapshot_path = tf.train.latest_checkpoint(snapshot_path)

    if not isinstance(feature_norm_method, list):
        feature_norm_method = [feature_norm_method]
    acceptable_methods = [None, 'signed_sqrt', 'unit_norm']
    for method in feature_norm_method:
        if method not in acceptable_methods:
            raise ValueError(
                'unknown norm method: {}. Use one of {}'.format(
                    method, acceptable_methods))
    self.feature_norm_method = feature_norm_method

    if vgg_16_heads is not None:
        should_restore_classifier = True

    if should_restore_classifier:
        if vgg_16_heads is None:
            reader = pywrap_tensorflow.NewCheckpointReader(snapshot_path)
            if net_name == 'inception_v1':
                var_value = reader.get_tensor(
                    'InceptionV1/Logits/Conv2d_0c_1x1/weights')
            else:
                var_value = reader.get_tensor('vgg_16/fc8/weights')
            # Axis 3 of the classifier kernel is taken as the class count.
            num_classes = var_value.shape[3]
        else:
            num_classes = vgg_16_heads
    else:
        # The classifier is not restored, so any class count works here.
        num_classes = 2 if vgg_16_heads is None else vgg_16_heads

    network_fn = nets_factory.get_network_fn(net_name,
                                             num_classes=num_classes,
                                             is_training=False)
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        net_name, is_training=False)

    eval_image_size = network_fn.default_image_size
    self.img_resize_shape = (eval_image_size, eval_image_size)  # (224, 224) for VGG

    with tf.Graph().as_default() as graph:
        self.graph = graph
        with tf.variable_scope('input'):
            input_pl = tf.placeholder(
                tf.float32,
                shape=[None, eval_image_size, eval_image_size, 3],
                name='x')
            # not used
            is_phase_train_pl = tf.placeholder(tf.bool,
                                               shape=tuple(),
                                               name='is_phase_train')

        def function_to_map(x):
            return image_preprocessing_fn(x, eval_image_size, eval_image_size)

        images = tf.map_fn(function_to_map, input_pl)

        logits, self.end_points = network_fn(images)
        self.__dict__.update(self.end_points)
        if net_name == 'inception_v1':
            for tensor_name in [
                    'Branch_0/Conv2d_0a_1x1', 'Branch_1/Conv2d_0a_1x1',
                    'Branch_1/Conv2d_0b_3x3', 'Branch_2/Conv2d_0a_1x1',
                    'Branch_2/Conv2d_0b_3x3', 'Branch_3/MaxPool_0a_3x3',
                    'Branch_3/Conv2d_0b_1x1'
            ]:
                full_tensor_name = 'InceptionV1/InceptionV1/Mixed_4d/' + tensor_name
                if 'MaxPool' in tensor_name:
                    full_tensor_name += '/MaxPool:0'
                else:
                    full_tensor_name += '/Relu:0'
                short_name = 'Mixed_4d/' + tensor_name
                self.__dict__[short_name] = tf.get_default_graph(
                ).get_tensor_by_name(full_tensor_name)
            self.MaxPool_0a_7x7 = tf.get_default_graph().get_tensor_by_name(
                "InceptionV1/Logits/MaxPool_0a_7x7/AvgPool:0")
        elif net_name in ['vgg_16', 'vgg_16_multihead']:
            # range (not the Python-2-only xrange) keeps this working on
            # Python 3 as well.
            for layer_name in ['fc6', 'fc7'] + \
                    ['conv{0}/conv{0}_{1}'.format(i, j)
                     for i in range(3, 6) for j in range(1, 4)]:
                self.__dict__['vgg_16/{}_prerelu'.format(layer_name)] = \
                    tf.get_default_graph().get_tensor_by_name(
                        "vgg_16/{}/BiasAdd:0".format(layer_name))

        config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=gpu_memory_fraction))
        self.sess = tf.Session(config=config)

        if should_restore_classifier:
            variables_to_restore = slim.get_model_variables()
        else:
            variables_to_restore = [
                var for var in slim.get_model_variables()
                if not var.op.name.startswith(classifier_scope[net_name])
            ]

        init_fn = slim.assign_from_checkpoint_fn(snapshot_path,
                                                 variables_to_restore)
        init_fn(self.sess)