예제 #1
0
def inference(ops, iterator, model_path=None, sess=None):
    """Run inference over one full pass of ``iterator`` and write the results.

    ``ops`` is evaluated repeatedly until TF signals the end of the input with
    ``tf.errors.OutOfRangeError``. The per-batch ids and predictions are then
    concatenated and passed to ``write`` with ``suffix='infer_info'``.

    Args:
      ops: fetches passed to ``sess.run``; each run must return an
        ``(ids, predictions)`` pair.
      iterator: dataset iterator whose ``initializer`` is run before the pass.
      model_path: required; forwarded to ``write`` as the output location.
      sess: session to use; defaults to ``melt.get_session()``.

    Raises:
      AssertionError: if ``model_path`` is falsy.
      ValueError: if the iterator produced no batches.
    """
    # The results are only ever written out, so an output location is mandatory.
    assert model_path, 'inference requires model_path to write results'

    ids_list = []
    predictions_list = []

    if not sess:
        sess = melt.get_session()

    # Re-initialize so every call starts from the beginning of the dataset.
    sess.run(iterator.initializer)

    try:
        while True:
            ids, predictions = sess.run(ops)
            ids_list.append(ids)
            predictions_list.append(predictions)
    except tf.errors.OutOfRangeError:
        # Normal end of a TF1 dataset pass.
        pass

    if not predictions_list:
        # np.concatenate([]) would fail with an opaque message; say what happened.
        raise ValueError('inference: iterator produced no batches')

    write(np.concatenate(ids_list),
          np.concatenate(predictions_list),
          model_path,
          labels=None,
          suffix='infer_info')
예제 #2
0
def inference(ops, iterator, num_steps, num_examples, model_path=None, num_gpus=1, sess=None):
  """Run multi-GPU inference for ``num_steps`` steps and write the results.

  Args:
    ops: fetches passed to ``sess.run``; each run must return a sequence with
      one ``(ids, predictions)`` pair per GPU tower.
    iterator: dataset iterator whose ``initializer`` is run before the pass.
    num_steps: number of ``sess.run(ops)`` calls.
    num_examples: true number of examples; the concatenated outputs are
      trimmed to this length (the final batch may be padded).
    model_path: output location forwarded to ``write``. When falsy (the
      default), nothing is written.
    num_gpus: number of per-tower results to read from each run.
    sess: session to use; defaults to ``melt.get_session()``.
  """
  ids_list = []
  predictions_list = []

  if not sess:
    sess = melt.get_session()

  # Re-initialize so the pass starts from the beginning of the dataset.
  sess.run(iterator.initializer)

  for _ in range(num_steps):
    results = sess.run(ops)
    for i in range(num_gpus):
      ids, predictions = results[i]
      ids_list.append(ids)
      predictions_list.append(predictions)

  # Trim away padding from the final batch.
  ids = np.concatenate(ids_list)[:num_examples]
  predicts = np.concatenate(predictions_list)[:num_examples]

  # model_path defaults to None; only write when an output location was given.
  if model_path:
    write(ids,
          predicts,
          model_path,
          labels=None,
          suffix='infer_info')
def main(_):
    """Entry point: initialize global resources, then call ``train()``.

    The positional argument is ignored (presumably argv from a
    tf.app.run / absl-style launcher -- TODO confirm).

    Side effects: calls ``melt.apps.train.init()``, defaults ``FLAGS.vocab``
    to ``vocab.txt`` inside ``os.path.dirname(FLAGS.model_dir)``, initializes
    image_util / vocabulary / text2ids, logs the algo and monitor level,
    stores a new session in the module-level ``sess`` and runs ``train()``.
    """
    #-----------init global resource
    melt.apps.train.init()

    # Default the vocab path to <dirname(model_dir)>/vocab.txt when not set.
    FLAGS.vocab = FLAGS.vocab or os.path.join(os.path.dirname(FLAGS.model_dir),
                                              'vocab.txt')

    image_util.init()

    vocabulary.init()
    text2ids.init()

    ## TODO FIXME: if the evaluator is initialized before the main graph (the
    ## assistant predictor with an image model), finetuning later goes wrong
    ## because the image scope is not yet defined. Set reuse=None? Even though
    ## the assistant lives in a different graph, the scope still seems reused.
    ## So for now the evaluator is lazily initialized, after the main graph is built.

    # try:
    #   evaluator.init()
    # except Exception:
    #   print(traceback.format_exc(), file=sys.stderr)
    #   print('evaluator init fail will not do metric eval')
    #   FLAGS.metric_eval = False

    logging.info('algo:{}'.format(FLAGS.algo))
    logging.info('monitor_level:{}'.format(FLAGS.monitor_level))

    # Publish the session at module level so other functions can use it.
    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)

    train()
예제 #4
0
def main(_):
    """Smoke-test the ``Dataset`` input pipeline on a small local sample.

    Points the input flags at small files under ``../input``, builds a
    training batch iterator and prints the ``id`` field of three batches for
    each of two passes.
    """
    FLAGS.train_input = '../input/train.small'
    FLAGS.valid_input = '../input/train.small'
    FLAGS.batch_size = 4
    FLAGS.feat_file_path = '../input/feature_index'
    FLAGS.field_file_path = '../input/feat_fields.old'
    melt.init()

    # Use Dataset('valid') instead to smoke-test the validation pipeline.
    dataset = Dataset('train')

    # Named batch_iter to avoid shadowing the builtin iter().
    batch_iter = dataset.make_batch()
    op = batch_iter.get_next()

    print('---batch_size', dataset.batch_size, FLAGS.batch_size)

    sess = melt.get_session()

    print('----sess', sess)

    # The horovod and non-horovod branches were identical, so one loop serves both.
    num_epochs, num_batches = 2, 3
    for epoch in range(num_epochs):
        for i in range(num_batches):
            batch = sess.run(op)
            print(epoch, i, len(batch[0]['id']), batch[0]['id'])
예제 #5
0
  def __init__(self, 
               weight_op='learning_rate_weight', 
               patience=3, 
               decay=0.8,  
               cmp=None,
               min_weight=None,
               min_learning_rate=None,
               initial_learning_rate=None,
               decay_start_epoch=0,
               sess=None):
    """Set up patience-based decay of a (learning-rate) weight.

    Args:
      weight_op: either the name of a graph collection whose last element is
        the weight op, or the weight op itself.
      patience: stored as ``max_patience``; the running ``patience`` counter
        starts at 0.
      decay: decay factor, stored as ``decay``.
      cmp: 'less' or 'greater' to compare scores with ``<`` / ``>``; any other
        value (e.g. a two-argument callable, or None) is stored as-is.
      min_weight: lower bound for the weight. If falsy it is derived as
        ``min_learning_rate / (initial_learning_rate or FLAGS.learning_rate)``.
      min_learning_rate: used to derive ``min_weight`` when that is not given.
      initial_learning_rate: denominator for that derivation; falls back to
        ``FLAGS.learning_rate``.
      decay_start_epoch: stored as ``decay_start_epoch``.
      sess: session used in graph mode; defaults to ``melt.get_session()``.
        Not stored when eager execution is enabled.

    Raises:
      NotImplementedError: if ``weight_op`` is a collection name and the
        lookup fails (e.g. the collection is empty). The eager-mode fallback
        is not implemented.
      ValueError: if ``min_weight`` is falsy and ``min_learning_rate`` is None.
    """
    import melt.utils.logging as logging
    if not tf.executing_eagerly():
      self.sess = sess or melt.get_session()
    if isinstance(weight_op, str):
      try:
        # By default melt.apps.train registers the weight variable in the
        # 'learning_rate_weight' collection; graph-mode models take this path.
        self.weight_op = tf.get_collection(weight_op)[-1]
      except Exception as e:
        # A bare string cannot be raised (that is itself a TypeError), so raise
        # a real exception describing the missing feature.
        # TODO: for eager mode, create a ones-initialized 'learning_rate_weight'
        # variable (tfe.Variable) here and add it to the collection.
        raise NotImplementedError(
            'lookup of weight collection %r failed (%s); eager-mode fallback '
            'is not implemented yet' % (weight_op, e))
      self.name = weight_op
    else:
      self.weight_op = weight_op
      self.name = 'weight'

    if cmp == 'less':
      self.cmp = lambda x, y: x < y
    elif cmp == 'greater':
      self.cmp = lambda x, y: x > y
    else:
      self.cmp = cmp
    self.score = None

    self.max_patience = patience
    self.decay = decay
    self.patience = 0
    self.count = 0
    self.min_weight = min_weight

    if not self.min_weight:
      if min_learning_rate is None:
        # Previously this failed with an opaque TypeError (None / float).
        raise ValueError('either min_weight or min_learning_rate must be given')
      self.min_weight = min_learning_rate / (initial_learning_rate or FLAGS.learning_rate)

    # Syncing the learning rate with the initial weight is done in melt.flow.

    self.decay_start_epoch = decay_start_epoch
예제 #6
0
def evaluate(eval_ops,
             iterator,
             num_steps,
             num_examples,
             model_path=None,
             num_gpus=1,
             sess=None):
    """Compute mean loss and accuracy over ``num_steps`` multi-GPU batches.

    Args:
      eval_ops: fetches passed to ``sess.run``; each run must return one
        ``(ids, loss, predictions, top_preds, labels)`` tuple per GPU tower.
        ``top_preds`` is fetched but not used.
      iterator: dataset iterator whose ``initializer`` is run first.
      num_steps: number of ``sess.run(eval_ops)`` calls.
      num_examples: true example count; concatenated outputs are trimmed to
        it because the final batch may be padded.
      model_path: if given, decoded ids, predictions and labels are written
        with ``suffix='valid_info'``.
      num_gpus: number of per-tower results to read from each run.
      sess: session to use; defaults to ``melt.get_session()``.

    Returns:
      ``([mean_loss, accuracy], ['metric/valid/loss', 'metric/valid/acc'])``.
    """
    ids_list = []
    predictions_list = []
    labels_list = []
    losses = []

    if not sess:
        sess = melt.get_session()

    sess.run(iterator.initializer)

    for _ in range(num_steps):
        results = sess.run(eval_ops)
        for i in range(num_gpus):
            ids, loss, predictions, _top_preds, labels = results[i]
            losses.append(loss)
            predictions_list.append(predictions)
            labels_list.append(labels)
            # Ids are only needed for the written report, so decode lazily.
            if model_path:
                ids_list.append(gezi.decode(ids))

    # Loss may be slightly off because of final-batch padding; acceptable here.
    loss = np.mean(losses)
    predicts = np.concatenate(predictions_list)[:num_examples]
    labels = np.concatenate(labels_list)[:num_examples]

    acc = np.mean(np.equal(predicts, labels))
    results = [loss, acc]
    names = ['metric/valid/loss', 'metric/valid/acc']

    if model_path:
        ids = np.concatenate(ids_list)[:num_examples]
        write(ids, predicts, model_path, labels, suffix='valid_info')

    return results, names
예제 #7
0
def evaluate_score():
  """Build the score graph, load a checkpoint and re-save it at step + 1.

  The score tensor is added to the 'score' collection (presumably so it is
  exported with the re-saved graph -- confirm). The re-saved model goes to
  the directory part of ``melt.get_model_dir_and_path(FLAGS.model_dir)``.
  """
  predictor = algos_factory.gen_predictor(FLAGS.algo)
  score_op = predictor.init_predict(TEXT_MAX_WORDS)
  tf.add_to_collection('score', score_op)
  predictor.load(FLAGS.model_dir)

  current_step = melt.get_model_step_from_dir(FLAGS.model_dir)
  target_dir, _unused_path = melt.get_model_dir_and_path(FLAGS.model_dir)
  print('step', current_step, file=sys.stderr)
  print('model_dir', target_dir)

  melt.save_model(melt.get_session(), target_dir, current_step + 1)
