Пример #1
0
    def __init__(self, batch_size=None):
        """Build the CRNN inference graph and restore trained weights.

        Reads settings from ``./net.cfg``, loads the label one-hot table from
        ``./data/word_onehot.txt``, builds the CRNN graph for a fixed batch
        size, and restores variables from ``./model/ckpt`` into a new
        ``tf.Session`` stored on ``self.sess``.

        Args:
            batch_size: Number of images per inference batch. When ``None``,
                ``test_batch_size`` from the config file is used instead.
        """
        net_params, train_params = parser_cfg_file('./net.cfg')
        self._model_save_path = str(train_params['model_save_path'])
        self.input_img_height = int(net_params['input_height'])
        self.input_img_width = int(net_params['input_width'])
        if batch_size is None:
            self.test_batch_size = int(net_params['test_batch_size'])
        else:
            self.test_batch_size = batch_size

        # Load the label one-hot table. The context manager guarantees the
        # file is closed even if reading or parsing fails.
        # NOTE(review): eval() executes arbitrary code from this file; that is
        # only acceptable because it is a local, trusted artifact. Consider
        # ast.literal_eval if the file holds only Python literals.
        with open('./data/word_onehot.txt', 'r') as f:
            words_onehot_dict = eval(f.read())
        self.words_list = list(words_onehot_dict.keys())
        # Aligned index-for-index with self.words_list.
        self.words_onehot_list = [words_onehot_dict[word] for word in self.words_list]

        # Build the network. The input placeholder has a fixed batch size and
        # a single channel: [batch, height, width, 1].
        self.inputs_tensor = tf.placeholder(
            tf.float32,
            [self.test_batch_size, self.input_img_height, self.input_img_width, 1])
        self.seq_len_tensor = tf.placeholder(tf.int32, [None], name='seq_len')

        # NOTE(review): the meaning of the trailing True flag is defined by
        # CRNN — confirm against its constructor.
        crnn_net = CRNN(net_params, self.inputs_tensor, self.seq_len_tensor,
                        self.test_batch_size, True)
        net_output, decoded, self.max_char_count = crnn_net.construct_graph()
        # Densify the first decoded sparse result; empty positions become -1.
        self.dense_decoded = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)

        self.sess = tf.Session()
        saver = tf.train.Saver()
        # NOTE(review): restores from a hard-coded path rather than
        # self._model_save_path read above — confirm this is intended.
        saver.restore(self.sess, "./model/ckpt")
    def __init__(self, pre_train=False):
        """Read the training configuration and build the CRNN training graph.

        Args:
            pre_train: When True, training resumes from the checkpoint at the
                configured ``model_save_path``. The starting step is read
                from ``./model/train_step.txt``. When False, training starts
                at step 0.

        Raises:
            AssertionError: If ``pre_train`` is True but no checkpoint exists
                at ``model_save_path``.
        """
        net_params, train_params = parser_cfg_file('./net.cfg')

        self.input_height = int(net_params['input_height'])
        self.input_width = int(net_params['input_width'])
        self.batch_size = int(train_params['batch_size'])
        self._learning_rate = float(train_params['learning_rate'])
        self._max_iterators = int(train_params['max_iterators'])
        self._train_logger_init()
        self._pre_train = pre_train
        self._model_save_path = str(train_params['model_save_path'])

        if self._pre_train:
            if not tf.train.checkpoint_exists(self._model_save_path):
                # Raise explicitly rather than `assert 0, print(...)`. A bare
                # assert is stripped under `python -O`, which would leave
                # _start_step unset, and its message was print()'s return
                # value (None). AssertionError is kept for existing callers.
                raise AssertionError('Checkpoint is invalid...')
            print('Checkpoint is valid...')
            # Resume the step counter stored next to the checkpoint.
            with open('./model/train_step.txt', 'r') as f:
                self._start_step = int(f.readline())
        else:
            self._start_step = 0

        # Input images: [batch, height, width, 1]. The height comes from the
        # config rather than a hard-coded 32, so the graph matches the
        # configured input size.
        self._inputs = tf.placeholder(
            tf.float32, [self.batch_size, self.input_height, self.input_width, 1])

        # Sparse int32 label placeholder.
        self._label = tf.sparse_placeholder(tf.int32, name='label')

        # Per-example sequence lengths, shape [None]. The original note
        # suggests a value like [32] * 64 — TODO confirm against the feed code.
        self._seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

        # NOTE(review): the meaning of the trailing True flag is defined by
        # CRNN — confirm against its constructor.
        crnn_net = CRNN(net_params, self._inputs, self._seq_len,
                        self.batch_size, True)
        self._net_output, self._decoded, self._max_char_count = crnn_net.construct_graph()
        # Densify the first decoded sparse result; empty positions become -1.
        self.dense_decoded = tf.sparse_tensor_to_dense(self._decoded[0],
                                                       default_value=-1)