def restore(sess, global_vars):
    """Seed matching variables from the VGG-16 checkpoint, initialize the rest.

    Checkpoint entries whose name ends in ``weights`` are copied into the
    variables of ``global_vars`` that carry the same name (plus the ``:0``
    output suffix). Every other variable in ``global_vars`` is run through
    ``tf.variables_initializer``.

    Args:
        sess: Session used to run the assign and initializer ops.
        global_vars: Collection of variables to restore or initialize. It is
            iterated more than once, so it must not be a one-shot iterator.
    """
    print('from vgg_16 pretrained model')
    ckpt_path = os.path.join(os.getcwd(), 'model/vgg_16.ckpt')
    reader = NewCheckpointReader(ckpt_path)

    # Original notes: the vgg_16 checkpoint carries no batch-norm variables,
    # and the conv biases are deliberately skipped (weights only).
    ckpt_tensor_names = {
        name + ':0'
        for name in reader.get_variable_to_dtype_map().keys()
        if re.match('^.*weights$', name)
    }
    restored_vars = [
        var for var in global_vars if var.name in ckpt_tensor_names
    ]

    value_ph = tf.placeholder(dtype=tf.float32)
    for var in restored_vars:
        # Drop the trailing ':0' to get back the checkpoint key.
        ckpt_key = var.name[:-2]
        sess.run(
            tf.assign(var, value_ph),
            feed_dict={value_ph: reader.get_tensor(ckpt_key)})

    initialized_vars = [
        var for var in global_vars if var not in restored_vars
    ]
    sess.run(tf.variables_initializer(initialized_vars))
def restore(sess, global_vars):
    """Seed matching variables from the ResNet-v2-50 checkpoint, init the rest.

    Checkpoint entries whose name ends in ``weights`` or ``biases`` are copied
    into the variables of ``global_vars`` that carry the same name (plus the
    ``:0`` output suffix). Every other variable in ``global_vars`` is run
    through ``tf.variables_initializer``.

    Args:
        sess: Session used to run the assign and initializer ops.
        global_vars: Collection of variables to restore or initialize. It is
            iterated more than once, so it must not be a one-shot iterator.
    """
    print('from resnet_v2_50 pretrained model')
    ckpt_path = os.path.join(os.getcwd(), 'model/resnet_v2_50.ckpt')
    reader = NewCheckpointReader(ckpt_path)

    def is_restorable(name):
        # Original note: both weights and biases of the conv and shortcut
        # layers are restored.
        return re.match('^.*weights$', name) or re.match('^.*biases$', name)

    ckpt_tensor_names = {
        name + ':0'
        for name in reader.get_variable_to_dtype_map().keys()
        if is_restorable(name)
    }
    restored_vars = [
        var for var in global_vars if var.name in ckpt_tensor_names
    ]

    value_ph = tf.placeholder(dtype=tf.float32)
    for var in restored_vars:
        # Drop the trailing ':0' to get back the checkpoint key.
        ckpt_key = var.name[:-2]
        sess.run(
            tf.assign(var, value_ph),
            feed_dict={value_ph: reader.get_tensor(ckpt_key)})

    initialized_vars = [
        var for var in global_vars if var not in restored_vars
    ]
    sess.run(tf.variables_initializer(initialized_vars))