예제 #8
0
def evaluate(eval_ops, iterator, model_path=None, sess=None):
    """Compute mean loss and accuracy over one full pass of ``iterator``.

    Args:
      eval_ops: fetches passed to ``sess.run``; each run must return
        ``(ids, loss, predictions, labels, images)``.
      iterator: dataset iterator whose ``initializer`` is run first.
      model_path: if given, ids, predictions, labels and the images (cast to
        ``uint8``) are written with ``suffix='valid_info'``.
      sess: session to use; defaults to ``melt.get_session()``.

    Returns:
      ``([mean_loss, accuracy], ['metric/valid/loss/avg', 'metric/valid/acc'])``.
    """
    ids_list = []
    predictions_list = []
    labels_list = []
    images_list = []
    losses = []

    if not sess:
        sess = melt.get_session()

    sess.run(iterator.initializer)

    try:
        while True:
            ids, loss, predictions, labels, images = sess.run(eval_ops)
            losses.append(loss)
            predictions_list.append(predictions)
            labels_list.append(labels)
            # Ids and images only feed the written report; skip the uint8
            # cast and buffering when nothing will be written.
            if model_path:
                ids_list.append(ids)
                images_list.append(images.astype(np.uint8))
    except tf.errors.OutOfRangeError:
        # Normal end of a TF1 dataset pass.
        pass

    loss = np.mean(losses)
    predicts = np.concatenate(predictions_list)
    labels = np.concatenate(labels_list)

    acc = np.mean(np.equal(predicts, labels))

    results = [loss, acc]
    names = ['metric/valid/loss/avg', 'metric/valid/acc']

    if model_path:
        write(np.concatenate(ids_list),
              predicts,
              model_path,
              labels,
              np.concatenate(images_list),
              suffix='valid_info')

    return results, names
예제 #9
0
  def __init__(self):
    """Set default input names and empty input statistics, and grab a session."""
    # Names of the train / valid / fixed-valid inputs.
    self.input_train_name = 'input_train'
    self.input_valid_name = 'input_valid'
    self.fixed_input_valid_name = 'fixed_input_valid'

    # Statistics shared by all app inputs; unset (None) until configured.
    # num_evaluate_examples is for showing small evaluate results, not cost.
    for attr_name in ('num_records', 'num_evaluate_examples', 'num_steps_per_epoch'):
      setattr(self, attr_name, None)

    self.sess = melt.get_session()

    self.train_with_validation = None
    self.eval_fixed = None
예제 #10
0
def main(_):
    """Initialize logging and text resources, then train in the global scope.

    The module-level ``sess`` is set to a new session before ``train()`` runs
    inside ``tf.variable_scope(global_scope)``.
    """
    # ----------- global resources
    logging.set_logging_path(gezi.get_dir(FLAGS.model_dir))

    vocabulary.init()
    text2ids.init()

    # evaluator.init() is disabled (commented out) in this entry point.

    logging.info('algo:{}'.format(FLAGS.algo))
    logging.info('monitor_level:{}'.format(FLAGS.monitor_level))

    # Empty unless a global scope is requested; an explicit
    # FLAGS.global_scope wins over the algorithm name.
    if not FLAGS.add_global_scope:
        global_scope = ''
    else:
        global_scope = FLAGS.global_scope or FLAGS.algo

    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
    with tf.variable_scope(global_scope):
        train()
예제 #11
0
def load_constant(data_npy,
                  sess=None,
                  trainable=False,
                  dtype=None,
                  shape=None,
                  name=None):
    """Create a variable holding (possibly large) numpy data.

    ``tf.constant`` is only suitable for small data, so this acts as melt's
    "large constant" with more general usage. See
    https://stackoverflow.com/questions/35687678/using-a-pre-trained-word-embedding-word2vec-or-glove-in-tensorflow

    Args:
      data_npy: a numpy array, or a path (str) loaded with ``np.load``.
      sess: session used to run the initializer when ``name`` is None;
        defaults to ``melt.get_session()``. Unused when ``name`` is given.
      trainable: whether the variable is trainable.
      dtype: TF dtype; inferred with ``npdtype2tfdtype`` when None.
      shape: variable shape; defaults to ``data_npy.shape``.
      name: if None, an unnamed ``tf.Variable`` is created outside all
        collections (``collections=[]``) and initialized immediately by
        feeding the data through a placeholder. Otherwise ``tf.get_variable``
        is used with ``tf.constant_initializer(data_npy)``.

    Returns:
      The created variable.
    """
    # isinstance (rather than ``type(...) is str``) also accepts str subclasses.
    if isinstance(data_npy, str):
        data_npy = np.load(data_npy)

    if dtype is None:
        dtype = npdtype2tfdtype(data_npy)
    if shape is None:
        shape = data_npy.shape

    if name is not None:
        # NOTE(review): constant_initializer likely embeds the data in the
        # GraphDef, which the placeholder path below avoids; confirm this is
        # acceptable for large inputs.
        return tf.get_variable(name,
                               shape=shape,
                               initializer=tf.constant_initializer(data_npy),
                               trainable=trainable)

    # Feed the data through a placeholder, presumably to keep it out of the
    # serialized graph (see the link above).
    data_init = tf.placeholder(dtype, shape)
    data = tf.Variable(data_init,
                       trainable=trainable,
                       collections=[],
                       name=name)
    if sess is None:
        sess = melt.get_session()
    sess.run(data.initializer, feed_dict={data_init: data_npy})
    return data
예제 #12
0
파일: train.py 프로젝트: buptpriswang/hasky
def main(_):
    """Entry point: initialize resources (optionally an image model) and train.

    Sets up logging under FLAGS.model_dir and calls
    ``melt.apps.train.init()``. If ``FLAGS.image_checkpoint_file`` exists, the
    image model is initialized. Initializes vocabulary / text2ids, stores a
    new session in the module-level ``sess``, and runs ``train()`` inside the
    scope returned by ``melt.apps.train.get_global_scope()``.
    """
    #-----------init global resource
    logging.set_logging_path(gezi.get_dir(FLAGS.model_dir))
    melt.apps.train.init()

    # Truthy only when a checkpoint path is set and the file exists.
    has_image_model = FLAGS.image_checkpoint_file and os.path.exists(
        FLAGS.image_checkpoint_file)
    if has_image_model:
        print('image_endpoint_feature_name:',
              FLAGS.image_endpoint_feature_name)
        melt.apps.image_processing.init(
            FLAGS.image_model_name,
            feature_name=FLAGS.image_endpoint_feature_name)

    # Without an image model, image features must come pre-computed.
    FLAGS.pre_calc_image_feature = FLAGS.pre_calc_image_feature or (
        not has_image_model)

    vocabulary.init()
    text2ids.init()

    ## TODO FIXME: if the evaluator is initialized before the main graph (the
    ## assistant predictor with an image model), finetuning later goes wrong
    ## because the image scope is not yet defined. Set reuse=None? Even though
    ## the assistant lives in a different graph, the scope still seems reused.
    ## So for now the evaluator is lazily initialized, after the main graph is built.

    # try:
    #   evaluator.init()
    # except Exception:
    #   print(traceback.format_exc(), file=sys.stderr)
    #   print('evaluator init fail will not do metric eval')
    #   FLAGS.metric_eval = False

    logging.info('algo:{}'.format(FLAGS.algo))
    logging.info('monitor_level:{}'.format(FLAGS.monitor_level))

    # Publish the session at module level so other functions can use it.
    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)

    global_scope = melt.apps.train.get_global_scope()
    with tf.variable_scope(global_scope):
        train()
예제 #13
0
def main(_):
  """Demo: decode a few hard-coded inputs with the single-best and beam outputs.

  Builds the prediction graph under the optional global scope and
  FLAGS.main_scope, loads the model from FLAGS.model_dir, then for each input
  prints its word ids, the ``text``/``score`` result with timing, and every
  beam-search candidate with its score.
  """
  text2ids.init()
  # Optional outer scope: FLAGS.global_scope, falling back to the algo name.
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  # Publish the session at module level.
  global sess
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      text, score, beam_text, beam_score = gen_predict_graph(predictor, scope)

  # NOTE(review): unlike the other demo entry points, sess is not passed here.
  predictor.load(FLAGS.model_dir) 
  # NOTE(review): the sample strings below look mis-decoded (replacement
  # characters); the original text -- presumably Chinese queries in a legacy
  # encoding such as GBK -- should be restored from the upstream source.
  #input_text = "������������_��������ǰ��Ա���Ƭ"
  input_texts = ['���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 '����̫����ô����',
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 '����ף�Ŀǰ4����1�굶']

  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    # Single-best decode for a batch of one, with timing.
    timer = gezi.Timer()
    text_, score_ = sess.run([text, score], {predictor.input_text_place : [word_ids]})
    print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    # Beam search: print every candidate of the single input.
    timer = gezi.Timer()
    texts, scores = sess.run([beam_text, beam_score], {predictor.input_text_place : [word_ids]})

    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_)

    print('beam_search using time(ms):', timer.elapsed_ms())
예제 #14
0
def main(_):
    """Set up logging and shared resources, then train in the global scope.

    The image model is initialized only when image features are not
    pre-computed (``FLAGS.pre_calc_image_feature`` is falsy). The session is
    stored in the module-level ``sess`` and ``train()`` runs inside the
    optional global variable scope.
    """
    # ----------- global resources
    logging.set_logging_path(gezi.get_dir(FLAGS.model_dir))

    if not FLAGS.pre_calc_image_feature:
        melt.apps.image_processing.init()

    InputApp.init()
    vocabulary.init()
    text2ids.init()
    evaluator.init()

    for flag_name in ('algo', 'monitor_level'):
        logging.info('{}:{}'.format(flag_name, getattr(FLAGS, flag_name)))

    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)

    # Empty unless a global scope is requested; an explicit
    # FLAGS.global_scope wins over the algorithm name.
    global_scope = (FLAGS.global_scope or FLAGS.algo) if FLAGS.add_global_scope else ''
    with tf.variable_scope(global_scope):
        train()
예제 #15
0
def main(_):
  """Debug script: run beam search and inspect attention keys/values.

  Builds the beam-search prediction graph, loads the model, prints all global
  variables, then runs beam search on single inputs and on a batch, printing
  the fetched 'attention_keys' / 'attention_values' collection tensors. At the
  end it multiplies the first example's attention values (with and without 5
  extra zero rows) by the fetched
  'seq2seq/main/decode/attention_keys/weights:0' tensor and prints both
  products.
  """
  text2ids.init()
  # Optional outer scope: FLAGS.global_scope, falling back to the algo name.
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      ##--notice if not add below len(tf.get_collection('encode_state') is 1, add below will be 2
      ## even though in init_predict_text(decode_method=SeqDecodeMethod.beam) will call generate_sequence_greedy
      #text, score = predictor.init_predict_text(decode_method=SeqDecodeMethod.greedy, 
      #                                          beam_size=FLAGS.beam_size,
      #                                          convert_unk=False)   
      #scope.reuse_variables()
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam, 
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)  

  predictor.load(FLAGS.model_dir, sess=sess) 

  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
  # NOTE(review): the sample strings below look mis-decoded (replacement
  # characters); the original text -- presumably Chinese queries in a legacy
  # encoding such as GBK -- should be restored from the upstream source.
  #input_text = "������������_��������ǰ��Ա���Ƭ"
  input_texts = [
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����̫����ô����',
                 #'����ף�Ŀǰ4����1�굶',
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 ]

  # Pass 1: one input at a time (batch of one).
  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    #timer = gezi.Timer()
    #text_, score_ = sess.run([text, score], {predictor.input_text_feed : [word_ids]})
    #print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    timer = gezi.Timer()
    print(tf.get_collection('encode_state'), len(tf.get_collection('encode_state')))
    texts, scores,  atkeys, atvals  = sess.run([beam_text, beam_score, 
                                             tf.get_collection('attention_keys')[0],
                                             tf.get_collection('attention_values')[0],
                                             ], 
                                            {predictor.input_text_feed : [word_ids]})

    print(atkeys)
    print(atvals)
    print(np.shape(atkeys), np.shape(atvals))

    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_)

    print('beam_search using time(ms):', timer.elapsed_ms())

  # Pass 2: several inputs in one batch.
  input_texts = [
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ����', #same length as lajiao sentence 15
                 #"����̫����ô����",
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'�޺콨�ǰ���˹��',
                 ]

  word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS) for input_text in input_texts]
  timer = gezi.Timer()
  # The last fetch is a hard-coded tensor name; it only exists if the graph
  # uses exactly this scope layout.
  texts_list, scores_list, atkeys2, atvals2, weights = sess.run([beam_text, beam_score, 
                                             tf.get_collection('attention_keys')[0],
                                             tf.get_collection('attention_values')[0],
                                             'seq2seq/main/decode/attention_keys/weights:0'
                                             ], 
                             feed_dict={predictor.input_text_feed: word_ids_list})
  
  print(atkeys2)
  print(atvals2[0])
  print(np.shape(atkeys2), np.shape(atvals2))

  print('beam_search using time(ms):', timer.elapsed_ms())

  # Compare values @ weights with and without 5 extra zero rows appended
  # (in float64) -- presumably to check that zero padding does not change
  # the non-padded rows.
  weights = tf.constant(weights)
  values = tf.get_collection('attention_values')[0]

  values = tf.squeeze(values[0])
  zeros = tf.zeros_like(values[:5,:])

  values2 = tf.concat([values, zeros], 0)

  values = tf.cast(values, tf.float64)
  values2 = tf.cast(values2, tf.float64)
  weights = tf.cast(weights, tf.float64)
  result1 = tf.matmul(values, weights)
  result2 = tf.matmul(values2, weights)

  # NOTE(review): word_ids is whatever the pass-1 loop left behind (its last input).
  result1_, result2_ = sess.run([result1, result2], feed_dict={predictor.input_text_feed: [word_ids]})
  #result2 = sess.run(result, feed_dict={predictor.input_text_feed: word_ids_list})

  print(result1_)
  print(result2_)
