示例#1
0
def test(img_num):
    """Run the detector on ``img_num`` batches and save visualisations.

    Builds the inference graph (shared backbone -> RPN -> Fast R-CNN) inside
    a fresh ``tf.Graph``, restores weights when
    ``restore_model.get_restorer()`` supplies a restorer, and for every
    batch writes four JPEG files into ``cfgs.TEST_SAVE_PATH``:

    * ``<name>_horizontal_fpn.jpg`` - predicted horizontal boxes
    * ``<name>_rotate_fpn.jpg``     - predicted rotated boxes
    * ``<name>_horizontal_gt.jpg``  - ground-truth horizontal boxes
    * ``<name>_rotate_gt.jpg``      - ground-truth rotated boxes

    Args:
        img_num: number of batches to pull from the input queue.

    Note:
        Axis 0 of the batch tensors is squeezed away below, which only
        succeeds when that axis has size 1 (presumably
        ``cfgs.BATCH_SIZE == 1`` - confirm against ``next_batch``).
    """
    with tf.Graph().as_default():

        # Input pipeline in evaluation mode; the fourth return value
        # (object counts) is not needed here.
        img_name_batch, img_batch, gtboxes_and_label_batch, _ = next_batch(
            dataset_name=cfgs.DATASET_NAME,
            batch_size=cfgs.BATCH_SIZE,
            shortside_len=cfgs.SHORT_SIDE_LEN,
            is_training=False)

        # Ground truth: split off the head information, convert the box
        # encoding with a Python helper, then pin the column count again
        # with an explicit reshape.
        gtboxes_and_label, head = get_head(
            tf.squeeze(gtboxes_and_label_batch, 0))
        gtboxes_and_label = tf.py_func(back_forward_convert,
                                       inp=[gtboxes_and_label],
                                       Tout=tf.float32)
        gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 6])
        head_quadrant = tf.py_func(get_head_quadrant,
                                   inp=[head, gtboxes_and_label],
                                   Tout=tf.float32)
        head_quadrant = tf.reshape(head_quadrant, [-1, 1])

        gtboxes_and_label_minAreaRectangle = get_horizen_minAreaRectangle(
            gtboxes_and_label)
        gtboxes_and_label_minAreaRectangle = tf.reshape(
            gtboxes_and_label_minAreaRectangle, [-1, 5])

        # ***********************************************************************************************
        # *                                         share net                                           *
        # ***********************************************************************************************
        # NOTE(review): the backbone is built with is_training=True even
        # though this is a test routine - confirm this is intentional.
        _, share_net = get_network_byname(net_name=cfgs.NET_NAME,
                                          inputs=img_batch,
                                          num_classes=None,
                                          is_training=True,
                                          output_stride=None,
                                          global_pool=False,
                                          spatial_squeeze=False)

        # ***********************************************************************************************
        # *                                            RPN                                              *
        # ***********************************************************************************************
        rpn = build_rpn.RPN(
            net_name=cfgs.NET_NAME,
            inputs=img_batch,
            gtboxes_and_label=None,  # no ground truth is fed at test time
            is_training=False,
            share_head=cfgs.SHARE_HEAD,
            share_net=share_net,
            stride=cfgs.STRIDE,
            anchor_ratios=cfgs.ANCHOR_RATIOS,
            anchor_scales=cfgs.ANCHOR_SCALES,
            scale_factors=cfgs.SCALE_FACTORS,
            base_anchor_size_list=cfgs.BASE_ANCHOR_SIZE_LIST,  # P2, P3, P4, P5, P6
            level=cfgs.LEVEL,
            top_k_nms=cfgs.RPN_TOP_K_NMS,
            rpn_nms_iou_threshold=cfgs.RPN_NMS_IOU_THRESHOLD,
            max_proposals_num=cfgs.MAX_PROPOSAL_NUM,
            rpn_iou_positive_threshold=cfgs.RPN_IOU_POSITIVE_THRESHOLD,
            rpn_iou_negative_threshold=cfgs.RPN_IOU_NEGATIVE_THRESHOLD,
            rpn_mini_batch_size=cfgs.RPN_MINIBATCH_SIZE,
            rpn_positives_ratio=cfgs.RPN_POSITIVE_RATE,
            remove_outside_anchors=False,  # whether remove anchors outside
            rpn_weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME])

        # rpn predict proposals
        rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals()

        # ***********************************************************************************************
        # *                                         Fast RCNN                                           *
        # ***********************************************************************************************
        fast_rcnn = build_fast_rcnn.FastRCNN(
            feature_pyramid=rpn.feature_pyramid,
            rpn_proposals_boxes=rpn_proposals_boxes,
            rpn_proposals_scores=rpn_proposals_scores,
            img_shape=tf.shape(img_batch),
            img_batch=img_batch,
            roi_size=cfgs.ROI_SIZE,
            roi_pool_kernel_size=cfgs.ROI_POOL_KERNEL_SIZE,
            scale_factors=cfgs.SCALE_FACTORS,
            gtboxes_and_label=None,
            gtboxes_and_label_minAreaRectangle=gtboxes_and_label_minAreaRectangle,
            fast_rcnn_nms_iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD,
            fast_rcnn_maximum_boxes_per_img=100,
            fast_rcnn_nms_max_boxes_per_class=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
            # score threshold for the detections that are reported
            show_detections_score_threshold=cfgs.FINAL_SCORE_THRESHOLD,
            num_classes=cfgs.CLASS_NUM,
            fast_rcnn_minibatch_size=cfgs.FAST_RCNN_MINIBATCH_SIZE,
            fast_rcnn_positives_ratio=cfgs.FAST_RCNN_POSITIVE_RATE,
            # IoU threshold separating positive from negative samples
            fast_rcnn_positives_iou_threshold=cfgs.FAST_RCNN_IOU_POSITIVE_THRESHOLD,
            use_dropout=cfgs.USE_DROPOUT,
            weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
            is_training=False,
            level=cfgs.LEVEL,
            head_quadrant=head_quadrant)

        # The two object-count outputs are not used by this routine.
        (fast_rcnn_decode_boxes, fast_rcnn_score, _, detection_category,
         fast_rcnn_decode_boxes_rotate, fast_rcnn_score_rotate,
         fast_rcnn_head_quadrant, _,
         detection_category_rotate) = fast_rcnn.fast_rcnn_predict()

        # Session setup: initialise variables, then overwrite them from the
        # checkpoint when one is available.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        restorer, restore_ckpt = restore_model.get_restorer()

        # Everything fetched per batch, in the order it is unpacked below.
        fetches = [img_name_batch, img_batch, gtboxes_and_label,
                   gtboxes_and_label_minAreaRectangle, head_quadrant,
                   fast_rcnn_decode_boxes, fast_rcnn_score,
                   detection_category, fast_rcnn_decode_boxes_rotate,
                   fast_rcnn_score_rotate, fast_rcnn_head_quadrant,
                   detection_category_rotate]

        config = tf.ConfigProto()
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(init_op)
            if restorer is not None:
                restorer.restore(sess, restore_ckpt)
                print('restore model')

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)

            # try/finally guarantees the queue-runner threads are stopped
            # and joined even when a batch fails part-way through.
            try:
                # The output directory does not change between batches, so
                # create it once up front.
                mkdir(cfgs.TEST_SAVE_PATH)

                for i in range(img_num):

                    start = time.time()
                    (_img_name_batch, _img_batch, _gtboxes_and_label,
                     _gtboxes_and_label_minAreaRectangle, _head_quadrant,
                     _fast_rcnn_decode_boxes, _fast_rcnn_score,
                     _detection_category, _fast_rcnn_decode_boxes_rotate,
                     _fast_rcnn_score_rotate, _fast_rcnn_head_quadrant,
                     _detection_category_rotate) = sess.run(fetches)
                    end = time.time()

                    img_name = str(_img_name_batch[0])
                    _img_batch = np.squeeze(_img_batch, axis=0)

                    # ---- predictions ----
                    _img_batch_fpn_horizonal = help_utils.draw_box_cv(
                        _img_batch,
                        boxes=_fast_rcnn_decode_boxes,
                        labels=_detection_category,
                        scores=_fast_rcnn_score)

                    _img_batch_fpn_rotate = help_utils.draw_rotate_box_cv(
                        _img_batch,
                        boxes=_fast_rcnn_decode_boxes_rotate,
                        labels=_detection_category_rotate,
                        scores=_fast_rcnn_score_rotate,
                        # pick the highest-scoring entry along axis 1
                        head=np.argmax(_fast_rcnn_head_quadrant, axis=1))

                    cv2.imwrite(
                        cfgs.TEST_SAVE_PATH +
                        '/{}_horizontal_fpn.jpg'.format(img_name),
                        _img_batch_fpn_horizonal)
                    cv2.imwrite(
                        cfgs.TEST_SAVE_PATH +
                        '/{}_rotate_fpn.jpg'.format(img_name),
                        _img_batch_fpn_rotate)

                    # ---- ground truth ----
                    # The last column holds the class label; both GT
                    # drawings use the same label vector.
                    gt_labels = np.reshape(_gtboxes_and_label[:, -1:],
                                           [-1]).astype(np.int64)

                    _img_batch_gt_horizontal = help_utils.draw_box_cv(
                        _img_batch,
                        boxes=_gtboxes_and_label_minAreaRectangle[:, :-1],
                        labels=gt_labels,
                        scores=None)

                    _img_batch_gt_rotate = help_utils.draw_rotate_box_cv(
                        _img_batch,
                        boxes=_gtboxes_and_label[:, :-1],
                        labels=gt_labels,
                        scores=None,
                        head=np.reshape(_head_quadrant, [-1]))

                    cv2.imwrite(
                        cfgs.TEST_SAVE_PATH +
                        '/{}_horizontal_gt.jpg'.format(img_name),
                        _img_batch_gt_horizontal)
                    cv2.imwrite(
                        cfgs.TEST_SAVE_PATH +
                        '/{}_rotate_gt.jpg'.format(img_name),
                        _img_batch_gt_rotate)

                    view_bar(
                        '{} image cost {}s'.format(img_name, (end - start)),
                        i + 1, img_num)
            finally:
                coord.request_stop()
                coord.join(threads)
