def __init__(self, batch_size=None):
    """Build the CRNN inference graph and restore trained weights.

    Reads network/training parameters from './net.cfg', loads the
    label -> one-hot mapping from './data/word_onehot.txt', builds the
    CRNN graph with fixed-size input placeholders, opens a tf.Session
    and restores variables from "./model/ckpt".

    Args:
        batch_size: Batch size used for the input placeholder. When None,
            the 'test_batch_size' value from the net config is used.
    """
    net_params, train_params = parser_cfg_file('./net.cfg')
    self._model_save_path = str(train_params['model_save_path'])
    self.input_img_height = int(net_params['input_height'])
    self.input_img_width = int(net_params['input_width'])
    if batch_size is None:
        self.test_batch_size = int(net_params['test_batch_size'])
    else:
        self.test_batch_size = batch_size

    # Load the label one-hot mapping. The file is closed as soon as it is read.
    with open('./data/word_onehot.txt', 'r') as f:
        data = f.read()
    # NOTE(review): eval() executes arbitrary code from this file. If the file
    # only ever holds a dict literal, ast.literal_eval would be a safer drop-in
    # — confirm the file format before switching.
    words_onehot_dict = eval(data)
    self.words_list = list(words_onehot_dict.keys())
    # One-hot entries in the same order as self.words_list.
    self.words_onehot_list = [words_onehot_dict[word] for word in self.words_list]

    # Build the network.
    self.inputs_tensor = tf.placeholder(
        tf.float32,
        [self.test_batch_size, self.input_img_height, self.input_img_width, 1])
    self.seq_len_tensor = tf.placeholder(tf.int32, [None], name='seq_len')
    # NOTE(review): meaning of the final positional `True` argument to CRNN is
    # defined elsewhere — confirm it is the intended mode for inference.
    crnn_net = CRNN(net_params, self.inputs_tensor, self.seq_len_tensor,
                    self.test_batch_size, True)
    _, decoded, self.max_char_count = crnn_net.construct_graph()
    # Densify the first decoded sparse result; missing positions become -1.
    self.dense_decoded = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)

    self.sess = tf.Session()
    saver = tf.train.Saver()
    # NOTE(review): restores from a hard-coded path rather than
    # self._model_save_path (read above but unused here) — confirm intended.
    saver.restore(self.sess, "./model/ckpt")
def __init__(self, pre_train=False):
    """Read training configuration and build the CRNN training graph.

    Reads network/training parameters from './net.cfg', initialises the
    training logger, determines the starting step (resuming from
    './model/train_step.txt' when `pre_train` is set), and creates the
    input, sparse-label and sequence-length placeholders plus the CRNN graph.

    Args:
        pre_train: When True, resume from the checkpoint at the configured
            'model_save_path'; the starting step is read from the first line
            of './model/train_step.txt'. When False, start from step 0.

    Raises:
        AssertionError: If `pre_train` is True and no checkpoint exists at
            the configured 'model_save_path'.
    """
    net_params, train_params = parser_cfg_file('./net.cfg')
    self.input_height = int(net_params['input_height'])
    self.input_width = int(net_params['input_width'])
    self.batch_size = int(train_params['batch_size'])
    self._learning_rate = float(train_params['learning_rate'])
    self._max_iterators = int(train_params['max_iterators'])
    self._train_logger_init()
    self._pre_train = pre_train
    self._model_save_path = str(train_params['model_save_path'])

    if self._pre_train:
        if not tf.train.checkpoint_exists(self._model_save_path):
            print('Checkpoint is invalid...')
            # Raise explicitly instead of `assert 0, print(...)`: an assert is
            # stripped under `python -O`, which would leave _start_step unset.
            # AssertionError is kept so existing handlers still match.
            raise AssertionError('Checkpoint is invalid...')
        print('Checkpoint is valid...')
        # `with` guarantees the file is closed even if int() fails to parse.
        with open('./model/train_step.txt', 'r') as f:
            self._start_step = int(f.readline())
    else:
        self._start_step = 0

    # Image batch; the height comes from the config so that it matches the
    # placeholder used at inference time (previously hard-coded to 32).
    self._inputs = tf.placeholder(
        tf.float32, [self.batch_size, self.input_height, self.input_width, 1])
    # Sparse label tensor.
    self._label = tf.sparse_placeholder(tf.int32, name='label')
    # Per-sample sequence lengths, one entry per batch element.
    self._seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

    crnn_net = CRNN(net_params, self._inputs, self._seq_len,
                    self.batch_size, True)
    self._net_output, self._decoded, self._max_char_count = (
        crnn_net.construct_graph())
    # Densify the first decoded sparse result; missing positions become -1.
    self.dense_decoded = tf.sparse_tensor_to_dense(self._decoded[0],
                                                   default_value=-1)