Esempio n. 1
0
def main(argv=None):
    """Run end-to-end text spotting over every image from ``get_images()``.

    For each image this function:

    1. runs the detection branch to get a shared feature map and a geometry
       map; a sigmoid is applied to geometry channels 0-2 and pixels whose
       channel 0 is >= 0.9 are grouped into quadrilaterals by ``nms``;
    2. keeps the quads whose scores are all positive and recognises them with
       the RoI-rotate + recognition branch, at most 32 RoIs per ``sess.run``
       to avoid running out of memory;
    3. writes ``<image stem>.txt`` to ``FLAGS.output_dir`` (MTWI layout: eight
       vertex coordinates then the recognised text, one line per box; empty
       when nothing is detected) and, unless ``FLAGS.no_write_images`` is set,
       a copy of the image with the boxes drawn.

    If recognition raises, or returns a different number of results than
    there are boxes, the image path is appended to a log file, the image is
    copied to a "wrong pictures" directory and no result file is written for
    it.

    Args:
        argv: unused.
    """
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:  # 17 == EEXIST: the directory already exists
            raise

    # NOTE: vocabulary correction with a BK-tree (FLAGS.use_vacab and
    # ./vocab.txt) is intentionally disabled in this variant.

    # Where images whose recognition failed are logged / copied.
    wrong_list_path = '/home/wsw/deeplearning/advancedFOTS/somepicwrong.txt'
    wrong_pic_dir = '/home/wsw/deeplearning/advancedFOTS/picwrong/'
    # Maximum number of RoIs recognised per sess.run (avoids OOM).
    roi_batch_size = 32

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        input_feature_map = tf.placeholder(tf.float32,
                                           shape=[None, None, None, 32],
                                           name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32,
                                                shape=[None, 6],
                                                name='input_transform_matrix')
        input_box_mask = []
        input_box_mask.append(
            tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32,
                                          shape=[None],
                                          name='input_box_widths')

        # Not fed or fetched below; kept so the graph layout matches training.
        input_seq_len = input_box_widths[tf.argmax(
            input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        shared_feature, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(
            input_feature_map, input_transform_matrix, input_box_mask,
            input_box_widths)
        recognition_logits = recognize_part.build_graph(
            pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits,
                                                input_box_widths)

        # Restore the exponential-moving-average shadow weights.
        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            # Change the slice start to resume after an interruption
            # (e.g. n = 10000 - index).
            im_fn_list = im_fn_list[0:]
            index = len(im_fn_list)
            for im_fn in im_fn_list:
                print('still have {} pictures left'.format(index))
                print(im_fn)

                # cv2 loads BGR; flip to RGB for the network.
                im = cv2.imread(im_fn)[:, :, ::-1]
                im_resized, (ratio_h, ratio_w) = resize_image(im)

                img = image.img_to_array(im_resized)
                img = preprocess_input(img, mode='tf')

                shared_feature_map, geometry = sess.run(
                    [shared_feature, f_geometry],
                    feed_dict={input_images: [img]})

                geometry = np.squeeze(geometry, axis=0)
                print('geometry is : ', (np.array(geometry)).shape)
                # Squash channels 0-2; channel 0 is thresholded as a score.
                geometry[:, :, :3] = sigmoid(geometry[:, :, :3])

                pixel_threshold = 0.9
                cond = np.greater_equal(geometry[:, :, 0], pixel_threshold)
                activation_pixels = np.where(cond)
                print('activation_pixels is : ',
                      (np.array(activation_pixels)).shape)

                quad_scores, boxes = nms(geometry, activation_pixels)

                # Keep quads whose scores are all positive; vertices are
                # permuted to order 0, 3, 2, 1 before recognition.
                input_roi_boxes = np.array([
                    box[[0, 3, 2, 1]]
                    for score, box in zip(quad_scores, boxes)
                    if np.amin(score) > 0
                ])

                # MTWI naming: "<image stem>.txt".
                res_file_path = os.path.join(
                    FLAGS.output_dir, '{}.txt'.format(
                        os.path.splitext(os.path.basename(im_fn))[0]))

                im_txt = None
                if input_roi_boxes.shape[0] != 0:
                    input_roi_boxes = input_roi_boxes[:, :8].reshape(-1, 8)
                    num_boxes = input_roi_boxes.shape[0]
                    recog_decode_list = []
                    recog_failed = False
                    # range(0, n, step) never yields an empty trailing chunk
                    # (n // 32 + 1 did whenever n was a multiple of 32).
                    for start_slice_index in range(0, num_boxes,
                                                   roi_batch_size):
                        tmp_roi_boxes = input_roi_boxes[
                            start_slice_index:start_slice_index +
                            roi_batch_size]

                        boxes_masks = [0] * tmp_roi_boxes.shape[0]
                        transform_matrixes, box_widths = get_project_matrix_and_width(
                            tmp_roi_boxes)

                        # Run end to end
                        try:
                            recog_decode = sess.run(dense_decode,
                                                    feed_dict={
                                                        input_feature_map:
                                                        shared_feature_map,
                                                        input_transform_matrix:
                                                        transform_matrixes,
                                                        input_box_mask[0]:
                                                        boxes_masks,
                                                        input_box_widths:
                                                        box_widths
                                                    })
                        except Exception as e:  # e.g. OOM on huge RoIs
                            # Best effort: flag the image instead of aborting
                            # the whole run.
                            print('recognition failed: {}'.format(e))
                            recog_failed = True
                            break
                        recog_decode_list.extend(recog_decode)

                    if recog_failed or len(recog_decode_list) != num_boxes:
                        print(
                            "detection and recognition result are not equal!")
                        with open(wrong_list_path, 'a') as f:
                            f.write('{}\r\n'.format(im_fn))
                        wrong_file_path = os.path.join(
                            wrong_pic_dir, os.path.basename(im_fn))
                        cv2.imwrite(wrong_file_path, im[:, :, ::-1])
                        index -= 1
                        print('------------------------------------')
                        continue

                    input_roi_boxes = input_roi_boxes.reshape((-1, 4, 2))

                    # OpenCV drawing needs a contiguous buffer: drawing on the
                    # im[:, :, ::-1] view fails or lands on a temporary copy
                    # (only the last box survived). Draw all boxes on one BGR
                    # copy instead.
                    im_draw = np.ascontiguousarray(im[:, :, ::-1])
                    with open(res_file_path, 'w') as f:
                        for i, box in enumerate(input_roi_boxes):
                            # [0, 3, 2, 1] is its own inverse: this restores
                            # the vertex order returned by nms (MTWI order).
                            box = box[[0, 3, 2, 1]]
                            # Back to original-image coordinates.
                            box = (box / [ratio_w, ratio_h]).astype(np.int32)

                            recognition_result = ground_truth_to_word(
                                recog_decode_list[i])
                            print('recognition_result : ',
                                  recognition_result)

                            f.write(
                                '{},{},{},{},{},{},{},{},{}\r\n'.format(
                                    box[0, 0], box[0, 1], box[1, 0],
                                    box[1, 1], box[2, 0], box[2, 1],
                                    box[3, 0], box[3, 1], recognition_result))

                            # Draw bounding box
                            cv2.polylines(im_draw,
                                          [box.reshape((-1, 1, 2))],
                                          True,
                                          color=(0, 0, 255),
                                          thickness=2)
                    im_txt = im_draw
                else:
                    # Every image gets a result file, empty if nothing found.
                    with open(res_file_path, 'w'):
                        pass

                # to show how many pictures are left
                index -= 1
                print('------------------------------------')

                if not FLAGS.no_write_images and im_txt is not None:
                    img_path = os.path.join(FLAGS.output_dir,
                                            os.path.basename(im_fn))
                    cv2.imwrite(img_path, im_txt)
Esempio n. 2
0
def main(argv=None):
    """Run detection + recognition with vocabulary correction on test images.

    Every image from ``get_images()`` is resized to 960x540 and passed through
    the detector; ``detect`` turns the score/geometry maps into boxes. The
    detected boxes then go through the recognition branch in chunks of at
    most 32 RoIs. Boxes with an edge shorter than 5 px are skipped. A
    recognised string for which ``contain_eng`` is true is replaced by the
    first candidate ``bktree_search`` returns from the ``FLAGS.vocab`` BK-tree,
    if there is one.

    Outputs:
        - ``<image basename>.txt`` in ``FLAGS.output_dir``: eight coordinates
          plus the text, one line per kept box; the file is empty when nothing
          is detected.
        - The annotated image, unless ``FLAGS.no_write_images`` is set.
        - Per-image timings, printed.

    Exits the process with status -1 if the number of recognition results
    does not match the number of detected boxes.

    Args:
        argv: unused.
    """
    import os
    import sys

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:  # 17 == EEXIST: the directory already exists
            raise

    bk_tree = BKTree(levenshtein, list_words(FLAGS.vocab))
    # Maximum number of RoIs recognised per sess.run (avoids OOM).
    roi_batch_size = 32

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        input_feature_map = tf.placeholder(tf.float32,
                                           shape=[None, None, None, 32],
                                           name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32,
                                                shape=[None, 6],
                                                name='input_transform_matrix')
        input_box_mask = []
        input_box_mask.append(
            tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32,
                                          shape=[None],
                                          name='input_box_widths')

        # Not fed or fetched below; kept so the graph layout matches training.
        input_seq_len = input_box_widths[tf.argmax(
            input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(
            input_feature_map, input_transform_matrix, input_box_mask,
            input_box_widths)
        recognition_logits = recognize_part.build_graph(
            pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits,
                                                input_box_widths)

        # Restore the exponential-moving-average shadow weights.
        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                # Kept in cv2's BGR order; annotations are drawn on it in place.
                im = cv2.imread(im_fn)
                im = cv2.resize(im, (960, 540))

                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)

                timer = {'detect': 0, 'restore': 0, 'nms': 0, 'recog': 0}
                start = time.time()
                shared_feature_map, score, geometry = sess.run(
                    [shared_feature, f_score, f_geometry],
                    feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score,
                                      geo_map=geometry,
                                      timer=timer)
                timer['detect'] = time.time() - start
                start = time.time()  # reset for recognition

                res_file_path = os.path.join(
                    FLAGS.output_dir,
                    '{}.txt'.format(os.path.basename(im_fn)))
                # Reset per image: if every box is filtered out below, the
                # previous image's annotation must not be saved under this name.
                im_txt = None
                if boxes is not None and boxes.shape[0] != 0:
                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    recog_decode_list = []
                    # range(0, n, step) never yields an empty trailing chunk
                    # (n // 32 + 1 did whenever n was a multiple of 32).
                    for start_slice_index in range(0, input_roi_boxes.shape[0],
                                                   roi_batch_size):
                        tmp_roi_boxes = input_roi_boxes[
                            start_slice_index:start_slice_index +
                            roi_batch_size]

                        boxes_masks = [0] * tmp_roi_boxes.shape[0]
                        transform_matrixes, box_widths = get_project_matrix_and_width(
                            tmp_roi_boxes)

                        # Run end to end
                        recog_decode = sess.run(dense_decode,
                                                feed_dict={
                                                    input_feature_map:
                                                    shared_feature_map,
                                                    input_transform_matrix:
                                                    transform_matrixes,
                                                    input_box_mask[0]:
                                                    boxes_masks,
                                                    input_box_widths:
                                                    box_widths
                                                })
                        recog_decode_list.extend(recog_decode)

                    timer['recog'] = time.time() - start
                    # Back to (N, 4, 2) in original-image coordinates.
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    if len(recog_decode_list) != boxes.shape[0]:
                        print(
                            "detection and recognition result are not equal!")
                        sys.exit(-1)

                    with open(res_file_path, 'w') as f:
                        for i, box in enumerate(boxes):
                            # to avoid submitting errors
                            box = sort_poly(box.astype(np.int32))
                            if (np.linalg.norm(box[0] - box[1]) < 5
                                    or np.linalg.norm(box[3] - box[0]) < 5):
                                continue
                            recognition_result = ground_truth_to_word(
                                recog_decode_list[i])

                            if contain_eng(recognition_result):
                                print(recognition_result)
                                fix_result = bktree_search(
                                    bk_tree, recognition_result.lower())
                                print(fix_result)
                                if len(fix_result) != 0:
                                    recognition_result = fix_result[0][1]

                            f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1],
                                recognition_result))

                            # Draw bounding box
                            cv2.polylines(
                                im, [box.astype(np.int32).reshape((-1, 1, 2))],
                                True,
                                color=(255, 255, 0),
                                thickness=1)
                            # Draw a 15 px band above the box's top edge as
                            # background for the text.
                            text_area = box.copy()
                            text_area[2, 1] = text_area[1, 1]
                            text_area[3, 1] = text_area[0, 1]
                            text_area[0, 1] = text_area[0, 1] - 15
                            text_area[1, 1] = text_area[1, 1] - 15
                            cv2.fillPoly(im, [
                                text_area.astype(np.int32).reshape((-1, 1, 2))
                            ],
                                         color=(255, 255, 0))
                            # OpenCV expects plain Python ints for `org`.
                            im_txt = cv2.putText(im, recognition_result,
                                                 (int(box[0, 0]),
                                                  int(box[0, 1])), font,
                                                 0.5, (0, 0, 255), 1)
                            # To render Chinese text instead:
                            # im_txt = cv2ImgAddText(im, recognition_result, box[0, 0], box[0, 1], (0, 0, 149), 20)
                else:
                    # Every image gets a result file, empty if nothing found.
                    with open(res_file_path, 'w'):
                        pass

                print(
                    '{} : detect {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms, recog {:.0f}ms'
                    .format(im_fn, timer['detect'] * 1000,
                            timer['restore'] * 1000, timer['nms'] * 1000,
                            timer['recog'] * 1000))

                duration = time.time() - start_time
                print('[timing] {}'.format(duration))

                if not FLAGS.no_write_images and im_txt is not None:
                    img_path = os.path.join(FLAGS.output_dir,
                                            os.path.basename(im_fn))
                    cv2.imwrite(img_path, im_txt)