예제 #16
0
파일: util.py 프로젝트: meng-jia/wenzheng
def multiply_learning_rate(lr, sess=None, name='learning_rate'):
  """Scale the last variable in graph collection ``name`` by ``lr`` in place.

  Uses ``sess`` if truthy, otherwise ``melt.get_session()``.
  NOTE(review): every call adds a new assign op to the graph.
  """
  session = sess if sess else melt.get_session()
  lr_var = tf.get_collection(name)[-1]
  session.run(tf.assign(lr_var, lr_var * lr))
예제 #17
0
 def __init__(self, sess=None):
   """Initialize the predictor with a TF session.

   Args:
     sess: session to use. When None (the default, so existing no-argument
       callers are unaffected), ``melt.get_session()`` is used.
   """
   super(PredictorBase, self).__init__()
   self.sess = sess if sess is not None else melt.get_session()
예제 #18
0
    def make_batch(self, batch_size, filenames, bptt=None, **kwargs):
        """Build a language-model style (x, y) batch source from a .npy corpus.

        On the first call ``filenames[0]`` is loaded with ``np.load`` and
        cached on ``self``; later calls reuse the cache. ``y`` is ``x``
        shifted by one position. Graph mode returns an object whose
        ``get_next()`` gives ``(x, y)`` tensors with static shape
        ``[batch_size, bptt]``. Eager mode returns a Python iterator of numpy
        ``(x, y)`` windows that reshuffles the cached data when created and
        at the end of each pass.

        Args:
          batch_size: number of rows the concatenated token stream is split into.
          filenames: only ``filenames[0]`` is read.
          bptt: window length; defaults to ``FLAGS.bptt``.
          **kwargs: ignored.
        """
        bptt = bptt or FLAGS.bptt
        self.batch_size = batch_size
        if self.data is None:
            # First call: load and cache the corpus. NOTE(review): presumably a
            # sequence of token-id arrays (np.concatenate flattens it) -- confirm.
            data_npy = np.load(filenames[0])
            self.data = data_npy
            self.data_npy = data_npy
            # Number of bptt windows per row once the flattened stream is
            # split into batch_size rows.
            self.epoch_size = (
                (len(np.concatenate(self.data_npy)) // batch_size) - 1) // bptt
        else:
            data_npy = self.data_npy

        with tf.device('/cpu:0'):
            if not tf.executing_eagerly():
                # NOTE(review): self.data is always set by the branch above, so
                # this block (and the variable creation below) looks unreachable;
                # `data` below is then the raw numpy data -- verify intent.
                if self.data is None:
                    self.data_npy_ori = data_npy
                    self.data_npy = np.concatenate(data_npy)
                    data_npy = self.data_npy

                data_npy = self.data_npy
                if self.data is None:
                    #self.data = tf.get_variable('input_%s' % self.subset, dtype=tf.int32, shape=data_npy.shape, initializer=tf.constant_initializer(data_npy), trainable=False)
                    # Fill the variable by feeding through a placeholder
                    # (presumably to keep the data out of the GraphDef).
                    self.data = tf.get_variable('input_%s' % self.subset,
                                                dtype=tf.int32,
                                                shape=data_npy.shape,
                                                trainable=False)
                    data_placeholder = tf.placeholder(tf.int32, data_npy.shape)
                    data_init = self.data.assign(data_placeholder)
                    sess = melt.get_session()
                    sess.run(data_init, feed_dict={data_placeholder: data_npy})

                data = self.data

                # Trim to a multiple of batch_size and lay out as
                # [batch_size, batch_len].
                data_len = tf.size(data)
                batch_len = data_len // batch_size
                data = tf.reshape(data[:batch_size * batch_len],
                                  [batch_size, batch_len])

                # Fail at run time if not even one (x, y) window fits.
                epoch_size = (batch_len - 1) // bptt
                assertion = tf.assert_positive(
                    epoch_size,
                    message="epoch_size == 0, decrease batch_size or bptt")
                with tf.control_dependencies([assertion]):
                    epoch_size = tf.identity(epoch_size, name="epoch_size")

                # Window index i in [0, epoch_size); queue-based, so queue
                # runners must be started for dequeue() to make progress.
                i = tf.train.range_input_producer(epoch_size,
                                                  shuffle=False).dequeue()
                x = tf.strided_slice(data, [0, i * bptt],
                                     [batch_size, (i + 1) * bptt])
                x.set_shape([batch_size, bptt])
                y = tf.strided_slice(data, [0, i * bptt + 1],
                                     [batch_size, (i + 1) * bptt + 1])
                y.set_shape([batch_size, bptt])

                # Minimal wrapper exposing get_next() like a dataset iterator.
                class Iter(object):
                    def __init__(self, x, y):
                        self.x = x
                        self.y = y

                    def __iter__(self):
                        return self

                    def get_next(self):
                        return self.x, self.y

                iter = Iter(x, y)
                return iter
            else:
                # in eager mode if tf.get_variable will be very slow...
                # epoch:0.02/1024 step:8600 elapsed:[1.312] batch_size:[32] batches/s:[76.23] insts/s:[2439] 1epoch:[1.40h] lr:[0.0010000] train_loss:[5.0710] valid_loss:[5.0503]
                class Iter():
                    def __init__(self, data):
                        # Keeps a reference to (and shuffles in place) the cached data.
                        self.ori_data = data
                        self.reset()

                    def reset(self):
                        # Shuffle the sequences, flatten, and lay out as
                        # [batch_size, batch_len].
                        self.i = 0
                        np.random.shuffle(self.ori_data)
                        self.data = np.concatenate(self.ori_data)
                        data_len = len(self.data)
                        batch_len = data_len // batch_size
                        self.data = self.data[:batch_size * batch_len].reshape(
                            [batch_size, batch_len])

                    def __iter__(self):
                        return self

                    def __next__(self):
                        i = self.i
                        data = self.data

                        if i < data.shape[1]:
                            # NOTE(review): slen is 0 when i == data.shape[1] - 1,
                            # which yields empty x/y slices.
                            slen = min(bptt, data.shape[1] - 1 - i)
                            x = data[:, i:i + slen]
                            y = data[:, i + 1:i + 1 + slen]
                            self.i += bptt
                            return x, y
                        else:
                            # End of pass: reshuffle for the next one, then stop.
                            self.reset()
                            raise StopIteration()

                return Iter(data_npy)
 def __init__(self, sess=None):
   """Initialize the predictor; use ``sess`` or, if None, melt's default session."""
   super(PredictorBase, self).__init__()
   self.sess = melt.get_session() if sess is None else sess
예제 #20
0
def tf_train_flow(train_once,
                  model_dir='./model',
                  max_models_keep=1,
                  save_interval_seconds=600,
                  save_interval_steps=1000,
                  num_epochs=None,
                  num_steps=None,
                  save_model=True,
                  save_interval_epochs=1,
                  num_steps_per_epoch=0,
                  restore_from_latest=True,
                  metric_eval_function=None,
                  sess=None):
    """Training loop similar to tf_flow, plus checkpoint restore and saving.

    If ``_get_model_path`` finds a checkpoint, it is restored. With
    ``restore_from_latest=False``, ``melt.recent_checkpoint`` is used instead.
    Otherwise all global and local variables are initialized. After that,
    ``train_once(sess, step, is_start=...)`` is called once per step.

    A checkpoint is saved every ``save_interval_steps`` steps. An epoch
    checkpoint is also saved every ``save_interval_epochs`` epochs, but only
    when ``num_steps_per_epoch`` is set. ``train_once`` returning ``True``
    requests an early stop.

    Training ends through ``tf.errors.OutOfRangeError``. This happens when the
    input is exhausted, on early stop, when ``num_steps`` is reached, or after
    1000 epochs. A final checkpoint is then saved and ``metric_eval_function``
    (if given) is called. The process exits with status 0 if the step/epoch
    budget was met; otherwise the error is re-raised.
    """
    if sess is None:
        #@TODO may have mutliple session ?
        sess = melt.get_session()
    logging.info('tf_train_flow start')
    print('max_models_keep:', max_models_keep)
    print('save_interval_seconds:', save_interval_seconds)

    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0)

    epoch_saver = tf.train.Saver(max_to_keep=max_models_keep,
                                 keep_checkpoint_every_n_hours=24)  # TODO

    # pre_step is the step of the last saved model; -1 when training from scratch.
    pre_step = -1
    model_path = _get_model_path(model_dir, save_model)
    # In case a checkpoint path such as ./model/model-ckpt1000 was passed -> ./model
    model_dir = gezi.get_dir(model_dir)
    if model_path is not None:
        if not restore_from_latest:
            print('using recent but not latest model', file=sys.stderr)
            model_path = melt.recent_checkpoint(model_dir)
        model_name = os.path.basename(model_path)
        timer = gezi.Timer('Loading and training from existing model [%s]' %
                           model_path)
        saver.restore(sess, model_path)
        timer.print()
        pre_step = melt.get_model_step(model_path)
        # Epoch checkpoints are numbered in epochs; convert to steps.
        if 'epoch' in model_name:
            pre_step *= num_steps_per_epoch
        # For non-zero epochs, without this we get:
        # Attempting to use uninitialized value input/input_producer/limit_epochs/epochs
        try:
            sess.run(tf.local_variables_initializer())
        except Exception:
            sess.run(tf.initialize_local_variables())
    else:
        print('Train all start step 0', file=sys.stderr)
        try:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
        except Exception:
            init_op = tf.group(tf.initialize_all_variables(),
                               tf.initialize_local_variables())

        sess.run(init_op)

    if save_interval_epochs and num_steps_per_epoch:
        epoch_dir = os.path.join(model_dir, 'epoch')
        gezi.try_mkdir(epoch_dir)

    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')
    try:
        step = start = pre_step + 1

        # Hack: when num_steps is below the restored step, just re-save the
        # loaded model and exit.
        if num_steps and num_steps < step:
            print('just load and resave then exit', file=sys.stderr)
            saver.save(sess, checkpoint_path, global_step=step)
            sess.close()
            sys.exit(0)

        while not coord.should_stop():
            stop = train_once(sess, step, is_start=(step == start))
            if save_model and step:
                # Step 0 is not saved here (guarded by `step` above).
                if step % save_interval_steps == 0:
                    timer = gezi.Timer('save model step %d to %s' %
                                       (step, checkpoint_path))
                    saver.save(sess, checkpoint_path, global_step=step)
                    timer.print()
                if save_interval_epochs and num_steps_per_epoch and step % (
                        num_steps_per_epoch * save_interval_epochs) == 0:
                    epoch_saver.save(sess,
                                     os.path.join(epoch_dir, 'model.epoch'),
                                     global_step=step)
            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            max_num_epochs = 1000
            if num_steps_per_epoch and step // num_steps_per_epoch == max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
            step += 1
    # `except X as e` is valid on Python 2.6+ and 3; `except X, e` is py2-only.
    except tf.errors.OutOfRangeError as e:
        if not (step
                == start) and save_model and step % save_interval_steps != 0:
            saver.save(sess, checkpoint_path, global_step=step)
        if metric_eval_function is not None:
            metric_eval_function()
        # Guard the epoch check: num_steps_per_epoch defaults to 0.
        if (num_epochs and num_steps_per_epoch
                and step / num_steps_per_epoch >= num_epochs) or (
                    num_steps and (step + 1) == start + num_steps):
            print('Done training for %d steps.' % (step), file=sys.stderr)
            # FIXME: coord.join does not seem to work here
            # (RuntimeError: Coordinator stopped with threads still running),
            # so exit the process directly.
            sys.exit(0)
        else:
            if num_steps_per_epoch:
                print('Should not stop, but stopped at epoch: %.3f' %
                      (step / num_steps_per_epoch),
                      file=sys.stderr)
            else:
                print('Should not stop, but stopped at step: %d' % step,
                      file=sys.stderr)
            raise e
예제 #21
0
import tensorflow as tf
import horovod.tensorflow as hvd 
from mpi4py import MPI
#import horovod.keras as hvd
import numpy as np
import melt
# Scratch script: each of exactly two horovod ranks takes half of range(100),
# rank 0 appends one extra element (uneven lengths), and the values are then
# replaced by strings -- presumably to try hvd.allgather on uneven string
# lists (the allgather call itself is commented out below).

# Split COMM_WORLD into subcommunicators
#subcomm = MPI.COMM_WORLD.Split(color=MPI.COMM_WORLD.rank % 2,
#                               key=MPI.COMM_WORLD.rank)

# Initialize Horovod
#hvd.init(comm=subcomm)
hvd.init()
hvd_r=int(hvd.rank())
assert hvd.size() == 2
sess = melt.get_session()
sess.run(tf.global_variables_initializer())
#each process compute a small part of something and then compute the average etc.
test_array= np.array(range(100))
#compute a small part: rank r takes elements [r * span, (r + 1) * span)
span = int(100 / hvd.size())
x=test_array[hvd_r * span: (hvd_r + 1) * span]
if hvd_r == 0:
  # Rank 0 gets one extra element, so the ranks hold different lengths.
  x = list(x)
  x.append(2019)
  #x = np.array(x)
# NOTE: redundant on rank 0 (already a list); converts the numpy slice on other ranks.
x = list(x) 
#x = [[1, a] for a in x] 
# Same length as before, every element the string 'abc'.
x = ['abc' for a in x]
#compute the average for all processes
#y=hvd.allgather(x, name='a')
#y=hvd.allgather(x, name='a')
예제 #22
0
파일: predict.py 프로젝트: fword/hasky
def main(_):
  """Smoke-test beam-search decoding on a few hard-coded query texts.

  Builds the predictor for ``FLAGS.algo`` (under a global variable scope when
  ``FLAGS.add_global_scope`` is set) and loads weights from
  ``FLAGS.model_dir``. It then prints every global variable and runs beam
  search: first on each text of ``input_texts`` separately, then on a second
  list as one batch. For each beam candidate it prints the id sequence, the
  decoded text, the score and ``math.log(score)``, plus the elapsed time.

  NOTE: the Chinese sample texts had been mis-encoded (GBK bytes decoded as
  UTF-8, leaving U+FFFD replacement characters), so the model was fed
  meaningless ids. They are restored here from a byte-level decode of the
  garbled text.
  """
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo

  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor = algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      ##--notice if not add below len(tf.get_collection('encode_state') is 1, add below will be 2
      ## even though in init_predict_text(decode_method=SeqDecodeMethod.beam) will call generate_sequence_greedy
      #text, score = predictor.init_predict_text(decode_method=SeqDecodeMethod.greedy,
      #                                          beam_size=FLAGS.beam_size,
      #                                          convert_unk=False)
      #scope.reuse_variables()
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam,
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)

  predictor.load(FLAGS.model_dir, sess=sess)

  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)

  #input_text = "王凯整容了吗_王凯整容前后对比照片"
  input_texts = [
                 #'包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网',
                 #'大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 #'宝宝太胖怎么办呢',
                 '蛋龟缸，目前4虎纹1剃刀',
                 '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 ]

  # One beam-search call per text (batch of one).
  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    #timer = gezi.Timer()
    #text_, score_ = sess.run([text, score], {predictor.input_text_feed : [word_ids]})
    #print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    timer = gezi.Timer()
    texts, scores = sess.run([beam_text, beam_score],
                             {predictor.input_text_feed: [word_ids]})

    # Only one example was fed; take its beam candidates.
    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_, math.log(score_))

    print('beam_search using time(ms):', timer.elapsed_ms())

  input_texts = [
                 '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 '包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网',
                 # (a truncated prefix of the sentence above, "same length as lajiao sentence 15";
                 #  its exact cut point was lost to mis-encoding)
                 #"宝宝太胖怎么办呢",
                 #'大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 #'大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 # (one more commented-out sample whose text was lost to mis-encoding)
                 ]

  # Same decoding, but all texts in a single batched sess.run call.
  word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS) for input_text in input_texts]
  timer = gezi.Timer()
  texts_list, scores_list = sess.run([beam_text, beam_score],
                                     feed_dict={predictor.input_text_feed: word_ids_list})

  for texts, scores in zip(texts_list, scores_list):
    for text, score in zip(texts, scores):
      print(text, text2ids.ids2text(text), score, math.log(score))

  print('beam_search using time(ms):', timer.elapsed_ms())
