Example #1
File: embsim.py  Project: tangqiqi123/hasky
    def __init__(self,
                 emb=None,
                 fixed_emb=None,
                 name=None,
                 fixed_name=None,
                 model_dir=None,
                 model_name=None,
                 sess=None):
        self._sess = sess or tf.InteractiveSession()

        if os.path.isdir(emb):
            model_dir = emb

        if model_dir is None:
            if isinstance(emb, str):
                emb = np.load(emb)
                emb = melt.load_constant(emb, name=name)
            if isinstance(fixed_emb, str):
                #fixed_emb is corpus embeddings, already summed and normed
                fixed_emb = melt.load(fixed_emb, name=fixed_name)
        else:
            model_path = melt.get_model_path(model_dir, model_name)
            emb = tf.Variable(0., name=name, validate_shape=False)
            #emb = tf.Variable(0., name=name)
            #like word2vec the name is 'w_in'
            embedding_saver = tf.train.Saver({name: emb})
            embedding_saver.restore(self._sess, model_path)

        #assume 0 index not used, 0 for PAD
        mask = zero_first_row(emb)
        emb = tf.multiply(emb, mask)

        self._emb = emb
        self._fixed_emb = fixed_emb
        self._normed_emb = None
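
The masking at the end of __init__ zeroes the embedding row reserved for PAD (index 0).
A minimal sketch of the same idea with plain TensorFlow 1.x ops, assuming zero_first_row
simply builds a column mask whose first entry is 0 (its real definition is not shown in
this excerpt):

import tensorflow as tf

vocab_size, dim = 5, 3
emb = tf.random_normal([vocab_size, dim])
# Column mask: 0 for the PAD row (index 0), 1 for every other row.
mask = tf.concat([tf.zeros([1, 1]), tf.ones([vocab_size - 1, 1])], axis=0)
masked = tf.multiply(emb, mask)  # broadcasting zeroes out row 0

with tf.Session() as sess:
    print(sess.run(masked))  # first row is all zeros
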
def get_imagenet_from_checkpoint(checkpoint_path):
    """
  net = get_net_from_checkpoint(checkpoint)
  net.func_name  # like inception_v4
  net.default_image_size # like 299
  """
    checkpoint = melt.get_model_path(checkpoint_path)
    if not checkpoint or \
     (not os.path.exists(checkpoint) \
        and not os.path.exists(checkpoint + '.index')):
        return None

    from tensorflow.python import pywrap_tensorflow
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint)
    var_to_shape_map = reader.get_variable_to_shape_map()
    name = None
    for key in var_to_shape_map.keys():
        prefix = key.split('/')[0]
        gnu_name = gezi.to_gnu_name(prefix)
        if gnu_name in nets_factory.networks_map:
            name = prefix
            break
    if name is None:
        return None
    else:
        nets_factory.networks_map[gnu_name].name = name
        return nets_factory.networks_map[gnu_name]
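
A hedged usage sketch for get_imagenet_from_checkpoint; the checkpoint path below is
hypothetical, and func_name / default_image_size are the attributes the docstring refers to:

net = get_imagenet_from_checkpoint('./inception_v4/model.ckpt-100000')
if net is not None:
    print(net.name)                # variable scope found in the checkpoint
    print(net.func_name)           # e.g. inception_v4
    print(net.default_image_size)  # e.g. 299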
Example #3
 def load(self, model_dir, var_list=None, model_name=None):
   """
   only load variables from the checkpoint file; you need to
   create the graph before calling load
   """
   self.model_path = melt.get_model_path(model_dir, model_name)
   saver = melt.restore_from_path(self.sess, self.model_path, var_list)
   return self.sess
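
A hedged usage sketch for load(); Predictor and build_graph are hypothetical names
standing in for the class that defines load() and for whatever code builds the graph:

predictor = Predictor()   # hypothetical class exposing load()
build_graph()             # the graph/variables must already exist before load()
sess = predictor.load('./model_dir', model_name='model.ckpt-10000')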
Example #4
 def restore(self, model_dir, model_name=None):
   """
   no need to create the graph first:
   restore the graph from the meta file, then restore values from the checkpoint file
   """
   self.model_path = model_path = melt.get_model_path(model_dir, model_name)
   meta_filename = '%s.meta'%model_path
   saver = tf.train.import_meta_graph(meta_filename)
   self.restore_from_graph()
   saver.restore(self.sess, model_path)
   return self.sess
 def load(self, model_dir, var_list=None, model_name=None, sess = None):
   """
   only load variables from the checkpoint file; you need to
   create the graph before calling load
   """
   if sess is not None:
     self.sess = sess
   self.model_path = melt.get_model_path(model_dir, model_name)
   timer = gezi.Timer('load model ok %s' % self.model_path)
   saver = melt.restore_from_path(self.sess, self.model_path, var_list)
   timer.print()
   return self.sess
 def restore(self, model_dir, model_name=None, sess=None):
   """
   no need to create the graph first:
   restore the graph from the meta file, then restore values from the checkpoint file
   """
   if sess is not None:
     self.sess = sess
   self.model_path = model_path = melt.get_model_path(model_dir, model_name)
   timer = gezi.Timer('restore meta graph and model ok %s' % model_path)
   meta_filename = '%s.meta'%model_path
   saver = tf.train.import_meta_graph(meta_filename)
   self.restore_from_graph()
   saver.restore(self.sess, model_path)
   #---TODO: clearing these collections does not work, the program runs but then hangs;
   #   FIXME: adding predictor + exact_predictor during training hits the same problem.
   #"@gauravsindhwani, can you still run the code successfully after you remove these two
   # collections, since they are actually part of the graph? I tried your way but found
   # the program is stuck after restoring."
   #https://github.com/tensorflow/tensorflow/issues/9747
   #tf.get_default_graph().clear_collection("queue_runners")
   #tf.get_default_graph().clear_collection("local_variables")
   #--for num_epochs not 0
   #self.sess.run(tf.local_variables_initializer())
   timer.print()
   return self.sess
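
A hedged sketch contrasting the two entry points (Predictor and build_graph are again
hypothetical names): load() assumes the graph is already built in the current process,
while restore() rebuilds it from the .meta file before loading variable values:

predictor = Predictor()

# Case 1: graph already built here, only variable values are needed.
build_graph()
sess = predictor.load('./model_dir')

# Case 2: nothing built yet; recover graph + values from the checkpoint directory.
sess = predictor.restore('./model_dir', model_name='model.ckpt-10000')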
Example #7
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys, os
import melt, gezi

assert melt.get_num_gpus() == 1

try:
    model_dir = sys.argv[1]
except Exception:
    model_dir = './'
model_path = melt.get_model_path(model_dir)

input_file = sys.argv[2]

result_file = model_path + '.full-evaluate-inference.txt'

batch_size = 50
if len(sys.argv) > 3:
    batch_size = int(sys.argv[3])

feature_name = 'attention'
if len(sys.argv) > 4:
    feature_name = sys.argv[4]

if not gezi.non_empty(result_file) or len(
        open(result_file).readlines()) != 30000:
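
A hedged sketch of the guard this last condition expresses, skipping inference when the
result file already holds the expected 30000 lines; run_inference is a hypothetical
stand-in for the omitted body:

def already_done(path, expected_lines=30000):
    # Complete when the file exists, is non-empty, and has the expected line count.
    return gezi.non_empty(path) and len(open(path).readlines()) == expected_lines

if not already_done(result_file):
    run_inference(model_path, input_file, result_file,
                  batch_size=batch_size, feature_name=feature_name)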
Example #8
#!/usr/bin/env python
# ==============================================================================
#          \file   show-var-of-model.py
#        \author   chenghuige
#          \date   2017-09-06 07:52:34.258312
#   \Description
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys, os
import melt

from tensorflow.python import pywrap_tensorflow

model_dir = sys.argv[1]
var_name = sys.argv[2]
checkpoint_path = melt.get_model_path(model_dir)
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    if var_name in key:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
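
From the shell the script above would be invoked as, for example,
python show-var-of-model.py ./model_dir w_in (paths hypothetical). A related hedged
sketch with the same reader API dumps every variable name and shape in a checkpoint:

from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader('./model_dir/model.ckpt-10000')
for key, shape in reader.get_variable_to_shape_map().items():
    print(key, shape)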