def print_tensors_in_checkpoint_file(file_name, tensor_name):
    """Print the contents of a checkpoint file to stdout.

    With a falsy `tensor_name`, every entry returned by
    `checkpoint_utils.list_variables` is printed as a tab-separated
    name/shape pair. Otherwise the value returned by
    `checkpoint_utils.load_variable` for that tensor is printed.

    Any exception raised while reading is caught and its message is printed
    instead of propagating; nothing is returned.

    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Name of the tensor in the checkpoint file to print, or a
        falsy value to list all variables.
    """
    try:
        if tensor_name:
            print("tensor_name: ", tensor_name)
            print(checkpoint_utils.load_variable(file_name, tensor_name))
        else:
            for var_name, var_shape in checkpoint_utils.list_variables(file_name):
                print("%s\t%s" % (var_name, str(var_shape)))
    except Exception as e:  # pylint: disable=broad-except
        message = str(e)
        print(message)
        # This particular message gets an extra hint about SNAPPY compression.
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    """Split variables by whether the checkpoint contains their name.

    A variable's lookup name is its `.name` truncated at the first ":".

    Args:
      checkpoint_path: Path passed to `checkpoint_utils.list_variables`.
      vars_to_check: Iterable of objects exposing a `.name` string that
        contains a ":"; a name without one raises `ValueError`.

    Returns:
      A tuple `(vars_in_checkpoint, vars_not_in_checkpoint)` of lists, each
      preserving the order of `vars_to_check`.
    """
    check_var_set = {
        entry[0] for entry in checkpoint_utils.list_variables(checkpoint_path)
    }

    vars_in_checkpoint = []
    vars_not_in_checkpoint = []
    # Classify in a single pass so that one-shot iterables (generators) are
    # handled correctly and each name is only sliced once.
    for var in vars_to_check:
        base_name = var.name[:var.name.index(":")]
        if base_name in check_var_set:
            vars_in_checkpoint.append(var)
        else:
            vars_not_in_checkpoint.append(var)
    return vars_in_checkpoint, vars_not_in_checkpoint
def load_checkpoints(sess, var_scopes=('encoder', 'decoder', 'dense')):
    """Restore non-Adam variables of the given scopes into `sess`.

    The checkpoint location comes from `configdl.lip_model_path`; when it is a
    directory, the result of `tf.train.latest_checkpoint` is used. If
    `configdl.featurizer` is set, the 'visual_frontend' scope is restored
    either from that same checkpoint (when the checkpoint lists a variable
    whose name contains 'visual_frontend') or from the checkpoint state found
    at `configdl.featurizer_model_path`.

    Args:
      sess: Session the savers restore into.
      var_scopes: Tuple of variable scopes to restore from the main checkpoint.
    """
    checkpoint_path = configdl.lip_model_path
    if checkpoint_path:
        checkpoint = (tf.train.latest_checkpoint(checkpoint_path)
                      if os.path.isdir(checkpoint_path) else checkpoint_path)

    if configdl.featurizer:
        if checkpoint_path:
            from tensorflow.contrib.framework.python.framework import checkpoint_utils
            stored = checkpoint_utils.list_variables(checkpoint)
            if any('visual_frontend' in entry[0] for entry in stored):
                var_scopes = var_scopes + ('visual_frontend',)

        if 'visual_frontend' not in var_scopes:
            # The main checkpoint has no frontend weights: take them from the
            # separate featurizer checkpoint instead.
            frontend_vars = tf.global_variables(scope='visual_frontend')
            featurizer_ckpt = tf.train.get_checkpoint_state(configdl.featurizer_model_path)
            frontend_vars = [v for v in frontend_vars if 'Adam' not in v.name]
            tf.train.Saver(frontend_vars).restore(sess, featurizer_ckpt.model_checkpoint_path)

    all_variables = [
        v
        for scope in var_scopes
        for v in tf.global_variables(scope=scope)
        if 'Adam' not in v.name
    ]

    if checkpoint_path:
        tf.train.Saver(all_variables).restore(sess, checkpoint)
        print("Restored saved model {}!".format(checkpoint))
def testGetAllVariables(self):
    """list_variables on the temp dir yields the four expected (name, shape) pairs."""
    expected = [
        ("useful_scope/var4", [9, 9]),
        ("var1", [1, 10]),
        ("var2", [10, 10]),
        ("var3", [100, 100]),
    ]
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
        _create_checkpoints(session, checkpoint_dir)
    listed = checkpoint_utils.list_variables(checkpoint_dir)
    self.assertEqual(listed, expected)
def checkpoint_dtype_cast(in_checkpoint_file, out_checkpoint_file):
    """Rewrite a checkpoint with its float32 variables stored as float16.

    Variables whose name contains "quant" or "amaxList" keep their stored
    dtype; every other float32 variable is re-declared as float16. The
    conversion runs in two stages: the input is restored through
    `CastFromFloat32SaverBuilder` and saved to 'tmp.ckpt' in the current
    directory, then the graph is rebuilt and 'tmp.ckpt' is re-saved with a
    plain saver to `out_checkpoint_file`.

    Args:
      in_checkpoint_file: Checkpoint to read names, shapes, dtypes and
        values from.
      out_checkpoint_file: Destination checkpoint path.
    """
    # Resolve each variable's target dtype once, from the checkpoint that was
    # actually passed in (previously the global init_checkpoint flag was read
    # here, silently ignoring `in_checkpoint_file`).
    var_specs = []
    for name, shape in checkpoint_utils.list_variables(in_checkpoint_file):
        stored_dtype = checkpoint_utils.load_variable(in_checkpoint_file, name).dtype
        if "quant" in name or "amaxList" in name:
            recon_dtype = stored_dtype
        else:
            recon_dtype = tf.float16 if stored_dtype == np.float32 else stored_dtype
        var_specs.append((name, shape, recon_dtype))

    def init_graph():
        # Declare every checkpoint variable in the current default graph.
        for name, shape, recon_dtype in var_specs:
            tf.get_variable(name, shape=shape, dtype=recon_dtype)

    init_graph()
    saver = tf.train.Saver(builder=CastFromFloat32SaverBuilder())
    with tf.Session() as sess:
        saver.restore(sess, in_checkpoint_file)
        saver.save(sess, 'tmp.ckpt')

    # Start from an empty graph so the second saver is a plain one.
    tf.reset_default_graph()
    init_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, 'tmp.ckpt')
        saver.save(sess, out_checkpoint_file)
def load_model(sess, model_dir, model_name, scope='generator'):
    """Restore the trainable variables of `scope` from the latest checkpoint.

    Only variables that are also listed in the checkpoint are restored; each
    model variable missing from the checkpoint is reported on stdout and
    skipped. The step counter is parsed from the last run of digits in the
    checkpoint file name.

    Args:
      sess: Session to restore into.
      model_dir: Parent directory of the model.
      model_name: Sub-directory of `model_dir` holding the checkpoints.
      scope: Variable scope whose trainable variables are restored.

    Returns:
      `(True, counter)` on success and `(True, 0)` otherwise.
      NOTE(review): the first element is `True` even on failure. That is the
      pre-existing contract and is kept for compatibility -- confirm whether
      callers expect `False` here.
    """
    model_dir = os.path.join(model_dir, model_name)
    ckpt = tf.train.get_checkpoint_state(model_dir)

    if not (ckpt and ckpt.model_checkpoint_path):
        print(" [!] Model load FAILED - no checkpoint in " + os.path.abspath(model_dir))
        return True, 0

    try:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        ckpt_path = os.path.join(model_dir, ckpt_name)

        vars_model = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
        names_in_ckpt = {
            entry[0] for entry in checkpoint_utils.list_variables(ckpt_path)
        }

        vars_to_restore = []
        for var in vars_model:
            var_name = var.name.split(':')[0]
            if var_name in names_in_ckpt:
                vars_to_restore.append(var)
            else:
                print(' [!] ' + var_name + ' not exists')

        saver = tf.train.Saver(vars_to_restore)
        saver.restore(sess, ckpt_path)

        # Last run of digits in the file name, e.g. "model-1234" -> 1234.
        step_match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        if step_match is None:
            raise ValueError("no step number in checkpoint name " + ckpt_name)
        counter = int(step_match.group(0))

        print(" [*] Model load SUCCESS - " + os.path.abspath(ckpt_path))
        return True, counter
    except Exception as err:
        # Report the real cause only; the "no checkpoint" message is reserved
        # for the case where no checkpoint state was found.
        print(" [!] Model load FAILED - " + os.path.abspath(model_dir) + ', ' + str(err))

    return True, 0