示例#2
0
def test_rotate(img_num=650):
    """Dump ground-truth box visualisations for ``img_num`` batches.

    No detection network is built here: the function only reads batches
    from the input pipeline, draws the horizontal and the rotated
    ground-truth boxes onto the image with TensorFlow drawing ops, and
    writes the results to ``cfgs.INFERENCE_SAVE_PATH`` as
    ``<name>_horizontal_fpn.jpg`` and ``<name>_rotate_fpn.jpg``.

    Args:
        img_num: number of batches to process. Defaults to 650, the value
            that was previously hard-coded.

    Note:
        Axis 0 of the batch tensors is squeezed away below, which only
        succeeds when that axis has size 1 (presumably
        ``cfgs.BATCH_SIZE == 1`` - confirm against ``next_batch``).
    """
    import os  # local import: only needed to create the output directory

    with tf.Graph().as_default():
        with tf.name_scope('get_batch'):
            # The fourth return value (object counts) is not needed here.
            img_name_batch, img_batch, gtboxes_and_label_batch, _ = next_batch(
                dataset_name=cfgs.DATASET_NAME,
                batch_size=cfgs.BATCH_SIZE,
                shortside_len=cfgs.SHORT_SIDE_LEN,
                is_training=True)
            gtboxes_and_label, head = get_head(
                tf.squeeze(gtboxes_and_label_batch, 0))
            gtboxes_and_label = tf.py_func(back_forward_convert,
                                           inp=[gtboxes_and_label],
                                           Tout=tf.float32)
            gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 6])
            head_quadrant = tf.py_func(get_head_quadrant,
                                       inp=[head, gtboxes_and_label],
                                       Tout=tf.float32)
            head_quadrant = tf.reshape(head_quadrant, [-1, 1])

            gtboxes_and_label_minAreaRectangle = get_horizen_minAreaRectangle(
                gtboxes_and_label)
            gtboxes_and_label_minAreaRectangle = tf.reshape(
                gtboxes_and_label_minAreaRectangle, [-1, 5])

        with tf.name_scope('draw_gtboxes'):
            # The last column is dropped before drawing; the number of rows
            # is passed as the text to render.
            gtboxes_in_img = draw_box_with_color(
                img_batch,
                tf.reshape(gtboxes_and_label_minAreaRectangle,
                           [-1, 5])[:, :-1],
                text=tf.shape(gtboxes_and_label_minAreaRectangle)[0])

            gtboxes_rotate_in_img = draw_box_with_color_rotate(
                img_batch,
                tf.reshape(gtboxes_and_label, [-1, 6])[:, :-1],
                text=tf.shape(gtboxes_and_label)[0],
                head=head_quadrant)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        config = tf.ConfigProto()
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)

            # try/finally guarantees the queue-runner threads are stopped
            # and joined even when a batch fails part-way through.
            try:
                # cv2.imwrite does not create missing directories, so make
                # sure the target exists before writing anything.
                if not os.path.isdir(cfgs.INFERENCE_SAVE_PATH):
                    os.makedirs(cfgs.INFERENCE_SAVE_PATH)

                for i in range(img_num):
                    img_gtboxes, img_gtboxes_rotate, img_name = sess.run(
                        [gtboxes_in_img, gtboxes_rotate_in_img,
                         img_name_batch])
                    img_gtboxes = np.squeeze(img_gtboxes, axis=0)
                    img_gtboxes_rotate = np.squeeze(img_gtboxes_rotate,
                                                    axis=0)

                    print(i)
                    name = str(img_name[0])
                    cv2.imwrite(
                        cfgs.INFERENCE_SAVE_PATH +
                        '/{}_horizontal_fpn.jpg'.format(name),
                        img_gtboxes)
                    cv2.imwrite(
                        cfgs.INFERENCE_SAVE_PATH +
                        '/{}_rotate_fpn.jpg'.format(name),
                        img_gtboxes_rotate)
            finally:
                coord.request_stop()
                coord.join(threads)
