Пример #1
0
def get_image_and_label():
    """Load (pic_path, labels) records for the train and val splits.

    Returns:
        (image_label_train, image_label_val): dicts keyed by line index.
        NOTE: val entries are ALSO merged into image_label_train under key
        line_idx + 30000, so the train dict covers both splits.
    """
    image_label_train = {}
    # 'with' guarantees the handle is closed (the original leaked it);
    # iterating the file object avoids materializing all lines at once.
    with open('../data/train.txt') as f:
        for line in f:
            line_idx, pic_path, boxes, labels, img_width, img_height = \
                parse_line(line)
            image_label_train[line_idx] = [pic_path, labels]
    image_label_val = {}
    with open('../data/val.txt') as f:
        for line in f:
            line_idx, pic_path, boxes, labels, img_width, img_height = \
                parse_line(line)
            # Offset by 30000 so val keys cannot collide with train keys.
            image_label_train[line_idx + 30000] = [pic_path, labels]
            image_label_val[line_idx] = [pic_path, labels]
    return image_label_train, image_label_val
 def generate_rot90_txt(self, annotation_txt, txt_dst):
     """Write a rot90-augmented annotation txt.

     For every line of `annotation_txt`, rewrite each box into the frame of
     the image rotated 90 degrees (as np.rot90 would produce) and emit a
     new line whose image path is 'rot90_<imgname>'.  Finally merges the
     two txts via self.concat_txts.
     NOTE(review): the rotated image itself is not written here (the
     np.rot90 call was commented out) — confirm it is produced elsewhere.
     """
     # 'with' closes both files even on error (the originals were leaked).
     with open(annotation_txt, 'r') as srcfile:
         contents = srcfile.readlines()
     with open(txt_dst, 'w') as dstfile:
         for line in contents:
             pic_path, boxes, labels = parse_line(line, is_str=True)
             img = get_img(pic_path)
             # Unreadable images (img is None or has no HxWxC shape) are
             # reported and skipped; narrow except instead of a bare one.
             try:
                 h, w, _ = img.shape
             except (AttributeError, ValueError):
                 print('wrong', pic_path)
                 continue
             img_dstdir, imgname = os.path.split(pic_path)
             savepath = os.path.join(img_dstdir, 'rot90_' + imgname)
             aline = savepath
             for i in range(len(labels)):
                 b = boxes[i]
                 # Box transform for a 90-degree rotation:
                 # (x1, y1, x2, y2) -> (y1, w - x2, y2, w - x1)
                 x1, y1, x2, y2 = b[1], w - b[2], b[3], w - b[0]
                 aline = aline + ' {} {} {} {} {}'.format(
                     labels[i], x1, y1, x2, y2)
             dstfile.write(aline)
             dstfile.write('\n')
     self.concat_txts(annotation_txt, txt_dst)
Пример #3
0
def parse_gt_rec(gt_filename, resize_img_size):
    '''
    Parse and re-organize the gt info, caching it in the module-level
    gt_dict (parsed only once; later calls return the cache).

    return:
        gt_dict: dict. Each key is a img_id, the value is the gt bboxes
        (rescaled to resize_img_size) in the corresponding img.
    '''
    global gt_dict

    # Cache hit: the file was already parsed on a previous call.
    if gt_dict:
        return gt_dict

    resize_w, resize_h = resize_img_size
    with open(gt_filename, 'r') as f:
        for line in f:
            img_id, pic_path, boxes, labels = parse_line(line)

            # Original image dimensions, needed to rescale the boxes.
            shape = cv2.imread(pic_path).shape
            ori_w, ori_h = shape[1], shape[0]

            gt_dict[img_id] = [
                [x_min * resize_w / ori_w, y_min * resize_h / ori_h,
                 x_max * resize_w / ori_w, y_max * resize_h / ori_h,
                 label]
                for (x_min, y_min, x_max, y_max), label in zip(boxes, labels)
            ]
    return gt_dict
Пример #4
0
def get_image_and_label(phase):
    """Load (pic_path, labels) records for one dataset split.

    Args:
        phase: 'train' selects ../data/train.txt; any other value selects
            ../data/val.txt (matching the original if/else behavior).

    Returns:
        dict mapping line_idx -> [pic_path, labels].
    """
    # The two original branches differed only in the file name, so the
    # duplicated loop bodies are collapsed into one.
    txt_path = '../data/train.txt' if phase == 'train' else '../data/val.txt'
    image_label_dict = {}
    # 'with' closes the handle (the original leaked it).
    with open(txt_path) as f:
        for line in f:
            line_idx, pic_path, boxes, labels, img_width, img_height = \
                parse_line(line)
            image_label_dict[line_idx] = [pic_path, labels]
    return image_label_dict