def __init__(self, corpus):
    """Reset the default graph, load the word embeddings and create placeholders.

    The "embeddings" variable is read from
    `GlobalConstants.WORD_EMBEDDING_FILE_PATH` and a zero row is prepended so
    that row 0 can serve as the padding feature. All model tensors that are
    built later start out as `None`.

    Args:
      corpus: Object providing `get_vocabulary_size()`; the embedding table
        must have exactly one more row than that size.
    """
    tf.reset_default_graph()
    self.corpus = corpus

    # Load embeddings
    embedding_path = GlobalConstants.WORD_EMBEDDING_FILE_PATH
    embedding_size = GlobalConstants.EMBEDDING_SIZE
    pretrained_var_list = checkpoint_utils.list_variables(embedding_path)
    self.wordEmbeddings = checkpoint_utils.load_variable(
        checkpoint_dir=embedding_path, name="embeddings")

    # IMPORTANT !!! Add a zero row at the 0. position. This will be used as
    # the padding feature.
    padding_row = np.zeros(shape=(1, embedding_size))
    self.wordEmbeddings = np.concatenate(
        [padding_row, self.wordEmbeddings], axis=0)
    assert self.wordEmbeddings.shape[0] == self.corpus.get_vocabulary_size() + 1
    assert self.wordEmbeddings.shape[1] == embedding_size

    # Graph inputs.
    self.batch_size = tf.placeholder(
        dtype=tf.int32, shape=[], name='batch_size')
    self.input_word_codes = tf.placeholder(
        dtype=tf.int32, shape=[None, None], name='input_word_codes')
    self.input_y = tf.placeholder(
        dtype=tf.int64, shape=[None], name='input_y')
    self.keep_prob = tf.placeholder(
        dtype=tf.float32, shape=[], name='keep_prob')
    self.sequence_length = tf.placeholder(
        dtype=tf.int32, shape=[None], name='sequence_length')
    self.max_sequence_length = tf.placeholder(
        dtype=tf.int32, name='max_sequence_length')
    self.isTrainingFlag = tf.placeholder(name="is_training", dtype=tf.bool)

    # Attributes that are filled in later.
    for attribute in ("embeddings", "inputs", "logits", "predictions",
                      "numOfCorrectPredictions", "accuracy", "optimizer",
                      "globalStep", "sess", "correctPredictions"):
        setattr(self, attribute, None)

    # L2 loss
    self.mainLoss = None
    self.l2_loss = tf.constant(0.0)
def load_model(sess, LOAD_MODEL_FILE, prefixs, strict=False):
    """Restore matching trainable variables from a pretrained checkpoint.

    A trainable variable is restored when its op name starts with one of the
    given prefixes, does not contain 'logits', is listed in the checkpoint,
    and has the same shape as the checkpoint entry.

    Args:
      sess: Session to restore into.
      LOAD_MODEL_FILE: Checkpoint path.
      prefixs: A single prefix string or a list of prefix strings.
      strict: When True, a failed restore returns False; otherwise a failed
        restore is only reported and True is returned.

    Returns:
      False if restoring failed and `strict` is set, True otherwise.
    """
    vars_in_pretrained_model = dict(
        checkpoint_utils.list_variables(LOAD_MODEL_FILE))

    # `str.startswith` takes a tuple of alternatives; testing all prefixes in
    # one call keeps a variable that matches several prefixes from being
    # collected more than once.
    prefixes = tuple(prefixs) if isinstance(prefixs, list) else prefixs

    vars_in_defined_model = []
    for var in tf.trainable_variables():
        name = var.op.name
        if not name.startswith(prefixes):
            continue
        if 'logits' in name or name not in vars_in_pretrained_model:
            continue
        if list(var.shape) == vars_in_pretrained_model[name]:
            vars_in_defined_model.append(var)

    saver = tf.train.Saver(vars_in_defined_model)
    try:
        saver.restore(sess, LOAD_MODEL_FILE)
        print("Model loaded in file: %s" % (LOAD_MODEL_FILE))
    except Exception:  # pylint: disable=broad-except
        # Best-effort restore: report the failure, but let KeyboardInterrupt
        # and SystemExit propagate.
        if strict:
            print("Fail to load modelfile: %s" % LOAD_MODEL_FILE)
            return False
        print("Fail loaded in file: %s" % (LOAD_MODEL_FILE))
        return True
    return True
def restore(self):
    """Print every entry listed in the u_net checkpoint, then restore it.

    The checkpoint is '<weight_dir>/u_net.ckpt'; it is restored into
    `self.sess` through `self.saver`.
    """
    sess, saver = self.sess, self.saver
    ckpt_path = self.weight_dir+'/u_net.ckpt'
    for entry in checkpoint_utils.list_variables(ckpt_path):
        print(entry)
    saver.restore(sess, ckpt_path)
def testGetAllVariables(self):
    """Check the (name, shape) pairs that list_variables reports for the temp dir."""
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        _create_checkpoints(session, checkpoint_dir)
    expected_variables = [
        ("useful_scope/var4", [9, 9]),
        ("var1", [1, 10]),
        ("var2", [10, 10]),
        ("var3", [100, 100]),
    ]
    self.assertEqual(
        checkpoint_utils.list_variables(checkpoint_dir), expected_variables)
def load_model(self, sess,
               checkpoint_path="D:\\sensor_data_generation\\models\\vae_model.ckpt"):
    """Assign checkpoint values to every non-Adam global variable.

    Each global variable whose name does not contain "Adam" is loaded from the
    checkpoint by its full `var.name` and assigned in `sess`.

    Args:
      sess: Session in which the assignments are evaluated.
      checkpoint_path: Checkpoint to read from. Defaults to the path that was
        previously hard-coded, so existing callers are unaffected.
    """
    # Listing the checkpoint first is kept from the original code; presumably
    # it makes an unreadable checkpoint fail before any variable is touched
    # -- TODO confirm.
    checkpoint_utils.list_variables(checkpoint_dir=checkpoint_path)

    for var in tf.global_variables():
        if "Adam" in var.name:
            continue
        source_array = checkpoint_utils.load_variable(
            checkpoint_dir=checkpoint_path, name=var.name)
        tf.assign(var, source_array).eval(session=sess)
    print("X")
def showParametersInCkpt(ckpt_dir):
    '''
    Print the parameter entries stored in the latest checkpoint of a directory.
    :param ckpt_dir: directory passed to tf.train.latest_checkpoint
    :return: None
    '''
    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
    newest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)

    from tensorflow.contrib.framework.python.framework import checkpoint_utils
    for entry in checkpoint_utils.list_variables(newest_checkpoint):
        print(entry)
def restore_new_scope(self, dir, saved_scope, tf_scope):
    """Restore variables saved under one scope into a differently named scope.

    Every non-Adam global variable under `tf_scope` in `self.graph` is mapped
    to a checkpoint key formed by replacing the `tf_scope` prefix with
    `saved_scope` and dropping the last two characters of the variable name.
    The values are read from '<dir>/model.ckpt'.

    Args:
      dir: Directory containing 'model.ckpt'.
      saved_scope: Scope prefix used inside the checkpoint.
      tf_scope: Scope prefix used in the current graph.

    Raises:
      tf.errors.NotFoundError: Re-raised after the variables listed for `dir`
        have been logged.
    """
    scoped_vars = [
        v
        for v in self.graph.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope=tf_scope)
        if "Adam" not in v.name
    ]
    prefix_len = len(tf_scope)
    var_remap = {
        saved_scope + v.name[prefix_len:-2]: v for v in scoped_vars
    }

    path = os.path.join(dir, "model.ckpt")
    saver = tf.train.Saver(var_list=var_remap)
    try:
        saver.restore(self._session, path)
    except tf.errors.NotFoundError:
        # Log what the checkpoint actually holds to help diagnose the mismatch.
        from tensorflow.contrib.framework.python.framework import checkpoint_utils
        logging.info(checkpoint_utils.list_variables(dir))
        raise