示例#3
0
def eval_ship(img_num):
    """Run the detector on ``img_num`` batches and dump evaluation data.

    Builds the inference graph (shared backbone -> RPN -> Fast R-CNN) inside
    a fresh ``tf.Graph``, restores weights when
    ``restore_model.get_restorer()`` supplies a restorer, and collects per
    image:

    * detections as rows ``[category, score, box...]`` for the horizontal
      and the rotated branch, handed to ``write_voc_results_file``;
    * ground-truth boxes (via ``make_dict_packle``), pickled to
      ``gtboxes_horizontal_dict.pkl`` and ``gtboxes_rotate_dict.pkl`` in the
      current working directory, keyed by image name.

    Args:
        img_num: number of batches to pull from the input queue.

    Note:
        Axis 0 of the ground-truth batch is squeezed away below, which only
        succeeds when that axis has size 1 (presumably
        ``cfgs.BATCH_SIZE == 1`` - confirm against ``next_batch``).
    """
    with tf.Graph().as_default():

        # NOTE(review): the input pipeline is created with is_training=True
        # although this is an evaluation routine - confirm this is intended.
        # The fourth return value (object counts) is not needed here.
        img_name_batch, img_batch, gtboxes_and_label_batch, _ = next_batch(
            dataset_name=cfgs.DATASET_NAME,
            batch_size=cfgs.BATCH_SIZE,
            shortside_len=cfgs.SHORT_SIDE_LEN,
            is_training=True)

        # Ground truth: split off the head information, convert the box
        # encoding with a Python helper, then pin the column count again
        # with an explicit reshape.
        gtboxes_and_label, head = get_head(
            tf.squeeze(gtboxes_and_label_batch, 0))
        gtboxes_and_label = tf.py_func(back_forward_convert,
                                       inp=[gtboxes_and_label],
                                       Tout=tf.float32)
        gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 6])
        head_quadrant = tf.py_func(get_head_quadrant,
                                   inp=[head, gtboxes_and_label],
                                   Tout=tf.float32)
        head_quadrant = tf.reshape(head_quadrant, [-1, 1])

        gtboxes_and_label_minAreaRectangle = get_horizen_minAreaRectangle(
            gtboxes_and_label)
        gtboxes_and_label_minAreaRectangle = tf.reshape(
            gtboxes_and_label_minAreaRectangle, [-1, 5])

        # ***********************************************************************************************
        # *                                         share net                                           *
        # ***********************************************************************************************
        # NOTE(review): the backbone is built with is_training=True even
        # though this is an evaluation routine - confirm this is intentional.
        _, share_net = get_network_byname(net_name=cfgs.NET_NAME,
                                          inputs=img_batch,
                                          num_classes=None,
                                          is_training=True,
                                          output_stride=None,
                                          global_pool=False,
                                          spatial_squeeze=False)

        # ***********************************************************************************************
        # *                                            RPN                                              *
        # ***********************************************************************************************
        rpn = build_rpn.RPN(
            net_name=cfgs.NET_NAME,
            inputs=img_batch,
            gtboxes_and_label=None,  # no ground truth is fed at eval time
            is_training=False,
            share_head=cfgs.SHARE_HEAD,
            share_net=share_net,
            stride=cfgs.STRIDE,
            anchor_ratios=cfgs.ANCHOR_RATIOS,
            anchor_scales=cfgs.ANCHOR_SCALES,
            scale_factors=cfgs.SCALE_FACTORS,
            base_anchor_size_list=cfgs.BASE_ANCHOR_SIZE_LIST,  # P2, P3, P4, P5, P6
            level=cfgs.LEVEL,
            top_k_nms=cfgs.RPN_TOP_K_NMS,
            rpn_nms_iou_threshold=cfgs.RPN_NMS_IOU_THRESHOLD,
            max_proposals_num=cfgs.MAX_PROPOSAL_NUM,
            rpn_iou_positive_threshold=cfgs.RPN_IOU_POSITIVE_THRESHOLD,
            rpn_iou_negative_threshold=cfgs.RPN_IOU_NEGATIVE_THRESHOLD,
            rpn_mini_batch_size=cfgs.RPN_MINIBATCH_SIZE,
            rpn_positives_ratio=cfgs.RPN_POSITIVE_RATE,
            remove_outside_anchors=False,  # whether remove anchors outside
            rpn_weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME])

        # rpn predict proposals
        rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals()

        # ***********************************************************************************************
        # *                                         Fast RCNN                                           *
        # ***********************************************************************************************
        fast_rcnn = build_fast_rcnn.FastRCNN(
            feature_pyramid=rpn.feature_pyramid,
            rpn_proposals_boxes=rpn_proposals_boxes,
            rpn_proposals_scores=rpn_proposals_scores,
            img_shape=tf.shape(img_batch),
            img_batch=img_batch,
            roi_size=cfgs.ROI_SIZE,
            roi_pool_kernel_size=cfgs.ROI_POOL_KERNEL_SIZE,
            scale_factors=cfgs.SCALE_FACTORS,
            gtboxes_and_label=None,
            gtboxes_and_label_minAreaRectangle=gtboxes_and_label_minAreaRectangle,
            fast_rcnn_nms_iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD,
            fast_rcnn_maximum_boxes_per_img=100,
            fast_rcnn_nms_max_boxes_per_class=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
            # score threshold for the detections that are reported
            show_detections_score_threshold=cfgs.FINAL_SCORE_THRESHOLD,
            num_classes=cfgs.CLASS_NUM,
            fast_rcnn_minibatch_size=cfgs.FAST_RCNN_MINIBATCH_SIZE,
            fast_rcnn_positives_ratio=cfgs.FAST_RCNN_POSITIVE_RATE,
            # IoU threshold separating positive from negative samples
            fast_rcnn_positives_iou_threshold=cfgs.FAST_RCNN_IOU_POSITIVE_THRESHOLD,
            use_dropout=cfgs.USE_DROPOUT,
            weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
            is_training=False,
            level=cfgs.LEVEL,
            head_quadrant=head_quadrant)

        # Only boxes, scores and categories are evaluated; the object counts
        # and the predicted head quadrant are not used by this routine.
        (fast_rcnn_decode_boxes, fast_rcnn_score, _, detection_category,
         fast_rcnn_decode_boxes_rotate, fast_rcnn_score_rotate, _, _,
         detection_category_rotate) = fast_rcnn.fast_rcnn_predict()

        # Session setup: initialise variables, then overwrite them from the
        # checkpoint when one is available.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        restorer, restore_ckpt = restore_model.get_restorer()

        # Everything fetched per batch, in the order it is unpacked below.
        # The image tensor itself is not fetched because nothing here
        # consumes the pixels.
        fetches = [img_name_batch, gtboxes_and_label,
                   gtboxes_and_label_minAreaRectangle,
                   fast_rcnn_decode_boxes, fast_rcnn_score,
                   detection_category, fast_rcnn_decode_boxes_rotate,
                   fast_rcnn_score_rotate, detection_category_rotate]

        config = tf.ConfigProto()
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(init_op)
            if restorer is not None:
                restorer.restore(sess, restore_ckpt)
                print('restore model')

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)

            gtboxes_horizontal_dict = {}
            gtboxes_rotate_dict = {}

            all_boxes_h = []
            all_boxes_r = []
            all_img_names = []

            # try/finally guarantees the queue-runner threads are stopped
            # and joined even when a batch or the result writing fails.
            try:
                for i in range(img_num):

                    start = time.time()
                    (_img_name_batch, _gtboxes_and_label,
                     _gtboxes_and_label_minAreaRectangle,
                     _fast_rcnn_decode_boxes, _fast_rcnn_score,
                     _detection_category, _fast_rcnn_decode_boxes_rotate,
                     _fast_rcnn_score_rotate,
                     _detection_category_rotate) = sess.run(fetches)
                    end = time.time()

                    img_name = str(_img_name_batch[0])

                    # gtboxes convert dict
                    gtbox_horizontal_list, gtbox_rotate_list = make_dict_packle(
                        _gtboxes_and_label,
                        _gtboxes_and_label_minAreaRectangle)

                    # Swap the coordinate columns pairwise so the rows read
                    # (xmin, ymin, xmax, ymax) and (x_c, y_c, w, h, theta),
                    # following the original variable naming.
                    boxes_h = _fast_rcnn_decode_boxes[:, [1, 0, 3, 2]]
                    boxes_r = _fast_rcnn_decode_boxes_rotate[:,
                                                             [1, 0, 3, 2, 4]]

                    # One row per detection: [category, score, box...].
                    dets_h = np.hstack(
                        (_detection_category.reshape(-1, 1),
                         _fast_rcnn_score.reshape(-1, 1), boxes_h))
                    dets_r = np.hstack(
                        (_detection_category_rotate.reshape(-1, 1),
                         _fast_rcnn_score_rotate.reshape(-1, 1), boxes_r))
                    all_boxes_h.append(dets_h)
                    all_boxes_r.append(dets_r)
                    all_img_names.append(img_name)

                    # A repeated image name replaces the earlier entry.
                    gtboxes_horizontal_dict[img_name] = list(
                        gtbox_horizontal_list)
                    gtboxes_rotate_dict[img_name] = list(gtbox_rotate_list)

                    print(img_name)

                    view_bar(
                        '{} image cost {}s'.format(img_name, (end - start)),
                        i + 1, img_num)

                # NOTE(review): both calls write to cfgs.EVALUATE_R_DIR; the
                # last argument (0 / 1) presumably selects horizontal vs.
                # rotated output - confirm against write_voc_results_file.
                write_voc_results_file(all_boxes_h, all_img_names,
                                       cfgs.EVALUATE_R_DIR, 0)
                write_voc_results_file(all_boxes_r, all_img_names,
                                       cfgs.EVALUATE_R_DIR, 1)

                # Context managers close the files even if pickling fails.
                with open('gtboxes_horizontal_dict.pkl', 'wb') as fw1:
                    pickle.dump(gtboxes_horizontal_dict, fw1)
                with open('gtboxes_rotate_dict.pkl', 'wb') as fw2:
                    pickle.dump(gtboxes_rotate_dict, fw2)
            finally:
                coord.request_stop()
                coord.join(threads)
