Example #1
    def save_rank_result(self):
        save_image_path = os.path.join(self.rankings_dir, 'result_images')
        check_path(save_image_path)

        save_rank_path = os.path.join(self.rankings_dir, 'rank_order')
        rank_files = os.listdir(save_rank_path)
        for i in range(len(rank_files)):
            with open(os.path.join(save_rank_path, 'rank_' + str(i) + '.txt'), 'r') as rank_f:
                ranks = rank_f.readlines()

            # mergeImages(name, files, box, size=(224, 224), axis=0)
            files = [self.query_db['paths'][i]]
            box = self.query_db['query_boxes'][i]
            name = os.path.join(save_image_path, 'result_' + str(i) + '.jpg')

            # Append the top_k retrieved images after the query image
            for j, order in enumerate(ranks):
                if j == self.top_k:
                    break
                image_path = order.strip().split(' ')[1]
                files.append(image_path)

            mergeImages(name, files, box, size=(128, 128), axis=1)

        # Stack all per-query result strips into one overview image
        name_list = []
        name_all = os.path.join(save_image_path, '0000_10.jpg')

        for i in range(len(self.query_db['query_feats'])):
            name_list.append(os.path.join(save_image_path, 'result_' + str(i) + '.jpg'))

        mergeImages(name_all, name_list, box=None, size=(128 * (self.top_k + 1), 128), axis=0)
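The helpers check_path and mergeImages are not shown on this page; the commented signature above suggests mergeImages tiles a list of image files into one strip. Below is a minimal sketch of how they might look, assuming Pillow and assuming that box, when given, is an (x1, y1, x2, y2) query region drawn on the first image; both points are assumptions inferred from the call sites, not the original implementation.

# Hypothetical sketch of the undisplayed helpers, inferred from their call sites.
import os
from PIL import Image, ImageDraw


def check_path(*paths):
    """Create each directory if it does not already exist (assumption based on usage)."""
    for p in paths:
        os.makedirs(p, exist_ok=True)


def mergeImages(name, files, box, size=(224, 224), axis=0):
    """Resize the listed images and tile them along `axis` into one image.

    Assumption: `box` (x1, y1, x2, y2), when given, outlines the query
    region on the first image before resizing.
    """
    tiles = []
    for idx, path in enumerate(files):
        img = Image.open(path).convert('RGB')
        if idx == 0 and box is not None:
            ImageDraw.Draw(img).rectangle(list(box), outline='red', width=3)
        tiles.append(img.resize(size))

    w, h = size
    if axis == 1:   # tile horizontally
        canvas = Image.new('RGB', (w * len(tiles), h))
        for idx, tile in enumerate(tiles):
            canvas.paste(tile, (idx * w, 0))
    else:           # tile vertically
        canvas = Image.new('RGB', (w, h * len(tiles)))
        for idx, tile in enumerate(tiles):
            canvas.paste(tile, (0, idx * h))
    canvas.save(name)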
Example #2
    def rerank_once(self, query_idx, rank):
        print('rerank_once', query_idx)
        query_feat = self.query_db['query_feats'][query_idx]
        bboxes = []
        distances = []
        frames = []
        for one in rank:
            bbox_info = self.db_feats['bbox_infos'][one[0]]
            frames.append(one[1])
            bbox_feats = np.array([bbox['feat'] for bbox in bbox_info])

            # Distance from the query feature to every candidate box in this frame
            dist_array = pairwise_distances(query_feat.reshape(1, -1), bbox_feats, self.dist_type, n_jobs=-1)

            # Keep the closest box of this frame and its distance
            distances.append(np.min(dist_array))
            idx = np.argmin(dist_array)
            bboxes.append(bbox_info[idx]['bbox'])

        # Sort frames by their best-box distance and keep the top_k
        dist = list(zip(distances, range(len(distances))))
        dist.sort(key=lambda x: x[0])
        best_idxs = [x[1] for x in dist[:self.top_k]]
        best_boxes = [bboxes[x] for x in best_idxs]
        best_distances = [distances[x] for x in best_idxs]
        best_frames = [frames[x] for x in best_idxs]

        rerankings_dir_info = os.path.join(self.rerankings_dir, 'rerank_order')
        check_path(rerankings_dir_info)
        with open(os.path.join(rerankings_dir_info, str(query_idx) + '.pkl'), 'wb') as f:
            pickle.dump(best_distances, f)
            pickle.dump(best_boxes, f)
            pickle.dump(best_frames, f)

        return best_idxs, best_boxes, best_distances, best_frames
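rerank_once stores three objects back to back in a single pickle file, and save_rerank_result in a later example reads them back with three pickle.load calls in the same order. A minimal round-trip sketch of that convention, with made-up values:

import pickle

# Dump order must match the load order used when reading the file back.
with open('0.pkl', 'wb') as f:             # file name mirrors str(query_idx) + '.pkl'
    pickle.dump([0.12, 0.34], f)           # best_distances
    pickle.dump([[10, 20, 50, 80]], f)     # best_boxes
    pickle.dump(['frames/000123.jpg'], f)  # best_frames

with open('0.pkl', 'rb') as f:
    best_distances = pickle.load(f)
    best_boxes = pickle.load(f)
    best_frames = pickle.load(f)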
Example #3
    def save_rank_kpi(self):
        save_kpi_path = os.path.join(self.rankings_dir, 'result_kpi')
        check_path(save_kpi_path)
        save_rank_path = os.path.join(self.rankings_dir, 'rank_order')
        rank_files = os.listdir(save_rank_path)

        with open(os.path.join(save_kpi_path, 'rank_kpi.txt'), 'a+') as f:
            related_num_total = 0
            right_num_total = 0
            wrong_num_total = 0
            recall = 0
            precise = 0
            for i in range(len(rank_files)):
                with open(os.path.join(save_rank_path, 'rank_' + str(i) + '.txt'), 'r') as rank_f:
                    ranks = rank_f.readlines()

                query_full_path = self.query_db['paths'][i]
                query_base_name = get_query_basename(query_full_path)

                related_num = self.query_db['related_num'][i]
                related_num_total += related_num
                # Count hits and misses among the top_k retrieved images
                for j, order in enumerate(ranks[:self.top_k]):
                    image_full_path = order.strip().split(' ')[1]
                    image_base_name = get_query_basename(image_full_path)

                    if image_base_name == query_base_name:
                        right_num_total += 1
                    else:
                        wrong_num_total += 1
            try:
                recall = right_num_total * 1.0 / related_num_total
                precise = right_num_total * 1.0 / (right_num_total + wrong_num_total)
            except ZeroDivisionError:
                print('--illegal value of top_k')
            print('--top_k = : ', self.top_k)
            print('---- rank recall : ', recall)
            print('---- rank precise: ', precise)
            f.write('top_k = : ' + str(self.top_k) + '\n')
            f.write('recall:  ' + str(recall)[:6] + '\n')
            f.write('precise: ' + str(precise)[:6] + '\n')
            f.write('-------------------------\n')
            f.write('-------------------------\n')

        return
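The KPI here is plain retrieval recall and precision aggregated over all queries: a hit is a retrieved image whose base name matches the query's, and related_num is the ground-truth count of related images per query. A small worked sketch of the same bookkeeping, with made-up numbers:

# Hypothetical numbers for two queries, each evaluated at top_k = 3.
per_query = [
    # (hits in top_k, misses in top_k, ground-truth related images)
    (2, 1, 4),
    (3, 0, 3),
]

right_total = sum(h for h, _, _ in per_query)          # 5
wrong_total = sum(m for _, m, _ in per_query)          # 1
related_total = sum(r for _, _, r in per_query)        # 7

recall = right_total / related_total                   # 5 / 7 ~= 0.714
precision = right_total / (right_total + wrong_total)  # 5 / 6 ~= 0.833
print(recall, precision)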
Example #4
    def write_rankings(self, final_scores):

        save_rank_path = os.path.join(self.rankings_dir, 'rank_order')
        check_path(save_rank_path)

        for i, query_info in enumerate(self.query_db['query_feats']):

            scores = final_scores[i, :]

            # Database indices sorted by ascending score (lowest distance first)
            ranking = np.arange(len(self.db_feats['rmacs']))[np.argsort(scores)]

            with open(os.path.join(save_rank_path, 'rank_' + str(i)) + '.txt', 'w') as savefile:
                for res in ranking:
                    savefile.write(str(res) + ': ' + self.db_feats['paths'][res] + '\n')
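write_rankings only consumes a precomputed query-by-database score matrix. Elsewhere on this page the distances come from sklearn's pairwise_distances with the configured metric ('cosine' in __init__), so smaller scores mean closer matches and np.argsort already yields best-first order. The sketch below illustrates that pipeline with random features standing in for the real descriptors; it is not the author's feature extraction.

import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
query_feats = rng.random((3, 256))   # 3 queries, 256-D descriptors
db_feats = rng.random((100, 256))    # 100 database images

# Rows: queries, columns: database entries; smaller cosine distance = better match.
final_scores = pairwise_distances(query_feats, db_feats, metric='cosine', n_jobs=-1)

for i in range(final_scores.shape[0]):
    ranking = np.argsort(final_scores[i, :])   # best-first database indices
    print('query', i, 'top-5:', ranking[:5])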
Example #5
    def __init__(self, mark, base):

        self.dimension = 256

        self.rank_top_k = 200
        self.top_k = 10
        self.feats_db_path = base + '/feats_db_' + mark + '/feats'
        query_feats_path = base + '/query_db/feats'

        # Distance type
        self.dist_type = 'cosine'
        # Load query features
        self.query_db = read_query_feats(query_feats_path)

        # Where to store the rankings and re-rankings
        self.rankings_dir = base + '/rank_' + mark
        check_path(self.rankings_dir)
        self.rerankings_dir = base + '/rerank_' + mark
        check_path(self.rerankings_dir)
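read_query_feats is defined elsewhere; the methods in the other examples only rely on a handful of keys in what it returns. The stub below is hypothetical and just documents the structure those call sites expect; the real loader and its on-disk format are not shown on this page.

def read_query_feats(query_feats_path):
    """Hypothetical stub: the real loader (and its file format) is not shown here.

    The call sites above expect a dict with these parallel, per-query entries:
      'paths'       -> query image paths
      'query_boxes' -> query boxes, passed as `box` to mergeImages
      'query_feats' -> query descriptors (256-D, per self.dimension)
      'related_num' -> ground-truth number of related database images
    """
    return {'paths': [], 'query_boxes': [], 'query_feats': [], 'related_num': []}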
Example #6
    def Annotation_Mission(self, annotation_info):
        """
        Handles all of the file paths for an annotation run.
        @param annotation_info:
        @return:
        """
        # 1: read the fields of annotation_info
        dataset_path = annotation_info["dataset_path"]  # dataset root path
        images_path = os.path.join(dataset_path, "images")
        annotations_path = os.path.join(dataset_path, "image_annotations")
        tools.check_path(images_path, annotations_path)  # make sure the directories exist

        # 2: decide which detection/tracking methods to use
        self.flag_trackcar = self.str_to_bool(annotation_info["flag_trackcar"])
        self.flag_detectcar = self.str_to_bool(
            annotation_info["flag_detectcar"])
        self.flag_detectarmor = self.str_to_bool(
            annotation_info["flag_detectarmor"])
        self.flag_show_annotation = self.str_to_bool(
            annotation_info["flag_show_annotation"])

        # 3: run the annotation
        images_list = os.listdir(images_path)
        try:
            images_list.sort(key=lambda x: int(x[:-4]))
        except ValueError:
            print("Some file names are not numeric; falling back to a plain sort")
            images_list.sort()  # fall back when the names are not all numbers

        # 3.1: if tracking is enabled from the start, initialise the tracker on the first frame
        if self.flag_trackcar:
            self.tracker = cv.MultiTracker_create()
            track_firstframe_path = os.path.join(images_path, images_list[0])
            track_firstframe = cv.imread(track_firstframe_path)
            self.tracker_init(track_firstframe)  # seed the multi-tracker

        # annotate image by image
        index_id = 0
        while index_id < len(images_list):
            filename = images_list[index_id][:-4]
            print(
                "*****************processing {}.jpg**************************".format(
                    filename))
            # 4.1: build image_path and txt_path
            image_path = os.path.join(images_path, images_list[index_id])
            txt_path = os.path.join(annotations_path, filename + ".txt")

            return_state = self.Annotation_File(image_path, txt_path)
            if return_state == 'b':      # go back one image
                index_id = index_id - 2
            if return_state == 'j':      # jump to a named image
                jump_data = input("Enter the name of the image to jump to: ")
                name = str(jump_data) + ".jpg"
                index_id = images_list.index(name)
                print("Jumping to index:", index_id)

            if return_state == 'q':      # quit
                cv.waitKey(1)
                cv.destroyAllWindows()
                break

            index_id = index_id + 1
            cv.waitKey(1)

        cv.destroyAllWindows()
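tracker_init is not shown. With the opencv-contrib builds that still expose cv.MultiTracker_create (as used above), a typical initialisation lets the user draw ROIs on the first frame and registers one single-object tracker per ROI. The function below is a hypothetical sketch under that assumption, not the author's implementation.

import cv2 as cv


def tracker_init_sketch(tracker, first_frame):
    """Hypothetical initialiser: select car ROIs by hand, add a KCF tracker for each.

    Assumes an opencv-contrib build where cv.MultiTracker_create,
    cv.selectROIs and cv.TrackerKCF_create are all available.
    """
    rois = cv.selectROIs("select cars", first_frame, showCrosshair=True)
    for roi in rois:
        tracker.add(cv.TrackerKCF_create(), first_frame, tuple(roi))
    cv.destroyWindow("select cars")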
Example #7
    def save_rerank_result(self):
        '''
        Merge the re-ranked result images per query and write recall/precision KPIs.
        '''

        save_image_path = os.path.join(self.rerankings_dir, 'result_images')
        check_path(save_image_path)

        rerankings_dir_info = os.path.join(self.rerankings_dir, 'rerank_order')
        rerank_files = sorted(os.listdir(rerankings_dir_info))

        save_kpi_path = os.path.join(self.rerankings_dir, 'result_kpi')
        check_path(save_kpi_path)

        with open(os.path.join(save_kpi_path, 'rerank_kpi.txt'), 'a+') as kpi_f:
            related_num_total = 0
            right_num_total = 0
            wrong_num_total = 0
            recall = 0
            precise = 0

            for i in range(len(rerank_files)):
                # Load what rerank_once dumped, in the same order
                with open(os.path.join(rerankings_dir_info, str(i) + '.pkl'), 'rb') as f:
                    best_distances = pickle.load(f)
                    best_boxes = pickle.load(f)
                    best_frames = pickle.load(f)

                files = [self.query_db['paths'][i]]

                query_base_name = get_query_basename(self.query_db['paths'][i])
                related_num = self.query_db['related_num'][i]
                related_num_total += related_num

                bboxes = [self.query_db['query_boxes'][i]]
                name = os.path.join(save_image_path, 'result_' + str(i) + '.jpg')

                for j, frame in enumerate(best_frames):
                    if j == self.top_k:
                        break
                    files.append(frame)
                    bboxes.append(best_boxes[j])
                    image_base_name = get_query_basename(frame)
                    if image_base_name == query_base_name:
                        right_num_total += 1
                    else:
                        wrong_num_total += 1

                mergeImages2(name, files, bboxes, size=(128, 128), axis=1)

            try:
                recall = right_num_total * 1.0 / related_num_total
                precise = right_num_total * 1.0 / (right_num_total + wrong_num_total)
            except ZeroDivisionError:
                print('--illegal value of top_k')

            print('--top_k = : ', self.top_k)
            print('---- rank recall : ', recall)
            print('---- rank precise: ', precise)
            kpi_f.write('top_k = : ' + str(self.top_k) + '\n')
            kpi_f.write('recall:  ' + str(recall)[:6] + '\n')
            kpi_f.write('precise: ' + str(precise)[:6] + '\n')
            kpi_f.write('-------------------------\n')
            kpi_f.write('-------------------------\n')

        # Stack all per-query strips into one overview image
        name_list = []
        name_all = os.path.join(save_image_path, 'all.jpg')
        for i in range(len(rerank_files)):
            name_list.append(os.path.join(save_image_path, 'result_' + str(i) + '.jpg'))

        mergeImages(name_all, name_list, box=None, size=(128 * (self.top_k + 1), 128), axis=0)
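mergeImages2 appears to differ from mergeImages only in taking one box per image instead of a single query box. The sketch below is a hypothetical variant under the same Pillow and box-format assumptions as the mergeImages sketch earlier; the real helper is not shown on this page.

from PIL import Image, ImageDraw


def mergeImages2(name, files, bboxes, size=(224, 224), axis=0):
    """Hypothetical variant of mergeImages: draw bboxes[idx] on files[idx] before tiling.

    Assumes each box is (x1, y1, x2, y2) in the original image coordinates.
    """
    tiles = []
    for path, box in zip(files, bboxes):
        img = Image.open(path).convert('RGB')
        if box is not None:
            ImageDraw.Draw(img).rectangle(list(box), outline='red', width=3)
        tiles.append(img.resize(size))

    w, h = size
    canvas = Image.new('RGB', (w * len(tiles), h) if axis == 1 else (w, h * len(tiles)))
    for idx, tile in enumerate(tiles):
        canvas.paste(tile, (idx * w, 0) if axis == 1 else (0, idx * h))
    canvas.save(name)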
Example #8
    def load_data_(self, config, img_dir, debug=False):
        """Load images and box annotations described by per-image JSON files.
        config: configuration object; config.annotion_path points at the JSON that lists the categories
        img_dir: directory scanned for per-image .json annotation files
        debug: if True, draw the loaded boxes and save them for visual checking
        """
        with open(config.annotion_path, encoding='utf-8') as f:
            dataset = json.load(f)
        self.config = config
        self.prepare_class(dataset['categories'])

        # Add images. This step should be optimized so it does not use too much memory.
        print("Loading images!")
        json_path_list = scan_specified_files(img_dir, key='.json')

        f = open('dataset_log.txt', 'w')
        time0 = time.time()
        counter = 0
        image_id_repeat = 0
        for idx, json_full_path in enumerate(json_path_list):
            jf = open(json_full_path, encoding='utf-8')
            info = json.load(jf)
            jf.close()

            width = info['width']
            height = info['height']
            img_full_path = os.path.join(os.path.split(json_full_path)[0], info['file_name'])

            # Sanity-check every image: it must exist and match the annotated size
            need_check_per_image = True
            if need_check_per_image:
                try:
                    img = image.load_img(img_full_path)
                except FileNotFoundError as e:
                    print(e)
                    f.write(str(idx) + ' : ' + img_full_path + '\n')
                    continue
                width_gt, height_gt = img.size  # TODO
                if [width, height] != [width_gt, height_gt]:
                    print('wrong width and height')
                    f.write(str(idx) + ': wrong width and height: ' + img_full_path + '\n')
                    sys.exit()
            
            re_category_ids = []
            re_bboxes = []
            if len(info['objects']) == 0 and not config.USING_NEGATIVE_IMG:
                continue  # skip object-free (negative) images when they are disabled
            if len(info['objects']) > 0 and not config.USING_POSITIVE_IMG:
                continue  # skip images with objects (positive) when they are disabled

            for obj in info['objects']:
                bbox = obj['bbox']

                # Normalise corner order to (x1, y1) top-left, (x2, y2) bottom-right
                x1 = min(bbox[0], bbox[2])
                y1 = min(bbox[1], bbox[3])
                x2 = max(bbox[0], bbox[2])
                y2 = max(bbox[1], bbox[3])
                if x1 >= x2 or y1 >= y2:
                    print('bbox_gt error ', bbox)
                    continue
                re_category_ids.append(obj['label'])
                re_bboxes.append([x1, y1, x2, y2])
                
            if debug:
                save_path = 'train_data_virsual_fold'
                check_path(save_path)
                drew_detect_resualt(img_full_path,
                                    os.path.join(save_path, img_full_path.split('/')[-1]),
                                    re_bboxes,
                                    re_category_ids,
                                    self.class_info,
                                    debug)

            # Balance the dataset: repeat negative/positive images by the configured factors
            repeat = 1
            if len(info['objects']) == 0:
                repeat = config.NEGATIVE_MULT
            if len(info['objects']) > 0:
                repeat = config.POSITIVE_MULT
            for i in range(repeat):
                self.add_image(
                    config.NAME,
                    image_id=image_id_repeat,
                    path=img_full_path,
                    width=width,
                    height=height,
                    category_ids=re_category_ids,
                    bboxes=re_bboxes
                )
                image_id_repeat += 1
            counter += 1
            step = 200
            if counter % step == 0:
                # Rough ETA based on the time the last `step` images took
                rest_time = (time.time() - time0) * ((len(json_path_list) - counter) / step)
                print('----Adding the image:', counter,
                      'rest time(sec) = ', rest_time)
                time0 = time.time()

        f.close()
        print('-----------loaded total image ----------------:', counter)
        print('-----------after balance total----------------:', image_id_repeat)
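scan_specified_files is another helper that is not shown; the call above expects it to return the full paths of every file under img_dir whose name matches the given key. The sketch below is an assumption about that behaviour (the exact matching rule is guessed), not the original implementation.

import os


def scan_specified_files(root_dir, key='.json'):
    """Hypothetical helper: recursively collect files whose names end with `key`."""
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.endswith(key):
                matches.append(os.path.join(dirpath, filename))
    return sorted(matches)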
Example #9
class ThisConfig(Config):
    """Configuration for training and evaluation on the clothes dataset.
    Derives from the base Config class and overrides some values.
    """
    # Give the configuration a recognizable name
    
    #%%   about path
    
    annotion_path = '../../data/train/jinnan2_round1_train_20190305/train_no_poly.json'
    train_img_dir = '../../data/train/train'
    val_img_dir = '../../data/train/val'
    restricted_img_dir = '../../data/train/jinnan2_round1_train_20190222/restricted'
    normal_img_dir = '../../data/train/jinnan2_round1_train_20190222/normal'
    
    test_img_dir = '../../data/test/jinnan2_round1_test_a_20190306'
    real_test_img_dir = '../../data/test/jinnan2_round1_test_a_20190306'
#    real_test_img_dir = '../../data/test/error'
#    real_test_img_dir = '../../data/test/single'
    real_test_img_dir = '/media/mosay/数据/jz/tianchi/data/test/final/jinnan2_round1_test_b_20190326'
    
    
    
    computer = 'jz'  #zy jz 426

    if computer == '426':
        COCO_WEIGHTS_PATH = '/raid/Guests/DaYea/clothes/mask_rcnn_coco.h5'
        IMG_DIR = '/raid/Guests/Jay/Jay/datasets/clothes/Img'
        annotations_path = os.path.join("/raid/Guests/zy/clothes/Anno/cloth_all.csv")
        class_path =  os.path.join("/raid/Guests/zy/clothes/Anno/list_class.csv")
        img_path = '/raid/Guests/zy/clothes/Anno/cloth.h5'
        mask_path = r'/raid/Guests/zy/clothes/Anno/mask.h5'
        
    elif computer == 'jz':
        COCO_WEIGHTS_PATH = 'models/mask_rcnn_coco.h5'
        IMG_DIR = r"/hdisk/Ubuntu/datasets/clothes/Img"
        annotations_path = r"/hdisk/Ubuntu/datasets/clothes/Anno/cloth_all.csv"
        class_path =  r"/hdisk/Ubuntu/datasets/clothes/Anno/list_class.csv"
        img_path = r"/hdisk/Ubuntu/datasets/clothes/Anno/cloth.h5"
        mask_path = r"/hdisk/Ubuntu/datasets/clothes/Anno/mask.h5"
        query_file = '/home/aaron/mydisk/datasets/from_zy/MVC/query_image'
        
    elif computer == 'zy':
        
        data_root = '/media/mosay/数据/jz/dataset'
    #    IMG_DIR = os.path.join(ROOT_DIR, "clothes/Img/")
        COCO_WEIGHTS_PATH = 'mask_rcnn_coco.h5'
        IMG_DIR =  data_root + '/Img'
        annotations_path = data_root + "/Anno/cloth_all.csv"
        class_path =   data_root + "/Anno/list_class.csv"
        img_path =  data_root + "/Anno/cloth.h5"
        mask_path =  data_root + "/Anno/mask.h5"
        query_file = data_root + '/MVC/query_image'
    
    
    
    #%%  about dataset
    
    
    
    NAME = "clothes"
    
    
    # Number of classes (including background)
    NUM_CLASSES = 46 + 1  # Background + class
    split_record_dir = 'split_record'
    VAL_DATA_RATE = 0.1
    
    # Path to trained weights file
    DEFAULT_LOGS_DIR = 'logs'
    #%% important args
    Mode='retrival' #'train' or 'evaluate' or 'retrival'
#    Mode='evaluate'
    #%% evaluate
    if Mode == 'evaluate' or Mode == 'retrival':
        if computer =='jz':
            retrival_data_root = '/home/aaron/mydisk/datasets/from_zy'
        elif computer == 'zy':
            retrival_data_root = '/media/mosay/数据/jz/dataset'
        else:
            retrival_data_root = ''
        subset = 'mini_subset' #mini_subset  subset  
        if subset == 'subset':
            DEFAULT_DATASET = os.path.join(retrival_data_root, 'MVC/subset_annotation.txt')
        if subset == 'mini_subset':
            DEFAULT_DATASET = os.path.join(retrival_data_root, 'MVC/mini_subset_annotation.txt')#
        
        USING_NEGATIVE_IMG = True   # use images that contain no objects
        USING_POSITIVE_IMG = False  # use images that contain objects
        NEGATIVE_MULT = 1
        POSITIVE_MULT = 1
        real_test = True  # if True, load images that have no ground truth
        init_with = "this"  # imagenet, coco, last, or this
        EVA_LIMIT=100000
        
        model_version = 'mask_rcnn_clothes_0742'
        
        THIS_WEIGHT_PATH = 'models/mask_rcnn_clothes_0742.h5'  
        if init_with == 'coco':
            NUM_CLASSES = 80 + 1  # COCO has 80 classes
    
        # Adjust down if you use a smaller GPU.
        GPU_COUNT = 1
        IMAGES_PER_GPU = 1
        # You can increase this during inference to generate more proposals.
        RPN_NMS_THRESHOLD = 0.5
        # Skip detections with < 1% confidence
        DETECTION_MIN_CONFIDENCE = 0.01
        # Non-maximum suppression threshold for detection
        DETECTION_NMS_DIFF_CLS = True   #!!!!!!!!! important
        DETECTION_NMS_THRESHOLD = 0.1   #!!!!!!!!! important
        POST_NMS_ROIS_INFERENCE = 1000   #!!!!!!!!! important
        map_iou_thr = 0.7
        arg_str = '_rn' + str(RPN_NMS_THRESHOLD)[2:4] + \
                  '_ds' + str(DETECTION_MIN_CONFIDENCE)[2:4] + \
                  '_dn' + str(DETECTION_NMS_THRESHOLD)[2:4]

        save_base_dir='test_' + model_version +'_'+ str(real_test)+'_' + arg_str
        check_path(save_base_dir)
        

        
    #%%train
    if Mode == 'train':
        USING_NEGATIVE_IMG = True
        USING_POSITIVE_IMG = True
        NEGATIVE_MULT = 1
        POSITIVE_MULT = 1
        # Which weights to start with?
        init_with = "coco"  # imagenet, coco, or last this
        THIS_WEIGHT_PATH = '/home/aaron/mydisk/aWS/ImgRetrival/logs/mask_rcnn_clothes_0464.h5'
        COCO_WEIGHTS_PATH = 'models/mask_rcnn_coco.h5'
        # Learning rate and momentum
        # The Mask RCNN paper uses lr=0.02, but on TensorFlow it causes
        # weights to explode. Likely due to differences in optimizer
        # implementation.
        LEARNING_RATE = 0.0001
        LEARNING_MOMENTUM = 0.9
        # Weight decay regularization
        WEIGHT_DECAY = 0.0001
        
        # Uncomment to train on 8 GPUs (default is 1)
        GPU_COUNT = 1
        # We use a GPU with 12GB memory, which can fit two images.
        # Adjust down if you use a smaller GPU.
        IMAGES_PER_GPU = 1
        # Number of training steps per epoch
        STEPS_PER_EPOCH = 300
        VALIDATION_STEPS = 50
        EPOCHS = 2000
        
        
        USE_RPN_ROIS = True
        rpn_fg_iou_thr = 0.5
        rpn_bg_iou_thr = 0.5
        # You can increase this during training to generate more proposals.
        RPN_NMS_THRESHOLD = 0.3
        RPN_TRAIN_ANCHORS_PER_IMAGE = 256
        POST_NMS_ROIS_TRAINING = 2000

    #%% stable args
    IMAGE_RESIZE_MODE = "square"
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128
    BACKBONE = "resnet50" #TODO
    
    # Image mean (RGB)
#    MEAN_PIXEL = np.array([123.7, 116.8, 103.9])
    MEAN_PIXEL = np.array([218.37592,213.07745,211.30586])
    # Length of square anchor side in pixels
#   RPN_ANCHOR_SCALES = (32, 64, 128, 256, 512)
#    RPN_ANCHOR_SCALES = (16, 32, 64, 128, 256)
#    RPN_ANCHOR_SCALES = (16, 32, 32, 64, 128)
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)
#    RPN_ANCHOR_RATIOS = [0.5, 1, 2]
    RPN_ANCHOR_RATIOS = [0.5, 1, 2]
    
    LOSS_WEIGHTS = {
        "rpn_class_loss": 1.,
        "rpn_bbox_loss": 1.,
        "mrcnn_class_loss": 1.,
        "mrcnn_bbox_loss": 1.
    }
Example #10
def save_one_evaluate_result(config,
                             image_id,
                             dataset,
                             result,
                             avg_height,
                             debug=False):
    """
        result = {
                "image_id": image_id,
                "path":dataset.source_image_link(image_id),
                "category_id": dataset.get_source_class_id(class_id, "cloth"),
                "bbox": [bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0]],
                "score": score,
                "segmentation": maskUtils.encode(np.asfortranarray(mask))}
        results.append(result)
        """
    [box_results] = result
    base_dir = config.save_base_dir
    check_path(base_dir)

    image_info_ = dataset.image_info[image_id]

    path = dataset.source_image_link(image_id)

    #%% detection
    save_img_path = os.path.join(base_dir, 'detect_results')
    check_path(save_img_path)

    img = Image.open(path)
    gt_bboxes = []
    gt_class_ids = []

    if 'bboxes' in image_info_:
        gt_bboxes = image_info_['bboxes']
        gt_class_ids = image_info_['category_ids']

    text_size = 20
    ttfont = ImageFont.truetype('lib/华文细黑.ttf', text_size)
    # Draw the ground-truth boxes in red with their class names
    for idx, box in enumerate(gt_bboxes):

        category_id = gt_class_ids[idx]
        class_name = dataset.class_info[category_id]['name']
        draw = ImageDraw.Draw(img)
        draw.line([(box[0], box[1]), (box[2], box[1]), (box[2], box[3]),
                   (box[0], box[3]), (box[0], box[1])],
                  width=3,
                  fill='red')
        draw.text((box[0] + 10, box[1]),
                  class_name.split('_')[0],
                  fill=(255, 0, 0),
                  font=ttfont)

    if len(box_results) == 0:
        print("image: no bbox", image_id)

    pure_bboxes = []
    pure_class_ids = []
    for bbox in box_results:

        box = bbox['bbox']
        pure_bboxes.append(box)
        category_id = bbox['category_id']
        pure_class_ids.append(category_id)
        class_name = dataset.class_info[category_id]['name']
        draw = ImageDraw.Draw(img)
        draw.line([(box[0], box[1]), (box[2], box[1]), (box[2], box[3]),
                   (box[0], box[3]), (box[0], box[1])],
                  width=3,
                  fill='blue')

        draw.text((box[0] + 10, box[1] + text_size),
                  class_name.split('_')[0],
                  fill=(0, 0, 255),
                  font=ttfont)
        draw.text((box[0] + 10, box[1] + text_size * 2),
                  str(bbox['score'])[:4],
                  fill=(0, 0, 255),
                  font=ttfont)
    if debug:
        img.show()
    img.save(save_img_path + '/' + path.split('/')[-1])
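The box outlines above are traced as five points with draw.line; Pillow's ImageDraw.rectangle draws the same outline in one call. A standalone sketch of that alternative (the image and box values are made up for illustration):

from PIL import Image, ImageDraw

img = Image.new('RGB', (200, 200), 'white')
draw = ImageDraw.Draw(img)
box = (20, 30, 120, 160)  # x1, y1, x2, y2
# One call instead of tracing five line segments (Pillow >= 5.3 for `width`)
draw.rectangle(box, outline='red', width=3)
img.save('rect_demo.jpg')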