예제 #23
0
def train(Dataset,
          model,
          loss_fn,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          sep=','):
  """Epoch training loop feeding a model from a TF input pipeline.

  Training batches come from ``Dataset('train').make_batch(...)``. They are
  read in a GPU-less TF session, converted with ``to_torch`` and used for a
  ``zero_grad`` / ``loss_fn(model, x, y)`` / ``backward`` / ``step`` update.
  A checkpoint (learning rate, learning-rate weight(s), global step, and in
  the non-torch branch the model and optimizer) is restored from
  ``FLAGS.model_dir + '/ckpt'`` when present, and training resumes at the
  epoch encoded in the checkpoint name.

  The arguments after ``loss_fn`` (evaluation/inference/writing options) are
  accepted for interface compatibility but are not used here.

  NOTE(review): the inner loop calls torch APIs (``model.train()``,
  ``optimizer.zero_grad()``, ``loss.backward()``), so it presumably only works
  with ``FLAGS.torch`` set. Nothing is saved to ``checkpoint_prefix`` in this
  function either — confirm whether that is intended.
  """
  import itertools

  if FLAGS.torch:
    if torch.cuda.is_available():
      model.cuda()

  input_ = FLAGS.train_input
  inputs = gezi.list_files(input_)
  inputs.sort()

  batch_size = FLAGS.batch_size

  num_gpus = melt.num_gpus()
  if num_gpus > 1:
    assert False, 'Eager mode train currently not support for num gpus > 1'

  if FLAGS.fold is not None:
    # Drop the current fold's shard (intended as the validation split).
    # NOTE(review): endswith('1.record') also matches '11.record' — confirm naming.
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  logging.info('inputs', inputs)

  dataset = Dataset('train')
  num_examples = dataset.num_examples_per_epoch('train')

  # TODO: validation / test datasets are not wired into this loop yet.

  if num_examples:
    if FLAGS.fold is not None:
      # One of len(inputs) + 1 shards was held out above. Use float division
      # so the ratio is not floored to 0 under Python 2.
      num_examples = int(num_examples * (float(len(inputs)) / (len(inputs) + 1)))
    # Ceiling division: a final partial batch still counts as a step.
    num_steps_per_epoch = -(-num_examples // batch_size)
  else:
    num_steps_per_epoch = None

  summary = tf.contrib.summary
  # All three writers currently share FLAGS.model_dir.
  writer = summary.create_file_writer(FLAGS.model_dir)
  writer_train = summary.create_file_writer(FLAGS.model_dir)
  writer_valid = summary.create_file_writer(FLAGS.model_dir)
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  tf.add_to_collection('learning_rate', learning_rate)

  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except IndexError:
    # Collection is empty: per-variable learning-rate weights are not in use.
    learning_rate_weights = None

  ckpt_dir = FLAGS.model_dir + '/ckpt'

  #TODO FIXME now I just changed tf code so to not by default save only latest 5
  # refer to https://github.com/tensorflow/tensorflow/issues/22036
  # manager = tf.contrib.checkpoint.CheckpointManager(
  #     checkpoint, directory=ckpt_dir, max_to_keep=5)
  # latest_checkpoint = manager.latest_checkpoint

  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
  logging.info('Latest checkpoint:', latest_checkpoint)
  checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

  # Objects tracked by the checkpoint in both branches.
  ckpt_objects = dict(learning_rate=learning_rate,
                      learning_rate_weight=learning_rate_weight,
                      global_step=global_step)
  if learning_rate_weights is not None:
    ckpt_objects['learning_rate_weights'] = learning_rate_weights

  if not FLAGS.torch:
    optimizer = melt.get_optimizer(FLAGS.optimizer)(learning_rate)
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, **ckpt_objects)

    # FLAGS.model_dir may itself be a checkpoint prefix.
    if os.path.exists(FLAGS.model_dir + '.index'):
      latest_checkpoint = FLAGS.model_dir

    checkpoint.restore(latest_checkpoint)

    # The number after the last '-' in the checkpoint name is taken as the epoch.
    start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint else 0
  else:
    # TODO torch with learning rate adjust
    optimizer = torch.optim.Adamax(model.parameters(), lr=FLAGS.learning_rate)

    if latest_checkpoint:
      torch_state = torch.load(latest_checkpoint + '.pyt')
      start_epoch = torch_state['epoch']
      model.load_state_dict(torch_state['state_dict'])
      optimizer.load_state_dict(torch_state['optimizer'])
      model.eval()
    else:
      start_epoch = 0

    checkpoint = tf.train.Checkpoint(**ckpt_objects)

  # TODO currently not support 0.1 epoch.. like this
  num_epochs = FLAGS.num_epochs

  class PytObj(object):
    """Wraps a plain value so it exposes ``.numpy()`` like a TF eager tensor."""
    def __init__(self, x):
      self.x = x
    def numpy(self):
      return self.x

  class PytMean(object):
    """Running mean of torch scalars, mimicking ``tfe.metrics.Mean``'s API.

    The accumulator is cleared on the first update after ``result()`` has
    been read, so each ``result()`` covers the updates since the previous read.
    """
    def __init__(self):
      self._val = 0.
      self.count = 0

      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      if not self.count:
        val = 0
      else:
        val = self._val / self.count
      # TODO just for compact with tf ..
      return PytObj(val)

  # TODO consider multiple gpu for torch

  train_iter = dataset.make_batch(batch_size, inputs, repeat=False, initializable=False)
  x_, y_ = train_iter.get_next()

  Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
  epoch_loss_avg = Mean()
  epoch_valid_loss_avg = Mean()  # not used yet (no validation loop)

  # GPU-less session; in this loop it only fetches input batches.
  sess = melt.get_session(device_count={'GPU': 0})
  step = 0
  for epoch in range(start_epoch, num_epochs):
    melt.set_global('epoch', '%.4f' % (epoch))
    sess.run(train_iter.initializer)

    model.train()

    # With an unknown epoch size, run until the (non-repeating) input is exhausted.
    steps = range(num_steps_per_epoch) if num_steps_per_epoch else itertools.count()

    #..... still OOM... FIXME TODO
    try:
      for _ in tqdm(steps, total=num_steps_per_epoch, ascii=True):
        x, y = sess.run([x_, y_])
        x, y = to_torch(x, y)

        optimizer.zero_grad()
        loss = loss_fn(model, x, y)
        loss.backward()
        optimizer.step()

        epoch_loss_avg(loss)

        if step % FLAGS.interval_steps == 0:
          print(step, epoch_loss_avg.result().numpy())

        step += 1
    except tf.errors.OutOfRangeError:
      # End of this epoch's data (repeat=False): expected, report below.
      pass
    # Report every epoch, whether the loop ran its step count or hit end of data.
    print('epoch:%d loss:%f' % (epoch, epoch_loss_avg.result().numpy()))