示例#4
0
def train():
    with tf.Graph().as_default():
        with tf.name_scope('get_batch'):
            img_name_batch, img_batch, gtboxes_and_label_batch, num_objects_batch = \
                next_batch(dataset_name=cfgs.DATASET_NAME,
                           batch_size=cfgs.BATCH_SIZE,
                           shortside_len=cfgs.SHORT_SIDE_LEN,
                           is_training=True)
            gtboxes_and_label, head = get_head(
                tf.squeeze(gtboxes_and_label_batch, 0))
            gtboxes_and_label = tf.py_func(back_forward_convert,
                                           inp=[gtboxes_and_label],
                                           Tout=tf.float32)
            gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 6])
            head_quadrant = tf.py_func(get_head_quadrant,
                                       inp=[head, gtboxes_and_label],
                                       Tout=tf.float32)
            head_quadrant = tf.reshape(head_quadrant, [-1, 1])

            gtboxes_and_label_minAreaRectangle = get_horizen_minAreaRectangle(
                gtboxes_and_label)

            gtboxes_and_label_minAreaRectangle = tf.reshape(
                gtboxes_and_label_minAreaRectangle, [-1, 5])

        with tf.name_scope('draw_gtboxes'):
            gtboxes_in_img = draw_box_with_color(
                img_batch,
                tf.reshape(gtboxes_and_label_minAreaRectangle,
                           [-1, 5])[:, :-1],
                text=tf.shape(gtboxes_and_label_minAreaRectangle)[0])

            gtboxes_rotate_in_img = draw_box_with_color_rotate(
                img_batch,
                tf.reshape(gtboxes_and_label, [-1, 6])[:, :-1],
                text=tf.shape(gtboxes_and_label)[0],
                head=head_quadrant)

        # ***********************************************************************************************
        # *                                         share net                                           *
        # ***********************************************************************************************
        _, share_net = get_network_byname(net_name=cfgs.NET_NAME,
                                          inputs=img_batch,
                                          num_classes=None,
                                          is_training=True,
                                          output_stride=None,
                                          global_pool=False,
                                          spatial_squeeze=False)

        # ***********************************************************************************************
        # *                                            rpn                                              *
        # ***********************************************************************************************
        rpn = build_rpn.RPN(
            net_name=cfgs.NET_NAME,
            inputs=img_batch,
            gtboxes_and_label=gtboxes_and_label_minAreaRectangle,
            is_training=True,
            share_head=cfgs.SHARE_HEAD,
            share_net=share_net,
            stride=cfgs.STRIDE,
            anchor_ratios=cfgs.ANCHOR_RATIOS,
            anchor_scales=cfgs.ANCHOR_SCALES,
            scale_factors=cfgs.SCALE_FACTORS,
            base_anchor_size_list=cfgs.
            BASE_ANCHOR_SIZE_LIST,  # P2, P3, P4, P5, P6
            level=cfgs.LEVEL,
            top_k_nms=cfgs.RPN_TOP_K_NMS,
            rpn_nms_iou_threshold=cfgs.RPN_NMS_IOU_THRESHOLD,
            max_proposals_num=cfgs.MAX_PROPOSAL_NUM,
            rpn_iou_positive_threshold=cfgs.RPN_IOU_POSITIVE_THRESHOLD,
            rpn_iou_negative_threshold=cfgs.
            RPN_IOU_NEGATIVE_THRESHOLD,  # iou>=0.7 is positive box, iou< 0.3 is negative
            rpn_mini_batch_size=cfgs.RPN_MINIBATCH_SIZE,
            rpn_positives_ratio=cfgs.RPN_POSITIVE_RATE,
            remove_outside_anchors=False,  # whether remove anchors outside
            rpn_weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME])

        rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals(
        )  # rpn_score shape: [300, ]

        rpn_location_loss, rpn_classification_loss = rpn.rpn_losses()
        rpn_total_loss = rpn_classification_loss + rpn_location_loss

        with tf.name_scope('draw_proposals'):
            # score > 0.5 is object
            rpn_object_boxes_indices = tf.reshape(
                tf.where(tf.greater(rpn_proposals_scores, 0.5)), [-1])
            rpn_object_boxes = tf.gather(rpn_proposals_boxes,
                                         rpn_object_boxes_indices)

            rpn_proposals_objcet_boxes_in_img = draw_box_with_color(
                img_batch,
                rpn_object_boxes,
                text=tf.shape(rpn_object_boxes)[0])
            rpn_proposals_boxes_in_img = draw_box_with_color(
                img_batch,
                rpn_proposals_boxes,
                text=tf.shape(rpn_proposals_boxes)[0])
        # ***********************************************************************************************
        # *                                         Fast RCNN                                           *
        # ***********************************************************************************************

        fast_rcnn = build_fast_rcnn.FastRCNN(
            feature_pyramid=rpn.feature_pyramid,
            rpn_proposals_boxes=rpn_proposals_boxes,
            rpn_proposals_scores=rpn_proposals_scores,
            img_shape=tf.shape(img_batch),
            img_batch=img_batch,
            roi_size=cfgs.ROI_SIZE,
            roi_pool_kernel_size=cfgs.ROI_POOL_KERNEL_SIZE,
            scale_factors=cfgs.SCALE_FACTORS,
            gtboxes_and_label=gtboxes_and_label,
            gtboxes_and_label_minAreaRectangle=
            gtboxes_and_label_minAreaRectangle,
            fast_rcnn_nms_iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD,
            fast_rcnn_maximum_boxes_per_img=100,
            fast_rcnn_nms_max_boxes_per_class=cfgs.
            FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
            show_detections_score_threshold=cfgs.
            FINAL_SCORE_THRESHOLD,  # show detections which score >= 0.6
            num_classes=cfgs.CLASS_NUM,
            fast_rcnn_minibatch_size=cfgs.FAST_RCNN_MINIBATCH_SIZE,
            fast_rcnn_positives_ratio=cfgs.FAST_RCNN_POSITIVE_RATE,
            fast_rcnn_positives_iou_threshold=cfgs.
            FAST_RCNN_IOU_POSITIVE_THRESHOLD,  # iou>0.5 is positive, iou<0.5 is negative
            use_dropout=cfgs.USE_DROPOUT,
            weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
            is_training=True,
            level=cfgs.LEVEL,
            head_quadrant=head_quadrant)

        fast_rcnn_decode_boxes, fast_rcnn_score, num_of_objects, detection_category, \
        fast_rcnn_decode_boxes_rotate, fast_rcnn_score_rotate, fast_rcnn_head_quadrant, \
        num_of_objects_rotate, detection_category_rotate = fast_rcnn.fast_rcnn_predict()
        fast_rcnn_location_loss, fast_rcnn_classification_loss, \
        fast_rcnn_location_rotate_loss, fast_rcnn_classification_rotate_loss, \
        fast_rcnn_head_quadrant_loss = fast_rcnn.fast_rcnn_loss()

        fast_rcnn_total_loss = fast_rcnn_location_loss + fast_rcnn_classification_loss + \
                               fast_rcnn_location_rotate_loss + fast_rcnn_classification_rotate_loss

        with tf.name_scope('draw_boxes_with_categories'):
            fast_rcnn_predict_boxes_in_imgs = draw_boxes_with_categories(
                img_batch=img_batch,
                boxes=fast_rcnn_decode_boxes,
                labels=detection_category,
                scores=fast_rcnn_score)

            fast_rcnn_predict_rotate_boxes_in_imgs = draw_boxes_with_categories_rotate(
                img_batch=img_batch,
                boxes=fast_rcnn_decode_boxes_rotate,
                labels=detection_category_rotate,
                scores=fast_rcnn_score_rotate,
                head=fast_rcnn_head_quadrant)

        # train
        total_loss = slim.losses.get_total_loss()

        global_step = slim.get_or_create_global_step()

        lr = tf.train.piecewise_constant(
            global_step,
            boundaries=[np.int64(20000), np.int64(40000)],
            values=[0.001, 0.0001, 0.00001])
        tf.summary.scalar('lr', lr)
        optimizer = tf.train.MomentumOptimizer(lr, momentum=cfgs.MOMENTUM)

        train_op = slim.learning.create_train_op(total_loss, optimizer,
                                                 global_step)

        # ***********************************************************************************************
        # *                                          Summary                                            *
        # ***********************************************************************************************
        # ground truth and predict
        tf.summary.image('img/gtboxes', gtboxes_in_img)
        tf.summary.image('img/gtboxes_rotate', gtboxes_rotate_in_img)
        tf.summary.image('img/faster_rcnn_predict',
                         fast_rcnn_predict_boxes_in_imgs)
        tf.summary.image('img/faster_rcnn_predict_rotate',
                         fast_rcnn_predict_rotate_boxes_in_imgs)
        # rpn loss and image
        tf.summary.scalar('rpn/rpn_location_loss', rpn_location_loss)
        tf.summary.scalar('rpn/rpn_classification_loss',
                          rpn_classification_loss)
        tf.summary.scalar('rpn/rpn_total_loss', rpn_total_loss)

        tf.summary.scalar('fast_rcnn/fast_rcnn_location_loss',
                          fast_rcnn_location_loss)
        tf.summary.scalar('fast_rcnn/fast_rcnn_classification_loss',
                          fast_rcnn_classification_loss)
        tf.summary.scalar('fast_rcnn/fast_rcnn_location_rotate_loss',
                          fast_rcnn_location_rotate_loss)
        tf.summary.scalar('fast_rcnn/fast_rcnn_classification_rotate_loss',
                          fast_rcnn_classification_rotate_loss)
        tf.summary.scalar('fast_rcnn/fast_rcnn_head_quadrant_loss',
                          fast_rcnn_head_quadrant_loss)
        tf.summary.scalar('fast_rcnn/fast_rcnn_total_loss',
                          fast_rcnn_total_loss)

        tf.summary.scalar('loss/total_loss', total_loss)

        tf.summary.image('rpn/rpn_all_boxes', rpn_proposals_boxes_in_img)
        tf.summary.image('rpn/rpn_object_boxes',
                         rpn_proposals_objcet_boxes_in_img)
        # learning rate (the same tensor is also logged above under the 'lr' tag)
        tf.summary.scalar('learning_rate', lr)

        summary_op = tf.summary.merge_all()
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        restorer, restore_ckpt = restore_model.get_restorer()
        saver = tf.train.Saver(max_to_keep=10)

        config = tf.ConfigProto()
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(init_op)
            if not restorer is None:
                restorer.restore(sess, restore_ckpt)
                print('restore model')
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)

            summary_path = os.path.join(FLAGS.summary_path, cfgs.VERSION)
            tools.mkdir(summary_path)
            summary_writer = tf.summary.FileWriter(summary_path,
                                                   graph=sess.graph)

            for step in range(cfgs.MAX_ITERATION):
                training_time = time.strftime('%Y-%m-%d %H:%M:%S',
                                              time.localtime(time.time()))
                start = time.time()

                _global_step, _img_name_batch, _rpn_location_loss, _rpn_classification_loss, \
                _rpn_total_loss, _fast_rcnn_location_loss, _fast_rcnn_classification_loss, \
                _fast_rcnn_location_rotate_loss, _fast_rcnn_classification_rotate_loss, \
                _fast_rcnn_total_loss, _total_loss, _ = \
                    sess.run([global_step, img_name_batch, rpn_location_loss, rpn_classification_loss,
                              rpn_total_loss, fast_rcnn_location_loss, fast_rcnn_classification_loss,
                              fast_rcnn_location_rotate_loss, fast_rcnn_classification_rotate_loss,
                              fast_rcnn_total_loss, total_loss, train_op])

                end = time.time()

                if step % 10 == 0:

                    print(""" {}: step{}    image_name:{} |\t
                                rpn_loc_loss:{} |\t rpn_cla_loss:{} |\t
                                rpn_total_loss:{} |
                                fast_rcnn_loc_loss:{} |\t fast_rcnn_cla_loss:{} |\t
                                fast_rcnn_loc_rotate_loss:{} |\t fast_rcnn_cla_rotate_loss:{} |\t
                                fast_rcnn_total_loss:{} |\t
                                total_loss:{} |\t pre_cost_time:{}s""" \
                          .format(training_time, _global_step, str(_img_name_batch[0]), _rpn_location_loss,
                                  _rpn_classification_loss, _rpn_total_loss, _fast_rcnn_location_loss,
                                  _fast_rcnn_classification_loss, _fast_rcnn_location_rotate_loss,
                                  _fast_rcnn_classification_rotate_loss,  _fast_rcnn_total_loss, _total_loss,
                                  (end - start)))

                if step % 50 == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, _global_step)
                    summary_writer.flush()

                if (step > 0
                        and step % 5000 == 0) or (step
                                                  == cfgs.MAX_ITERATION - 1):
                    save_dir = os.path.join(FLAGS.trained_checkpoint,
                                            cfgs.VERSION)
                    tools.mkdir(save_dir)

                    save_ckpt = os.path.join(
                        save_dir, 'voc_' + str(_global_step) + 'model.ckpt')
                    saver.save(sess, save_ckpt)
                    print(' weights had been saved')

            coord.request_stop()
            coord.join(threads)