def part_load_from_checkpoint(self, sess):
    """Restore only those global variables that the checkpoint also contains.

    The checkpoint is `self.config.load_checkpoint_path`. A variable matches
    when its name, truncated at the first ":", is listed in the checkpoint.

    Args:
      sess: Session to restore into.
    """
    model_vars = tf.global_variables()  # variables of the current model
    ckpt_path = self.config.load_checkpoint_path
    ckpt_names = {
        entry[0] for entry in checkpoint_utils.list_variables(ckpt_path)
    }  # names stored in the checkpoint

    # Variables of the current model that are also present in the checkpoint.
    shared_vars = []
    for var in model_vars:
        if var.name[:var.name.index(":")] in ckpt_names:
            shared_vars.append(var)

    tf.train.Saver(var_list=shared_vars).restore(sess, ckpt_path)
    print("Restored variables from the parsed checkpoint")
def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    """Partition variables by whether their base name is in the checkpoint.

    The base name is the variable's `.name` truncated at the first ':' and,
    if it contains '/part_', truncated again at the first '/part_'.

    Args:
      checkpoint_path: Path passed to `checkpoint_utils.list_variables`.
      vars_to_check: Variables (objects with a `.name` containing ':').

    Returns:
      `(vars_in_checkpoint, vars_not_in_checkpoint)`, two lists in input order.
    """
    check_var_set = {
        entry[0] for entry in checkpoint_utils.list_variables(checkpoint_path)
    }

    vars_in_checkpoint, vars_not_in_checkpoint = [], []
    for var in vars_to_check:
        full_name = var.name
        base_name = full_name[:full_name.index(':')]
        # Everything before the first '/part_' (the whole name if absent).
        base_name = base_name.partition('/part_')[0]
        bucket = (vars_in_checkpoint if base_name in check_var_set
                  else vars_not_in_checkpoint)
        bucket.append(var)
    return vars_in_checkpoint, vars_not_in_checkpoint
def build_codebook_multi(encoder, dataset, args, checkpoint_file_basename=None):
    """Create a `Codebook`, passing along the embedding names found in a checkpoint.

    For every checkpoint variable whose name contains 'embedding_normalized_',
    the text after the last '/embedding_normalized_' and before the following
    '.' is collected. The matching entries and the collected list are printed.

    Args:
      encoder: Forwarded to `Codebook`.
      dataset: Forwarded to `Codebook`.
      args: Config object; `getboolean('Embedding', 'EMBED_BB')` is read.
      checkpoint_file_basename: Optional checkpoint to scan; when None, no
        existing embeddings are reported.

    Returns:
      The constructed `Codebook`.
    """
    embed_bb = args.getboolean('Embedding', 'EMBED_BB')

    existing_embs = []
    if checkpoint_file_basename is not None:
        for entry in checkpoint_utils.list_variables(checkpoint_file_basename):
            var_name = entry[0]
            if 'embedding_normalized_' not in var_name:
                continue
            print(entry)
            suffix = var_name.split('/embedding_normalized_')[-1]
            existing_embs.append(suffix.split('.')[0])

    print(existing_embs)
    return Codebook(encoder, dataset, embed_bb, existing_embs)
def restore(self, checkpoint_path, restrict_vars=None):
    """Restore variables from a checkpoint into `self._sess`.

    Args:
      checkpoint_path: A checkpoint file, or a directory from which
        `tf.train.latest_checkpoint` picks the checkpoint.
      restrict_vars: Optional collection of variable names. When non-empty,
        only checkpoint variables named in it are restored; otherwise the
        checkpoint variables that are currently reported as uninitialized
        are restored.
    """
    if os.path.isdir(checkpoint_path):
        restore_checkpoint = tf.train.latest_checkpoint(
            checkpoint_path, latest_filename=None)
    else:
        restore_checkpoint = checkpoint_path

    # Names of the variables stored in 'restore_checkpoint'.
    ckpt_names = set()
    for var_name, _ in checkpoint_utils.list_variables(restore_checkpoint):
        ckpt_names.add(var_name)

    if restrict_vars and len(restrict_vars) > 0:
        candidates = restrict_vars
    else:
        # No explicit list: fall back to whatever is still uninitialized.
        raw_names = self._sess.run(
            tf.report_uninitialized_variables(tf.global_variables()))
        candidates = [raw.decode("utf-8") for raw in raw_names]
    restore_vars = list(ckpt_names.intersection(candidates))

    # Look the variables up by name, reusing the ones already in the graph.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        restore_variables = [
            tf.get_variable(
                name,
                dtype=self._sess.graph.get_tensor_by_name(name + ":0").dtype)
            for name in restore_vars
        ]

    restore_vars_saver = tf.train.Saver(var_list=restore_variables)
    restore_vars_names = ", ".join(
        "%s" % variable.name for variable in restore_variables)
    print("Restoring variables ({}) from '{}'... ".format(
        restore_vars_names, restore_checkpoint), end="")
    restore_vars_saver.restore(self._sess, restore_checkpoint)
    print("Done.")
def load_trained_classifier(self, sess, run_id, target_category, iteration):
    """Assign saved values to the trainable variables of this classifier.

    The checkpoint is expected at
    '<this file's dir>/../models/<target_category>/lstm<run_id>_iteration<iteration>/lstm<run_id>_iteration<iteration>.ckpt'.
    Each trainable variable under `self.classifierName` is looked up in the
    checkpoint by its name with the '<classifierName>/' prefix removed.

    Args:
      sess: Session in which the assignments are evaluated.
      run_id: Run identifier used in the checkpoint name.
      target_category: Sub-folder of the models directory.
      iteration: Iteration number used in the checkpoint name.
    """
    tvars = tf.trainable_variables(scope=self.classifierName)

    run_tag = "lstm{0}_iteration{1}".format(run_id, iteration)
    here = pathlib.Path(__file__).parent.absolute()
    model_path = os.path.join(
        here, "..", "models", target_category, run_tag, run_tag + ".ckpt")

    saved_vars = checkpoint_utils.list_variables(checkpoint_dir=model_path)

    # Skip "<classifierName>" plus the separator that follows it.
    prefix_length = len(self.classifierName) + 1
    for variable in tvars:
        stored_value = checkpoint_utils.load_variable(
            checkpoint_dir=model_path, name=variable.name[prefix_length:])
        tf.assign(variable, stored_value).eval(session=sess)
def load_fp32_weights_into_fp16_vars(checkpoint_path: Path) -> List:
    """Load fp32 weights from checkpoint path into fp16 variables.

    For every global variable whose op name is listed in the checkpoint, an
    assign op is added to the 'restore_ops' graph collection. Values stored as
    float32 are cast to float16 first; other dtypes are assigned unchanged.

    Assumes that the caller has already run `tf.global_variables_initializer()`
    in its session.

    Note: ops are appended to the 'restore_ops' collection, so calling this
    more than once on the same graph returns the ops from every call.

    Args:
        checkpoint_path: Checkpoint path

    Returns:
        Collection of ops to use to restore the weights in the graph.
    """
    # A set keeps the per-variable membership test constant-time.
    checkpoint_variables = {
        var_name for var_name, _ in list_variables(checkpoint_path)
    }

    for graph_var in tf.global_variables():
        op_name = graph_var.op.name
        if op_name not in checkpoint_variables:
            continue
        var = load_variable(checkpoint_path, op_name)
        weights = tf.cast(var, tf.float16) if var.dtype == np.float32 else var
        tf.add_to_collection('restore_ops', graph_var.assign(weights))

    return tf.get_collection('restore_ops')