def load_ckpt(self, pretrained):
    """Restore the latest checkpoint, or initialize variables when training.

    The newest checkpoint under ``cfg.ckpt_dir`` is restored through
    ``self.saver``. If that fails and ``self.is_training`` is true, all global
    variables are initialized instead. When ``pretrained`` is truthy and
    ``premodel['path']`` exists under ``cfg.workspace``, variables whose
    checkpoint name matches the ``premodel['rptn']`` regex are first loaded
    from ``premodel['ckpt']`` and only the remaining variables are freshly
    initialized.

    If the restore fails while not training, the failure is reported and no
    variable is touched.

    Args:
        pretrained: Truthy to seed matching variables from the pretrained
            checkpoint described by ``premodel`` before initializing.
    """
    print('trying to restore last checkpoint')
    try:
        last_ckpt_path = tf.train.latest_checkpoint(
            checkpoint_dir=cfg.ckpt_dir)
        self.saver.restore(self.sess, save_path=last_ckpt_path)
    except Exception as err:  # pylint: disable=broad-except
        # Best-effort restore: a missing or unreadable checkpoint falls back
        # to initialization below. Unlike a bare ``except`` this no longer
        # swallows KeyboardInterrupt / SystemExit, and the cause is reported.
        print('could not restore checkpoint:', err)
    else:
        print('restored checkpoint from:', last_ckpt_path)
        return

    if not self.is_training:
        return

    print('init variables')
    restored_vars = []
    global_vars = tf.global_variables()

    # NOTE(review): existence is checked on premodel['path'] while the reader
    # opens premodel['ckpt'] -- presumably intentional (a checkpoint prefix is
    # not itself a file); confirm against the premodel definition.
    if pretrained and os.path.exists(
            os.path.join(cfg.workspace, premodel['path'])):
        # restore from tf-slim model
        print('from ' + premodel['endp'])
        import re
        from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

        reader = NewCheckpointReader(
            os.path.join(cfg.workspace, premodel['ckpt']))

        # Only checkpoint entries matching the configured pattern are
        # restored (original note: "only restoring conv's weights").
        ckpt_tensor_names = {
            name + ':0'
            for name in reader.get_variable_to_dtype_map()
            if re.match(premodel['rptn'], name)
        }
        restored_vars = [
            var for var in global_vars if var.name in ckpt_tensor_names
        ]

        value_ph = tf.placeholder(tf.float32, shape=None)
        for var in restored_vars:
            # Drop the trailing ':0' to get back the checkpoint key.
            ckpt_key = var.name[:-2]
            self.sess.run(
                tf.assign(var, value_ph),
                feed_dict={value_ph: reader.get_tensor(ckpt_key)})

    initialized_vars = list(set(global_vars) - set(restored_vars))
    self.sess.run(tf.variables_initializer(initialized_vars))
def _get_reader_for_run(self, run):
    """Return the checkpoint reader for ``run``, creating it on first use.

    Readers are memoized in ``self.readers``. On a miss, a reader is opened on
    the ``model_checkpoint_path`` of ``self.configs[run]`` and cached. Any
    error raised while opening the reader propagates and nothing is cached.
    """
    cache = self.readers
    if run not in cache:
        checkpoint_path = self.configs[run].model_checkpoint_path
        cache[run] = NewCheckpointReader(checkpoint_path)
    return cache[run]
def _load_softmax(self, vocab: C2I, full_vocab_path: str,
                  ckpt_dir: str) -> None:
    """Copy softmax rows and biases for the words in ``vocab`` into the decoder.

    The full vocabulary is read from ``full_vocab_path`` (one word per line).
    The bias vector comes from ``ckpt-softmax8`` and the weight matrix from
    the eight files ``ckpt-softmax0`` .. ``ckpt-softmax7``. Chunk ``i`` is
    paired with the words at positions ``i, i + 8, i + 16, ...`` of the full
    vocabulary.

    A word present in ``vocab`` has its row and bias written to
    ``self.decoder_w`` / ``self.decoder_b`` at ``vocab[word]``. In addition,
    ``'</S>'`` is stored at the index of ``vocab.eos_token`` and ``'<UNK>'``
    at the index of ``vocab.unk_token``.

    Args:
        vocab: Mapping from word to decoder index, exposing ``eos_token`` and
            ``unk_token``.
        full_vocab_path: Path of the text file listing the full vocabulary.
        ckpt_dir: Directory holding the ``ckpt-softmax*`` files.
    """

    def store(index, weight_row, bias_value) -> None:
        self.decoder_w[index] = weight_row
        self.decoder_b[index] = bias_value

    with open(full_vocab_path) as f:
        full_vocab: List[str] = f.read().strip().split('\n')

    bias_path = os.path.join(ckpt_dir, 'ckpt-softmax8')
    full_bias = NewCheckpointReader(bias_path).get_tensor('softmax/b')
    stop = len(full_bias)

    # Original note: SoftMax is chunked into 8 arrays of size 100000x1024
    for i in range(8):
        chunk_path = os.path.join(ckpt_dir, f'ckpt-softmax{i}')
        sm_reader = NewCheckpointReader(chunk_path)
        sm_chunk = sm_reader.get_tensor(f'softmax/W_{i}').astype(np.float32)

        bias_chunk = full_bias[i:stop:8]
        vocab_chunk = full_vocab[i:stop:8]

        for j, w in enumerate(vocab_chunk):
            # Index eagerly so a size mismatch surfaces for every word, not
            # only for the ones that end up being stored.
            sm = sm_chunk[j]
            bias = bias_chunk[j]

            if w in vocab:
                store(vocab[w], sm, bias)

            if w == '</S>':
                store(vocab[vocab.eos_token], sm, bias)
            elif w == '<UNK>':
                store(vocab[vocab.unk_token], sm, bias)