예제 #24
0
파일: flow.py 프로젝트: tangqiqi123/hasky
def tf_train_flow(train_once_fn,
                  model_dir='./model',
                  max_models_keep=1,
                  save_interval_seconds=600,
                  save_interval_steps=1000,
                  num_epochs=None,
                  num_steps=None,
                  save_model=True,
                  save_interval_epochs=1,
                  num_steps_per_epoch=0,
                  restore_from_latest=True,
                  metric_eval_fn=None,
                  init_fn=None,
                  sess=None):
  """Run a training loop with model reload, periodic saving and epoch bookkeeping.

  Similar flow to tf_flow, but it first tries to reload a model from
  ``model_dir`` and saves checkpoints while training.

  Args:
    train_once_fn: called as ``train_once_fn(sess, step, is_start=...)`` once
      per step. Returning ``True`` requests an early stop.
    model_dir: model directory, or a checkpoint path inside it.
    max_models_keep: ``max_to_keep`` of the step-checkpoint saver.
    save_interval_seconds: converted to ``keep_checkpoint_every_n_hours``.
    save_interval_steps: save a step checkpoint every this many steps.
    num_epochs: stop once this many epochs are done. Only enforced when
      ``num_steps_per_epoch`` is non-zero.
    num_steps: stop after this many steps counted from the start step.
    save_model: whether to save checkpoints.
    save_interval_epochs: epoch checkpoints (under ``model_dir/epoch``) are
      saved every this many epochs.
    num_steps_per_epoch: steps per epoch; 0 disables all epoch-based logic.
    restore_from_latest: if False, restore ``melt.recent_checkpoint(model_dir)``.
    metric_eval_fn: optional callable run once when training stops.
    init_fn: optional ``init_fn(sess)`` run after variable init on a fresh start.
    sess: session to use; defaults to ``melt.get_session()``.

  Exits the process with status 0 when the step/epoch budget is reached.
  Re-raises ``tf.errors.OutOfRangeError`` if training stopped before that.
  """
  if sess is None:
    #TODO melt.get_session is global session but may cause
    sess = melt.get_session()
  logging.info('tf_train_flow start')
  print('max_models_keep:', max_models_keep)
  print('save_interval_seconds:', save_interval_seconds)

  saver = tf.train.Saver(
    max_to_keep=max_models_keep,
    keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0)

  epoch_saver = tf.train.Saver()
  best_epoch_saver = tf.train.Saver()

  # pre_step is the step of the last saved model; -1 when training from scratch.
  pre_step = -1
  model_path = _get_model_path(model_dir, save_model)
  model_dir = gezi.get_dir(model_dir) #incase you pass ./model/model-ckpt1000 -> ./model
  if model_path is not None:
    if not restore_from_latest:
      print('using recent but not latest model', file=sys.stderr)
      model_path = melt.recent_checkpoint(model_dir)
    model_name = os.path.basename(model_path)
    timer = gezi.Timer('Loading and training from existing model [%s]'%model_path)
    saver.restore(sess, model_path)
    timer.print()
    pre_step = melt.get_model_step(model_path)
    if 'epoch' in model_name:
      # Presumably epoch-named checkpoints carry an epoch count rather than a
      # step count — TODO confirm against _get_checkpoint_path.
      pre_step *= num_steps_per_epoch
    #for non 0 eopochs  without this will be
    #Attempting to use uninitialized value input/input_producer/limit_epochs/epochs
    sess.run(tf.local_variables_initializer())
  else:
    print('Train all start step 0', file=sys.stderr)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    if init_fn is not None:
      init_fn(sess)

  if save_interval_epochs and num_steps_per_epoch:
    epoch_dir = os.path.join(model_dir, 'epoch')
    gezi.try_mkdir(epoch_dir)

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  checkpoint_path = os.path.join(model_dir, 'model.ckpt')

  tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
  try:
    step = start = pre_step + 1
    #hack just for save one model after load
    if num_steps and num_steps < step:
      print('just load and resave then exit', file=sys.stderr)
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
      sess.close()
      sys.exit(0)

    num_bad_epochs = 0
    pre_epoch_eval_loss = 1e20
    best_epoch_eval_loss = 1e20
    num_allowed_bad_epochs = 4 #allow 5 non decrease eval loss epochs  before stop
    while not coord.should_stop():
      stop = train_once_fn(sess, step, is_start=(step==start))
      # `and step`: step 0 is never saved here.
      if save_model and step:
        if step % save_interval_steps == 0:
          timer = gezi.Timer('save model step %d to %s'%(step, checkpoint_path))
          saver.save(sess,
                     _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                     global_step=step)
          timer.print()
        # End of an epoch: track eval loss, best model, and epoch checkpoints.
        if save_interval_epochs and num_steps_per_epoch and step % num_steps_per_epoch == 0:
          epoch = step // num_steps_per_epoch
          eval_loss = melt.eval_loss()
          if eval_loss:
            #['eval_loss:3.2','eal_accuracy:4.3']
            eval_loss = float(eval_loss.strip('[]').split(',')[0].strip("'").split(':')[-1])
            best_loss_file = os.path.join(epoch_dir, 'best_eval_loss.txt')
            if os.path.exists(best_loss_file):
              # Last field of the first line is the best loss so far.
              with open(best_loss_file) as f:
                best_epoch_eval_loss = float(f.readline().split()[-1].strip())
            if eval_loss < best_epoch_eval_loss:
              best_epoch_eval_loss = eval_loss
              logging.info('Now best eval loss is epoch %d eval_loss:%f' % (epoch, eval_loss))
              with open(best_loss_file, 'w') as f:
                f.write('%d %d %f\n'%(epoch, step, best_epoch_eval_loss))
              best_epoch_saver.save(sess,
                                    os.path.join(epoch_dir,'model.cpkt-best'))

            with open(os.path.join(epoch_dir, 'eval_loss.txt'), 'a') as f:
              f.write('%d %d %f\n'%(epoch, step, eval_loss))
            if eval_loss >= pre_epoch_eval_loss:
              num_bad_epochs += 1
              if num_bad_epochs > num_allowed_bad_epochs:
                logging.warning('Evaluate loss not decrease for last %d epochs'% (num_allowed_bad_epochs + 1))
                if not os.path.exists(os.path.join(epoch_dir,'model.cpkt-noimprove')):
                  best_epoch_saver.save(sess, os.path.join(epoch_dir,'model.cpkt-noimprove'))
                # Early stopping on no improvement is deliberately disabled
                # (was: if early_stop: stop = True).
            else:
              num_bad_epochs = 0
            pre_epoch_eval_loss = eval_loss
          if step % (num_steps_per_epoch * save_interval_epochs) == 0:
            epoch_saver.save(sess,
                            os.path.join(epoch_dir,'model.cpkt-%d'%epoch),
                            global_step=step)
      if stop is True:
        print('Early stop running %d stpes'%(step), file=sys.stderr)
        raise tf.errors.OutOfRangeError(None, None,'Early stop running %d stpes'%(step))
      if num_steps and (step + 1) == start + num_steps:
        raise tf.errors.OutOfRangeError(None, None,'Reached max num steps')
      # Only enforce an epoch budget when one is given. Comparing against None
      # is always True in Python 2, which stopped training after one step.
      if num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= num_epochs:
        raise tf.errors.OutOfRangeError(None, None,'Reached max num epochs of %d'%num_epochs)
      step += 1
  except tf.errors.OutOfRangeError as e:
    if not (step==start) and save_model and step % save_interval_steps != 0:
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
    if metric_eval_fn is not None:
      metric_eval_fn()
    # float() avoids Python 2 floor division. 0. when epoch size is unknown
    # (num_steps_per_epoch=0 previously raised ZeroDivisionError here).
    epochs_done = float(step) / num_steps_per_epoch if num_steps_per_epoch else 0.
    if (num_epochs and num_steps_per_epoch and epochs_done >= num_epochs) or \
       (num_steps and (step + 1) == start + num_steps):
      print('Done training for %.3f epochs, %d steps.' % (epochs_done, step + 1), file=sys.stderr)
      #FIXME becase coord.join seems not work,  RuntimeError: Coordinator stopped with threads still running: Thread-9
      sys.exit(0)
    else:
      print('Should not stop, but stopped at epoch: %.3f'%(epochs_done), file=sys.stderr)
      print(traceback.format_exc(), file=sys.stderr)
      raise e
예제 #25
0
def main(_):
  """Score sample query texts and beam-decode them with a loaded predictor.

  Builds the predictor for ``FLAGS.algo`` (under a global variable scope when
  ``FLAGS.add_global_scope`` is set). It creates the scoring op first, then
  the beam-search decoding op with variables reused, and loads weights from
  ``FLAGS.model_dir``. For each text in ``input_texts`` it prints the word
  ids and their decoded text, and the score of the text fed as both input
  and target. It then prints every beam candidate (ids, decoded text, score,
  log(score)) and the beam-search time.
  """
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo

  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor = algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      ##--notice if not add below len(tf.get_collection('encode_state') is 1, add below will be 2
      ## even though in init_predict_text(decode_method=SeqDecodeMethod.beam) will call generate_sequence_greedy
      #text, score = predictor.init_predict_text(decode_method=SeqDecodeMethod.greedy,
      #                                          beam_size=FLAGS.beam_size,
      #                                          convert_unk=False)
      #scope.reuse_variables()
      #score = predictor.init_predict(exact_loss=True)
      #score = predictor.init_predict(exact_prob=True)
      score = predictor.init_predict()
      # The decoder below must share the variables just created by init_predict.
      scope.reuse_variables()
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam,
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)

  predictor.load(FLAGS.model_dir, sess=sess)

  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
  #input_text = "王凯整容了吗_王凯整容前后对比照片"
  input_texts = [
                 #'包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网',
                 #'大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 #'宝宝太胖怎么办呢',
                 #'蛋龟缸,目前4虎纹1剃刀',
                 #'大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 #'2015羊年中国风年会晚会签到板设计',
                 '完美 玛丽艳脱角质霜'
                 ]

  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    #timer = gezi.Timer()
    #text_, score_ = sess.run([text, score], {predictor.input_text_feed : [word_ids]})
    #print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    #texts, scores, preids, paids, seqlens = sess.run([beam_text, beam_score,
    #                         tf.get_collection('preids')[-1],
    #                         tf.get_collection('paids')[-1],
    #                         tf.get_collection('seqlens')[-1]],
    #                                        {predictor.input_text_feed : [word_ids]})
    #print(preids)
    #print(paids)
    #print(seqlens)

    # Score the text fed as both input and target. Keep the result under a new
    # name: rebinding `score` would replace the graph tensor with a numpy
    # value and make the next iteration's sess.run fail.
    score_value = sess.run(score, {predictor.input_text_feed: [word_ids],
                                   predictor.text_feed: [word_ids]})
    print(score_value)

    # Start timing here so the reported time covers beam search only.
    timer = gezi.Timer()
    texts, scores = sess.run([beam_text, beam_score],
                             {predictor.input_text_feed: [word_ids]})

    # Only one example was fed; take its beam candidates.
    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_, math.log(score_))

    print('beam_search using time(ms):', timer.elapsed_ms())
