def num_examples_per_epoch(subset='train', dir=None):
    """Read the record count stored alongside a data subset.

    The count is read (via ``gezi.read_int_from`` with a default of ``None``)
    from ``<dir>/num_records.txt``. When ``dir`` is falsy, the directory is
    derived with ``gezi.dirname`` from the input flag for ``subset``:

    * ``'train'`` -> first comma-separated entry of ``FLAGS.train_input``
    * ``'valid'`` -> ``FLAGS.valid_input``
    * ``'test'``  -> ``FLAGS.test_input``

    Args:
        subset: one of ``'train'``, ``'valid'`` or ``'test'``.
        dir: optional directory containing ``num_records.txt``; overrides the
            flag-derived directory when truthy.

    Returns:
        Whatever ``gezi.read_int_from`` returns for that file (presumably the
        integer count, or ``None`` if it cannot be read -- confirm in gezi).

    Raises:
        ValueError: if ``subset`` is not a recognised subset name.
    """
    if subset not in ('train', 'valid', 'test'):
        raise ValueError('Invalid data subset "%s"' % subset)

    record_dir = dir
    if not record_dir:
        if subset == 'train':
            # train_input may list several comma-separated inputs; only the
            # first one is used to locate num_records.txt.
            source = FLAGS.train_input.split(',')[0]
        elif subset == 'valid':
            source = FLAGS.valid_input
        else:
            source = FLAGS.test_input
        record_dir = gezi.dirname(source)

    return gezi.read_int_from(record_dir + '/num_records.txt', None)
def init(vocab_path_=None, append=None):
    """Build the module-level vocabulary once; later calls are no-ops.

    On the first call this sets the globals ``vocab`` and ``vocab_size``:

    * If ``FLAGS.vocab_buckets`` is truthy, ``vocab`` is
      ``Vocabulary(buckets=FLAGS.vocab_buckets)``.
    * Otherwise the vocab file path is taken from ``vocab_path_``, then
      ``FLAGS.vocab``, then ``gezi.dirname(FLAGS.model_dir) + '/vocab.txt'``.
      It is stored in the global ``vocab_path`` and written back to
      ``FLAGS.vocab``, and ``vocab`` is loaded from that file.

    ``vocab_size`` is ``vocab.size()``, capped at ``FLAGS.vocab_size`` when
    that flag is truthy.

    Args:
        vocab_path_: explicit vocab file path; takes precedence over flags.
        append: forwarded to ``Vocabulary`` in file mode. When ``None`` it
            defaults to ``FLAGS.vocab_append``, forced to ``True`` if
            ``gezi.env_has('VOCAB_APPEND')``.

    Raises:
        AssertionError: if ``vocab_size`` is not greater than
            ``FLAGS.num_reserved_ids`` (likely a wrong/empty vocab path).
    """
    global vocab, vocab_size, vocab_path
    if vocab is not None:
        return

    if FLAGS.vocab_buckets:
        vocab = Vocabulary(buckets=FLAGS.vocab_buckets)
    else:
        vocab_path = (vocab_path_ or FLAGS.vocab
                      or gezi.dirname(FLAGS.model_dir) + '/vocab.txt')
        # Record the resolved path so later readers of FLAGS.vocab see it.
        FLAGS.vocab = vocab_path
        logging.info('vocab:{}'.format(vocab_path))
        logging.info('NUM_RESERVED_IDS:{}'.format(FLAGS.num_reserved_ids))
        if append is None:
            flag_append = FLAGS.vocab_append
            # The VOCAB_APPEND environment switch overrides the flag default.
            append = True if gezi.env_has('VOCAB_APPEND') else flag_append
        vocab = Vocabulary(vocab_path, FLAGS.num_reserved_ids,
                           append=append,
                           max_words=FLAGS.vocab_max_words,
                           min_count=FLAGS.vocab_min_count)

    size_cap = FLAGS.vocab_size
    vocab_size = min(vocab.size(), size_cap) if size_cap else vocab.size()
    logging.info('vocab_size:{}'.format(vocab_size))
    assert vocab_size > FLAGS.num_reserved_ids, 'empty vocab, wrong vocab path? %s' % FLAGS.vocab

    special_ids = (
        ('vocab_start:{} id:{}', vocab.start_id),
        ('vocab_end:{} id:{}', vocab.end_id),
        ('vocab_unk:{} id:{}', vocab.unk_id),
    )
    for template, id_getter in special_ids:
        token_id = id_getter()
        logging.info(template.format(vocab.key(token_id), token_id))
def get_model_dir(model_dir, model_name=None):
    """Return the directory of the model path resolved from ``model_dir``.

    If ``tf.train.get_checkpoint_state(model_dir)`` reports a checkpoint, the
    model path is ``model_dir`` joined with the basename of that checkpoint
    path. Otherwise it is ``model_dir`` itself, or ``model_dir/model_name``
    when ``model_name`` is given. The result is ``gezi.dirname`` of that path.

    NOTE(review): when no checkpoint is found and ``model_name`` is None, this
    returns ``gezi.dirname(model_dir)``. That looks intended for callers who
    pass a checkpoint file prefix rather than a directory -- confirm.

    Args:
        model_dir: checkpoint directory (or possibly a checkpoint path prefix).
        model_name: optional file name under ``model_dir``, used only when no
            checkpoint state is found.

    Returns:
        ``gezi.dirname`` of the resolved model path.
    """
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # Only the basename of the recorded checkpoint is kept and re-rooted
        # under model_dir, presumably so a moved/copied checkpoint directory
        # still resolves locally.
        model_path = os.path.join(
            model_dir, os.path.basename(ckpt.model_checkpoint_path))
    elif model_name is None:
        model_path = model_dir
    else:
        model_path = os.path.join(model_dir, model_name)
    return gezi.dirname(model_path)