def checkpoint_dtype_cast(in_checkpoint_file, out_checkpoint_file):
    """Write a copy of a checkpoint in which float32 variables are float16.

    Stage one restores the input through `CastFromFloat32SaverBuilder` into
    float16 variables and saves "./init_ckpt/tmp.ckpt". Stage two rebuilds
    the graph, restores that temporary checkpoint with a plain saver and
    saves it to `out_checkpoint_file`.

    Args:
      in_checkpoint_file: Source checkpoint.
      out_checkpoint_file: Destination checkpoint path.
    """
    tmp_checkpoint = "./init_ckpt/tmp.ckpt"
    var_list = checkpoint_utils.list_variables(in_checkpoint_file)

    def init_graph():
        # Declare each checkpoint variable, demoting float32 to float16.
        for name, shape in var_list:
            stored_dtype = checkpoint_utils.load_variable(
                in_checkpoint_file, name).dtype
            if stored_dtype == np.float32:
                recon_dtype = tf.float16
            else:
                recon_dtype = stored_dtype
            tf.get_variable(name, shape=shape, dtype=recon_dtype)

    init_graph()
    casting_saver = tf.train.Saver(builder=CastFromFloat32SaverBuilder())
    with tf.Session() as sess:
        casting_saver.restore(sess, in_checkpoint_file)
        casting_saver.save(sess, tmp_checkpoint)

    tf.reset_default_graph()
    init_graph()
    plain_saver = tf.train.Saver()
    with tf.Session() as sess:
        plain_saver.restore(sess, tmp_checkpoint)
        plain_saver.save(sess, out_checkpoint_file)
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
  """Build the evaluation graph and run `train_op` until `num_steps`.

  Args:
    train_log_dir: Directory handed to `summary_utils.LoggingFileWriter`.
    checkpoint: Optional checkpoint path. When truthy, the graph variables
      that are listed in it are restored before training starts (only if the
      global step read from the session is 0).
    eval_every_n_steps: Summaries are evaluated and written on steps that are
      a multiple of this value.
    num_steps: Training stops once the global step reaches this value.
  """
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
  # The results of these two calls are not used below; presumably they are
  # made so that the model's variables get created -- TODO confirm.
  batch = dataset()
  pre_logit, outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()
  # NOTE(review): this `var_list` is overwritten in the `if checkpoint:`
  # branch and is otherwise unused.
  var_list = list(
      snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))

  tf.logging.info("all vars")
  for v in tf.all_variables():
    tf.logging.info(" %s" % str(v))
  # Second lookup of the global step created above.
  global_step = tf.train.get_global_step()
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  # One training step also increments the global step.
  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
  if checkpoint:
    # Restore only the graph variables whose op name appears in the checkpoint.
    str_var_list = checkpoint_utils.list_variables(checkpoint)
    name_to_v_map = {v.op.name: v for v in tf.all_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)
    # Every global variable under "LocalWeightUpdateProcess" must be covered
    # by the checkpoint.
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # global step should be restored from the evals job checkpoint or zero for fresh.
    step = sess.run(global_step)

    # `saver` only exists when `checkpoint` is truthy, which this condition
    # also requires.
    if step == 0 and checkpoint:
      tf.logging.info("force restore")
      saver.restore(sess, checkpoint)
      tf.logging.info("force restore done")
      # Restart counting from zero after the restore.
      sess.run(reset_global_step)
      step = sess.run(global_step)

    while step < num_steps:
      if step % eval_every_n_steps == 0:
        # On these steps the summaries are fetched and written as well.
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])
args.image_size, feature_name=args.feature_name, annotation_threshold=args.annotation_threshold) #initialize variables------------------ sess.run(tf.global_variables_initializer()) #set up savers------------ saver = tf.train.Saver(max_to_keep=20) summary_writer = tf.summary.FileWriter(args.logs_dir, sess.graph) #try to reload weights-------------------- if args.model_dir is not None: ckpt = tf.train.get_checkpoint_state(args.model_dir) if ckpt and ckpt.model_checkpoint_path: vars_stored = [var[0] for var in list_variables(args.model_dir)] vars_restore = [ v for v in tf.global_variables() if v.name[0:-2] in vars_stored ] restore_saver = tf.train.Saver(vars_restore) restore_saver.restore(sess, ckpt.model_checkpoint_path) print("Model restored...") else: print("Model restore failed...") #start training------------------------- for itr in range(args.max_iteration): if not args.weight_data: batch = train_dataset_reader.next_batch_in_seqs( args.batch_size, args.n_steps) else:
serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(input_dict) save_path = nn_estimator.export_savedmodel(models_dir, serving_input_receiver_fn) return save_path # GETTING VALIDATION SCORES from tensorflow.contrib.framework.python.framework import checkpoint_utils import re from collections import defaultdict files = [checkpoint_file] eval_data_dict = pickle.load(open('data/eval_data_dict_%d_files_%s_rows.pkl' % (NUM_DATAFILES, str(ROWS_TO_EXTRACT)), 'rb')) random_model_dir = files[0] ckpt_var_list = [var_name for (var_name, shape) in checkpoint_utils.list_variables(random_model_dir)] vars_to_restore = list() for graph_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): if graph_var.name[:-2] in ckpt_var_list: vars_to_restore.append(graph_var) try: global_step = tf.train.create_global_step() except: print("global_step already exists") if global_step not in vars_to_restore: vars_to_restore.append(global_step) init_global = tf.global_variables_initializer()
def train(args):
    """Train a `CommentModel`, checkpointing and evaluating it periodically.

    Creates the output, log, model and result directories, dumps `args` to
    'args.json', reads the dataset sizes from 'train_meta.json' and
    'dev_meta.json' in the records directory, and then runs the training
    loop. Every `args.checkpoint` steps a checkpoint is saved and, unless
    `args.no_eval` is set, the dev set is evaluated: a new best perplexity
    saves the 'best' checkpoint, while `args.patience` evaluations without
    improvement halve the learning rate.

    The loop has no exit condition; the function runs until it is interrupted
    or an exception is raised.

    Args:
        args: Parsed arguments namespace. The attributes read here are
            `output`, `log_dir`, `data_dir`, `records_dir`, `batch_size`,
            `restore`, `restore_model`, `period`, `checkpoint`, `no_eval`
            and `patience`; it is also passed to `get_record_parser`,
            `get_batch_dataset` and `CommentModel`.
    """
    output_dir = args.output
    log_dir = args.log_dir if args.log_dir else os.path.join(output_dir, 'log')
    model_dir = os.path.join(output_dir, 'model')
    records_dir = args.records_dir if not args.data_dir else os.path.join(
        args.data_dir, args.records_dir)
    result_dir = os.path.join(output_dir, 'result')
    for directory in [output_dir, log_dir, model_dir, result_dir]:
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Save the args info to the output dir.
    with open(os.path.join(output_dir, 'args.json'), 'w') as p:
        json.dump(vars(args), p, indent=2)

    # Load meta info.
    with open(os.path.join(records_dir, 'train_meta.json'), 'r',
              encoding='utf8') as p:
        train_total = json.load(p)['size']
        batch_num_per_epoch = int(np.ceil(train_total / args.batch_size))
        print(f'{time.asctime()} - batch num per epoch: {batch_num_per_epoch}')
    with open(os.path.join(records_dir, 'dev_meta.json'), 'r',
              encoding='utf8') as p:
        dev_total = json.load(p)['size']

    train_records_file = os.path.join(records_dir, 'train.tfrecords')
    dev_records_file = os.path.join(records_dir, 'dev.tfrecords')

    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):
        parser = get_record_parser(args)
        train_dataset = get_batch_dataset(train_records_file, parser, args)
        dev_dataset = get_batch_dataset(dev_records_file, parser, args)
        # A string-handle iterator lets the same model read from either the
        # train or the dev dataset, selected through the feed dict.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = CommentModel(args, iterator)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        sess = tf.Session(config=session_config)

        writer = tf.summary.FileWriter(log_dir)
        best_ppl = tf.Variable(300,
                               trainable=False,
                               name='best_ppl',
                               dtype=tf.float32)
        saver = tf.train.Saver(max_to_keep=10000)

        if args.restore:
            model_file = args.restore_model or tf.train.latest_checkpoint(
                model_dir)
            print(f'{time.asctime()} - Restore model from {model_file}..')
            var_list = [
                entry[0] for entry in checkpoint_utils.list_variables(model_file)
            ]
            # Variables found in the checkpoint are restored; the rest are
            # freshly initialized.
            saved_vars = [
                v for v in tf.global_variables()
                if v.name.split(':')[0] in var_list
            ]
            res_saver = tf.train.Saver(saved_vars)
            res_saver.restore(sess, model_file)

            left_vars = [
                v for v in tf.global_variables()
                if v.name.split(':')[0] not in var_list
            ]
            sess.run(tf.initialize_variables(left_vars))
            print(
                f'{time.asctime()} - Restore {len(var_list)} vars and initialize {len(left_vars)} vars.'
            )
            print(left_vars)
        else:
            print(f'{time.asctime()} - Initialize model..')
            sess.run(tf.global_variables_initializer())

        train_handle = sess.run(train_iterator.string_handle())
        dev_handle = sess.run(dev_iterator.string_handle())

        sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool)))

        patience = 0
        lr = sess.run(model.lr)
        b_ppl = sess.run(best_ppl)
        print(f'{time.asctime()} - lr: {lr:.3f} best_ppl:{b_ppl:.3f}')

        t0 = datetime.now()
        while True:
            global_step = sess.run(model.global_step) + 1
            epoch = int(np.ceil(global_step / batch_num_per_epoch))
            loss, loss_gen, ppl, train_op, merge_sum, target, check_1 = sess.run(
                [
                    model.loss, model.loss_gen, model.ppl, model._train_op,
                    model._summaries, model.target, model.check_dec_outputs
                ],
                feed_dict={handle: train_handle})
            ela_time = str(datetime.now() - t0).split('.')[0]
            # The trailing '\r' with end='' rewrites the same console line.
            print(
                (f'{time.asctime()} - step/epoch:{global_step}/{epoch:<3d} '
                 f'gen_loss:{loss_gen:<3.3f} '
                 f'ppl:{ppl:<4.3f} '
                 f'elapsed:{ela_time}\r'),
                end='')

            if global_step % args.period == 0:
                writer.add_summary(merge_sum, global_step)
                writer.flush()

            if global_step % args.checkpoint == 0:
                model_file = os.path.join(model_dir, 'model')
                saver.save(sess, model_file, global_step=global_step)

            if global_step % args.checkpoint == 0 and not args.no_eval:
                # Switch the model to evaluation mode for the dev pass only.
                sess.run(
                    tf.assign(model.is_train, tf.constant(False,
                                                          dtype=tf.bool)))
                metrics, summ = evaluate_batch(model,
                                               dev_total // args.batch_size,
                                               sess, handle, dev_handle,
                                               iterator)
                sess.run(
                    tf.assign(model.is_train, tf.constant(True,
                                                          dtype=tf.bool)))
                for s in summ:
                    writer.add_summary(s, global_step)

                dev_ppl = metrics['ppl']
                dev_gen_loss = metrics['gen_loss']

                tqdm.write(
                    f'{time.asctime()} - Evaluate after steps:{global_step}, '
                    f' gen_loss:{dev_gen_loss:.4f}, ppl:{dev_ppl:.3f}')

                if dev_ppl < b_ppl:
                    sess.run(tf.assign(best_ppl, dev_ppl))
                    # Keep the local threshold in sync with the variable;
                    # otherwise every later evaluation is compared against the
                    # initial value and 'best' can be overwritten by a worse
                    # model.
                    b_ppl = dev_ppl
                    saver.save(sess, save_path=os.path.join(model_dir, 'best'))
                    tqdm.write(
                        f'{time.asctime()} - the ppl is lower than current best ppl so saved the model.'
                    )
                    patience = 0
                else:
                    patience += 1
                    if patience >= args.patience:
                        lr = lr / 2
                        sess.run(
                            tf.assign(model.lr,
                                      tf.constant(lr, dtype=tf.float32)))
                        patience = 0
                        tqdm.write(
                            f'{time.asctime()} - The lr is decayed form {lr*2} to {lr}.'
                        )