Пример #5
0
def get_data_from_web_and_save(data_batch_path, data_train_save_path, data_val_save_path, nfs_mount_path, config, logger, project_name="sjht"):
    """Fetch the train / background / val data batches listed in the batch
    file and save them as annotation txt files.

    Returns:
        (train_num, val_num, class_num_dic): number of saved train lines
        (background lines included), number of val lines, and a per-label
        instance count over the non-background train data (used later for
        probability-weighted cropping; 'gj' and the configured fill-zero
        labels are excluded).
    """
    data_batch_path = data_batch_path.strip()
    assert os.path.exists(data_batch_path), "data batch path is not exit:%s" % data_batch_path
    train_out = open(data_train_save_path, "w")
    val_out = open(data_val_save_path, "w")

    delete_labels = config.data_set["delete_labels"]
    delete_labels_background = config.data_set["delete_labels_background"]
    batch_lines = [raw.strip() for raw in open(data_batch_path).readlines()]

    # The batch file is partitioned by three marker lines; each slice keeps
    # its own marker line at the head, exactly as before.
    idx_train = batch_lines.index("#train_data_batch")
    idx_background = batch_lines.index("#train_data_background_batch")
    idx_val = batch_lines.index("#val_data_batch")
    data_train_batch = batch_lines[idx_train:idx_background]
    data_train_background_batch = batch_lines[idx_background:idx_val]
    data_val_batch = batch_lines[idx_val:]

    data_train = get_data(logger, data_train_batch, project_name=project_name, delete_labels=delete_labels, nfs_mount_path=nfs_mount_path)
    data_train_background = get_data_background(logger, data_train_background_batch, project_name=project_name, delete_labels=delete_labels_background, nfs_mount_path=nfs_mount_path)
    data_val = get_data_val(logger, data_val_batch, project_name=project_name, delete_labels=delete_labels, nfs_mount_path=nfs_mount_path)

    logger.info_ai(meg="download data, train data num:%d, background data num:%d, val data num:%d" % (len(data_train), len(data_train_background), len(data_val)))

    # Count instances per class (before background data is appended) so
    # later image cropping can be probability-weighted per class.
    class_num_dic = {}
    skip_labels = config.data_set["fill_zero_label_names"]
    for line in data_train:
        pic_path, boxes, labels = parse_line(line)
        for label in labels:
            if label in skip_labels or label == "gj":
                continue
            class_num_dic[label] = class_num_dic.get(label, 0) + 1

    data_train = data_train + data_train_background

    for line in data_train:
        train_out.write(line)
        train_out.write("\n")
    train_out.close()
    logger.info_ai(meg="train data save in:%s" % data_train_save_path)

    if len(data_val) > 0:
        for line in data_val:
            val_out.write(line)
            val_out.write("\n")
    val_out.close()
    logger.info_ai(meg="val data save in:%s" % data_val_save_path)

    return len(data_train), len(data_val), class_num_dic
Пример #6
0
def parse_gt_rec(gt_filename, target_img_size, letterbox_resize=True):
    '''
    Parse and re-organize the gt info, caching it in the module-level
    gt_dict (parsed only once; later calls return the cache).

    Args:
        gt_filename: annotation file consumed line-by-line via parse_line.
        target_img_size: (new_width, new_height) the boxes are rescaled to.
        letterbox_resize: when True, scale boxes with the aspect-preserving
            letterbox ratio plus centering offsets; otherwise stretch each
            axis independently.

    return:
        gt_dict: dict. Each key is a img_id, the value is the gt bboxes in
        the corresponding img.
    '''
    global gt_dict

    if not gt_dict:
        new_width, new_height = target_img_size
        with open(gt_filename, 'r') as f:
            for line in f:
                img_id, pic_path, boxes, labels, ori_width, ori_height = parse_line(
                    line)

                if letterbox_resize:
                    # The letterbox ratio and padding depend only on the
                    # image size, so compute them once per image instead of
                    # once per box (they were loop-invariant).
                    resize_ratio = min(new_width / ori_width,
                                       new_height / ori_height)
                    resize_w = int(resize_ratio * ori_width)
                    resize_h = int(resize_ratio * ori_height)
                    dw = int((new_width - resize_w) / 2)
                    dh = int((new_height - resize_h) / 2)

                objects = []
                for i in range(len(labels)):
                    x_min, y_min, x_max, y_max = boxes[i]
                    label = labels[i]

                    if letterbox_resize:
                        objects.append([
                            x_min * resize_ratio + dw,
                            y_min * resize_ratio + dh,
                            x_max * resize_ratio + dw,
                            y_max * resize_ratio + dh, label
                        ])
                    else:
                        objects.append([
                            x_min * new_width / ori_width,
                            y_min * new_height / ori_height,
                            x_max * new_width / ori_width,
                            y_max * new_height / ori_height, label
                        ])
                gt_dict[img_id] = objects
    return gt_dict