예제 #26
0
def tf_train_flow(
        train_once_fn,
        model_dir=None,
        log_dir=None,
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=None,
        freeze_graph=False,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        valid_interval_epochs=0,
        inference_fn=None,
        inference_interval_epochs=0,
        init_fn=None,
        restore_fn=None,
        restore_include=None,
        restore_exclude=None,
        save_all_scope=False,  #TODO save load from restore scope only but svae all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        output_collection_names=None,
        output_node_names=None,
        learning_rate=None,  #not use yet, just use in train_once
        learning_rate_patience=None,
        learning_rate_decay_factor=None,
        write_during_train=True,
        model=None,
        sess=None):
    """
  similary flow as tf_flow, but add model try reload and save
  """
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    model_dir_ = model_dir
    if use_horovod and hvd.rank() != 0:
        model_dir = None

    if sess is None:
        #TODO melt.get_session is global session but may cause non close at last
        sess = melt.get_session()

    if FLAGS.use_tpu:
        sess.run(tpu.initialize_system())
    #logging.info('tf_train_flow start')
    #logging.info('max_models_keep:', max_models_keep)
    #logging.info('save_interval_seconds:', save_interval_seconds)

    if model_dir:
        if model:
            checkpoint = tf.train.Checkpoint(model=model)
            ckpt_dir = model_dir + '/ckpt'
            checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

        #this is usefull for you use another model with another scope, and just load and restore/save initalize your scope vars!
        #this is not for finetune but mainly for like using another model as in predict like this introducing graph other model scope and ignore here

        # var_list = None if not restore_scope else tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
        # #logging.info('-------------var_list', var_list)

        # if not variables_to_restore:
        #   variables_to_restore = var_list

        if not variables_to_restore:
            variables_to_restore = slim.get_variables_to_restore(
                include=restore_include, exclude=restore_exclude)

        if not variables_to_save:
            variables_to_save = variables_to_restore
        if save_all_scope:
            variables_to_save = None

        #if variables_to_restore is None:
        logging.info('variables_to_restore from %s' % model_dir)
        #load all var in checkpoint try to save all var(might more then original checkpoint) if not specifiy variables_to_save
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        #logging.info('varnames_in_checkpoint: {}'.format(varnames_in_checkpoint))

        # TODO has someproblem say  tf.Variable 'r_net/text_encoder/cudnn_rnn/cu_dnngru/recurrent_kernel/adam_v:0' even though in checkpoint I have renated it as ignore/rnet
        variables_to_restore_from_model = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)
        #logging.info('variables_to_restore_from_model: {}'.format(variables_to_restore_from_model))
        if not variables_to_restore:
            variables_to_restore = variables_to_restore_from_model
        else:
            variables_to_restore = [
                v for v in variables_to_restore
                if v in variables_to_restore_from_model
            ]
        if restore_exclude:
            for excl in restore_exclude:
                variables_to_restore = [
                    v for v in variables_to_restore if not excl in v.name
                ]
        #--tf 1.6 adadelta will have same vars...
        variables_to_restore = list(set(variables_to_restore))
        #logging.info('variables_to_restore', variables_to_restore[:100])
        logging.info('variables_to_restore', [
            x for x in variables_to_restore if not 'OptimizeLoss' in x.name
        ][:100])

    ##finally remove global_step since melt.apps.train will handle it!
    global_step = tf.train.get_or_create_global_step()

    #variables_to_restore = [v for v in variables_to_restore if not tf.GraphKeys.GLOBAL_STEP in v.name]
    #variables_to_restore = [v for v in variables_to_restore if not 'learning_rate' in v.name]

    # TODO fixme if step, step2.. and in checkpoint step then here will be step2...
    #print('------------', [v for v in variables_to_restore if 'step' in v.name])
    loader = tf.train.Saver(var_list=variables_to_restore)

    logging.info('max models to keep {}, keep every {} hours'.format(
        max_models_keep, save_interval_seconds / 3600.0))
    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)
    #logging.info('variables_to_save:{}'.format(variables_to_save))

    # # #TODO for safe restore all init will be ok ?
    # # if variables_to_restore is None:
    init_op = tf.group(
        tf.global_variables_initializer(
        ),  #variables_initializer(global_variables())
        tf.local_variables_initializer()
    )  #variables_initializer(local_variables())
    # # else:
    # #   init_op = tf.group(tf.variables_initializer(variables_to_restore),
    # #                      tf.local_variables_initializer())

    ##--mostly this will be fine except for using assistant predictor, initialize again! will make assistant predictor wrong
    ##so assume to all run init op! if using assistant predictor, make sure it use another session

    # https://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables
    # def guarantee_initialized_variables(session, list_of_variables = None):
    #   if list_of_variables is None:
    #       list_of_variables = tf.global_variables()
    #   uninitialized_variables = list(tf.get_variable(name) for name in
    #                                  session.run(tf.report_uninitialized_variables(list_of_variables)))
    #   return unintialized_variables

    # unintialized_variables = guarantee_initialized_variables(sess)
    # init_op = tf.group(tf.initialize_variables(uninitialized_vars), tf.local_variables_initializer())

    timer = gezi.Timer('sess run init_op in melt.tf_train_flow')
    #model.save('./weights')

    # notice
    sess.run(init_op)

    timer.print_elapsed()

    #melt.init_uninitialized_variables(sess)

    #pre_step means the step last saved, train without pretrained,then -1
    pre_step = -1
    fixed_pre_step = -1  #fixed pre step is for epoch num to be correct if you change batch size
    #print(model_dir)
    pre_epoch = None
    if model_dir:
        model_path = _get_model_path(model_dir, save_model)
        # if not model_path:
        #   model_path = _get_model_path(os.path.join(model_dir, 'epoch'))
        #print(model_path)
        model_dir = gezi.get_dir(
            model_dir)  #incase you pass ./model/model-ckpt1000 -> ./model

        if model_path is not None:
            if not restore_from_latest:
                logging.info('using recent but not latest model')
                model_path = melt.recent_checkpoint(model_dir)
            model_name = os.path.basename(model_path)
            timer = gezi.Timer(
                'Loading and training from existing model [%s]' % model_path)
            if restore_fn is not None:
                restore_fn(sess)
            loader.restore(sess, model_path)
            ## not supported
            #model.save()
            #model.save_weights('./weights')
            timer.print()
            #pre_step = melt.get_model_step(model_path) - 1 if FLAGS.global_step is None else FLAGS.global_step -1
            # TODO check ..
            pre_step = sess.run(tf.train.get_global_step()) - 1
            pre_epoch = melt.get_model_epoch(
                model_path
            ) if FLAGS.global_epoch is None else FLAGS.global_epoch
            fixed_pre_step = pre_step
            # if pre_epoch is not None:
            #   #like using batch size 32, then reload train using batch size 64
            #   if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
            #     fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
            #     logging.info('Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'\
            #       .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
        else:
            latest_checkpoint = None
            if not use_horovod:  #now will hang
                try:
                    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                    if latest_checkpoint:
                        logging.info(
                            'Try start from eager trained mode, latest checkpoint:',
                            latest_checkpoint)
                        checkpoint.restore(latest_checkpoint).run_restore_ops(
                            session=sess)

                        pre_epoch = int(latest_checkpoint.split('-')[-1])
                        #pre_step = pre_epoch * num_steps_per_epoch - 1
                        # TODO check
                        pre_step = sess.run(tf.train.get_global_step()) - 1
                        fixed_pre_step = pre_step
                        logging.info('Start step is:', pre_step)
                except Exception:
                    logging.info(
                        'Something wrong with restore from eager trained model'
                    )
                if latest_checkpoint is None:
                    logging.info('Train all start step 0')
                    #https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
                    #tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
                    #tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
                    #which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
                    #init_op = tf.group(tf.global_variables_initializer(),
                    #                   tf.local_variables_initializer())
                    #[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)

                    #sess.run(init_op)

                    #like use image model, build image graph, reload first train, and then will go to same checkpoint all varaible just restore will ok
                    #for finetune from loading other model init
                    if init_fn is not None:
                        init_fn(sess)

    if gezi.env_has('METRIC'):
        l = metric_eval_fn(model_path)
        print(list(zip(l[1], l[0])))
        exit(0)

    #sess.run(tf.assign(global_step, tf.constant(global_step_val, dtype=tf.int64)))
    try:
        learning_rate = tf.get_collection('learning_rate')[-1]
        learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
        sess.run(tf.assign(learning_rate,
                           learning_rate * learning_rate_weight))
    except Exception:
        # if not using weight_decay but using optimizer decay then will go here as learning rate is a tensor can not assign
        pass

    try:
        logging.info('Actual start global step:',
                     sess.run(global_step), 'learning rate:',
                     sess.run(learning_rate), 'learning_rate_weight:',
                     sess.run(learning_rate_weight))
    except Exception:
        pass

    if model_dir_:
        #if save_interval_epochs and num_steps_per_epoch and num_steps >= 0:
        epoch_dir = os.path.join(model_dir_, 'epoch')
        gezi.try_mkdir(epoch_dir)
        checkpoint_path = os.path.join(model_dir_, 'model.ckpt')

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if use_horovod:
        bcast = hvd.broadcast_global_variables(0)
        sess.run(bcast)

    #tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        if use_horovod:
            ## TODO FIXME why bcast here not work ? simple test work see tests/bcast.py
            #comm.bcast(pre_step, root=0)
            temp = np.array([pre_step, fixed_pre_step])
            comm.Bcast(temp, root=0)
            pre_step = temp[0]
            fixed_pre_step = temp[1]

        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1

        #first = True

        #hack just for save one model after load
        if num_steps < 0 or (num_steps and num_steps < step):
            logging.info('just load and resave then exit')
            model_path_ = _get_checkpoint_path(checkpoint_path,
                                               step,
                                               num_steps_per_epoch,
                                               epoch=pre_epoch)
            saver.save(sess, model_path_, global_step=step + 1)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step + 1,
                                  output_collection_names, output_node_names)
            sess.close()
            exit(0)

        if num_epochs < 0:
            only_one_step = True
            logging.info('just run one step')

        if FLAGS.work_mode != 'train':
            assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
            if 'valid' in FLAGS.work_mode:
                vals, names = metric_eval_fn(FLAGS.model_dir)
                logging.info(list(zip(names, vals)))
            if 'test' in FLAGS.work_mode:
                inference_fn(FLAGS.model_dir)
            exit(0)

        #early_stop = True #TODO allow config
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  #allow 5 non decrease eval loss epochs  before stop
        epoch_saved_step = 0
        while not coord.should_stop():
            model_step_path = None
            if model_dir_:
                model_path_ = os.path.join(
                    epoch_dir, 'model.ckpt-%.2f' %
                    (fixed_step / float(num_steps_per_epoch)))
                model_step_path_ = model_path_ + '-' + str(step)
                if (write_during_train and metric_eval_fn is not None
                        and valid_interval_epochs and fixed_step %
                        int(num_steps_per_epoch * valid_interval_epochs) == 0):
                    model_step_path = model_step_path_
                else:
                    model_step_path = None

            if step == 0:
                model_step_path = None

            #print('--------------------step', step)
            stop = train_once_fn(
                sess,
                step,
                is_start=(step == start),
                fixed_step=fixed_step,
                num_epochs=num_epochs,
                model_path=model_step_path,
                use_horovod=use_horovod,
                ## TODO FIXME this line will cause   tensorflow.python.framework.errors_impl.NotFoundError: Resource localhost/save_counter/N10tensorflow3VarE does not exist.
            )

            #first = False

            if only_one_step:
                stop = True

            step += 1
            fixed_step += 1

            if save_model and step and model_dir:
                #step 0 is also saved! actually train one step and save
                if step % save_interval_steps == 0:
                    timer = gezi.Timer(
                        'save model step %d to %s' % (step, checkpoint_path),
                        False)
                    model_path_ = _get_checkpoint_path(checkpoint_path,
                                                       fixed_step,
                                                       num_steps_per_epoch)
                    saver.save(sess, model_path_, global_step=step)
                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)
                    #if log_dir != model_dir:
                    #  assert log_dir
                    #  command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                    #  print(command, file=sys.stderr)
                    #  os.system(command)
                    timer.print_elapsed()

                if save_interval_steps and num_steps_per_epoch and fixed_step % int(
                        num_steps_per_epoch * save_interval_epochs) == 0:
                    # TODO only epoch in name not sep ?
                    epoch_saved_step = step
                    model_path_ = os.path.join(
                        epoch_dir, 'model.ckpt-%.2f' %
                        (fixed_step / float(num_steps_per_epoch)))
                    model_step_path = model_path_ + '-' + str(step)
                    epoch_saver.save(sess, model_path_, global_step=step)
                    #epoch_saver.save(sess, model_path_)

                    ## TODO FIXME do not support tf.keras save currently with horovod
                    # if model:
                    #   #model.save_weights(epoch_dir + '/ckpt-%.2f' % (fixed_step / float(num_steps_per_epoch)))
                    #   # TODO FIXME if restart will save from 1... again..
                    #   checkpoint.save(checkpoint_prefix, session=sess)
                    #   #print(sess.run(checkpoint.save_counter))

                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)

                if write_during_train:
                    if inference_fn is not None and inference_interval_epochs and fixed_step % int(
                            num_steps_per_epoch *
                            inference_interval_epochs) == 0:
                        model_step_path = model_path_ + '-' + str(step)
                        try:
                            #print('--------------inference fn')
                            inference_fn(model_path=model_step_path)
                        except Exception:
                            logging.info(traceback.format_exc())

                    # if metric_eval_fn is not None and valid_interval_epochs and fixed_step % int(num_steps_per_epoch * valid_interval_epochs) == 0:
                    #   model_step_path = model_path_ + '-' + str(step)
                    #   try:
                    #     metric_eval_fn(model_path=model_step_path)
                    #   except Exception:
                    #     logging.info(traceback.format_exc())

            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            #max_num_epochs = 1000
            max_num_epochs = num_epochs
            #if max_num_epochs and num_steps_per_epoch and fixed_step // num_steps_per_epoch >= max_num_epochs:
            if max_num_epochs and num_steps_per_epoch and fixed_step / num_steps_per_epoch > max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
    #except tf.errors.OutOfRangeError, e:
    except tf.errors.OutOfRangeError:
        # if run 2 epoch and we have just epoch saved, do not need to save only 1 step more model
        if (step - epoch_saved_step > 1) and not (
                step == start
        ) and save_model and step % save_interval_steps != 0 and model_dir:
            model_path_ = _get_checkpoint_path(checkpoint_path, step,
                                               num_steps_per_epoch)
            saver.save(sess, model_path_, global_step=step)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step,
                                  output_collection_names, output_node_names)
            if log_dir != model_dir:
                assert log_dir
                command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                print(command, file=sys.stderr)
                os.system(command)
        if only_one_step:
            logging.info('Done one step')
            exit(0)

        # if (step - epoch_saved_step > 1) and metric_eval_fn is not None:
        #   metric_eval_fn(model_path=model_step_path)

        if (num_epochs and fixed_step / num_steps_per_epoch >= num_epochs) or (
                num_steps and step == start + num_steps):
            logging.info('Done training for %.3f epochs, %d steps.' %
                         (fixed_step / num_steps_per_epoch, step))
            #FIXME becase coord.join seems not work,  RuntimeError: Coordinator stopped with threads still running: Thread-9
            exit(0)
        else:
            logging.info('Should not stop, but stopped at epoch: %.3f' %
                         (fixed_step / num_steps_per_epoch))
            logging.info(traceback.format_exc())
            #raise e
    finally:
        coord.request_stop()

    coord.join(threads, stop_grace_period_secs=5)
    #FIMXE due to use melt.get_session(global not handle del well)
    #Done training for 3090020 steps.
    #Exception TypeError: "'NoneType' object is not callable" in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x7f6cf33cd450>> ignored
    if FLAGS.use_tpu:
        sess.run(tpu.shutdown_system())
    sess.close()