def _load_softmax(self, vocab: C2I, full_vocab_path: str,
                  ckpt_dir: str) -> None:
    """Copy softmax rows and biases for the words in ``vocab`` into the decoder.

    The full vocabulary is read from ``full_vocab_path`` (one word per line).
    The bias vector comes from ``ckpt-softmax8`` and the weight matrix from
    the eight files ``ckpt-softmax0`` .. ``ckpt-softmax7``, all converted to
    torch tensors. Chunk ``i`` is paired with the words at positions
    ``i, i + 8, i + 16, ...`` of the full vocabulary.

    A word present in ``vocab`` has its row and bias written to
    ``self.decoder_w`` / ``self.decoder_b`` at ``vocab[word]``. In addition,
    ``"</S>"`` is stored at the index of ``vocab.eos_token`` and ``"<UNK>"``
    at the index of ``vocab.unk_token``.

    Args:
        vocab: Mapping from word to decoder index, exposing ``eos_token`` and
            ``unk_token``.
        full_vocab_path: Path of the text file listing the full vocabulary.
        ckpt_dir: Directory holding the ``ckpt-softmax*`` files.
    """
    from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

    def store(index, weight_row, bias_value) -> None:
        self.decoder_w[index] = weight_row
        self.decoder_b[index] = bias_value

    with open(full_vocab_path) as f:
        full_vocab: List[str] = f.read().strip().split("\n")

    bias_path = os.path.join(ckpt_dir, "ckpt-softmax8")
    bias_reader = NewCheckpointReader(bias_path)
    full_bias = torch.from_numpy(bias_reader.get_tensor("softmax/b"))
    stop = full_bias.size(0)

    # Original note: SoftMax is chunked into 8 arrays of size 100000x1024
    for i in range(8):
        chunk_path = os.path.join(ckpt_dir, f"ckpt-softmax{i}")
        sm_reader = NewCheckpointReader(chunk_path)
        sm_chunk = torch.from_numpy(sm_reader.get_tensor(f"softmax/W_{i}"))

        bias_chunk = full_bias[i:stop:8]
        vocab_chunk = full_vocab[i:stop:8]

        for j, w in enumerate(vocab_chunk):
            # Index eagerly so a size mismatch surfaces for every word, not
            # only for the ones that end up being stored.
            sm = sm_chunk[j]
            bias = bias_chunk[j]

            if w in vocab:
                store(vocab[w], sm, bias)

            if w == "</S>":
                store(vocab[vocab.eos_token], sm, bias)
            elif w == "<UNK>":
                store(vocab[vocab.unk_token], sm, bias)
def _get_reader_for_run(self, run):
    """Return the cached checkpoint reader for ``run``, or ``None``.

    Readers are memoized in ``self.readers``. On a miss, a reader is opened on
    the ``model_checkpoint_path`` of ``self._configs[run]``. If that path is
    empty, or opening the reader raises, ``None`` is cached and returned, so
    the lookup is not retried on later calls.
    """
    cache = self.readers
    if run in cache:
        return cache[run]

    checkpoint_path = self._configs[run].model_checkpoint_path
    reader = None
    if checkpoint_path:
        try:
            reader = NewCheckpointReader(checkpoint_path)
        except Exception:  # pylint: disable=broad-except
            logging.warning('Failed reading %s', checkpoint_path)

    cache[run] = reader
    return reader
