예제 #1
0
 def num_examples_per_epoch(subset='train', dir=None):
     """Read the record count stored next to a data subset's input files.

     The count is read from ``<directory>/num_records.txt`` through
     ``gezi.read_int_from``, with ``None`` passed as the default value.

     Args:
       subset: one of ``'train'``, ``'valid'`` or ``'test'``.
       dir: directory holding ``num_records.txt``. When falsy, the directory
         is derived from the matching ``FLAGS.*_input`` value; for ``'train'``
         only the first comma-separated entry of ``FLAGS.train_input`` is used.

     Returns:
       Whatever ``gezi.read_int_from`` returns for that file.

     Raises:
       ValueError: if ``subset`` is not one of the three known names.
     """
     # Reject unknown subsets up front, before any flag is touched.
     if subset not in ('train', 'valid', 'test'):
         raise ValueError('Invalid data subset "%s"' % subset)

     # An explicit directory wins; flags are only consulted as a fallback.
     if dir:
         base = dir
     elif subset == 'train':
         first_input = FLAGS.train_input.split(',')[0]
         base = gezi.dirname(first_input)
     elif subset == 'valid':
         base = gezi.dirname(FLAGS.valid_input)
     else:
         base = gezi.dirname(FLAGS.test_input)

     return gezi.read_int_from(base + '/num_records.txt', None)
예제 #2
0
def init(vocab_path_=None, append=None):
    """Populate the module-level ``vocab``, ``vocab_size`` and ``vocab_path``.

    Does nothing when ``vocab`` is already set. Otherwise builds a
    ``Vocabulary`` either from ``FLAGS.vocab_buckets`` or from a vocab file,
    records the effective size and logs the start/end/unk entries.

    Args:
      vocab_path_: vocab file to load. When falsy, falls back to
        ``FLAGS.vocab`` and then to ``vocab.txt`` under
        ``gezi.dirname(FLAGS.model_dir)``.
      append: forwarded to ``Vocabulary`` as ``append``. When ``None`` it is
        taken from ``FLAGS.vocab_append`` and forced to ``True`` if
        ``gezi.env_has('VOCAB_APPEND')`` is truthy.

    Side effects:
      In the file-based path, ``FLAGS.vocab`` is overwritten with the
      resolved path.
    """
    global vocab, vocab_size, vocab_path

    # Already initialised: keep the existing vocabulary untouched.
    if vocab is not None:
        return

    if FLAGS.vocab_buckets:
        vocab = Vocabulary(buckets=FLAGS.vocab_buckets)
    else:
        resolved = vocab_path_ or FLAGS.vocab
        if not resolved:
            resolved = gezi.dirname(FLAGS.model_dir) + '/vocab.txt'
        vocab_path = resolved
        FLAGS.vocab = vocab_path
        logging.info('vocab:{}'.format(vocab_path))
        logging.info('NUM_RESERVED_IDS:{}'.format(FLAGS.num_reserved_ids))

        # The environment override is only consulted when the caller did
        # not pass an explicit ``append`` value.
        use_append = append
        if use_append is None:
            use_append = FLAGS.vocab_append
            if gezi.env_has('VOCAB_APPEND'):
                use_append = True

        vocab = Vocabulary(
            vocab_path,
            FLAGS.num_reserved_ids,
            append=use_append,
            max_words=FLAGS.vocab_max_words,
            min_count=FLAGS.vocab_min_count,
        )

    # A non-zero FLAGS.vocab_size caps the reported size.
    cap = FLAGS.vocab_size
    full_size = vocab.size()
    vocab_size = min(full_size, cap) if cap else full_size
    logging.info('vocab_size:{}'.format(vocab_size))
    assert vocab_size > FLAGS.num_reserved_ids, 'empty vocab, wrong vocab path? %s' % FLAGS.vocab

    start_id = vocab.start_id()
    logging.info('vocab_start:{} id:{}'.format(vocab.key(start_id), start_id))
    end_id = vocab.end_id()
    logging.info('vocab_end:{} id:{}'.format(vocab.key(end_id), end_id))
    unk_id = vocab.unk_id()
    logging.info('vocab_unk:{} id:{}'.format(vocab.key(unk_id), unk_id))
예제 #3
0
파일: util.py 프로젝트: meng-jia/wenzheng
def get_model_dir(model_dir, model_name=None):
  """Resolve a model path under ``model_dir`` and return its ``gezi.dirname``.

  Args:
    model_dir: directory passed to ``tf.train.get_checkpoint_state``.
    model_name: optional name joined onto ``model_dir`` when no checkpoint
      state (or no checkpoint path) is found there.

  Returns:
    ``gezi.dirname`` of: the latest checkpoint's basename joined onto
    ``model_dir`` when checkpoint state is available; otherwise
    ``model_dir/model_name`` if ``model_name`` is given, else ``model_dir``.
    The existence of the resolved path is not checked.
  """
  state = tf.train.get_checkpoint_state(model_dir)
  latest = state.model_checkpoint_path if state else None

  if latest:
    # Re-root the checkpoint under model_dir, keeping only its file name.
    return gezi.dirname(os.path.join(model_dir, os.path.basename(latest)))

  if model_name is None:
    return gezi.dirname(model_dir)

  return gezi.dirname(os.path.join(model_dir, model_name))