if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) ckpt = tf.train.get_checkpoint_state(ckpt_dir) start = 0 if ckpt and ckpt.model_checkpoint_path: start = int(ckpt.model_checkpoint_path.split("-")[1]) logger.info("start by iteration: %d" % (start)) saver = tf.train.Saver() saver.restore(sess, ckpt.model_checkpoint_path) logger.info("model is restored using " + str(ckpt)) elif pretraind_model: restore = {} from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables slim = tf.contrib.slim for scope in list_variables(pretraind_model): if 'conv' in scope[0]: variables_to_restore = slim.get_variables(scope=scope[0]) if variables_to_restore: restore[scope[0]] = variables_to_restore[ 0] # variables_to_restore is list : [op] saver = tf.train.Saver(restore) saver.restore(sess, pretraind_model) logger.info("model is restored conv only using " + str(pretraind_model)) assign_op = global_step.assign(start) sess.run(assign_op) model_saver = tf.train.Saver(max_to_keep=30) # train
tf.constant(img, dtype=tf.float32), axis=0) #tf.zeros([batch_size, img_size, img_size, 3]) logits, endpoints = resnet_v1.resnet_v1_50(inputbatch, 1000, is_training=False) config = tf.ConfigProto() config.gpu_options.allow_growth = True config.allow_soft_placement = True config.log_device_placement = False sess = tf.Session(config=config) variables_to_restore = [] a = [ name for name, _ in checkpoint_utils.list_variables( 'pretrained_model/resnet_v1_50.ckpt') ] # print a for var in slim.get_model_variables(): if (var.op.name.startswith('resnet_v1_50')) and ( var.op.name in a) and ('logits' not in var.op.name): variables_to_restore.append(var) # print variables_to_restore # slim.assign_from_checkpoint_fn('pretrained_model/resnet_v1_50.ckpt', variables_to_restore, ignore_missing_vars=False) init = tf.global_variables_initializer() sess.run(init) saver = tf.train.Saver(variables_to_restore) saver.restore(sess, 'pretrained_model/resnet_v1_50.ckpt') # print a.keys()
def checkpoint_quantization(in_checkpoint_file, out_checkpoint_file,
                            per_channel_quantization):
    """Quantize the encoder kernels of a checkpoint and add per-layer amax lists.

    Every variable found in `in_checkpoint_file` is recreated in the default
    graph and restored. Then, for each encoder layer i:

    * the `quant_min`/`quant_max` pair of every activation quantizer in
      `amax_name_list` is read from the checkpoint and stored as four
      consecutive values (amax, amax/127, amax/127/127, 127/amax);
    * each kernel in `kernel_name_list` is passed through
      `transformer_op_module.weight_quantize`, written back to its variable,
      and the returned amax values are appended starting at
      `ACTIVATION_AMAX_NUM`;
    * `INT8O_GEMM_NUM` de-quantization factors and three values for the TRT
      fused MHA are appended;
    * the result is assigned to `bert/encoder/layer_{i}/amaxList`.

    Finally all graph variables are saved to `out_checkpoint_file`.

    Args:
      in_checkpoint_file: Checkpoint to read. It is used both for listing and
        loading variables and for the restore; previously the listing/loading
        silently used `tf.flags.FLAGS.init_checkpoint` instead.
      out_checkpoint_file: Path handed to `saver.save` for the result.
      per_channel_quantization: Forwarded unchanged to
        `transformer_op_module.weight_quantize`.
    """
    var_list = checkpoint_utils.list_variables(in_checkpoint_file)

    def init_graph():
        """Recreate the checkpoint variables and one amaxList per layer.

        Returns:
          (layer_num, amaxTotalNum, restore_vars): number of encoder layers
          (highest `layer_<n>` index seen + 1), length of each amaxList, and
          the variables that mirror the checkpoint contents.
        """
        restore_vars = []
        layer_num = 0
        # Raw string: "\d" inside a plain literal is an invalid escape sequence.
        regex = re.compile(r'layer_\d+')
        amaxTotalNum = 0
        for name, shape in var_list:
            var = checkpoint_utils.load_variable(in_checkpoint_file, name)
            # Size the amax list once, from the first intermediate dense kernel.
            if "intermediate/dense/kernel" in name and amaxTotalNum == 0:
                amaxTotalNum = (ACTIVATION_AMAX_NUM + 9 * shape[0] +
                                INT8O_GEMM_NUM + TRT_FUSED_MHA_AMAX_NUM)
                print(amaxTotalNum, shape[0])
            restore_vars.append(
                tf.get_variable(name, shape=shape, dtype=var.dtype))
            tmp = regex.findall(name)
            if len(tmp) < 1:
                continue
            num_tmp = int(tmp[0].replace("layer_", ""))
            if layer_num < num_tmp:
                layer_num = num_tmp
        # Layer indices are zero-based, so the count is the largest index + 1.
        layer_num = layer_num + 1
        #add new var for amax
        for i in range(layer_num):
            tf.get_variable("bert/encoder/layer_{}/amaxList".format(i),
                            shape=[amaxTotalNum],
                            dtype=tf.float32)
        return layer_num, amaxTotalNum, restore_vars

    layer_num, amaxTotalNum, restore_vars = init_graph()
    # `restorer` covers only what the input checkpoint holds; `saver` also
    # covers the amaxList variables created by init_graph.
    restorer = tf.train.Saver(restore_vars)
    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        restorer.restore(sess, in_checkpoint_file)
        kernel_name_list = [
            "attention/self/query",
            "attention/self/key",
            "attention/self/value",
            "attention/output/dense",
            "intermediate/dense",
            "output/dense",
        ]

        # Activation quantizers, in the order their values are laid out.
        # Entry k occupies slots 4*k .. 4*k+3 of amaxList.
        amax_name_list = [
            "attention/self/query/input_quantizer",  # input_scale, 0
            "attention/self/query/aftergemm_quantizer",  # Q_aftergemm_scale, 1
            "attention/self/matmul_q_input_quantizer",  # Qbias_scale, 2
            "attention/self/key/aftergemm_quantizer",  # K_aftergemm_scale, 3
            "attention/self/matmul_k_input_quantizer",  # Kbias_scale, 4
            "attention/self/value/aftergemm_quantizer",  # V_aftergemm_scale, 5
            "attention/self/matmul_v_input_quantizer",  # Vbias_scale, 6
            "attention/self/softmax_input_quantizer",  # bmm1_scale, 7
            "attention/self/matmul_a_input_quantizer",  # Softmax_scale, 8
            "attention/output/dense/input_quantizer",  # bmm2_scale, 9
            "attention/output/dense/aftergemm_quantizer",  # Proj_aftergemm_scale, 10
            "intermediate/dense/input_quantizer",  # ProjBiasNorm_scale, 11
            "intermediate/dense/aftergemm_quantizer",  # FC1_aftergemm_scale, 12
            "output/dense/input_quantizer",  # F1Bias_scale, 13
            "output/dense/aftergemm_quantizer",  # FC2_aftergemm_scale, 14
            "special_F2Bias_scale",  # F2Bias_scale, 15
        ]

        # The three lists below are parallel: position j names the weight,
        # input and output quantizer of the j-th INT8-output GEMM.
        int8O_gemm_weight_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        int8O_gemm_weight_list = [
            "attention/self/query",  # Q_aftergemm
            "attention/self/key",  # K_aftergemm
            "attention/self/value",  # V_aftergemm
            "attention/self/matmul_k_input_quantizer",  # bmm1_aftergemm
            "attention/self/matmul_v_input_quantizer",  # bmm2_aftergemm
            "attention/output/dense",  # Proj_aftergemm
            "intermediate/dense",  # FC1_aftergemm
            "output/dense",  # FC2_aftergemm
        ]

        int8O_gemm_input_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        int8O_gemm_input_list = [
            "attention/self/query/input_quantizer",  # Q_aftergemm
            "attention/self/key/input_quantizer",  # K_aftergemm
            "attention/self/value/input_quantizer",  # V_aftergemm
            "attention/self/matmul_q_input_quantizer",  # bmm1_aftergemm
            "attention/self/matmul_a_input_quantizer",  # bmm2_aftergemm
            "attention/output/dense/input_quantizer",  # Proj_aftergemm
            "intermediate/dense/input_quantizer",  # FC1_aftergemm
            "output/dense/input_quantizer",  # FC2_aftergemm
        ]

        int8O_gemm_output_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        int8O_gemm_output_list = [
            "attention/self/query/aftergemm_quantizer",  # Q_aftergemm
            "attention/self/key/aftergemm_quantizer",  # K_aftergemm
            "attention/self/value/aftergemm_quantizer",  # V_aftergemm
            "attention/self/softmax_input_quantizer",  # bmm1_aftergemm
            "attention/output/dense/input_quantizer",  # bmm2_aftergemm
            "attention/output/dense/aftergemm_quantizer",  # Proj_aftergemm
            "intermediate/dense/aftergemm_quantizer",  # FC1_aftergemm
            "output/dense/aftergemm_quantizer",  # FC2_aftergemm
        ]

        for i in range(layer_num):
            amaxList = np.zeros([amaxTotalNum])
            amax_id = 0
            for amax_name in amax_name_list:
                if amax_name == "special_F2Bias_scale":
                    # This entry has no quantizer of its own: it reuses the
                    # input quantizer (amax_name_list[0]) of the NEXT layer.
                    if i != layer_num - 1:
                        name = "bert/encoder/layer_{}/{}/quant_max:0".format(
                            i + 1, amax_name_list[0])
                        quant_max = checkpoint_utils.load_variable(
                            in_checkpoint_file, name)
                        name = "bert/encoder/layer_{}/{}/quant_min:0".format(
                            i + 1, amax_name_list[0])
                        quant_min = checkpoint_utils.load_variable(
                            in_checkpoint_file, name)
                        if abs(quant_max) > abs(quant_min):
                            amax = abs(quant_max)
                        else:
                            amax = abs(quant_min)
                    else:
                        #not used, placeholder
                        amax = 1.0

                    amaxList[amax_id] = amax
                    amax_id += 1
                    amaxList[amax_id] = amax / 127.0
                    amax_id += 1
                    amaxList[amax_id] = amax / 127.0 / 127.0
                    amax_id += 1
                    amaxList[amax_id] = 127.0 / amax
                    amax_id += 1
                    continue

                name = "bert/encoder/layer_{}/{}/quant_max:0".format(
                    i, amax_name)
                quant_max = checkpoint_utils.load_variable(
                    in_checkpoint_file, name)
                name = "bert/encoder/layer_{}/{}/quant_min:0".format(
                    i, amax_name)
                quant_min = checkpoint_utils.load_variable(
                    in_checkpoint_file, name)

                # amax is the larger magnitude of the two range ends.
                if abs(quant_max) > abs(quant_min):
                    amax = abs(quant_max)
                else:
                    amax = abs(quant_min)

                if amax_name in int8O_gemm_input_list:
                    int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                        amax_name)] = amax
                    # The key and value GEMMs are given the query input amax.
                    if amax_name == "attention/self/query/input_quantizer":
                        int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                            "attention/self/key/input_quantizer")] = amax
                        int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                            "attention/self/value/input_quantizer")] = amax

                if amax_name in int8O_gemm_output_list:
                    int8O_gemm_output_amax_list[int8O_gemm_output_list.index(
                        amax_name)] = amax

                if amax_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[int8O_gemm_weight_list.index(
                        amax_name)] = amax

                amaxList[amax_id] = amax
                amax_id += 1
                amaxList[amax_id] = amax / 127.0
                amax_id += 1
                amaxList[amax_id] = amax / 127.0 / 127.0
                amax_id += 1
                amaxList[amax_id] = 127.0 / amax
                amax_id += 1

            print("done process layer_{} activation amax".format(i))

            #kernel amax starts from ACTIVATION_AMAX_NUM
            amax_id = ACTIVATION_AMAX_NUM
            for kernel_name in kernel_name_list:
                kernel = tf.get_default_graph().get_tensor_by_name(
                    "bert/encoder/layer_{}/{}/kernel:0".format(i, kernel_name))
                name = "bert/encoder/layer_{}/{}/kernel_quantizer/quant_max:0".format(
                    i, kernel_name)
                quant_max2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(in_checkpoint_file, name))
                name = "bert/encoder/layer_{}/{}/kernel_quantizer/quant_min:0".format(
                    i, kernel_name)
                quant_min2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(in_checkpoint_file, name))
                kernel_processed, quant_max_processed = transformer_op_module.weight_quantize(
                    kernel,
                    quant_max2,
                    quant_min2,
                    per_channel_quantization=per_channel_quantization)
                kernel_processed_, quant_max_processed_ = sess.run(
                    [kernel_processed, quant_max_processed])
                # Overwrite the float kernel with its processed counterpart.
                sess.run(tf.assign(kernel, kernel_processed_))
                if kernel_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[int8O_gemm_weight_list.index(
                        kernel_name)] = quant_max_processed_[0]
                for e in quant_max_processed_:
                    amaxList[amax_id] = e
                    amax_id += 1

            #for int8O gemm deQuant
            for j in range(INT8O_GEMM_NUM):
                amaxList[amax_id] = (int8O_gemm_input_amax_list[j] *
                                     int8O_gemm_weight_amax_list[j]) / (
                                         127.0 * int8O_gemm_output_amax_list[j])
                amax_id += 1

            #for trt fused MHA amax
            # Slots 8/16/24/32/36 are the raw amax of amax_name_list entries
            # 2, 4, 6, 8 and 9 (four slots per entry).
            #### QKV_addBias_amax
            amaxList[amax_id] = np.maximum(
                np.maximum(amaxList[8], amaxList[16]), amaxList[24])
            amax_id += 1
            #### softmax amax
            amaxList[amax_id] = amaxList[32]
            amax_id += 1
            #### bmm2 amax
            amaxList[amax_id] = amaxList[36]
            amax_id += 1

            amaxL = tf.get_default_graph().get_tensor_by_name(
                "bert/encoder/layer_{}/amaxList:0".format(i))
            sess.run(tf.assign(amaxL, amaxList))

            print("done process layer_{} kernel weight".format(i))

        saver.save(sess, out_checkpoint_file)
# Initialise variables, start the input queue threads, then build the
# checkpoint-name -> graph-variable map that the restoring Saver uses.
sess.run(init)
sess.run(init_local)
coord = tf.train.Coordinator()
# start the threads
tf.train.start_queue_runners(sess=sess, coord=coord)

### ckpt
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
restore = {}
from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables
slim = tf.contrib.slim

# `'conv' or 'fc1_image' or ... in name` parses as `'conv' or (...)`, which is
# always truthy, so the original tests never filtered anything. Test each
# substring explicitly instead.
image_keys = ('conv', 'fc1_image', 'fc2_image', 'fc3_image', 'fc4_image')
for scope in list_variables(ckpt):
    if any(key in scope[0] for key in image_keys):
        variables_to_restore = slim.get_variables(scope=scope[0])
        if variables_to_restore:
            # variables_to_restore is a list; only its first entry is used.
            restore[scope[0]] = variables_to_restore[0]

adj_keys = ('fc1_adj', 'fc2_adj', 'fc3_adj', 'fc4_adj')
for scope in list_variables(DET_ckpt):
    if any(key in scope[0] for key in adj_keys):
        variables_to_restore = slim.get_variables(scope=scope[0])
        if variables_to_restore:
            # variables_to_restore is a list; only its first entry is used.
            restore[scope[0]] = variables_to_restore[0]
saver = tf.train.Saver(restore)
def list_vars_in_checkpoint(self, dirname):
    '''
    Debugging aid for TensorFlow checkpoints.

    Makes `dirname` absolute with os.path.abspath, passes it to tf.contrib's
    `list_variables`, and returns that call's result unchanged.
    '''
    from tensorflow.contrib.framework.python.framework.checkpoint_utils import (
        list_variables as _list_checkpoint_variables)
    return _list_checkpoint_variables(os.path.abspath(dirname))
def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
    """Build the evaluation graph and run its training op until `num_steps`.

    The graph comes from `evaluation.construct_evaluation_graph`, configured
    with the TinyMnist dataset, the MoreLocalWeightUpdate learner/process and
    two meta objectives. Each loop iteration runs the one-step training op
    together with a global-step increment.

    Args:
      train_log_dir: Directory handed to `summary_utils.LoggingFileWriter`.
      checkpoint_dir: If truthy, the graph variables whose names appear in this
        checkpoint are restored once, when the session starts at global step 0.
        The value is passed as-is to `checkpoint_utils.list_variables` and to
        `saver.restore`.
      eval_every_n_steps: Summaries are evaluated and written on steps that are
        a multiple of this value.
      num_steps: The loop ends once the global step reaches this value.

    Raises:
      AssertionError: If `checkpoint_dir` is set and some variable in the
        "LocalWeightUpdateProcess" scope is not covered by the checkpoint.
    """
    dataset_fn = datasets.mnist.TinyMnist
    w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
    theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

    meta_objectives = []
    meta_objectives.append(
        meta_objective.linear_regression.LinearRegressionMetaObjective)
    meta_objectives.append(meta_objective.sklearn.LogisticRegression)

    checkpoint_vars, train_one_step_op, (
        base_model, dataset) = evaluation.construct_evaluation_graph(
            theta_process_fn=theta_process_fn,
            w_learner_fn=w_learner_fn,
            dataset_fn=dataset_fn,
            meta_objectives=meta_objectives)
    batch = dataset()
    # The two results are not used below; presumably the call is needed to
    # instantiate the model's ops/variables -- TODO confirm.
    pre_logit, outputs = base_model(batch)

    global_step = tf.train.get_or_create_global_step()
    # NOTE(review): this value is only read after being reassigned in the
    # `if checkpoint_dir:` branch below.
    var_list = list(
        snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))

    tf.logging.info("all vars")
    for v in tf.all_variables():
        tf.logging.info(" %s" % str(v))
    # Fetches the global step again; it was created (or found) a few lines up.
    global_step = tf.train.get_global_step()
    accumulate_global_step = global_step.assign_add(1)
    reset_global_step = global_step.assign(0)

    # One training step = the model's update op plus a global-step increment.
    train_op = tf.group(train_one_step_op, accumulate_global_step, name="train_op")

    summary_op = tf.summary.merge_all()

    file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
    if checkpoint_dir:
        # Restore only the graph variables that the checkpoint actually holds.
        str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
        name_to_v_map = {v.op.name: v for v in tf.all_variables()}
        var_list = [
            name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
        ]
        saver = tf.train.Saver(var_list)
        # Every variable of the LocalWeightUpdateProcess scope must be among
        # the variables that will be restored.
        missed_variables = [
            v.op.name for v in set(
                snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                           tf.GraphKeys.GLOBAL_VARIABLES)) -
            set(var_list)
        ]
        assert len(missed_variables) == 0, "Missed a theta variable."

    hooks = []

    with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

        # global step should be restored from the evals job checkpoint or zero for fresh.
        step = sess.run(global_step)

        if step == 0 and checkpoint_dir:
            tf.logging.info("force restore")
            saver.restore(sess, checkpoint_dir)
            tf.logging.info("force restore done")
            # Start counting from zero regardless of what the restore loaded.
            sess.run(reset_global_step)
            step = sess.run(global_step)

        while step < num_steps:
            if step % eval_every_n_steps == 0:
                # Summary step: train and record summaries in the same run.
                s, _, step = sess.run([summary_op, train_op, global_step])
                file_writer.add_summary(s, step)
            else:
                _, step = sess.run([train_op, global_step])
def train():
    """Build the graph, warm-start from available checkpoints and train.

    Steps, in order:
      1. Create the source/reference mesh placeholders, the model, its loss,
         the learning-rate/bn-decay summaries and the optimizer selected by
         `OPTIMIZER` ('momentum' or 'adam').
      2. Open a session, create the train/test summary writers and the
         training op, and initialise all variables.
      3. Load the pretrained CNN weights through `load_model` (scope
         'vgg_16'); the function returns early if that call reports failure.
      4. If a checkpoint state exists in `PRETRAINED_DEFORM3D_PATH`, load the
         'sharebiasnet', 'srcpc' and 'refpc' scopes from it.
      5. If a checkpoint state exists in `PRETRAINED_MODEL_PATH`, try to
         restore the overall model from it; a failure is reported and
         training continues.
      6. Run `MAX_EPOCH` epochs, saving a checkpoint whenever the epoch loss
         improves and additionally every 10th epoch.

    Raises:
      ValueError: If `OPTIMIZER` is neither 'momentum' nor 'adam'.
    """
    with tf.Graph().as_default():
        # NOTE(review): the device scope is assumed to span the whole
        # function body -- confirm against the original layout.
        with tf.device('/gpu:0'):
            src_mesh = model.mesh_placeholder_inputs(BATCH_SIZE,
                                                     MAX_NVERTS,
                                                     MAX_NTRIS,
                                                     img_size=(IMG_SIZE,
                                                               IMG_SIZE),
                                                     scope='src')
            ref_mesh = model.mesh_placeholder_inputs(BATCH_SIZE,
                                                     MAX_NVERTS,
                                                     MAX_NTRIS,
                                                     img_size=(IMG_SIZE,
                                                               IMG_SIZE),
                                                     scope='ref')

            is_training_pl = tf.placeholder(tf.bool, shape=())
            print(is_training_pl)

            # Note the global_step=batch parameter to minimize.
            # That tells the optimizer to helpfully increment the 'batch'
            # parameter for you every time it trains.
            batch = tf.Variable(0, name='batch')
            bn_decay = get_bn_decay(batch)
            tf.summary.scalar('bn_decay', bn_decay)

            print("--- Get model and loss")
            # Get model and loss
            end_points = model.get_model(src_mesh,
                                         ref_mesh,
                                         NUM_POINTS,
                                         is_training_pl,
                                         bn=False)
            loss, end_points = model.get_loss(end_points)
            tf.summary.scalar('loss', loss)

            print("--- Get training operator")
            # Get training operator
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            else:
                # Previously this fell through and failed later with a
                # NameError on `optimizer`.
                raise ValueError("Unsupported OPTIMIZER: %r" % (OPTIMIZER,))

            # Create a session
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.99)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            config.log_device_placement = False
            sess = tf.Session(config=config)

            # Add summary writers
            merged = tf.summary.merge_all()
            train_writer = tf.summary.FileWriter(
                os.path.join(LOG_DIR, 'train'), sess.graph)
            test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'),
                                                sess.graph)

            ##### all
            update_variables = [
                x for x in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
            ]

            train_op = optimizer.minimize(loss,
                                          global_step=batch,
                                          var_list=update_variables)

            # Init variables
            init = tf.global_variables_initializer()
            sess.run(init)

            ######### Loading Checkpoint ###############
            # CNN(Pretrained from ImageNet)
            if not load_model(
                    sess, PRETRAINED_CNN_MODEL_FILE, 'vgg_16', strict=True):
                return

            # load weights from 3D deform net
            ckptstate = tf.train.get_checkpoint_state(PRETRAINED_DEFORM3D_PATH)
            if ckptstate is not None:
                PRETRAINED_DEFORM3D_MODEL_FILE = os.path.join(
                    PRETRAINED_DEFORM3D_PATH,
                    os.path.basename(ckptstate.model_checkpoint_path))
                load_model(sess, PRETRAINED_DEFORM3D_MODEL_FILE,
                           ['sharebiasnet', 'srcpc', 'refpc'])

            # Overall
            # Variables with 'lr' or 'batch' in their name are excluded from
            # this saver, so they are neither restored nor saved by it.
            saver = tf.train.Saver([
                v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
                if ('lr' not in v.name) and ('batch' not in v.name)
            ])

            ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)
            if ckptstate is not None:
                LOAD_MODEL_FILE = os.path.join(
                    PRETRAINED_MODEL_PATH,
                    os.path.basename(ckptstate.model_checkpoint_path))
                # NOTE(review): the listing itself is not used. The call is
                # kept so that an unreadable checkpoint still raises here
                # instead of being reported by the handler below -- confirm
                # whether that is wanted.
                checkpoint_utils.list_variables(LOAD_MODEL_FILE)
                try:
                    with NoStdStreams():
                        saver.restore(sess, LOAD_MODEL_FILE)
                    print("Model loaded in file: %s" % LOAD_MODEL_FILE)
                except Exception as err:  # pylint: disable=broad-except
                    # Best effort: training proceeds from the weights loaded
                    # so far. A bare `except:` would also have swallowed
                    # KeyboardInterrupt/SystemExit and hidden the reason.
                    print("Fail to load overall modelfile: %s" %
                          PRETRAINED_MODEL_PATH)
                    print("Reason: %s" % err)

            ###########################################

            ops = {
                'src_mesh': src_mesh,
                'ref_mesh': ref_mesh,
                'is_training_pl': is_training_pl,
                'loss': loss,
                'train_op': train_op,
                'merged': merged,
                'step': batch,
                'end_points': end_points
            }

            best_loss = 1e20
            for epoch in range(MAX_EPOCH):
                log_string('**** EPOCH %03d ****' % (epoch))
                sys.stdout.flush()

                epoch_loss = train_one_epoch(sess, ops, train_writer, saver)
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    save_path = saver.save(
                        sess,
                        os.path.join(LOG_DIR,
                                     "best_model_epoch_%03d.ckpt" % (epoch)))
                    log_string("Model saved in file: %s" % save_path)

                # Save the variables to disk.
                if epoch % 10 == 0:
                    save_path = saver.save(sess,
                                           os.path.join(LOG_DIR, "model.ckpt"))
                    log_string("Model saved in file: %s" % save_path)
def _saver_for_checkpoint(checkpoint_path):
    '''create a saver for the graph variables that are stored in a checkpoint

    Must be called while the target graph is the default graph.

    Args:
        checkpoint_path: path of the checkpoint whose variable names select
            which graph variables the saver covers

    Returns:
        a tf.train.Saver over the matching variables
    '''

    #a set gives constant-time membership tests and, unlike zip(...)[0], works
    #in python 3 and for a checkpoint without any variables
    varnames = {
        name for name, _ in checkpoint_utils.list_variables(checkpoint_path)}
    variables = [v for v in tf.all_variables()
                 if v.name.split(':')[0] in varnames]
    return tf.train.Saver(variables)


def main(_):
    '''does everything for testing

    Reads the database, feature, asr, lm, asr-lm and decoder configurations,
    builds the combined asr-lm classifier and a decoder, restores the lm and
    asr networks from their checkpoints, decodes the test set and prints the
    score computed against the reference transcriptions.
    '''

    decoder_cfg_file = None

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(FLAGS.asr_expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(
        os.path.join(FLAGS.asr_expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_asr_cfg = configparser.ConfigParser()
    parsed_asr_cfg.read(os.path.join(FLAGS.asr_expdir, 'model', 'asr.cfg'))
    asr_cfg = dict(parsed_asr_cfg.items('asr'))

    #read the lm config file
    parsed_lm_cfg = configparser.ConfigParser()
    parsed_lm_cfg.read(os.path.join(FLAGS.lm_expdir, 'model', 'lm.cfg'))
    lm_cfg = dict(parsed_lm_cfg.items('lm'))

    #read the asr-lm config file
    parsed_asr_lm_cfg = configparser.ConfigParser()
    parsed_asr_lm_cfg.read('config/asr_lm.cfg')
    asr_lm_cfg = dict(parsed_asr_lm_cfg.items('asr-lm'))

    #read the decoder config file
    if decoder_cfg_file is None:
        decoder_cfg_file = os.path.join(FLAGS.asr_expdir, 'model',
                                        'decoder.cfg')
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(decoder_cfg_file)
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #create a feature reader
    featdir = os.path.join(database_cfg['test_dir'], feat_cfg['name'])

    with open(os.path.join(featdir, 'maxlength'), 'r') as fid:
        max_length = int(fid.read())

    reader = feature_reader.FeatureReader(
        scpfile=os.path.join(featdir, 'feats.scp'),
        cmvnfile=os.path.join(featdir, 'cmvn.scp'),
        utt2spkfile=os.path.join(featdir, 'utt2spk'),
        max_length=max_length)

    #read the feature dimension
    with open(
        os.path.join(database_cfg['train_dir'], feat_cfg['name'], 'dim'),
        'r') as fid:

        input_dim = int(fid.read())

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)

    #create the classifier
    classifier = asr_lm_classifier.AsrLmClassifier(
        conf=asr_lm_cfg,
        asr_conf=asr_cfg,
        lm_conf=lm_cfg,
        output_dim=coder.num_labels)

    lm_checkpoint = os.path.join(FLAGS.lm_expdir, 'model', 'network.ckpt')
    asr_checkpoint = os.path.join(FLAGS.asr_expdir, 'model', 'network.ckpt')

    #create a decoder
    graph = tf.Graph()
    with graph.as_default():
        decoder = decoder_factory.factory(
            conf=decoder_cfg,
            classifier=classifier,
            input_dim=input_dim,
            max_input_length=reader.max_length,
            coder=coder,
            expdir=FLAGS.asr_expdir)

        #the savers have to be built inside the graph context so that they
        #see this graph's variables

        #create the lm saver
        lm_saver = _saver_for_checkpoint(lm_checkpoint)

        #create the asr saver
        asr_saver = _saver_for_checkpoint(asr_checkpoint)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True #pylint: disable=E1101
    config.allow_soft_placement = True

    with tf.Session(graph=graph, config=config) as sess:

        #load the lm model
        lm_saver.restore(sess, lm_checkpoint)

        #load the asr model
        asr_saver.restore(sess, asr_checkpoint)

        #decode with the neural net
        decoded = decoder.decode(reader, sess)

    #the path to the text file
    textfile = database_cfg['testtext']

    #read all the reference transcriptions
    #each line is "<utterance id> <transcription words...>"
    references = dict()
    with open(textfile) as fid:
        for line in fid:
            splitline = line.strip().split(' ')
            references[splitline[0]] = ' '.join(splitline[1:])

    #compute the character error rate
    score = decoder.score(decoded, references)

    #the parenthesised form prints the same text in python 2 and python 3
    print('score: %f' % score)
def list_variables(checkpoint_dir):
    """Pass-through to `checkpoint_utils.list_variables`.

    See `tf.contrib.framework.list_variables`.

    Args:
      checkpoint_dir: Forwarded unchanged to the underlying call.

    Returns:
      Whatever `checkpoint_utils.list_variables` returns for `checkpoint_dir`.
    """
    listing = checkpoint_utils.list_variables(checkpoint_dir)
    return listing