def _load_lstm(self, ckpt_dir: str, device: str) -> None:
    """Load the LSTM parameters of every layer from ``ckpt-lstm``.

    For each layer ``l`` in ``range(self.num_layers)`` the checkpoint tensors
    are converted to torch tensors and stored in ``self.weight[l]``,
    ``self.bias[l]``, ``self.weight_P[l]`` and ``self.peepholes[l, p]`` for
    ``p`` in ``"f"``, ``"i"``, ``"o"``. Afterwards each of them is converted
    to ``config.DTYPE`` and moved to ``device``.

    Args:
        ckpt_dir: Directory holding the ``ckpt-lstm`` checkpoint.
        device: Target device passed to ``Tensor.to``.
    """
    from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

    lstm_reader = NewCheckpointReader(os.path.join(ckpt_dir, "ckpt-lstm"))

    def read(tensor_name):
        return torch.from_numpy(lstm_reader.get_tensor(tensor_name))

    def finalize(tensor):
        # Original note: "Cast to float32 tensors" -- the dtype actually
        # comes from config.DTYPE.
        return tensor.to(config.DTYPE).to(device)

    gates = ["f", "i", "o"]

    for l in range(self.num_layers):
        # The weight matrices are stored as 8 chunks, joined along dim 0.
        # Shapes below are taken from the original notes.
        weight_chunks = [read(f"lstm/lstm_{l}/W_{i}") for i in range(8)]
        # Shape: (2048, 32768)
        self.weight[l] = torch.cat(weight_chunks, dim=0)

        # Shape: (32768,)
        self.bias[l] = read(f"lstm/lstm_{l}/B")

        projection_chunks = [read(f"lstm/lstm_{l}/W_P_{i}") for i in range(8)]
        # Shape: (8192, 1024)
        self.weight_P[l] = torch.cat(projection_chunks, dim=0)

        for p in gates:
            # Shape: (8192, 8192) according to the original note.
            self.peepholes[l, p] = read(f"lstm/lstm_{l}/W_{p.upper()}_diag")

        self.weight[l] = finalize(self.weight[l])
        self.weight_P[l] = finalize(self.weight_P[l])
        self.bias[l] = finalize(self.bias[l])
        for p in gates:
            self.peepholes[l, p] = finalize(self.peepholes[l, p])
def _load_lstm(self, ckpt_dir: str) -> None:
    """Load the parameters of both LSTM layers from ``ckpt-lstm``.

    For layers 0 and 1 the checkpoint arrays are cast to ``float32`` and
    stored in ``self.weight[l]`` (transposed), ``self.bias[l]``,
    ``self.weight_P[l]`` and ``self.peepholes[l, p]`` for ``p`` in ``'f'``,
    ``'i'``, ``'o'``.

    Args:
        ckpt_dir: Directory holding the ``ckpt-lstm`` checkpoint.
    """
    lstm_reader = NewCheckpointReader(os.path.join(ckpt_dir, 'ckpt-lstm'))

    def as_float32(array):
        return array.astype(np.float32)

    for l in range(2):
        # The weight matrices are stored as 8 chunks, joined along axis 0.
        # Shapes below are taken from the original notes.
        weight_chunks = [
            lstm_reader.get_tensor(f'lstm/lstm_{l}/W_{i}') for i in range(8)
        ]
        # (32768, 2048) once transposed
        self.weight[l] = as_float32(np.concatenate(weight_chunks)).T

        # (32768,)
        self.bias[l] = as_float32(lstm_reader.get_tensor(f'lstm/lstm_{l}/B'))

        projection_chunks = [
            lstm_reader.get_tensor(f'lstm/lstm_{l}/W_P_{i}') for i in range(8)
        ]
        # (8192, 1024)
        self.weight_P[l] = as_float32(np.concatenate(projection_chunks))

        for p in ['F', 'I', 'O']:
            # Checkpoint names use the upper-case gate letter, the
            # peephole keys use the lower-case one.
            diag = lstm_reader.get_tensor(f'lstm/lstm_{l}/W_{p}_diag')
            self.peepholes[l, p.lower()] = as_float32(diag)