Esempio n. 3
0
def main(photo):
    """Detect and recognise text in one uploaded photo.

    ``photo`` must provide ``save(stream)`` and ``filename``. It is presumably
    a Flask/Werkzeug upload object from an HTTP POST (TODO confirm against the
    caller). Every call builds a fresh TF graph and session and restores the
    checkpoint found in ``FLAGS.checkpoint_path``.

    Side effects:
        - Writes ``res_<stem>.txt`` to ``FLAGS.output_dir``, where ``<stem>``
          is the part of the upload's base name before the first ``.``. Each
          line holds eight coordinates followed by the recognised text; the
          file is empty when nothing is detected.
        - Unless ``FLAGS.no_write_images`` is set, also writes the image
          (annotated when boxes were found) under the upload's base name.

    Returns:
        list of recognised strings, one per box that survives the small-box
        filter (empty when nothing is detected).

    Raises:
        ValueError: if the uploaded bytes cannot be decoded as an image.
        RuntimeError: if the number of recognition results differs from the
            number of detected boxes.
    """
    import os

    # The filename comes from the client: keep only its base name so output
    # files cannot escape FLAGS.output_dir.
    upload_name = os.path.basename(photo.filename)
    res_file_path = os.path.join(
        FLAGS.output_dir, 'res_{}.txt'.format(upload_name.split('.')[0]))

    # Decode the uploaded file in memory into an OpenCV (BGR) matrix.
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    in_memory_file = io.BytesIO()
    photo.save(in_memory_file)
    data = np.frombuffer(in_memory_file.getvalue(), dtype=np.uint8)
    im_bgr = cv2.imdecode(data, cv2.IMREAD_COLOR)
    if im_bgr is None:
        raise ValueError(
            'could not decode uploaded image {!r}'.format(photo.filename))
    img = im_bgr[:, :, ::-1]  # RGB view for the network

    # Maximum number of RoIs recognised per sess.run (avoids OOM).
    roi_batch_size = 32
    text_result = []

    # A private graph per call: building into the process-wide default graph
    # makes the second call fail because 'global_step' already exists there.
    with tf.Graph().as_default():
        input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
        input_feature_map = tf.placeholder(tf.float32, shape=[None, None, None, 32], name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32, shape=[None, 6], name='input_transform_matrix')
        input_box_mask = []
        input_box_mask.append(tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32, shape=[None], name='input_box_widths')

        # Not fed or fetched below; kept so the graph layout matches training.
        input_seq_len = input_box_widths[tf.argmax(input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(input_feature_map, input_transform_matrix, input_box_mask, input_box_widths)
        recognition_logits = recognize_part.build_graph(pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits, input_box_widths)

        # Restore the exponential-moving-average shadow weights.
        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            start_time = time.time()
            im_resized, (ratio_h, ratio_w) = resize_image(img)
            timer = {'detect': 0, 'restore': 0, 'nms': 0, 'recog': 0}
            start = time.time()
            shared_feature_map, score, geometry = sess.run(
                [shared_feature, f_score, f_geometry],
                feed_dict={input_images: [im_resized]})

            boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
            timer['detect'] = time.time() - start
            start = time.time()  # reset for recognition

            if boxes is not None and boxes.shape[0] != 0:
                input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                recog_decode_list = []
                # range(0, n, step) never yields an empty trailing chunk
                # (n // 32 + 1 did whenever n was a multiple of 32).
                for start_slice_index in range(0, input_roi_boxes.shape[0], roi_batch_size):
                    tmp_roi_boxes = input_roi_boxes[start_slice_index:start_slice_index + roi_batch_size]

                    boxes_masks = [0] * tmp_roi_boxes.shape[0]
                    transform_matrixes, box_widths = get_project_matrix_and_width(tmp_roi_boxes)

                    # Run end to end
                    recog_decode = sess.run(dense_decode, feed_dict={
                        input_feature_map: shared_feature_map,
                        input_transform_matrix: transform_matrixes,
                        input_box_mask[0]: boxes_masks,
                        input_box_widths: box_widths})
                    recog_decode_list.extend(recog_decode)

                timer['recog'] = time.time() - start
                # Back to (N, 4, 2) in original-image coordinates.
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h

                if len(recog_decode_list) != boxes.shape[0]:
                    # Raise rather than exit(): this runs inside a request
                    # handler and must not terminate the server process.
                    raise RuntimeError("detection and recognition result are not equal!")

                with open(res_file_path, 'w') as f:
                    for i, box in enumerate(boxes):
                        # to avoid submitting errors
                        box = sort_poly(box.astype(np.int32))
                        if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                            continue
                        recognition_result = ground_truth_to_word(recog_decode_list[i])
                        text_result.append(recognition_result)

                        f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(
                            box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1], recognition_result
                        ))

                        # Draw in place on the contiguous BGR image. The old
                        # img[:, :, ::-1] views made OpenCV fail or draw on a
                        # throw-away copy each time.
                        cv2.polylines(im_bgr, [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
                        # Draw a 15 px band above the box's top edge as
                        # background for the text.
                        text_area = box.copy()
                        text_area[2, 1] = text_area[1, 1]
                        text_area[3, 1] = text_area[0, 1]
                        text_area[0, 1] = text_area[0, 1] - 15
                        text_area[1, 1] = text_area[1, 1] - 15

                        cv2.fillPoly(im_bgr, [text_area.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 0))
                        # OpenCV expects plain Python ints for `org`.
                        cv2.putText(im_bgr, recognition_result, (int(box[0, 0]), int(box[0, 1])), font, 0.5, (0, 0, 255), 1)
            else:
                # Always produce a result file, empty if nothing was found.
                with open(res_file_path, 'w'):
                    pass

            print('{} : detect {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms, recog {:.0f}ms'.format(
                photo.filename, timer['detect'] * 1000, timer['restore'] * 1000, timer['nms'] * 1000, timer['recog'] * 1000))

            duration = time.time() - start_time
            print('[timing] {}'.format(duration))

            if not FLAGS.no_write_images:
                # im_bgr is in the colour order imwrite expects and carries the
                # annotations, if any.
                img_path = os.path.join(FLAGS.output_dir, upload_name)
                cv2.imwrite(img_path, im_bgr)

    return text_result
Esempio n. 4
0
def predictmtwi(image_path, return_dic,
                checkpoint_path='/home/wsw/workplace/deeplearning/DRWord/algorithm/DRWord_ch/checkpointmtwi/',
                gpu='1'):
    """Detect and recognise text in a single image with the MTWI model.

    A fresh graph is built and the exponential-moving-average weights are
    restored from the latest checkpoint in ``checkpoint_path``. Quads come
    from ``nms`` over the sigmoid-activated geometry map (pixel threshold
    0.9). The text in each quad is decoded in batches of at most 32 ROIs.
    Every quad is drawn with color (0, 0, 255) and thickness 2 onto the
    image as loaded by ``cv2.imread``.

    Args:
        image_path: Path of the image to process (read with ``cv2.imread``).
        return_dic: Mutable mapping that receives the results under the keys
            ``'im_result_mtwi'`` and ``'txt_result_mtwi'``.
            NOTE(review): presumably a ``multiprocessing.Manager().dict()``
            filled from a worker process -- confirm with the caller.
        checkpoint_path: Directory holding the TensorFlow checkpoint.
        gpu: Value written to ``CUDA_VISIBLE_DEVICES``. ``'-1'`` hides all
            GPUs; ``'0'`` selects GPU 0.

    Returns:
        ``(txt_result, im_result)``: the list of recognised strings and the
        image buffer as returned by ``cv2.imread`` with the boxes drawn on
        it. When no text is found, or recognition fails, the buffer is
        returned without drawings. The same two values are always stored in
        ``return_dic``.

    Raises:
        IOError: If the image cannot be read or no checkpoint is found.
    """
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    # Number of ROIs recognised per sess.run; bounds memory on text-dense images.
    roi_batch_size = 32

    bgr = cv2.imread(image_path)
    if bgr is None:
        raise IOError('Cannot read image {}'.format(image_path))
    # Channel-reversed view of the same buffer; this is what the network sees.
    im = bgr[:, :, ::-1]

    with tf.Graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        input_feature_map = tf.placeholder(tf.float32,
                                           shape=[None, None, None, 32],
                                           name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32,
                                                shape=[None, 6],
                                                name='input_transform_matrix')
        input_box_mask = [
            tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0')
        ]
        input_box_widths = tf.placeholder(tf.int32,
                                          shape=[None],
                                          name='input_box_widths')

        input_seq_len = input_box_widths[tf.argmax(
            input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        shared_feature, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(
            input_feature_map, input_transform_matrix, input_box_mask,
            input_box_widths)
        recognition_logits = recognize_part.build_graph(
            pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits,
                                                input_box_widths)

        # Restore the moving-average shadow values instead of the raw weights.
        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(checkpoint_path)
            if ckpt_state is None or not ckpt_state.model_checkpoint_path:
                raise IOError(
                    'No checkpoint found in {}'.format(checkpoint_path))
            model_path = os.path.join(
                checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_resized, (ratio_h, ratio_w) = resize_image(im)
            img = preprocess_input(image.img_to_array(im_resized), mode='tf')

            shared_feature_map, geometry = sess.run(
                [shared_feature, f_geometry],
                feed_dict={input_images: [img]})

            geometry = np.squeeze(geometry, axis=0)
            geometry[:, :, :3] = sigmoid(geometry[:, :, :3])

            pixel_threshold = 0.9
            activation_pixels = np.where(
                np.greater_equal(geometry[:, :, 0], pixel_threshold))
            quad_scores, boxes = nms(geometry, activation_pixels)

            # Keep quads whose scores are all positive, with corners permuted
            # to [0, 3, 2, 1] for the ROI-rotate input. The same permutation
            # is applied again before drawing, which restores the nms order.
            input_roi_boxes = np.array([
                box[[0, 3, 2, 1]]
                for score, box in zip(quad_scores, boxes)
                if np.amin(score) > 0
            ])

            txt_result = []
            # Drawing goes straight into the imread buffer, so the returned
            # image always has the same channel order whether or not any
            # text was found.
            im_result = bgr

            if input_roi_boxes.shape[0] != 0:
                input_roi_boxes = input_roi_boxes[:, :8].reshape(-1, 8)
                n_rois = input_roi_boxes.shape[0]

                recog_decode_list = []
                # Stepping by the batch size never yields an empty batch, even
                # when n_rois is an exact multiple of roi_batch_size.
                for start_slice_index in range(0, n_rois, roi_batch_size):
                    tmp_roi_boxes = input_roi_boxes[
                        start_slice_index:start_slice_index + roi_batch_size]
                    boxes_masks = [0] * tmp_roi_boxes.shape[0]
                    transform_matrixes, box_widths = get_project_matrix_and_width(
                        tmp_roi_boxes)
                    try:
                        recog_decode = sess.run(
                            dense_decode,
                            feed_dict={
                                input_feature_map: shared_feature_map,
                                input_transform_matrix: transform_matrixes,
                                input_box_mask[0]: boxes_masks,
                                input_box_widths: box_widths,
                            })
                    except Exception as err:
                        # Best effort: once a batch fails, decoded strings can
                        # no longer be paired with their boxes, so report no
                        # text for this image instead of crashing.
                        print('Recognition failed for {}: {}'.format(
                            image_path, err))
                        recog_decode_list = None
                        break
                    recog_decode_list.extend(recog_decode)

                if (recog_decode_list is not None
                        and len(recog_decode_list) == n_rois):
                    for box, decoded in zip(
                            input_roi_boxes.reshape((-1, 4, 2)),
                            recog_decode_list):
                        box = box[[0, 3, 2, 1]]
                        # Divide by the ratios returned by resize_image.
                        box = (box / [ratio_w, ratio_h]).astype(np.int32)
                        txt_result.append(ground_truth_to_word(decoded))
                        cv2.polylines(bgr, [box.reshape((-1, 1, 2))],
                                      True,
                                      color=(0, 0, 255),
                                      thickness=2)

    return_dic['im_result_mtwi'] = im_result
    return_dic['txt_result_mtwi'] = txt_result
    return txt_result, im_result
def main(argv=None):
    """Run end-to-end text spotting over NBA frames and evaluate the result.

    For every frame that ``get_image_self`` returns for the hard-coded
    ``video2image`` directory, this function:

    - resizes the frame to 960x540 and detects text boxes;
    - recognises the text in batches of at most 32 ROIs;
    - replaces English strings with the first BK-tree match from the
      vocabulary in ``FLAGS.vocab``;
    - sorts the strings into team names (``all_team``), quarters
      (``quarter``), ``scores`` (all digits), ``time_left`` (contains ':')
      and ``remainder_attack_time`` (contains '.' but not ':');
    - builds a result with ``get_content``.

    Unless ``FLAGS.just_infer`` is set, each result is compared with
    ``get_score_info_v2`` applied to the ground-truth labels. The wrong and
    total counts are printed, and the precision too when at least one frame
    was evaluated.

    Args:
        argv: Unused.

    Exits with status -1 if the number of recognised strings differs from
    the number of detected boxes.

    Raises:
        IOError: If no checkpoint is found in ``FLAGS.checkpoint_path``.
    """
    import os
    import sys

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:  # 17 == EEXIST: the directory already exists
            raise

    # BK-tree over the vocabulary words, used to correct English OCR output.
    bk_tree = BKTree(levenshtein, list_words(FLAGS.vocab))
    # Number of ROIs recognised per sess.run; bounds memory on busy frames.
    roi_batch_size = 32
    image_dir = "/data/ceph_11015/ssd/anhan/nba/video2image"

    def _center(box):
        """Return the midpoint [x, y] of the diagonal box[0] -> box[2]."""
        return [(int(box[0, 0]) + int(box[2, 0])) / 2,
                (int(box[0, 1]) + int(box[2, 1])) / 2]

    def _x_span(box):
        """Return a new list [x of box[0], x of box[2]]."""
        return [box[0, 0], box[2, 0]]

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        input_feature_map = tf.placeholder(tf.float32,
                                           shape=[None, None, None, 32],
                                           name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32,
                                                shape=[None, 6],
                                                name='input_transform_matrix')
        input_box_mask = [
            tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0')
        ]
        input_box_widths = tf.placeholder(tf.int32,
                                          shape=[None],
                                          name='input_box_widths')

        input_seq_len = input_box_widths[tf.argmax(
            input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(
            input_feature_map, input_transform_matrix, input_box_mask,
            input_box_widths)
        recognition_logits = recognize_part.build_graph(
            pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits,
                                                input_box_widths)

        # Restore the moving-average shadow values instead of the raw weights.
        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            if ckpt_state is None or not ckpt_state.model_checkpoint_path:
                raise IOError('No checkpoint found in {}'.format(
                    FLAGS.checkpoint_path))
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            if FLAGS.just_infer:
                im_fn_list, _, _ = get_image_self(image_dir)
            else:
                im_fn_list, corridate_list, label_list = get_image_self(
                    image_dir)

            wrong = 0
            total = 0
            for ind, im_fn in enumerate(im_fn_list):
                raw = cv2.imread(im_fn)
                if raw is None:
                    print('{} : cannot read image, skipped'.format(im_fn))
                    continue
                im = cv2.resize(raw[:, :, ::-1], (960, 540))
                im_resized, (ratio_h, ratio_w) = resize_image(im)

                timer = {'detect': 0, 'restore': 0, 'nms': 0, 'recog': 0}
                start = time.time()
                shared_feature_map, score, geometry = sess.run(
                    [shared_feature, f_score, f_geometry],
                    feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score,
                                      geo_map=geometry,
                                      timer=timer)
                timer['detect'] = time.time() - start
                start = time.time()  # reset for recognition

                res = None
                str_list = []
                if boxes is not None and boxes.shape[0] != 0:
                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    recog_decode_list = []
                    # Stepping by the batch size never yields an empty batch,
                    # even when the box count is a multiple of the batch size.
                    for start_slice_index in range(0, input_roi_boxes.shape[0],
                                                   roi_batch_size):
                        tmp_roi_boxes = input_roi_boxes[
                            start_slice_index:start_slice_index +
                            roi_batch_size]
                        boxes_masks = [0] * tmp_roi_boxes.shape[0]
                        transform_matrixes, box_widths = get_project_matrix_and_width(
                            tmp_roi_boxes)

                        # Run end to end
                        recog_decode = sess.run(
                            dense_decode,
                            feed_dict={
                                input_feature_map: shared_feature_map,
                                input_transform_matrix: transform_matrixes,
                                input_box_mask[0]: boxes_masks,
                                input_box_widths: box_widths,
                            })
                        recog_decode_list.extend(recog_decode)

                    timer['recog'] = time.time() - start

                    # Divide box coordinates by the ratios from resize_image.
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    if len(recog_decode_list) != boxes.shape[0]:
                        print(
                            "detection and recognition result are not equal!")
                        sys.exit(-1)

                    # Each category maps a key to the box centre. Keys that
                    # can repeat get a "_<index>" suffix. `points` maps the
                    # same keys to the box's x span.
                    scores = {}
                    score_index = 0
                    time_left = {}
                    time_index = 0
                    team_name = {}
                    quarter_dict = {}
                    remainder_attack_time = {}
                    remainder_attack_time_index = 0
                    recognition_result_num = 0
                    points = {}
                    for i, box in enumerate(boxes):
                        box = sort_poly(box.astype(np.int32))
                        # Skip boxes with an edge shorter than 5 px.
                        if (np.linalg.norm(box[0] - box[1]) < 5
                                or np.linalg.norm(box[3] - box[0]) < 5):
                            continue
                        recognition_result = ground_truth_to_word(
                            recog_decode_list[i])

                        if contain_eng(recognition_result):
                            # Replace with the first BK-tree hit (presumably
                            # the nearest vocabulary word).
                            fix_result = bktree_search(
                                bk_tree, recognition_result.lower())
                            if len(fix_result) != 0:
                                recognition_result = fix_result[0][1]

                        # The checks are independent: a string may land in
                        # more than one category.
                        if recognition_result in all_team:
                            team_name[recognition_result] = _center(box)
                            points[recognition_result] = _x_span(box)

                        if recognition_result in quarter:
                            quarter_dict[recognition_result] = _center(box)
                            points[recognition_result] = _x_span(box)

                        if recognition_result.isdigit():
                            key = recognition_result + "_" + str(score_index)
                            scores[key] = _center(box)
                            points[key] = _x_span(box)
                            score_index += 1

                        if ":" in recognition_result:
                            key = recognition_result + "_" + str(time_index)
                            time_left[key] = _center(box)
                            points[key] = _x_span(box)
                            time_index += 1

                        if "." in recognition_result and ":" not in recognition_result:
                            key = (recognition_result + "_" +
                                   str(remainder_attack_time_index))
                            remainder_attack_time[key] = _center(box)
                            points[key] = _x_span(box)
                            remainder_attack_time_index += 1

                        recognition_result_num += 1
                        str_list.append(recognition_result)
                        # Drawing boxes / text onto the frame is disabled here.

                    if recognition_result_num in (5, 6, 7, 8):
                        res = get_content(remainder_attack_time, time_left,
                                          team_name, scores, quarter_dict)
                    elif recognition_result_num == 9:
                        sort_points = sorted(points.items(),
                                             key=lambda item: item[1][0])
                        x_coordiate = []
                        for _, (x_left, x_right) in sort_points:
                            x_coordiate.append(x_left)
                            x_coordiate.append(x_right)
                        # `points` only holds categorised strings, so it can
                        # have fewer than five entries. Drops happen only when
                        # the x spans, ordered by left edge, do not overlap.
                        # TODO: document why the 2nd and 5th regions are dropped.
                        if (len(sort_points) > 4
                                and sorted(x_coordiate) == x_coordiate):
                            drop1 = sort_points[1][0]
                            drop2 = sort_points[4][0]
                            if drop1 in remainder_attack_time:
                                remainder_attack_time = remove_key(
                                    remainder_attack_time, drop1)
                            if drop2 in remainder_attack_time:
                                remainder_attack_time = remove_key(
                                    remainder_attack_time, drop2)
                            if drop1 in time_left:
                                time_left = remove_key(time_left, drop1)
                            if drop2 in time_left:
                                time_left = remove_key(time_left, drop2)
                            if drop1 in scores:
                                scores = remove_key(scores, drop1)
                            if drop2 in scores:
                                scores = remove_key(scores, drop2)
                        res = get_content(remainder_attack_time, time_left,
                                          team_name, scores, quarter_dict)

                if not FLAGS.just_infer:
                    corridate_true = corridate_list[ind].split("_")[4:]
                    label_true = label_list[ind].split("_")
                    res_true = get_score_info_v2(corridate_true, label_true)
                    if res != res_true:
                        wrong += 1
                    total += 1
                    print(
                        im_fn.split("/")[-1], label_list[ind], res_true, res,
                        "_".join(str_list))
                else:
                    print(im_fn.split("/")[-1], res, "_".join(str_list))

            print("wrong:{}".format(wrong))
            print("total:{}".format(total))
            # total stays 0 in just_infer mode (or with no readable frames).
            if total > 0:
                print("precision:{}".format((total - wrong) / total))