예제 #27
0
def tf_train_flow(
        train_once_fn,
        model_dir='./model',
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=1,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        init_fn=None,
        restore_fn=None,
        restore_scope=None,
        save_all_scope=False,  #TODO save load from restore scope only but svae all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        sess=None):
    """Training loop similar to tf_flow, plus checkpoint reload and save.

    All global and local variables are initialized first. If
    ``_get_model_path`` finds a model, it is restored, after
    ``restore_fn(sess)`` if given. Otherwise ``init_fn(sess)`` is called.
    ``train_once_fn`` then runs once per step until one of these happens:
    it returns ``True``, a step/epoch limit is hit,
    ``tf.errors.OutOfRangeError`` is raised, or the coordinator requests
    a stop.

    Args:
      train_once_fn: called as ``train_once_fn(sess, step, is_start=...,
        fixed_step=...)``; returning ``True`` ends training.
      model_dir: model directory. A checkpoint path is reduced to its
        directory via ``gezi.get_dir``.
      max_models_keep: ``max_to_keep`` of the periodic step saver.
      save_interval_seconds: converted to the step saver's
        ``keep_checkpoint_every_n_hours``.
      save_interval_steps: save a checkpoint every this many steps. A falsy
        value disables step-interval saving.
      num_epochs: epoch limit (checked as ``step // num_steps_per_epoch``).
        A negative value runs a single step and exits.
      num_steps: step limit counted from the resumed start step. A negative
        value, or one smaller than the resumed step, only reloads, resaves
        and exits.
      save_model: passed to ``_get_model_path``; also gates all saving.
      save_interval_epochs: interval (in epochs) of the per-epoch saver.
      num_steps_per_epoch: steps per epoch. 0 disables all epoch logic.
      restore_from_latest: if False, restore ``melt.recent_checkpoint``
        instead of the latest checkpoint.
      metric_eval_fn: called with no arguments after the loop ends via
        OutOfRangeError. It is not called for a single-step run.
      init_fn: called with ``sess`` when no model is restored, e.g. to load
        weights for finetuning.
      restore_fn: called with ``sess`` just before the checkpoint restore.
      restore_scope: if set, the default restore/save list is the
        GLOBAL_VARIABLES of this scope.
      save_all_scope: if True, save all variables regardless of the
        restore list.
      variables_to_restore: explicit restore list. Defaults to the scope
        variables, else every variable named in the checkpoint.
      variables_to_save: explicit save list. Defaults to
        ``variables_to_restore``.
      sess: session to use; defaults to ``melt.get_session()``.

    Side effects: writes ``train.pbtxt`` to ``model_dir``. At epoch
    boundaries it also writes ``best_eval_loss.txt`` / ``eval_loss.txt``
    and checkpoints under ``model_dir/epoch``.

    Exits via ``sys.exit(0)`` when training finishes as planned, and after
    a resave-only or single-step run. Re-raises the OutOfRangeError if
    training stopped before reaching ``num_epochs`` / ``num_steps``.
    """
    if sess is None:
        #TODO melt.get_session is global session but may cause non close at last
        sess = melt.get_session()
    logging.info('tf_train_flow start')
    print('max_models_keep:', max_models_keep, file=sys.stderr)
    print('save_interval_seconds:', save_interval_seconds, file=sys.stderr)

    # restore_scope lets another model live in a different scope of the same
    # graph (e.g. one used only for prediction): only this scope's variables
    # are restored/saved here. It is not meant for finetuning.
    var_list = None if not restore_scope else tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
    if not variables_to_restore:
        variables_to_restore = var_list
    if not variables_to_save:
        variables_to_save = variables_to_restore
    if save_all_scope:
        variables_to_save = None

    if variables_to_restore is None:
        # Load every variable named in the checkpoint. With
        # variables_to_save=None the savers may save more variables than
        # the original checkpoint had.
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        variables_to_restore = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)

    loader = tf.train.Saver(var_list=variables_to_restore)
    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)

    # Initialize everything first; a restore below overwrites the restored
    # subset. NOTE: this re-initializes every variable in the session, so a
    # helper model (e.g. an assistant predictor) must use its own session.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # pre_step: step of the last saved checkpoint, -1 when training from scratch.
    pre_step = -1
    # fixed_pre_step keeps the epoch count right if the batch size changed.
    fixed_pre_step = -1
    model_path = _get_model_path(model_dir, save_model)
    # in case you pass ./model/model-ckpt1000 -> ./model
    model_dir = gezi.get_dir(model_dir)
    if model_path is not None:
        if not restore_from_latest:
            print('using recent but not latest model', file=sys.stderr)
            model_path = melt.recent_checkpoint(model_dir)
        timer = gezi.Timer('Loading and training from existing model [%s]' %
                           model_path)
        if restore_fn is not None:
            restore_fn(sess)
        loader.restore(sess, model_path)
        timer.print()
        pre_step = melt.get_model_step(model_path)
        pre_epoch = melt.get_model_epoch(model_path)
        fixed_pre_step = pre_step
        # Guard on num_steps_per_epoch: its default 0 would divide by zero.
        if pre_epoch is not None and num_steps_per_epoch:
            # e.g. trained with batch size 32, now reloaded with batch size 64
            if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
                fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
                logging.info(
                    'Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'
                    .format(pre_step / num_steps_per_epoch, pre_epoch,
                            fixed_pre_step))
    else:
        print('Train all start step 0', file=sys.stderr)
        # Variables are already initialized above; init_fn can load weights
        # from another model (e.g. for finetuning).
        if init_fn is not None:
            init_fn(sess)

    if save_interval_epochs and num_steps_per_epoch:
        epoch_dir = os.path.join(model_dir, 'epoch')
        gezi.try_mkdir(epoch_dir)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')

    tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1
        # Hack: just load, resave the model and exit. Compare against None
        # explicitly: `None < 0` raises TypeError on Python 3 (and was True on
        # Python 2, which wrongly took this path with the default num_steps).
        if num_steps is not None and (num_steps < 0 or
                                      (num_steps and num_steps < step)):
            print('just load and resave then exit', file=sys.stderr)
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
            sess.close()
            sys.exit(0)

        if num_epochs is not None and num_epochs < 0:
            only_one_step = True
            print('just run one step', file=sys.stderr)

        # Early stopping on a non-decreasing eval loss is disabled. After too
        # many bad epochs only a warning and a 'model.ckpt-noimprove'
        # snapshot are produced.
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        # warn once more than this many consecutive epochs fail to improve
        num_allowed_bad_epochs = 4
        while not coord.should_stop():
            stop = train_once_fn(sess,
                                 step,
                                 is_start=(step == start),
                                 fixed_step=fixed_step)
            if only_one_step:
                stop = True
            # step 0 is never saved here
            if save_model and step:
                if save_interval_steps and step % save_interval_steps == 0:
                    timer = gezi.Timer('save model step %d to %s' %
                                       (step, checkpoint_path))
                    saver.save(sess,
                               _get_checkpoint_path(checkpoint_path,
                                                    fixed_step,
                                                    num_steps_per_epoch),
                               global_step=step)
                    timer.print()
                # Epoch boundary (by fixed_step, so a batch-size change on
                # resume keeps epochs aligned).
                if (save_interval_epochs and num_steps_per_epoch
                        and fixed_step % num_steps_per_epoch == 0):
                    epoch = fixed_step // num_steps_per_epoch
                    eval_loss = melt.eval_loss()
                    if eval_loss:
                        # Presumably a string like
                        # "['eval_loss:3.2','eal_accuracy:4.3']": take the
                        # value of the first "name:value" entry.
                        eval_loss = float(
                            eval_loss.strip('[]').split(',')[0].strip(
                                "'").split(':')[-1])
                        best_loss_file = os.path.join(epoch_dir,
                                                      'best_eval_loss.txt')
                        # Resume the best loss recorded by an earlier run.
                        if os.path.exists(best_loss_file):
                            with open(best_loss_file) as f:
                                best_epoch_eval_loss = float(
                                    f.readline().split()[-1].strip())
                        if eval_loss < best_epoch_eval_loss:
                            best_epoch_eval_loss = eval_loss
                            logging.info(
                                'Now best eval loss is epoch %d eval_loss:%f' %
                                (epoch, eval_loss))
                            with open(best_loss_file, 'w') as f:
                                f.write('%d %d %f\n' %
                                        (epoch, step, best_epoch_eval_loss))
                            best_epoch_saver.save(
                                sess, os.path.join(epoch_dir,
                                                   'model.ckpt-best'))

                        with open(os.path.join(epoch_dir, 'eval_loss.txt'),
                                  'a') as f:
                            f.write('%d %d %f\n' % (epoch, step, eval_loss))
                        if eval_loss >= pre_epoch_eval_loss:
                            num_bad_epochs += 1
                            if num_bad_epochs > num_allowed_bad_epochs:
                                logging.warning(
                                    'Evaluate loss not decrease for last %d epochs'
                                    % (num_allowed_bad_epochs + 1))
                                noimprove_path = os.path.join(
                                    epoch_dir, 'model.ckpt-noimprove')
                                # V2 checkpoints write <prefix>.index rather
                                # than <prefix> itself; check both so the
                                # first no-improve snapshot is kept.
                                if not (os.path.exists(noimprove_path) or
                                        os.path.exists(noimprove_path +
                                                       '.index')):
                                    best_epoch_saver.save(sess, noimprove_path)
                        else:
                            num_bad_epochs = 0
                        pre_epoch_eval_loss = eval_loss
                    if step % (num_steps_per_epoch *
                               save_interval_epochs) == 0:
                        epoch_saver.save(sess,
                                         os.path.join(epoch_dir,
                                                      'model.ckpt-%d' % epoch),
                                         global_step=step)
            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            max_num_epochs = num_epochs
            if (max_num_epochs and num_steps_per_epoch
                    and step // num_steps_per_epoch >= max_num_epochs):
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
            step += 1
            fixed_step += 1
    # `except X as e` form: the old `except X, e` is a SyntaxError on Python 3.
    except tf.errors.OutOfRangeError:
        # Save the final model unless nothing was trained or the periodic
        # step saver already saved this step.
        if not (step == start) and save_model and (
                not save_interval_steps or step % save_interval_steps != 0):
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
        if only_one_step:
            print('Done one step', file=sys.stderr)
            sys.exit(0)
        if metric_eval_fn is not None:
            metric_eval_fn()
        # Avoid ZeroDivisionError with the default num_steps_per_epoch=0.
        epochs_done = (float(step) / num_steps_per_epoch
                       if num_steps_per_epoch else 0.0)
        if (num_epochs and epochs_done >= num_epochs) or (
                num_steps and (step + 1) == start + num_steps):
            print('Done training for %.3f epochs, %d steps.' %
                  (epochs_done, step + 1),
                  file=sys.stderr)
            #FIXME becase coord.join seems not work,  RuntimeError: Coordinator stopped with threads still running: Thread-9
            sys.exit(0)
        else:
            print('Should not stop, but stopped at epoch: %.3f' % epochs_done,
                  file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            raise
    finally:
        # Signal queue-runner threads to stop on every exit path
        # (normal return, sys.exit, or re-raise).
        coord.request_stop()
예제 #28
0
파일: train.py 프로젝트: tangqiqi123/hasky
def train_flow(ops,
               names=None,
               gen_feed_dict_fn=None,
               deal_results_fn=melt.print_results,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict_fn=None,
               deal_eval_results_fn=melt.print_results,
               optimizer=None,
               learning_rate=0.1,
               num_steps_per_epoch=None,
               model_dir=None,
               metric_eval_fn=None,
               debug=False,
               summary_excls=None,
               init_fn=None,
               sess=None):
    """Build the training op for `ops` and delegate to `melt.flow.train_flow`.

    Most run configuration (save/eval intervals, epochs, work mode, ...) is
    read from FLAGS rather than from arguments.

    Args:
      ops: sequence whose first element is either a single loss tensor or a
        list/tuple of per-tower losses (multi-gpu). The train op is prepended
        and, in the multi-tower case, the first element is replaced by the
        last tower's loss. The caller's sequence is NOT modified.
      names, gen_feed_dict_fn, deal_results_fn, eval_ops, eval_names,
      gen_eval_feed_dict_fn, deal_eval_results_fn, num_steps_per_epoch,
      model_dir, metric_eval_fn, summary_excls, init_fn: forwarded to
        melt.flow.train_flow (eval_ops / metric_eval_fn may be cleared
        depending on FLAGS.work_mode, see below).
      optimizer: optimizer spec given to melt.util.get_optimizer; defaults to
        FLAGS.optimizer.
      learning_rate: ignored -- the learning rate always comes from
        gen_learning_rate(). Kept only for interface compatibility.
      debug: if True, wrap the session in tf_debug.LocalCLIDebugWrapperSession.
      sess: session to run in; defaults to melt.get_session().

    FLAGS.work_mode:
      'train'        -- no eval ops, no metric eval.
      'train_metric' -- no eval ops; metric_eval_fn must be provided.
      'train_valid'  -- no metric eval.
      'test'         -- no training ops; eval every step; saving disabled.

    Returns:
      Whatever melt.flow.train_flow returns.
    """
    if sess is None:
        sess = melt.get_session()
    if debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    logging.info('learning_rate:{}'.format(FLAGS.learning_rate))
    # Batch size is not defined here but in app code (e.g. input_app.py).
    melt.set_global('batch_size', FLAGS.batch_size)
    melt.set_global('num_gpus', max(FLAGS.num_gpus, 1))

    # NOTICE: melt/__init__.py does `from melt.flow import *`, so
    # melt.flow.train.train_flow is not reachable by attribute access; use
    # `from melt.flow.train.train_flow import train_flow` instead.

    if optimizer is None:
        optimizer = FLAGS.optimizer

    # Work on a copy: the code below inserts/replaces elements, which must not
    # leak into the caller's list (and would fail outright on a tuple).
    ops = list(ops)

    # Set up the training ops.
    # Scope name '' only works in tf >= 0.11 (tf 0.10 always adds an
    # OptimizeLoss scope: 0.10 uses variable_op_scope, 0.11 variable_scope).
    optimize_scope = None if FLAGS.optimize_has_scope else ''
    # The `learning_rate` argument is intentionally superseded here.
    learning_rate, learning_rate_decay_fn = gen_learning_rate()
    # A list/tuple first element means per-tower (multi-gpu) losses.
    if not isinstance(ops[0], (list, tuple)):
        train_op = tf.contrib.layers.optimize_loss(
            loss=ops[0],
            global_step=None,
            learning_rate=learning_rate,
            optimizer=melt.util.get_optimizer(optimizer),
            clip_gradients=FLAGS.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn,
            name=optimize_scope)
    else:
        # The cifar10 example pins everything but tower losses to /cpu:0;
        # here leaving placement to TF was observed to be faster (no p2p
        # between GPUs on the tested machine either). See
        # https://github.com/tensorflow/tensorflow/issues/4881
        train_op = melt.layers.optimize_loss(
            losses=ops[0],
            num_gpus=FLAGS.num_gpus,
            global_step=None,
            learning_rate=learning_rate,
            optimizer=melt.util.get_optimizer(optimizer),
            clip_gradients=FLAGS.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn,
            name=optimize_scope)
        # Report the last tower's loss as the loss in ops.
        ops[0] = ops[0][-1]

    ops.insert(0, train_op)

    #-----------post deal
    save_interval_seconds = FLAGS.save_interval_seconds if FLAGS.save_interval_seconds > 0 \
       else FLAGS.save_interval_hours * 3600

    interval_steps = FLAGS.interval_steps
    eval_interval_steps = FLAGS.eval_interval_steps
    metric_eval_interval_steps = FLAGS.metric_eval_interval_steps
    save_model = FLAGS.save_model
    save_interval_steps = FLAGS.save_interval_steps
    if not save_interval_steps:
        # Effectively disable step-based saving.
        save_interval_steps = 1000000000000

    if FLAGS.work_mode == 'train':
        eval_ops = None
        metric_eval_fn = None
        logging.info('running train only mode')
    elif FLAGS.work_mode == 'train_metric':
        eval_ops = None
        assert metric_eval_fn is not None, 'set metric_eval to 1'
        logging.info('running train+metric mode')
    elif FLAGS.work_mode == 'train_valid':
        metric_eval_fn = None
        logging.info('running train+valid mode')
    elif FLAGS.work_mode == 'test':
        ops = None
        logging.info('running test only mode')
        interval_steps = 0
        # Every "step" is now an eval step ...
        eval_interval_steps = 1
        # ... so re-express the metric interval in units of eval steps.
        # Floor division keeps it an integer step count (plain `/=` yields a
        # float under true division); skip when eval interval is 0 to avoid
        # ZeroDivisionError.
        if FLAGS.eval_interval_steps:
            metric_eval_interval_steps //= FLAGS.eval_interval_steps
        save_model = False

    return melt.flow.train_flow(
        ops,
        names=names,
        gen_feed_dict_fn=gen_feed_dict_fn,
        deal_results_fn=deal_results_fn,
        eval_ops=eval_ops,
        eval_names=eval_names,
        gen_eval_feed_dict_fn=gen_eval_feed_dict_fn,
        deal_eval_results_fn=deal_eval_results_fn,
        interval_steps=interval_steps,
        eval_interval_steps=eval_interval_steps,
        num_epochs=FLAGS.num_epochs,
        num_steps=FLAGS.num_steps,
        save_interval_seconds=save_interval_seconds,
        save_interval_steps=save_interval_steps,
        save_model=save_model,
        save_interval_epochs=FLAGS.save_interval_epochs,
        # Must be None: the optimizer has already been applied above.
        optimizer=None,
        learning_rate=learning_rate,
        num_steps_per_epoch=num_steps_per_epoch,
        max_models_keep=FLAGS.max_models_keep,
        model_dir=model_dir,
        restore_from_latest=FLAGS.restore_from_latest,
        metric_eval_fn=metric_eval_fn,
        metric_eval_interval_steps=metric_eval_interval_steps,
        no_log=FLAGS.no_log,
        summary_excls=summary_excls,
        init_fn=init_fn,
        sess=sess)
예제 #29
0
  def __init__(self, 
               weights_op='learning_rate_weights', 
               patience=3, 
               decay=0.8, 
               cmp=None,
               names=None,
               num_weights=None, 
               min_weight=None,
               min_learning_rate=None,
               initial_learning_rate=None,
               initial_score=None,
               decay_start_epoch=0,
               sess=None):
    import melt.utils.logging as logging
    if not tf.executing_eagerly():
      self.sess = sess or melt.get_session()

    if num_weights is None:
      assert names
      num_weights = len(names)

    logging.info('decay:', decay, 'cmp:', cmp)
    assert cmp == 'less' or cmp == 'greater'

    if cmp == 'less':
      self.cmp = lambda x, y: x < y
      self.scores = np.ones([num_weights]) * 1e10
    elif cmp == 'greater':
      self.cmp = lambda x, y: x > y  
      self.scores = np.ones([num_weights]) * -1e10
    else:
      # TODO...
      self.cmp = cmp
      assert initial_score
      self.scores = [initial_score] * num_weights

    #self.scores = None

    self.max_patience = patience
    self.decay = decay

    # TODO patience also varaible so can save and restore ?
    self.patience = [0] * num_weights
    self.count = [0] * num_weights
    self.names = names or list(map(str, range(num_weights)))

    self.min_weight = min_weight

    self.decay_start_epoch = decay_start_epoch

    if not self.min_weight:
      self.min_weight = min_learning_rate / (initial_learning_rate or FLAGS.learning_rate)

    if isinstance(weights_op, str):
      try:
        self.weights_op = tf.get_collection(weights_op)[-1]
      except Exception:
        #self.weights_op = tf.get_variable('lr_ratios', initializer=tf.ones([num_classes], dtype=tf.float32))
        #tf.add_to_collection('lr_ratios', lr_ratios)
        raise 'TODO..'
    else:
      self.weights_op = weights_op