def get_list_of_variables_from_ckpt(ckpt_file):
    """Return the names of all variables stored in the checkpoint ``ckpt_file``.

    The names are the keys of the reader's variable-to-dtype map, returned as
    a list.
    """
    dtype_by_name = NewCheckpointReader(ckpt_file).get_variable_to_dtype_map()
    return list(dtype_by_name)
def _split_comma_separated(names):
    """Split a comma separated string into a list, dropping all spaces."""
    return names.replace(" ", "").split(",")


def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Variable values are restored into a session from the first available
    source, in this order: `input_saver_def`, `input_meta_graph_def`,
    `input_saved_model_dir`, and finally the raw checkpoint matched by name
    against the tensors of the imported graph.

    Args:
      input_graph_def: A `GraphDef`.
      input_saver_def: A `SaverDef` (optional).
      input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
        priority. Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.
      output_node_names: The name(s) of the output nodes, comma separated.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: String where to write the frozen `GraphDef`. Nothing is
        written when it is empty.
      clear_devices: A Bool whether to remove device specifications.
      initializer_nodes: Comma separated string of initializer nodes to run
        before freezing. Only run when restoring from a `MetaGraphDef` or
        from the raw checkpoint.
      variable_names_whitelist: The set of variable names to convert
        (optional, by default, all variables are converted).
      variable_names_blacklist: The set of variable names to omit converting
        to constants (optional).
      input_meta_graph_def: A `MetaGraphDef` (optional),
      input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
        and variables (optional).
      saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
        load, in string format (optional).
      checkpoint_version: Tensorflow variable file format
        (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)

    Returns:
      The frozen `GraphDef`, i.e. the result of converting the variables of
      the graph to constants.

    Raises:
      ValueError: If the checkpoint does not exist (and no SavedModel dir is
        given), if `output_node_names` is empty, if the model contains
        partition variables, or if the model contains no variables at all.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir and
            not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError(
            "Input checkpoint '" + input_checkpoint + "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(
                saver_def=input_saver_def, write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(
                input_meta_graph_def, clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_comma_separated(initializer_nodes))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is
            # heuristic based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example
                    # it's 'global_step' or a similar housekeeping element)
                    # so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(
                    var_list=var_list, write_version=checkpoint_version)
            except TypeError:
                # `var_list` is required to be a map of variable names to
                # Variable tensors. Partition variables are Identity tensors
                # that cannot be handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                # Models that have been frozen previously do not contain
                # Variables.
                if _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice.")
                raise

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_comma_separated(initializer_nodes))

        variable_names_whitelist = (
            _split_comma_separated(variable_names_whitelist)
            if variable_names_whitelist else None)
        variable_names_blacklist = (
            _split_comma_separated(variable_names_blacklist)
            if variable_names_blacklist else None)

        # The meta graph, when given, takes precedence over the plain
        # GraphDef as the graph whose variables are converted.
        if input_meta_graph_def:
            graph_def_to_freeze = input_meta_graph_def.graph_def
        else:
            graph_def_to_freeze = input_graph_def

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            _split_comma_separated(output_node_names),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
tarfile.open(args.checkpoint).extractall(checkpoint_dir) files = [ os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir) ] checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir elif args.checkpoint.endswith('.zip'): checkpoint_dir = args.tmp zipfile.ZipFile(args.checkpoint).extractall(checkpoint_dir) files = [ os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir) ] checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir else: checkpoint_dir = args.checkpoint reader = NewCheckpointReader( tensorflow.train.latest_checkpoint(checkpoint_dir)) blobs = {k: reader.get_tensor(k) for k in reader.get_variable_to_shape_map()} if args.output_path.endswith('.json'): with open(args.output_path, 'w') as f: json.dump({k: blob.tolist() for k, blob in blobs.items()}, f, sort_keys=True, indent=2) elif args.output_path.endswith('.h5'): import h5py with h5py.File(args.output_path, 'w') as h: h.update(**blobs) elif args.output_path.endswith('.npy') or args.output_path.endswith('.npz'): (np.savez if args.output_path[-1] == 'z' else numpy.save)(args.output_path,