def main(): """ Main freezing function. """ args = parse_args() image_width = 120 image_height = 32 _, _, num_classes = Dataset.create_character_maps() model = TextRecognition(is_training=False, num_classes=num_classes) model_out = model(inputdata=tf.placeholder( tf.float32, [1, image_height, image_width, 1])) output_dir = args.output_dir if args.output_dir else os.path.join( os.path.dirname(args.checkpoint), 'export') saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess=sess, save_path=args.checkpoint) graph_file = os.path.join(output_dir, 'graph.pb') frozen_graph = freezing_graph(sess, graph_file, output_node_names=[model_out.name[:-2]]) mo_params = { 'model_name': 'text_recognition', 'data_type': args.data_type, } export_ir_dir = os.path.join(output_dir, 'IR', args.data_type) execute_mo(mo_params, frozen_graph, export_ir_dir)
def test_training_loss(self):
    """ Test for checking that training loss decreases. """

    model = TextRecognition(is_training=True, num_classes=self.dataset.num_classes)
    next_sample = self.dataset().make_one_shot_iterator().get_next()
    model_out = model(inputdata=next_sample[0])

    ctc_loss = tf.reduce_mean(
        tf.nn.ctc_loss(labels=next_sample[1], inputs=model_out,
                       sequence_length=self.seq_length * np.ones(self.batch_size)))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(0.1).minimize(loss=ctc_loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        l0 = sess.run(ctc_loss)
        print('loss before', l0)

        for _ in range(10):
            sess.run(optimizer)

        l1 = sess.run(ctc_loss)
        print('loss after', l1)

        assert l1 < l0
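# `tf.nn.ctc_loss` expects time-major logits of shape [max_time, batch_size, num_classes] by
# default and labels given as an int32 tf.SparseTensor, so `model_out` above is assumed to be
# time-major. A self-contained sketch of the same loss on arbitrary random data:
def ctc_loss_shapes_sketch(max_time=30, batch_size=2, num_classes=37):
    logits = tf.random_normal([max_time, batch_size, num_classes])
    # Sparse labels encoding two sequences, [1, 2, 3] and [4, 5].
    labels = tf.SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
                             values=tf.constant([1, 2, 3, 4, 5], dtype=tf.int32),
                             dense_shape=[batch_size, 3])
    seq_len = np.full(batch_size, max_time, dtype=np.int32)
    return tf.reduce_mean(tf.nn.ctc_loss(labels=labels, inputs=logits, sequence_length=seq_len))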
def main(): """ Main testing funciton. """ args = parse_args() sequence_length = 30 image_width = 120 image_height = 32 dataset = Dataset(args.annotation_path, image_width, image_height, repeat=1) next_sample = dataset().make_one_shot_iterator().get_next() model = TextRecognition(is_training=False, num_classes=dataset.num_classes) images_ph = tf.placeholder(tf.float32, [1, image_height, image_width, 1]) model_out = model(inputdata=images_ph) decoded, _ = tf.nn.ctc_beam_search_decoder(model_out, sequence_length * np.ones(1), merge_repeated=False) saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess=sess, save_path=args.weights_path) correct = 0.0 dataset_len = len(dataset) for _ in tqdm(range(dataset_len)): images_batch, labels_batch = sess.run(next_sample) preds, _ = sess.run([decoded, model_out], feed_dict={images_ph: images_batch}) try: predicted = Dataset.sparse_tensor_to_str( preds[0], dataset.int_to_char)[0] expected = Dataset.sparse_tensor_to_str( labels_batch, dataset.int_to_char)[0].lower() except: print('Could not find a word') continue correct += 1 if predicted == expected else 0 if args.show and predicted != expected: image = np.reshape(images_batch, [image_height, image_width, -1]).astype( np.uint8) cv2.imshow('image', image) print('pr, gt', predicted, expected) k = cv2.waitKey(0) if k == 27: sess.close() return print('accuracy', correct / dataset_len) return
sess_r_h = tf.Session()
#init = tf.global_variables_initializer()
#sess_r_h.run(init)  # an imported frozen graph has no variables to initialise

with tf.Graph().as_default():
    recognition_graph_def = tf.GraphDef()
    with open(recognition_model_v_path, "rb") as f:
        recognition_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(recognition_graph_def, name="")
    sess_r_v = tf.Session()
    #init = tf.global_variables_initializer()
    #sess_r_v.run(init)  # an imported frozen graph has no variables to initialise

'''
bs = 4
model = TextRecognition(is_training=False, num_classes=37)

images_ph_h = tf.placeholder(tf.float32, [bs, 32, 240, 1])
model_out_h = model(inputdata=images_ph_h)
saver_h = tf.train.Saver()
sess_r_h = tf.Session()
saver_h.restore(sess=sess_r_h, save_path=recognition_model_h_path)
decoded_h, _ = tf.nn.ctc_beam_search_decoder(model_out_h, 60 * np.ones(bs), merge_repeated=False)

images_ph_v = tf.placeholder(tf.float32, [bs, 32, 320, 1])
model_out_v = model(inputdata=images_ph_v)
saver_v = tf.train.Saver()
sess_r_v = tf.Session()
saver_v.restore(sess=sess_r_v, save_path=recognition_model_v_path)
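# The frozen recognition graphs imported above are driven by looking tensors up by name in the
# imported graph. The default names below ('input:0', 'output:0') and the helper itself are
# illustrative assumptions, not names taken from this repository's exported models.
def run_frozen_recognition_sketch(sess, image_batch, input_name='input:0', output_name='output:0'):
    """Run one image batch through an imported frozen graph (illustrative sketch)."""
    input_tensor = sess.graph.get_tensor_by_name(input_name)
    output_tensor = sess.graph.get_tensor_by_name(output_name)
    return sess.run(output_tensor, feed_dict={input_tensor: image_batch})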
def main(): """ Main training function. """ args = parse_args() seq_length = 30 batch_size = 64 image_width, image_height = 120, 32 handle = tf.placeholder(tf.string, shape=[]) dataset_train = Dataset(args.annotation_path, image_width, image_height, batch_size=batch_size, shuffle=True) iterator_train = dataset_train().make_initializable_iterator() if args.annotation_path_test != '': dataset_test = Dataset(args.annotation_path_test, image_width, image_height, batch_size=batch_size, shuffle=False, repeat=1) iterator_test = dataset_test().make_initializable_iterator() iterator = tf.data.Iterator.from_string_handle( handle, dataset_train().output_types, dataset_train().output_shapes, dataset_train().output_classes) next_sample = iterator.get_next() is_training_ph = tf.placeholder(tf.bool) model = TextRecognition(is_training=is_training_ph, num_classes=dataset_train.num_classes, backbone_dropout=args.backbone_dropout) model_out = model(inputdata=next_sample[0]) ctc_loss = tf.reduce_mean(tf.nn.ctc_loss(labels=next_sample[1], inputs=model_out, sequence_length=seq_length * np.ones(batch_size))) reg_loss = tf.losses.get_regularization_loss() loss = ctc_loss if args.reg: loss += reg_loss decoded, _ = tf.nn.ctc_beam_search_decoder(model_out, seq_length * np.ones(batch_size), merge_repeated=False) edit_dist = tf.edit_distance(tf.cast(decoded[0], tf.int32), next_sample[1]) crw = tf.nn.zero_fraction(edit_dist) global_step = tf.Variable(0, name='global_step', trainable=False) starter_learning_rate = args.learning_rate learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000000, 0.1, staircase=True) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): optimizer = tf.train.AdadeltaOptimizer(learning_rate).minimize(loss=loss, global_step=global_step) # Set tf summary train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) model_descr = str(train_start_time) tboard_save_path = 'tboard/' + model_descr if not os.path.exists(tboard_save_path): os.makedirs(tboard_save_path) tf.summary.scalar(name='ctc_loss', tensor=ctc_loss) tf.summary.scalar(name='reg_loss', tensor=reg_loss) tf.summary.scalar(name='total_loss', tensor=loss) tf.summary.scalar(name='Learning_Rate', tensor=learning_rate) merge_summary_op = tf.summary.merge_all() test_acc_ph = tf.placeholder(dtype=np.float32) test_acc_summary = tf.summary.scalar(name='test_acc_ph', tensor=test_acc_ph) # Set saver configuration saver = tf.train.Saver(max_to_keep=1000) model_save_dir = 'model/' + model_descr if not os.path.exists(model_save_dir): os.makedirs(model_save_dir) model_name = 'model_' + model_descr + '.ckpt' model_save_path = os.path.join(model_save_dir, model_name) summary_writer = tf.summary.FileWriter(tboard_save_path) with tf.Session() as sess: sess.run(iterator_train.initializer) if args.weights_path is None: print('Training from scratch') init = tf.global_variables_initializer() sess.run(init) else: print('Restore model from {:s}'.format(args.weights_path)) saver.restore(sess=sess, save_path=args.weights_path) training_handle = sess.run(iterator_train.string_handle()) if args.annotation_path_test != '': test_handle = sess.run(iterator_test.string_handle()) for _ in range(args.num_steps): _, c, step, summary = sess.run([optimizer, ctc_loss, global_step, merge_summary_op], feed_dict={is_training_ph: True, handle: training_handle}) if step % 100 == 0: summary_writer.add_summary(summary=summary, global_step=step) print('Iter: {:d} cost= {:9f}'.format(step, 
c)) if step % 1000 == 0: saver.save(sess=sess, save_path=model_save_path, global_step=global_step) if args.annotation_path_test: sess.run(iterator_test.initializer) correct = 0.0 for _ in range(len(dataset_test) // batch_size): correct += sess.run(crw, feed_dict={is_training_ph: False, handle: test_handle}) test_accuracy = correct / (len(dataset_test) // batch_size) print('Iter: {:d} cost= {:9f} TEST accuracy= {:9f}'.format(step, c, test_accuracy)) summary = sess.run(test_acc_summary, feed_dict={test_acc_ph: test_accuracy}) summary_writer.add_summary(summary=summary, global_step=step)
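# The training loop above switches between the train and test input pipelines through a single
# `tf.data.Iterator.from_string_handle` placeholder. A minimal, self-contained sketch of that
# pattern on toy datasets (the values and dataset contents here are arbitrary, nothing is
# repository-specific):
def feedable_iterator_sketch():
    handle = tf.placeholder(tf.string, shape=[])
    train_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3]).repeat()
    test_ds = tf.data.Dataset.from_tensor_slices([10, 20, 30])

    train_iter = train_ds.make_initializable_iterator()
    test_iter = test_ds.make_initializable_iterator()
    iterator = tf.data.Iterator.from_string_handle(handle, train_ds.output_types,
                                                   train_ds.output_shapes)
    next_element = iterator.get_next()

    with tf.Session() as sess:
        sess.run([train_iter.initializer, test_iter.initializer])
        train_handle, test_handle = sess.run([train_iter.string_handle(),
                                              test_iter.string_handle()])
        print(sess.run(next_element, feed_dict={handle: train_handle}))  # element of train_ds
        print(sess.run(next_element, feed_dict={handle: test_handle}))   # element of test_ds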