def cut_box_signal_process(data_lines,
                           process_id,
                           save_path,
                           do_extension_boxes,
                           extension_ratio=1.3,
                           process_over=None):
    """Worker process: crop every labeled box out of each annotated image
    and save the crops under save_path/<label>/.

    Args:
        data_lines: annotation lines consumed by parse_line.
        process_id: numeric id of this worker (used in filenames and logs).
        save_path: root directory for the per-label crop folders.
        do_extension_boxes: when True, enlarge boxes via extension_boxes
            by extension_ratio before cropping.
        extension_ratio: box enlargement factor.
        process_over: shared mapping; process_over[process_id] is set to 1
            on completion so the parent can detect the worker finished.
    """
    globalvar.globalvar.logger.info_ai(meg="start process:%d" % process_id)
    save_id = 0
    for line in data_lines:
        try:
            save_id = save_id + 1
            pic_path_ori, boxes, labels = parse_line(line)
            if len(boxes) == 0:
                continue

            img = cv2.imread(pic_path_ori)
            if img is None:
                # Bug fix: the original used %d on a string path (raising
                # TypeError) and then fell through to index a None image;
                # report the missing image and skip the line instead.
                print("img is not exists:%s" % pic_path_ori)
                continue
            if do_extension_boxes:
                boxes = extension_boxes(img,
                                        boxes,
                                        extension_ratio=extension_ratio)

            for i in range(len(boxes)):
                save_id = save_id + 1
                x0, y0, x1, y1 = boxes[i]
                img_save = img[y0:y1, x0:x1, :]
                if not os.path.exists(save_path + "/%s" % labels[i]):
                    os.mkdir(save_path + "/%s" % labels[i])
                cut_img_save_path = save_path + "/%s/%s_%d_%d.jpg" % (
                    labels[i], pic_path_ori.split("/")[-1], process_id,
                    save_id)
                cv2.imwrite(cut_img_save_path, img_save)
                globalvar.globalvar.logger.info_ai(
                    meg="cut box img save to:%s" % cut_img_save_path)
        except Exception as ex:
            # Best-effort per-line handling: log and keep processing the
            # remaining lines (deliberate — a worker must not die mid-batch).
            globalvar.globalvar.logger.info_ai(
                meg="cut box img find Exception:%s" % ex)

    process_over[process_id] = 1
    globalvar.globalvar.logger.info_ai(meg="this process id over:%d" %
                                       process_id)
                config=config_common,
                logger=logger,
                project_name="sjht")
        else:
            logger.info_ai(
                meg="data batch is None, use train file and val file")
            data_train = open(
                config_common.data_set["train_file_path"]).readlines()
            train_num = len(data_train)
            data_val = open(
                config_common.data_set["val_file_path"]).readlines()
            val_num = len(data_val)
            # 储存每个类别嵌件个数,以便后续裁剪图片时按概率裁剪
            class_num_dic = {}
            for line in data_train:
                pic_path, boxes, labels = parse_line(line)
                for label in labels:
                    if label not in config_common.data_set[
                            "fill_zero_label_names"] and label != "gj":
                        if label not in class_num_dic:
                            class_num_dic[label] = 1
                        else:
                            class_num_dic[label] += 1

        logger.info_ai(
            meg="prepare data over, get all class name from data file",
            get_ins={"class_num_dic": class_num_dic})

        lr_decay_freq = int(
            config_common.model_set["lr_decay_freq_epoch"] * train_num /
            (len(gpu_device) * config_common.model_set["batch_size"]))
 def generate_crop_txt(self, annotation_txt, txt_dst):
     """Write a center-crop-augmented annotation txt.

     For each annotated image, take the largest centered square crop (crop
     along the longer side) and rewrite every box into the crop's frame,
     dropping boxes that are mostly cut away.  New lines point at
     'crop_<imgname>'; already-square images are skipped; lines where no
     box survives are skipped.  Finally merges the original and new txts
     via self.concat_txts.
     NOTE(review): the cropped image `rimg` is computed but never saved in
     this method — confirm it is written elsewhere before relying on the
     generated paths.
     """
     # crop center region
     contents = open(annotation_txt, 'r').readlines()
     dstfile = open(txt_dst, 'w')
     for line in contents:
         pic_path, boxes, labels = parse_line(line, is_str=True)
         img = get_img(pic_path)
         # Unreadable images (no usable .shape) are reported and skipped.
         try:
             h, w, _ = img.shape
         except:
             print('wrong', pic_path)
             continue
         # rimg = np.rot90(img).copy()
         img_dstdir, imgname = os.path.split(pic_path)
         savepath = os.path.join(img_dstdir, 'crop_' + imgname)
         aline = savepath
         # flag records whether at least one box survives the crop.
         flag = 0
         if h > w:
             # Portrait image: keep a centered w x w square (crop rows).
             st = int(h / 2 - w / 2)
             rimg = img[st:st + w, :, :]
             for i in range(len(labels)):
                 b = boxes[i]
                 # Shift y by the crop offset and clamp into [1, w - 2].
                 x1, y1, x2, y2 = b[0], max(1, b[1] - st), b[2], min(
                     b[3] - st, w - 2)
                 # Drop a box if it lost >80% of its height or its clipped
                 # height is under 5% of the crop side.
                 if (y2 - y1) / (b[3] - b[1]) < 0.2 or (y2 - y1) / w < 0.05:
                     continue
                 aline = aline + ' {} {} {} {} {}'.format(
                     labels[i], x1, y1, x2, y2)
                 flag = 1
                 # plot_one_box(img, [b[0], b[1], b[2], b[3]], label='test')
                 # cv2.imshow('ori', img)
             #     plot_one_box(rimg, [x1, y1, x2, y2], label='test')
             # cv2.imshow('test', rimg)
             # cv2.waitKey(0)
             # aline = aline + ' ' + str(labels[i]) + ' ' + str(b[0]) + ' ' + str(b[1]) + ' ' + str(b[2]) + ' ' + str(b[3])
         elif h < w:
             # Landscape image: keep a centered h x h square (crop columns).
             st = int(w / 2 - h / 2)
             rimg = img[:, st:st + h, :]
             for i in range(len(labels)):
                 b = boxes[i]
                 # Shift x by the crop offset and clamp into [1, h - 2].
                 x1, y1, x2, y2 = max(1, b[0] - st), b[1], min(
                     b[2] - st, h - 2), b[3]
                 # Drop a box if it lost >80% of its width or its clipped
                 # width is under 5% of the crop side.
                 if (x2 - x1) / (b[2] - b[0]) < 0.2 or (x2 - x1) / h < 0.05:
                     continue
                 aline = aline + ' {} {} {} {} {}'.format(
                     labels[i], x1, y1, x2, y2)
                 flag = 1
                 # plot_one_box(img, [b[0], b[1], b[2], b[3]], label='test')
                 # cv2.imshow('ori', img)
                 # plot_one_box(rimg, [x1, y1, x2, y2], label='test')
             # cv2.imshow('test', rimg)
             # cv2.waitKey(0)
             # aline = aline + ' ' + str(labels[i]) + ' ' + str(b[0]) + ' ' + str(b[1]) + ' ' + str(b[2]) + ' ' + str(b[3])
         else:
             # Square image: nothing to crop.
             print('no crop', pic_path)
             continue
         # No surviving box -> do not emit an annotation line at all.
         if flag == 0:
             continue
         dstfile.write(aline)
         dstfile.write('\n')
     dstfile.close()
     self.concat_txts(annotation_txt, txt_dst)
Пример #10
0
                config=config_common,
                logger=logger,
                project_name="sjht")
        else:
            logger.info_ai(
                meg="data batch is None, use train file and val file")
            data_train = open(
                config_common.data_set["train_file_path"]).readlines()
            train_num = len(data_train)
            data_val = open(
                config_common.data_set["val_file_path"]).readlines()
            val_num = len(data_val)
            # 储存每个类别嵌件个数,以便后续裁剪图片时按概率裁剪
            class_num_dic = {}
            for line in data_val:
                pic_path, boxes, labels = parse_line(line)
                for label in labels:
                    if label not in config_common.data_set[
                            "fill_zero_label_names"] and label != "gj":
                        if label not in class_num_dic:
                            class_num_dic[label] = 1
                        else:
                            class_num_dic[label] += 1

            logger.info_ai(
                meg="prepare data over, get all class name from data file",
                get_ins={"class_num_dic": class_num_dic})

        detection_scores = list(np.arange(0.05, 0.99, 0.05))
        class_scores = list(np.arange(0.05, 0.99, 0.05))
        if config_common.test_other_info_set["show_wrong_data"]: