Example #1
def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.

  If `tensor_name` is provided, prints the content of the tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    if not tensor_name:
      variables = checkpoint_utils.list_variables(file_name)
      for name, shape in variables:
        print("%s\t%s" % (name, str(shape)))
    else:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
Example #2
def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    check_var_list = checkpoint_utils.list_variables(checkpoint_path)
    check_var_list = [x[0] for x in check_var_list]
    check_var_set = set(check_var_list)
    vars_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] in check_var_set]
    vars_not_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] not in check_var_set]
    return vars_in_checkpoint, vars_not_in_checkpoint
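A sketch of how the two lists might feed a partial restore (the checkpoint prefix and session are assumptions, not part of the example above):

ckpt = "/tmp/model.ckpt"                               # hypothetical checkpoint prefix
found, missing = scan_checkpoint_for_vars(ckpt, tf.global_variables())
saver = tf.train.Saver(var_list=found)                 # restore only what the checkpoint has
with tf.Session() as sess:
    sess.run(tf.variables_initializer(missing))        # freshly initialize everything else
    saver.restore(sess, ckpt)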
Example #3
def load_checkpoints(sess, var_scopes=('encoder', 'decoder', 'dense')):
    checkpoint_path = configdl.lip_model_path
    if checkpoint_path:
        if os.path.isdir(checkpoint_path):
            checkpoint = tf.train.latest_checkpoint(checkpoint_path)
        else:
            checkpoint = checkpoint_path

    if configdl.featurizer:

        if checkpoint_path:
            from tensorflow.contrib.framework.python.framework import checkpoint_utils
            var_list = checkpoint_utils.list_variables(checkpoint)
            for var in var_list:
                if 'visual_frontend' in var[0]:
                    var_scopes = var_scopes + ('visual_frontend',)
                    break

        if 'visual_frontend' not in var_scopes:
            featurizer_vars = tf.global_variables(scope='visual_frontend')
            featurizer_ckpt = tf.train.get_checkpoint_state(configdl.featurizer_model_path)
            featurizer_vars = [var for var in featurizer_vars if 'Adam' not in var.name]
            tf.train.Saver(featurizer_vars).restore(sess, featurizer_ckpt.model_checkpoint_path)

    all_variables = []
    for scope in var_scopes:
        all_variables += [var for var in tf.global_variables(scope=scope)
                          if 'Adam' not in var.name]
    if checkpoint_path:
        tf.train.Saver(all_variables).restore(sess, checkpoint)
        print("Restored saved model {}!".format(checkpoint))
Example #4
 def testGetAllVariables(self):
     checkpoint_dir = self.get_temp_dir()
     with self.cached_session() as session:
         _create_checkpoints(session, checkpoint_dir)
     self.assertEqual(checkpoint_utils.list_variables(checkpoint_dir),
                      [("useful_scope/var4", [9, 9]), ("var1", [1, 10]),
                       ("var2", [10, 10]), ("var3", [100, 100])])
Example #5
def checkpoint_dtype_cast(in_checkpoint_file, out_checkpoint_file):
    var_list = checkpoint_utils.list_variables(tf.flags.FLAGS.init_checkpoint)

    def init_graph():
        for name, shape in var_list:
            var = checkpoint_utils.load_variable(
                tf.flags.FLAGS.init_checkpoint, name)
            if "quant" in name or "amaxList" in name:
                recon_dtype = var.dtype
            else:
                recon_dtype = tf.float16 if var.dtype == np.float32 else var.dtype
            tf.get_variable(name, shape=shape, dtype=recon_dtype)

    init_graph()
    saver = tf.train.Saver(builder=CastFromFloat32SaverBuilder())
    with tf.Session() as sess:
        saver.restore(sess, in_checkpoint_file)
        saver.save(sess, 'tmp.ckpt')

    tf.reset_default_graph()

    init_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, 'tmp.ckpt')
        saver.save(sess, out_checkpoint_file)
Example #6
def load_model(sess, model_dir, model_name, scope='generator'):
    model_dir = os.path.join(model_dir, model_name)

    ckpt = tf.train.get_checkpoint_state(model_dir)
    try:
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            vars_model = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)

            vars_ckpt = checkpoint_utils.list_variables(os.path.join(model_dir, ckpt_name))

            vars_in_model = [var.name.split(':')[0] for var in vars_model]
            vars_in_ckpt = [var[0] for var in vars_ckpt]
            vars_to_remove = []
            for var in vars_in_model:
                if var not in vars_in_ckpt:
                    print(' [!] ' + var + ' does not exist')
                    for i in range(len(vars_model)):
                        if vars_model[i].name.split(':')[0] == var:
                            vars_to_remove.append(vars_model[i])
            for var in vars_to_remove:
                vars_model.remove(var)

            saver = tf.train.Saver(vars_model)
            saver.restore(sess, os.path.join(model_dir, ckpt_name))
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Model load SUCCESS - " + os.path.abspath(os.path.join(model_dir, ckpt_name)))
            return True, counter
    except Exception as err:
        print(" [!] Model load FAILED - " + os.path.abspath(model_dir) + ', ' + str(err))

    print(" [!] Model load FAILED - no checkpoint in " + os.path.abspath(model_dir))
    return False, 0
Example #8
 def __init__(self, corpus):
     tf.reset_default_graph()
     self.corpus = corpus
     # Load embeddings
     pretrained_var_list = checkpoint_utils.list_variables(GlobalConstants.WORD_EMBEDDING_FILE_PATH)
     self.wordEmbeddings = checkpoint_utils.load_variable(checkpoint_dir=GlobalConstants.WORD_EMBEDDING_FILE_PATH,
                                                          name="embeddings")
     # IMPORTANT: prepend a zero row at index 0; it will be used as the padding feature.
     self.wordEmbeddings = np.concatenate([np.zeros(shape=(1, GlobalConstants.EMBEDDING_SIZE)), self.wordEmbeddings],
                                          axis=0)
     assert self.wordEmbeddings.shape[0] == self.corpus.get_vocabulary_size() + 1
     assert self.wordEmbeddings.shape[1] == GlobalConstants.EMBEDDING_SIZE
     self.batch_size = tf.placeholder(dtype=tf.int32, shape=[], name='batch_size')
     self.input_word_codes = tf.placeholder(dtype=tf.int32, shape=[None, None], name='input_word_codes')
     self.input_y = tf.placeholder(dtype=tf.int64, shape=[None], name='input_y')
     self.keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
     self.sequence_length = tf.placeholder(dtype=tf.int32, shape=[None], name='sequence_length')
     self.max_sequence_length = tf.placeholder(dtype=tf.int32, name='max_sequence_length')
     self.isTrainingFlag = tf.placeholder(name="is_training", dtype=tf.bool)
     self.embeddings = None
     self.inputs = None
     self.logits = None
     self.predictions = None
     self.numOfCorrectPredictions = None
     self.accuracy = None
     self.optimizer = None
     self.globalStep = None
     self.sess = None
     self.correctPredictions = None
     # L2 loss
     self.mainLoss = None
     self.l2_loss = tf.constant(0.0)
Example #10
def load_model(sess, LOAD_MODEL_FILE, prefixs, strict=False):
    vars_in_pretrained_model = dict(
        checkpoint_utils.list_variables(LOAD_MODEL_FILE))
    # print(vars_in_pretrained_model)
    vars_in_defined_model = []

    for var in tf.trainable_variables():
        if isinstance(prefixs, list):
            for prefix in prefixs:
                if (var.op.name.startswith(prefix)) and (
                        var.op.name in vars_in_pretrained_model.keys()) and (
                            'logits' not in var.op.name):
                    if (list(var.shape) == vars_in_pretrained_model[
                            var.op.name]):
                        vars_in_defined_model.append(var)
        else:
            if (var.op.name.startswith(prefixs)) and (
                    var.op.name in vars_in_pretrained_model.keys()) and (
                        'logits' not in var.op.name):
                if (list(var.shape) == vars_in_pretrained_model[var.op.name]):
                    vars_in_defined_model.append(var)
    # print(vars_in_defined_model)
    saver = tf.train.Saver(vars_in_defined_model)
    try:
        saver.restore(sess, LOAD_MODEL_FILE)
        print("Model loaded in file: %s" % (LOAD_MODEL_FILE))
    except Exception:
        if strict:
            print("Failed to load model file: %s" % LOAD_MODEL_FILE)
            return False
        else:
            print("Failed to load file: %s" % LOAD_MODEL_FILE)
            return True

    return True
Example #11
 def restore(self):
     sess = self.sess
     saver = self.saver
     #print_tensors_in_checkpoint_file(self.weight_dir+'/u_net.ckpt', all_tensors=True, tensor_name = '')
     var_list = checkpoint_utils.list_variables(self.weight_dir+'/u_net.ckpt')
     for var in var_list:
         print(var)
     saver.restore(sess, self.weight_dir+'/u_net.ckpt')
Example #12
 def testGetAllVariables(self):
   checkpoint_dir = self.get_temp_dir()
   with self.test_session() as session:
     _create_checkpoints(session, checkpoint_dir)
   self.assertEqual(
       checkpoint_utils.list_variables(checkpoint_dir),
       [("useful_scope/var4", [9, 9]), ("var1", [1, 10]), ("var2", [10, 10]),
        ("var3", [100, 100])])
Example #13
 def load_model(self, sess):
     saved_vars = checkpoint_utils.list_variables(
         checkpoint_dir="D:\\sensor_data_generation\\models\\vae_model.ckpt")
     all_vars = tf.global_variables()
     for var in all_vars:
         if "Adam" in var.name:
             continue
         source_array = checkpoint_utils.load_variable(
             checkpoint_dir="D:\\sensor_data_generation\\models\\vae_model.ckpt", name=var.name)
         tf.assign(var, source_array).eval(session=sess)
     print("X")
Example #14
def showParametersInCkpt(ckpt_dir):
    '''
    Print the parameter names stored in a checkpoint file.
    :param ckpt_dir:
    :return:
    '''
    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
    latest_ckp = tf.train.latest_checkpoint(ckpt_dir)
    # print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='', all_tensor_names=True)
    from tensorflow.contrib.framework.python.framework import checkpoint_utils
    var_list = checkpoint_utils.list_variables(latest_ckp)
    for v in var_list:
        print(v)
Example #15
 def restore_new_scope(self, dir, saved_scope, tf_scope):
     var_remap = dict()
     vars = [v for v in self.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf_scope) if "Adam" not in v.name]
     for var in vars:
         var_remap[saved_scope + var.name[len(tf_scope):-2]] = var
     path = os.path.join(dir, "model.ckpt")
     saver = tf.train.Saver(var_list=var_remap)
     try:
         saver.restore(self._session, path)
     except tf.errors.NotFoundError as e:
         from tensorflow.contrib.framework.python.framework import checkpoint_utils
         logging.info(checkpoint_utils.list_variables(dir))
         raise e
Example #16
 def part_load_from_checkpoint(self, sess):
     var_list = tf.global_variables()
     # present model variables
     check_var_list = checkpoint_utils.list_variables(
         self.config.load_checkpoint_path)
     # checkpoint variables
     check_var_list = [x[0] for x in check_var_list]
     check_var_set = set(check_var_list)
     vars_in_checkpoint = [
         x for x in var_list if x.name[:x.name.index(":")] in check_var_set
     ]
     # vars in the present model that are also present in the checkpoint
     saverPart = tf.train.Saver(var_list=vars_in_checkpoint)
     saverPart.restore(sess, self.config.load_checkpoint_path)
     print("Restored variables from the parsed checkpoint")
Example #17
def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    check_var_list = checkpoint_utils.list_variables(checkpoint_path)
    # print('check_var_list:', check_var_list)
    # print('vars_to_check:', vars_to_check)
    check_var_set = set()
    for x in check_var_list:
        check_var_set.add(x[0])
    vars_in_checkpoint = []
    vars_not_in_checkpoint = []
    for x in vars_to_check:
        var_name = x.name[:x.name.index(':')]
        if '/part_' in var_name:
            var_name = var_name[:var_name.index('/part_')]
        if var_name in check_var_set:
            vars_in_checkpoint.append(x)
        else:
            vars_not_in_checkpoint.append(x)
    return vars_in_checkpoint, vars_not_in_checkpoint
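The '/part_' handling above matters for partitioned variables, which show up in the graph as <name>/part_0, <name>/part_1, ... while the checkpoint stores the joined name. A small sketch, with hypothetical names and sizes:

with tf.variable_scope("demo", partitioner=tf.fixed_size_partitioner(2)):
    tf.get_variable("embeddings", shape=[10, 4])
print([v.name for v in tf.global_variables()])
# e.g. ['demo/embeddings/part_0:0', 'demo/embeddings/part_1:0']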
Example #18
def build_codebook_multi(encoder,
                         dataset,
                         args,
                         checkpoint_file_basename=None):
    embed_bb = args.getboolean('Embedding', 'EMBED_BB')

    existing_embs = []
    if checkpoint_file_basename is not None:
        var_list = checkpoint_utils.list_variables(checkpoint_file_basename)
        for v in var_list:
            if 'embedding_normalized_' in v[0]:
                print(v)
                existing_embs.append(
                    v[0].split('/embedding_normalized_')[-1].split('.')[0])

    print(existing_embs)
    codebook = Codebook(encoder, dataset, embed_bb, existing_embs)
    return codebook
Example #19
    def restore(self, checkpoint_path, restrict_vars=None):

        if os.path.isdir(checkpoint_path):
            restore_checkpoint = tf.train.latest_checkpoint(
                checkpoint_path, latest_filename=None)
        else:
            restore_checkpoint = checkpoint_path

        # Retrieves the variables inside 'restore_checkpoint'
        ckpt_vars = [
            name for name, shape in checkpoint_utils.list_variables(
                restore_checkpoint)
        ]

        if restrict_vars and len(restrict_vars) > 0:
            restore_vars = list(set(ckpt_vars).intersection(restrict_vars))
        else:
            # If no list is provided, all the variables contained in the checkpoint will be restored
            uninit_vars = [
                bs.decode("utf-8") for bs in self._sess.run(
                    tf.report_uninitialized_variables(tf.global_variables()))
            ]
            restore_vars = list(set(ckpt_vars).intersection(uninit_vars))

        restore_variables = []
        # Retrieves the variables to be restored from their name
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            for name in restore_vars:
                restore_variables.append(
                    tf.get_variable(
                        name,
                        dtype=self._sess.graph.get_tensor_by_name(name +
                                                                  ":0").dtype))

        restore_vars_saver = tf.train.Saver(var_list=restore_variables)
        restore_vars_names = map(lambda v: v.name, restore_variables)
        restore_vars_names = ''.join("%s, " % v
                                     for v in restore_vars_names)[:-2]

        print("Restoring variables ({}) from '{}'... ".format(
            restore_vars_names, restore_checkpoint),
              end="")
        restore_vars_saver.restore(self._sess, restore_checkpoint)
        print("Done.")
Example #20
 def load_trained_classifier(self, sess, run_id, target_category,
                             iteration):
     tvars = tf.trainable_variables(scope=self.classifierName)
     file_path = pathlib.Path(__file__).parent.absolute()
     model_folder = os.path.join(file_path, "..", "models", target_category)
     checkpoint_folder = os.path.join(
         model_folder, "lstm{0}_iteration{1}".format(run_id, iteration))
     model_path = os.path.join(
         checkpoint_folder,
         "lstm{0}_iteration{1}.ckpt".format(run_id, iteration))
     saved_vars = checkpoint_utils.list_variables(checkpoint_dir=model_path)
     for var in tvars:
         # assert len([_var for _var in saved_vars if _var.name == var.name]) == 1
         # if "Adam" in var.name:
         #     continue
         var_name = var.name[len(self.classifierName) + 1:]
         source_array = checkpoint_utils.load_variable(
             checkpoint_dir=model_path, name=var_name)
         tf.assign(var, source_array).eval(session=sess)
Example #21
def load_fp32_weights_into_fp16_vars(checkpoint_path: Path) -> List:
    """Load fp32 weights from checkpoint path into fp16 variables.

    Assumes that the caller has executed `sess.run(tf.global_variables_initializer())`

    Args:
        checkpoint_path: Checkpoint path

    Returns:
        Collection of ops to use to restore the weights in the graph.
    """
    checkpoint_variables = [var_name for var_name, _ in list_variables(checkpoint_path)]

    for graph_var in tf.global_variables():
        if graph_var.op.name in checkpoint_variables:
            var = load_variable(checkpoint_path, graph_var.op.name)
            weights = tf.cast(var, tf.float16) if var.dtype == np.float32 else var
            tf.add_to_collection('restore_ops', graph_var.assign(weights))
    return tf.get_collection('restore_ops')
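A minimal sketch of how the returned ops might be applied; the checkpoint prefix is hypothetical and the fp16 graph is assumed to be built already:

from pathlib import Path

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())        # required, per the docstring above
    restore_ops = load_fp32_weights_into_fp16_vars(Path("/tmp/fp32_model.ckpt"))
    sess.run(restore_ops)                               # copy the casted weights into the graph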
Example #22
def checkpoint_dtype_cast(in_checkpoint_file, out_checkpoint_file):
    var_list = checkpoint_utils.list_variables(in_checkpoint_file)

    def init_graph():
        for name, shape in var_list:
            var = checkpoint_utils.load_variable(in_checkpoint_file, name)
            recon_dtype = tf.float16 if var.dtype == np.float32 else var.dtype
            tf.get_variable(name, shape=shape, dtype=recon_dtype)

    init_graph()
    saver = tf.train.Saver(builder=CastFromFloat32SaverBuilder())
    with tf.Session() as sess:
        saver.restore(sess, in_checkpoint_file)
        saver.save(sess, "./init_ckpt/tmp.ckpt")

    tf.reset_default_graph()

    init_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "./init_ckpt/tmp.ckpt")
        saver.save(sess, out_checkpoint_file)
Example #23
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
  batch = dataset()
  pre_logit, outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()
  var_list = list(
      snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))

  tf.logging.info("all vars")
  for v in tf.all_variables():
    tf.logging.info("   %s" % str(v))
  global_step = tf.train.get_global_step()
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
  if checkpoint:
    str_var_list = checkpoint_utils.list_variables(checkpoint)
    name_to_v_map = {v.op.name: v for v in tf.all_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # The global step is restored from the eval job's checkpoint, or starts at zero for a fresh run.
    step = sess.run(global_step)

    if step == 0 and checkpoint:
      tf.logging.info("force restore")
      saver.restore(sess, checkpoint)
      tf.logging.info("force restore done")
      sess.run(reset_global_step)
      step = sess.run(global_step)

    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])
Example #24
                         args.image_size,
                         feature_name=args.feature_name,
                         annotation_threshold=args.annotation_threshold)

#initialize variables------------------
sess.run(tf.global_variables_initializer())

#set up savers------------
saver = tf.train.Saver(max_to_keep=20)
summary_writer = tf.summary.FileWriter(args.logs_dir, sess.graph)

#try to reload weights--------------------
if args.model_dir is not None:
    ckpt = tf.train.get_checkpoint_state(args.model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        vars_stored = [var[0] for var in list_variables(args.model_dir)]
        vars_restore = [
            v for v in tf.global_variables() if v.name[0:-2] in vars_stored
        ]
        restore_saver = tf.train.Saver(vars_restore)
        restore_saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")
    else:
        print("Model restore failed...")

#start training-------------------------
for itr in range(args.max_iteration):
    if not args.weight_data:
        batch = train_dataset_reader.next_batch_in_seqs(
            args.batch_size, args.n_steps)
    else:
        serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(input_dict)
        save_path = nn_estimator.export_savedmodel(models_dir, serving_input_receiver_fn)
        
    return save_path

# GETTING VALIDATION SCORES
from tensorflow.contrib.framework.python.framework import checkpoint_utils
import re
from collections import defaultdict

files = [checkpoint_file]
eval_data_dict = pickle.load(open('data/eval_data_dict_%d_files_%s_rows.pkl' % (NUM_DATAFILES, str(ROWS_TO_EXTRACT)), 'rb'))

random_model_dir = files[0]
ckpt_var_list = [var_name for (var_name, shape) in checkpoint_utils.list_variables(random_model_dir)]

vars_to_restore = list()
for graph_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    if graph_var.name[:-2] in ckpt_var_list:
        vars_to_restore.append(graph_var)

try:
    global_step = tf.train.create_global_step()
except:
    print("global_step already exists")
    
if global_step not in vars_to_restore:
    vars_to_restore.append(global_step)

init_global = tf.global_variables_initializer()
Example #26
def train(args):
    output_dir = args.output
    log_dir = args.log_dir if args.log_dir else os.path.join(output_dir, 'log')
    model_dir = os.path.join(output_dir, 'model')
    records_dir = args.records_dir if not args.data_dir else os.path.join(
        args.data_dir, args.records_dir)
    result_dir = os.path.join(output_dir, 'result')
    for dir in [output_dir, log_dir, model_dir, result_dir]:
        if not os.path.exists(dir):
            os.makedirs(dir)

    # save the args info to the output dir.
    with open(os.path.join(output_dir, 'args.json'), 'w') as p:
        json.dump(vars(args), p, indent=2)

    # load meta info
    with open(os.path.join(records_dir, 'train_meta.json'),
              'r',
              encoding='utf8') as p:
        train_total = json.load(p)['size']
        batch_num_per_epoch = int(np.ceil(train_total / args.batch_size))
        print(f'{time.asctime()} - batch num per epoch: {batch_num_per_epoch}')

    with open(os.path.join(records_dir, 'dev_meta.json'), 'r',
              encoding='utf8') as p:
        dev_total = json.load(p)['size']

    train_records_file = os.path.join(records_dir, 'train.tfrecords')
    dev_records_file = os.path.join(records_dir, 'dev.tfrecords')

    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):

        parser = get_record_parser(args)
        train_dataset = get_batch_dataset(train_records_file, parser, args)
        dev_dataset = get_batch_dataset(dev_records_file, parser, args)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = CommentModel(args, iterator)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        sess = tf.Session(config=session_config)

        # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

        writer = tf.summary.FileWriter(log_dir)
        best_ppl = tf.Variable(300,
                               trainable=False,
                               name='best_ppl',
                               dtype=tf.float32)

        saver = tf.train.Saver(max_to_keep=10000)
        if args.restore:
            model_file = args.restore_model or tf.train.latest_checkpoint(
                model_dir)
            print(f'{time.asctime()} - Restore model from {model_file}..')
            var_list = [
                _[0] for _ in checkpoint_utils.list_variables(model_file)
            ]
            saved_vars = [
                _ for _ in tf.global_variables()
                if _.name.split(':')[0] in var_list
            ]
            res_saver = tf.train.Saver(saved_vars)
            res_saver.restore(sess, model_file)

            left_vars = [
                _ for _ in tf.global_variables()
                if _.name.split(':')[0] not in var_list
            ]
            sess.run(tf.initialize_variables(left_vars))
            print(
                f'{time.asctime()} - Restore {len(var_list)} vars and initialize {len(left_vars)} vars.'
            )
            print(left_vars)
        else:
            print(f'{time.asctime()} - Initialize model..')
            sess.run(tf.global_variables_initializer())
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)

        train_handle = sess.run(train_iterator.string_handle())
        dev_handle = sess.run(dev_iterator.string_handle())

        sess.run(tf.assign(model.is_train, tf.constant(True,
                                                       dtype=tf.bool)))  #tmp

        patience = 0

        lr = sess.run(model.lr)
        b_ppl = sess.run(best_ppl)
        print(f'{time.asctime()} - lr: {lr:.3f}  best_ppl:{b_ppl:.3f}')

        t0 = datetime.now()

        while True:
            global_step = sess.run(model.global_step) + 1
            epoch = int(np.ceil(global_step / batch_num_per_epoch))

            loss, loss_gen, ppl, train_op, merge_sum, target, check_1 = sess.run(
                [
                    model.loss, model.loss_gen, model.ppl, model._train_op,
                    model._summaries, model.target, model.check_dec_outputs
                ],
                feed_dict={handle: train_handle})

            ela_time = str(datetime.now() - t0).split('.')[0]

            print(
                (f'{time.asctime()} - step/epoch:{global_step}/{epoch:<3d}   '
                 f'gen_loss:{loss_gen:<3.3f}  '
                 f'ppl:{ppl:<4.3f}  '
                 f'elapsed:{ela_time}\r'),
                end='')

            if global_step % args.period == 0:
                writer.add_summary(merge_sum, global_step)
                writer.flush()

            if global_step % args.checkpoint == 0:
                model_file = os.path.join(model_dir, 'model')
                saver.save(sess, model_file, global_step=global_step)

            # if  global_step % batch_num_per_epoch== 0:
            if global_step % args.checkpoint == 0 and not args.no_eval:
                sess.run(
                    tf.assign(model.is_train, tf.constant(False,
                                                          dtype=tf.bool)))
                metrics, summ = evaluate_batch(model,
                                               dev_total // args.batch_size,
                                               sess, handle, dev_handle,
                                               iterator)
                sess.run(
                    tf.assign(model.is_train, tf.constant(True,
                                                          dtype=tf.bool)))

                for s in summ:
                    writer.add_summary(s, global_step)

                dev_ppl = metrics['ppl']
                dev_gen_loss = metrics['gen_loss']

                tqdm.write(
                    f'{time.asctime()} - Evaluate after steps:{global_step}, '
                    f' gen_loss:{dev_gen_loss:.4f},  ppl:{dev_ppl:.3f}')

                if dev_ppl < b_ppl:
                    sess.run(tf.assign(best_ppl, dev_ppl))
                    saver.save(sess, save_path=os.path.join(model_dir, 'best'))
                    tqdm.write(
                        f'{time.asctime()} - the ppl is lower than current best ppl so saved the model.'
                    )
                    patience = 0
                else:
                    patience += 1

                if patience >= args.patience:
                    lr = lr / 2
                    sess.run(
                        tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))
                    patience = 0
                    tqdm.write(
                        f'{time.asctime()} - The lr was decayed from {lr*2} to {lr}.'
                    )
Example #27
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    start = 0

    if ckpt and ckpt.model_checkpoint_path:
        start = int(ckpt.model_checkpoint_path.split("-")[1])
        logger.info("start by iteration: %d" % (start))
        saver = tf.train.Saver()
        saver.restore(sess, ckpt.model_checkpoint_path)
        logger.info("model is restored using " + str(ckpt))
    elif pretraind_model:
        restore = {}
        from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables
        slim = tf.contrib.slim
        for scope in list_variables(pretraind_model):
            if 'conv' in scope[0]:
                variables_to_restore = slim.get_variables(scope=scope[0])
                if variables_to_restore:
                    restore[scope[0]] = variables_to_restore[
                        0]  # variables_to_restore is list : [op]
        saver = tf.train.Saver(restore)
        saver.restore(sess, pretraind_model)
        logger.info("model is restored conv only using " +
                    str(pretraind_model))

    assign_op = global_step.assign(start)
    sess.run(assign_op)
    model_saver = tf.train.Saver(max_to_keep=30)

    # train
Example #28
            tf.constant(img, dtype=tf.float32),
            axis=0)  #tf.zeros([batch_size, img_size, img_size, 3])

        logits, endpoints = resnet_v1.resnet_v1_50(inputbatch,
                                                   1000,
                                                   is_training=False)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        variables_to_restore = []

        a = [
            name for name, _ in checkpoint_utils.list_variables(
                'pretrained_model/resnet_v1_50.ckpt')
        ]
        # print a
        for var in slim.get_model_variables():
            if (var.op.name.startswith('resnet_v1_50')) and (
                    var.op.name in a) and ('logits' not in var.op.name):
                variables_to_restore.append(var)
        # print variables_to_restore
        # slim.assign_from_checkpoint_fn('pretrained_model/resnet_v1_50.ckpt', variables_to_restore, ignore_missing_vars=False)
        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver(variables_to_restore)
        saver.restore(sess, 'pretrained_model/resnet_v1_50.ckpt')
        # print a.keys()
Example #29
def checkpoint_quantization(in_checkpoint_file, out_checkpoint_file,
                            per_channel_quantization):
    var_list = checkpoint_utils.list_variables(tf.flags.FLAGS.init_checkpoint)

    def init_graph():
        restore_vars = []
        layer_num = 0
        regex = re.compile(r'layer_\d+')
        amaxTotalNum = 0
        for name, shape in var_list:
            var = checkpoint_utils.load_variable(
                tf.flags.FLAGS.init_checkpoint, name)
            if "intermediate/dense/kernel" in name and amaxTotalNum == 0:
                amaxTotalNum = ACTIVATION_AMAX_NUM + 9 * shape[
                    0] + INT8O_GEMM_NUM + TRT_FUSED_MHA_AMAX_NUM
                print(amaxTotalNum, shape[0])
            recon_dtype = var.dtype
            restore_vars.append(
                tf.get_variable(name, shape=shape, dtype=var.dtype))
            tmp = regex.findall(name)
            if len(tmp) < 1:
                continue
            num_tmp = int(tmp[0].replace("layer_", ""))
            if layer_num < num_tmp:
                layer_num = num_tmp
        layer_num = layer_num + 1
        #add new var for amax
        for i in range(layer_num):
            tf.get_variable("bert/encoder/layer_{}/amaxList".format(i),
                            shape=[amaxTotalNum],
                            dtype=tf.float32)
        return layer_num, amaxTotalNum, restore_vars

    layer_num, amaxTotalNum, restore_vars = init_graph()
    restorer = tf.train.Saver(restore_vars)
    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        restorer.restore(sess, in_checkpoint_file)
        kernel_name_list = [
            "attention/self/query", "attention/self/key",
            "attention/self/value", "attention/output/dense",
            "intermediate/dense", "output/dense"
        ]

        #input_scale, 0
        amax_name_list = [
            "attention/self/query/input_quantizer",
            #Q_aftergemm_scale, 1
            "attention/self/query/aftergemm_quantizer",
            #Qbias_scale, 2
            "attention/self/matmul_q_input_quantizer",
            #K_aftergemm_scale, 3
            "attention/self/key/aftergemm_quantizer",
            #Kbias_scale, 4
            "attention/self/matmul_k_input_quantizer",
            #V_aftergemm_scale, 5
            "attention/self/value/aftergemm_quantizer",
            #Vbias_scale, 6
            "attention/self/matmul_v_input_quantizer",
            #bmm1_scale, 7
            "attention/self/softmax_input_quantizer",
            #Softmax_scale, 8
            "attention/self/matmul_a_input_quantizer",
            #bmm2_scale, 9
            "attention/output/dense/input_quantizer",
            #Proj_aftergemm_scale, 10
            "attention/output/dense/aftergemm_quantizer",
            #ProjBiasNorm_scale, 11
            "intermediate/dense/input_quantizer",
            #FC1_aftergemm_scale, 12
            "intermediate/dense/aftergemm_quantizer",
            #F1Bias_scale, 13
            "output/dense/input_quantizer",
            #FC2_aftergemm_scale, 14
            "output/dense/aftergemm_quantizer",
            #F2Bias_scale, 15
            "special_F2Bias_scale",
        ]

        int8O_gemm_weight_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        #Q_aftergemm
        int8O_gemm_weight_list = [
            "attention/self/query",
            #K_aftergemm
            "attention/self/key",
            #V_aftergemm
            "attention/self/value",
            #bmm1_aftergemm
            "attention/self/matmul_k_input_quantizer",
            #bmm2_aftergemm
            "attention/self/matmul_v_input_quantizer",
            #Proj_aftergemm
            "attention/output/dense",
            #FC1_aftergemm
            "intermediate/dense",
            #FC2_aftergemm
            "output/dense"
        ]

        int8O_gemm_input_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        #Q_aftergemm
        int8O_gemm_input_list = [
            "attention/self/query/input_quantizer",
            #K_aftergemm
            "attention/self/key/input_quantizer",
            #V_aftergemm
            "attention/self/value/input_quantizer",
            #bmm1_aftergemm
            "attention/self/matmul_q_input_quantizer",
            #bmm2_aftergemm
            "attention/self/matmul_a_input_quantizer",
            #Proj_aftergemm
            "attention/output/dense/input_quantizer",
            #FC1_aftergemm
            "intermediate/dense/input_quantizer",
            #FC2_aftergemm
            "output/dense/input_quantizer"
        ]

        int8O_gemm_output_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        #Q_aftergemm
        int8O_gemm_output_list = [
            "attention/self/query/aftergemm_quantizer",
            #K_aftergemm
            "attention/self/key/aftergemm_quantizer",
            #V_aftergemm
            "attention/self/value/aftergemm_quantizer",
            #bmm1_aftergemm
            "attention/self/softmax_input_quantizer",
            #bmm2_aftergemm
            "attention/output/dense/input_quantizer",
            #Proj_aftergemm
            "attention/output/dense/aftergemm_quantizer",
            #FC1_aftergemm
            "intermediate/dense/aftergemm_quantizer",
            #FC2_aftergemm
            "output/dense/aftergemm_quantizer"
        ]

        factor = 1000000.0
        for i in range(layer_num):
            amaxList = np.zeros([amaxTotalNum])
            amax_id = 0
            for amax_name in amax_name_list:
                if amax_name == "special_F2Bias_scale":
                    if i != layer_num - 1:
                        name = "bert/encoder/layer_{}/{}/quant_max:0".format(
                            i + 1, amax_name_list[0])
                        quant_max = checkpoint_utils.load_variable(
                            tf.flags.FLAGS.init_checkpoint, name)
                        name = "bert/encoder/layer_{}/{}/quant_min:0".format(
                            i + 1, amax_name_list[0])
                        quant_min = checkpoint_utils.load_variable(
                            tf.flags.FLAGS.init_checkpoint, name)
                        if abs(quant_max) > abs(quant_min):
                            amax = abs(
                                quant_max)  #int(abs(quant_max)*factor)/factor
                        else:
                            amax = abs(
                                quant_min)  #int(abs(quant_min)*factor)/factor
                    else:

                        #not used, placeholder
                        amax = 1.0

                    amaxList[amax_id] = amax
                    amax_id += 1
                    amaxList[amax_id] = amax / 127.0
                    amax_id += 1
                    amaxList[amax_id] = amax / 127.0 / 127.0
                    amax_id += 1
                    amaxList[amax_id] = 127.0 / amax
                    amax_id += 1
                    continue

                name = "bert/encoder/layer_{}/{}/quant_max:0".format(
                    i, amax_name)
                quant_max = checkpoint_utils.load_variable(
                    tf.flags.FLAGS.init_checkpoint, name)
                name = "bert/encoder/layer_{}/{}/quant_min:0".format(
                    i, amax_name)
                quant_min = checkpoint_utils.load_variable(
                    tf.flags.FLAGS.init_checkpoint, name)

                if abs(quant_max) > abs(quant_min):
                    amax = abs(quant_max)  #int(abs(quant_max)*factor)/factor
                else:
                    amax = abs(quant_min)  #int(abs(quant_min)*factor)/factor

                if amax_name in int8O_gemm_input_list:
                    int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                        amax_name)] = amax
                    if amax_name == "attention/self/query/input_quantizer":
                        int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                            "attention/self/key/input_quantizer")] = amax
                        int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                            "attention/self/value/input_quantizer")] = amax

                if amax_name in int8O_gemm_output_list:
                    int8O_gemm_output_amax_list[int8O_gemm_output_list.index(
                        amax_name)] = amax

                if amax_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[int8O_gemm_weight_list.index(
                        amax_name)] = amax

                amaxList[amax_id] = amax
                amax_id += 1
                amaxList[amax_id] = amax / 127.0
                amax_id += 1
                amaxList[amax_id] = amax / 127.0 / 127.0
                amax_id += 1
                amaxList[amax_id] = 127.0 / amax
                amax_id += 1

            print("done process layer_{} activation amax".format(i))

            #kernel amax starts from ACTIVATION_AMAX_NUM
            amax_id = ACTIVATION_AMAX_NUM
            for kernel_id, kernel_name in enumerate(kernel_name_list):
                kernel = tf.get_default_graph().get_tensor_by_name(
                    "bert/encoder/layer_{}/{}/kernel:0".format(i, kernel_name))

                name = "bert/encoder/layer_{}/{}/kernel_quantizer/quant_max:0".format(
                    i, kernel_name)
                quant_max2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(
                        tf.flags.FLAGS.init_checkpoint, name))

                name = "bert/encoder/layer_{}/{}/kernel_quantizer/quant_min:0".format(
                    i, kernel_name)
                quant_min2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(
                        tf.flags.FLAGS.init_checkpoint, name))

                kernel_processed, quant_max_processed = transformer_op_module.weight_quantize(
                    kernel,
                    quant_max2,
                    quant_min2,
                    per_channel_quantization=per_channel_quantization)
                kernel_processed_, quant_max_processed_ = sess.run(
                    [kernel_processed, quant_max_processed])
                sess.run(tf.assign(kernel, kernel_processed_))
                if kernel_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[int8O_gemm_weight_list.index(
                        kernel_name)] = quant_max_processed_[0]
                for e in quant_max_processed_:
                    amaxList[amax_id] = e
                    amax_id += 1

            #for int8O gemm deQuant
            for j in range(INT8O_GEMM_NUM):
                amaxList[amax_id] = (int8O_gemm_input_amax_list[j] *
                                     int8O_gemm_weight_amax_list[j]) / (
                                         127.0 *
                                         int8O_gemm_output_amax_list[j])
                amax_id += 1

            #for trt fused MHA amax
            #### QKV_addBias_amax
            amaxList[amax_id] = np.maximum(
                np.maximum(amaxList[8], amaxList[16]), amaxList[24])
            amax_id += 1
            #### softmax amax
            amaxList[amax_id] = amaxList[32]
            amax_id += 1
            #### bmm2 amax
            amaxList[amax_id] = amaxList[36]
            amax_id += 1
            amaxL = tf.get_default_graph().get_tensor_by_name(
                "bert/encoder/layer_{}/amaxList:0".format(i))
            sess.run(tf.assign(amaxL, amaxList))

            print("done process layer_{} kernel weight".format(i))

        saver.save(sess, out_checkpoint_file)
Example #30
sess.run(init)
sess.run(init_local)

coord = tf.train.Coordinator()
# start the threads
tf.train.start_queue_runners(sess=sess, coord=coord)

### ckpt
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

restore = {}
from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables

slim = tf.contrib.slim
for scope in list_variables(ckpt):
    if any(key in scope[0] for key in
           ('conv', 'fc1_image', 'fc2_image', 'fc3_image', 'fc4_image')):
        variables_to_restore = slim.get_variables(scope=scope[0])
        if variables_to_restore:
            restore[scope[0]] = variables_to_restore[
                0]  # variables_to_restore is list : [op]

for scope in list_variables(DET_ckpt):
    if any(key in scope[0] for key in ('fc1_adj', 'fc2_adj', 'fc3_adj', 'fc4_adj')):
        variables_to_restore = slim.get_variables(scope=scope[0])
        if variables_to_restore:
            restore[scope[0]] = variables_to_restore[
                0]  # variables_to_restore is list : [op]

saver = tf.train.Saver(restore)
Example #31
 def list_vars_in_checkpoint(self, dirname):
     ''' Just for tf debugging.  '''
     from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables
     abspath = os.path.abspath(dirname)
     return list_variables(abspath)
Example #32
def train(train_log_dir,
          checkpoint_dir,
          eval_every_n_steps=10,
          num_steps=3000):
    dataset_fn = datasets.mnist.TinyMnist
    w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
    theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

    meta_objectives = []
    meta_objectives.append(
        meta_objective.linear_regression.LinearRegressionMetaObjective)
    meta_objectives.append(meta_objective.sklearn.LogisticRegression)

    checkpoint_vars, train_one_step_op, (
        base_model, dataset) = evaluation.construct_evaluation_graph(
            theta_process_fn=theta_process_fn,
            w_learner_fn=w_learner_fn,
            dataset_fn=dataset_fn,
            meta_objectives=meta_objectives)
    batch = dataset()
    pre_logit, outputs = base_model(batch)

    global_step = tf.train.get_or_create_global_step()
    var_list = list(
        snt.get_variables_in_module(base_model,
                                    tf.GraphKeys.TRAINABLE_VARIABLES))

    tf.logging.info("all vars")
    for v in tf.all_variables():
        tf.logging.info("   %s" % str(v))
    global_step = tf.train.get_global_step()
    accumulate_global_step = global_step.assign_add(1)
    reset_global_step = global_step.assign(0)

    train_op = tf.group(train_one_step_op,
                        accumulate_global_step,
                        name="train_op")

    summary_op = tf.summary.merge_all()

    file_writer = summary_utils.LoggingFileWriter(train_log_dir,
                                                  regexes=[".*"])
    if checkpoint_dir:
        str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
        name_to_v_map = {v.op.name: v for v in tf.all_variables()}
        var_list = [
            name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
        ]
        saver = tf.train.Saver(var_list)
        missed_variables = [
            v.op.name for v in set(
                snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                           tf.GraphKeys.GLOBAL_VARIABLES)) -
            set(var_list)
        ]
        assert len(missed_variables) == 0, "Missed a theta variable."

    hooks = []

    with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

        # The global step is restored from the eval job's checkpoint, or starts at zero for a fresh run.
        step = sess.run(global_step)

        if step == 0 and checkpoint_dir:
            tf.logging.info("force restore")
            saver.restore(sess, checkpoint_dir)
            tf.logging.info("force restore done")
            sess.run(reset_global_step)
            step = sess.run(global_step)

        while step < num_steps:
            if step % eval_every_n_steps == 0:
                s, _, step = sess.run([summary_op, train_op, global_step])
                file_writer.add_summary(s, step)
            else:
                _, step = sess.run([train_op, global_step])
Example #33
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:0'):
            src_mesh = model.mesh_placeholder_inputs(BATCH_SIZE,
                                                     MAX_NVERTS,
                                                     MAX_NTRIS,
                                                     img_size=(IMG_SIZE,
                                                               IMG_SIZE),
                                                     scope='src')
            ref_mesh = model.mesh_placeholder_inputs(BATCH_SIZE,
                                                     MAX_NVERTS,
                                                     MAX_NTRIS,
                                                     img_size=(IMG_SIZE,
                                                               IMG_SIZE),
                                                     scope='ref')

            is_training_pl = tf.placeholder(tf.bool, shape=())
            print(is_training_pl)

            # Note the global_step=batch parameter to minimize.
            # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
            batch = tf.Variable(0, name='batch')
            bn_decay = get_bn_decay(batch)
            tf.summary.scalar('bn_decay', bn_decay)

            print("--- Get model and loss")
            # Get model and loss

            end_points = model.get_model(src_mesh,
                                         ref_mesh,
                                         NUM_POINTS,
                                         is_training_pl,
                                         bn=False)
            loss, end_points = model.get_loss(end_points)
            tf.summary.scalar('loss', loss)

            print("--- Get training operator")
            # Get training operator
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)

            # Create a session
            config = tf.ConfigProto()
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.99)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            config.log_device_placement = False
            sess = tf.Session(config=config)
            # sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

            # Add summary writers
            merged = tf.summary.merge_all()
            train_writer = tf.summary.FileWriter(
                os.path.join(LOG_DIR, 'train'), sess.graph)
            test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'),
                                                sess.graph)

            ##### all
            update_variables = [
                x for x in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
            ]

            train_op = optimizer.minimize(loss,
                                          global_step=batch,
                                          var_list=update_variables)

            # Init variables
            init = tf.global_variables_initializer()
            sess.run(init)

            ######### Loading Checkpoint ###############
            # CNN(Pretrained from ImageNet)
            if not load_model(
                    sess, PRETRAINED_CNN_MODEL_FILE, 'vgg_16', strict=True):
                return

            # load weights from 3D deform net
            ckptstate = tf.train.get_checkpoint_state(PRETRAINED_DEFORM3D_PATH)
            if ckptstate is not None:
                PRETRAINED_DEFORM3D_MODEL_FILE = os.path.join(
                    PRETRAINED_DEFORM3D_PATH,
                    os.path.basename(ckptstate.model_checkpoint_path))
                load_model(sess, PRETRAINED_DEFORM3D_MODEL_FILE,
                           ['sharebiasnet', 'srcpc', 'refpc'])

            # Overall
            saver = tf.train.Saver([
                v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
                if ('lr' not in v.name) and ('batch' not in v.name)
            ])
            ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)

            if ckptstate is not None:
                LOAD_MODEL_FILE = os.path.join(
                    PRETRAINED_MODEL_PATH,
                    os.path.basename(ckptstate.model_checkpoint_path))
                vars_in_pretrained_model = dict(
                    checkpoint_utils.list_variables(LOAD_MODEL_FILE))
                checkpoint_keys = set(vars_in_pretrained_model.keys())
                current_keys = set([
                    v.name.encode('ascii', 'ignore')
                    for v in tf.global_variables()
                ])
                try:
                    with NoStdStreams():
                        saver.restore(sess, LOAD_MODEL_FILE)
                    print("Model loaded in file: %s" % LOAD_MODEL_FILE)
                except Exception:
                    print("Failed to load overall model file: %s" %
                          PRETRAINED_MODEL_PATH)

            ###########################################

            ops = {
                'src_mesh': src_mesh,
                'ref_mesh': ref_mesh,
                'is_training_pl': is_training_pl,
                'loss': loss,
                'train_op': train_op,
                'merged': merged,
                'step': batch,
                'end_points': end_points
            }

            best_loss = 1e20
            for epoch in range(MAX_EPOCH):
                log_string('**** EPOCH %03d ****' % (epoch))
                sys.stdout.flush()

                epoch_loss = train_one_epoch(sess, ops, train_writer, saver)
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    save_path = saver.save(
                        sess,
                        os.path.join(LOG_DIR,
                                     "best_model_epoch_%03d.ckpt" % (epoch)))
                    log_string("Model saved in file: %s" % save_path)

                # Save the variables to disk.
                if epoch % 10 == 0:
                    save_path = saver.save(sess,
                                           os.path.join(LOG_DIR, "model.ckpt"))
                    log_string("Model saved in file: %s" % save_path)
Example #34
File: test.py Project: uniq10/nabu
def main(_):
    '''does everything for testing'''

    decoder_cfg_file = None

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(FLAGS.asr_expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(
        os.path.join(FLAGS.asr_expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_asr_cfg = configparser.ConfigParser()
    parsed_asr_cfg.read(os.path.join(FLAGS.asr_expdir, 'model', 'asr.cfg'))
    asr_cfg = dict(parsed_asr_cfg.items('asr'))

    #read the lm config file
    parsed_lm_cfg = configparser.ConfigParser()
    parsed_lm_cfg.read(os.path.join(FLAGS.lm_expdir, 'model', 'lm.cfg'))
    lm_cfg = dict(parsed_lm_cfg.items('lm'))

    #read the asr-lm config file
    parsed_asr_lm_cfg = configparser.ConfigParser()
    parsed_asr_lm_cfg.read('config/asr_lm.cfg')
    asr_lm_cfg = dict(parsed_asr_lm_cfg.items('asr-lm'))

    #read the decoder config file
    if decoder_cfg_file is None:
        decoder_cfg_file = os.path.join(FLAGS.asr_expdir, 'model',
                                        'decoder.cfg')
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(decoder_cfg_file)
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #create a feature reader
    featdir = os.path.join(database_cfg['test_dir'], feat_cfg['name'])

    with open(os.path.join(featdir, 'maxlength'), 'r') as fid:
        max_length = int(fid.read())

    reader = feature_reader.FeatureReader(
        scpfile=os.path.join(featdir, 'feats.scp'),
        cmvnfile=os.path.join(featdir, 'cmvn.scp'),
        utt2spkfile=os.path.join(featdir, 'utt2spk'),
        max_length=max_length)

    #read the feature dimension
    with open(
        os.path.join(database_cfg['train_dir'], feat_cfg['name'],
                     'dim'),
        'r') as fid:

        input_dim = int(fid.read())

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)


    #create the classifier
    classifier = asr_lm_classifier.AsrLmClassifier(
        conf=asr_lm_cfg,
        asr_conf=asr_cfg,
        lm_conf=lm_cfg,
        output_dim=coder.num_labels)

    #create a decoder
    graph = tf.Graph()
    with graph.as_default():
        decoder = decoder_factory.factory(
            conf=decoder_cfg,
            classifier=classifier,
            input_dim=input_dim,
            max_input_length=reader.max_length,
            coder=coder,
            expdir=FLAGS.asr_expdir)


        #create the lm saver
        varnames = list(zip(*checkpoint_utils.list_variables(os.path.join(
            FLAGS.lm_expdir, 'model', 'network.ckpt'))))[0]
        variables = [v for v in tf.all_variables()
                     if v.name.split(':')[0] in varnames]
        lm_saver = tf.train.Saver(variables)

        #create the asr saver
        varnames = list(zip(*checkpoint_utils.list_variables(os.path.join(
            FLAGS.asr_expdir, 'model', 'network.ckpt'))))[0]
        variables = [v for v in tf.all_variables()
                     if v.name.split(':')[0] in varnames]
        asr_saver = tf.train.Saver(variables)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True #pylint: disable=E1101
    config.allow_soft_placement = True

    with tf.Session(graph=graph, config=config) as sess:
        #load the lm model
        lm_saver.restore(
            sess, os.path.join(FLAGS.lm_expdir, 'model', 'network.ckpt'))

        #load the asr model
        asr_saver.restore(
            sess, os.path.join(FLAGS.asr_expdir, 'model', 'network.ckpt'))

        #decode with the neural net
        decoded = decoder.decode(reader, sess)

    #the path to the text file
    textfile = database_cfg['testtext']

    #read all the reference transcriptions
    with open(textfile) as fid:
        lines = fid.readlines()

    references = dict()
    for line in lines:
        splitline = line.strip().split(' ')
        references[splitline[0]] = ' '.join(splitline[1:])

    #compute the character error rate
    score = decoder.score(decoded, references)

    print('score: %f' % score)
Example #35
def list_variables(checkpoint_dir):
  """See `tf.contrib.framework.list_variables`."""
  return checkpoint_utils.list_variables(checkpoint_dir)
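For reference, list_variables returns a list of (name, shape) pairs; a small sketch filtering by a scope prefix (the checkpoint directory and prefix are hypothetical):

for name, shape in list_variables("/tmp/train_dir"):
    if name.startswith("useful_scope/"):
        print(name, shape)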