def restore(self, checkpoint=None):
    """Reload the values of all variables from a checkpoint file.

    Parameters
    ----------
    checkpoint: str
        the path to the checkpoint file to load. If this is None, the most
        recent checkpoint will be chosen automatically. Call
        get_checkpoints() to get a list of all available checkpoints.
    """
    if not self.built:
        self.build()
    if checkpoint is None:
        checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if checkpoint is None:
        raise ValueError('No checkpoint found')
    with self._get_tf("Graph").as_default():
        reader = NewCheckpointReader(checkpoint)
        # Restore only the graph variables that actually exist in the
        # checkpoint, keyed by their op name.
        stored_names = set(reader.get_variable_to_shape_map())
        restorable = {
            variable.op.name: variable
            for variable in tf.global_variables()
            if variable.op.name in stored_names
        }
        tf.train.Saver(var_list=restorable).restore(self.session, checkpoint)
def restore(self, checkpoint=None):
    """Reload the values of all variables from a checkpoint file.

    Parameters
    ----------
    checkpoint: str
        the path to the checkpoint file to load. If this is None, the most
        recent checkpoint will be chosen automatically. Call
        get_checkpoints() to get a list of all available checkpoints.
    """
    if not self.built:
        self.build()
    if checkpoint is None:
        checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if checkpoint is None:
        raise ValueError('No checkpoint found')
    with self._get_tf("Graph").as_default():
        reader = NewCheckpointReader(checkpoint)
        stored_names = set(reader.get_variable_to_shape_map())
        # Keep only model variables that are present in the checkpoint,
        # comparing on the name with the trailing ':<output index>' removed.
        restorable = []
        for variable in self.get_variables():
            base_name = variable.name
            colon = base_name.rfind(':')
            if colon != -1:
                base_name = base_name[:colon]
            if base_name in stored_names:
                restorable.append(variable)
        tf.train.Saver(var_list=restorable).restore(self.session, checkpoint)
def load_from_dir(self, model_dir, model_name):
    """Restore variable values from the checkpoint at ``model_dir/model_name``.

    Only variables present in both the current graph and the checkpoint are
    restored, so a checkpoint from a model with extra or missing variables
    can still be loaded.

    Parameters
    ----------
    model_dir: str
        directory containing the checkpoint files.
    model_name: str
        checkpoint file prefix inside model_dir.
    """
    if not self.built:
        self.build()
    # Build the checkpoint path once instead of recomputing it for the
    # restore call below.
    checkpoint = os.path.join(model_dir, model_name)
    with self._get_tf("Graph").as_default():
        reader = NewCheckpointReader(checkpoint)
        var_names = set(reader.get_variable_to_shape_map())
        var_map = {
            x.op.name: x
            for x in tf.global_variables() if x.op.name in var_names
        }
        saver = tf.train.Saver(var_list=var_map)
        saver.restore(self.session, checkpoint)
def create_multihead_initializers(checkpoint_paths):
    """Build assign ops that seed each ensemble tower from its own checkpoint.

    For every checkpoint in ``checkpoint_paths``, reads the tower-0 variables
    (excluding RMSProp optimizer slots) and creates ops assigning those
    values to the matching variables of tower ``idx`` in the current graph.
    Only the MPI server rank builds the ops; all other ranks return [].

    Parameters
    ----------
    checkpoint_paths: list of str
        one checkpoint path per ensemble head/tower, in tower order.

    Returns
    -------
    list of TF ops that perform the assignments when run.
    """
    comm = get_mpi_comm_world()
    if comm.Get_rank() != SERVER_RANK:
        return []
    readers = [NewCheckpointReader(path) for path in checkpoint_paths]
    # All checkpoints are assumed to share tower-0's variable layout, so the
    # variable names are taken from the first reader only — TODO confirm.
    source_prefix = "model/tower_0/"
    variable_names = [
        name for name in readers[0].get_variable_to_shape_map()
        if "RMSProp" not in name
    ]
    # Strip the tower-0 prefix by its length rather than the magic slice
    # [14:], which silently depended on len("model/tower_0/") == 14.
    truncated_names = [name[len(source_prefix):] for name in variable_names]
    initializers_ = []
    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
        for idx, reader in enumerate(readers):
            for name in truncated_names:
                values_np = reader.get_tensor(source_prefix + name)
                target_name = f"model/tower_{idx}/" + name
                initializers_.append(
                    tf.get_variable(target_name).assign(values_np))
    return initializers_
def _load_checkpoint(filepath):
    """Open a checkpoint reader for the checkpoint referenced by ``filepath``."""
    from tensorflow.python.pywrap_tensorflow_internal import NewCheckpointReader

    # Resolve the user-supplied path to the concrete checkpoint file name,
    # then hand it straight to the reader.
    return NewCheckpointReader(_get_checkpoint_filename(filepath))
"""List the tensors stored in a TensorFlow checkpoint.

Prints every tensor name found in the checkpoint and, when --print_value is
set, the tensor's value as well.
"""
import os

import tensorflow as tf
from tensorflow.python.pywrap_tensorflow_internal import NewCheckpointReader

tf.app.flags.DEFINE_string('model_path', '../checkpoints/couplet_seq2seq', 'Model path')
tf.app.flags.DEFINE_string('model_name', 'couplet.ckpt-70000', 'Model name')
tf.app.flags.DEFINE_bool('print_value', False, 'Print value of tensors')
FLAGS = tf.app.flags.FLAGS

# checkpoint path
checkpoint_path = os.path.join(FLAGS.model_path, FLAGS.model_name)
reader = NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

# print keys and values: iterate the map in sorted order directly instead of
# first copying its keys into an intermediate list
for key in sorted(var_to_shape_map):
    print('tensor_name:', key)
    if FLAGS.print_value:
        print(reader.get_tensor(key))