Пример #1
0
def generator(
        input_images_dir,
        input_gt_dir,
        input_size=512,
        batch_size=12,
        random_scale=np.array([0.8, 0.85, 0.9, 0.95, 1.0, 1.1, 1.2]),
):
    """Endlessly yield augmented training batches.

    The image list is fetched once through ``ICDARLoader.get_images`` and
    reshuffled at the start of every pass.  Each image is resized to
    960x540, randomly rescaled, cropped with ``crop_area``, zero-padded to a
    square and resized to ``input_size`` x ``input_size``; ``generate_rbox``
    then builds the targets.  Images that cannot be read, have no annotation
    file, or end up without any usable text label are skipped.  A partially
    filled batch left over at the end of a pass is discarded.

    Args:
        input_images_dir: Directory passed to ``data_loader.get_images``.
        input_gt_dir: Directory containing ``<image basename>.txt``
            annotation files.
        input_size: Side length of the square image handed to
            ``generate_rbox``.
        batch_size: Number of images per yielded batch.
        random_scale: Candidate scale factors; one is drawn per image.  The
            default array is shared between calls and is only read here.

    Yields:
        An 11-tuple ``(images, image_fns, score_maps, geo_maps,
        training_masks, transform_matrixes, boxes_masks, box_widths,
        text_labels_sparse, text_polyses, text_labels)``.  ``images`` holds
        float32 arrays with the channel axis reversed relative to what
        ``cv2.imread`` returned; the three map lists are sampled with a
        stride of 4 in both spatial axes.

    Raises:
        ValueError: If the loader returns no images (the loop could
            otherwise never yield).
    """
    # Fixed working resolution (width, height) applied before augmentation.
    base_w, base_h = 960, 540

    data_loader = ICDARLoader(edition='13', shuffle=True)
    image_list = np.array(data_loader.get_images(input_images_dir))
    print('{} training images in {} '.format(image_list.shape[0],
                                             input_images_dir))
    if image_list.shape[0] == 0:
        # Without this guard the ``while True`` loop below would spin
        # forever without ever yielding a batch.
        raise ValueError(
            'no training images found in {}'.format(input_images_dir))

    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        batch_images = []
        batch_image_fns = []
        batch_score_maps = []
        batch_geo_maps = []
        batch_training_masks = []

        batch_text_polyses = []
        batch_text_tagses = []
        batch_boxes_masks = []

        batch_text_labels = []
        count = 0
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                if im is None:
                    # cv2.imread signals failure by returning None.
                    print('image file {} could not be read'.format(im_fn))
                    continue
                h, w, _ = im.shape

                file_name = os.path.basename(im_fn) + '.txt'
                txt_fn = os.path.join(input_gt_dir, file_name)
                if not os.path.exists(txt_fn):
                    print('text file {} does not exists'.format(txt_fn))
                    continue

                # load_annotation also returns the text transcriptions.
                text_polys, text_tags, text_labels = data_loader.load_annotation(
                    txt_fn)
                if text_polys.shape[0] == 0:
                    continue

                text_polys, text_tags, text_labels = check_and_validate_polys(
                    text_polys, text_tags, text_labels, (h, w))
                if text_polys.shape[0] == 0:
                    # Nothing survived validation; skip before any geometry
                    # is touched.
                    continue

                ############################# Data Augmentation ##############################
                # Bring the frame to the working resolution.  The polygon
                # scale is derived from the real image size so the boxes
                # stay aligned even when the input is not 1920x1080 (for
                # which both ratios are exactly 0.5).
                im = cv2.resize(im, dsize=(base_w, base_h))
                text_polys[:, :, 0] *= base_w / float(w)
                text_polys[:, :, 1] *= base_h / float(h)

                # Random isotropic rescale.  It is kept separate from the
                # resize above; merging both into one fixed-size resize
                # would cancel the scale augmentation.
                rd_scale = np.random.choice(random_scale)
                im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
                text_polys *= rd_scale

                # Rotation augmentation is intentionally not applied: the
                # boxes in this dataset are all horizontal.

                # Random crop; selected_poly indexes the annotations kept.
                im, text_polys, text_tags, selected_poly = crop_area(
                    im, text_polys, text_tags, crop_background=False)
                text_labels = [text_labels[k] for k in selected_poly]
                if text_polys.shape[0] == 0 or len(text_labels) == 0:
                    continue

                # Zero-pad to a square whose side is the larger of the image
                # sides and input_size.
                new_h, new_w, _ = im.shape
                max_h_w_i = np.max([new_h, new_w, input_size])
                im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
                im_padded[:new_h, :new_w, :] = im
                im = im_padded

                # Resize the padded square to the network input size and
                # scale the polygons by the same ratios.
                new_h, new_w, _ = im.shape
                resize_h = input_size
                resize_w = input_size
                im = cv2.resize(im, dsize=(resize_w, resize_h))
                text_polys[:, :, 0] *= resize_w / float(new_w)
                text_polys[:, :, 1] *= resize_h / float(new_h)
                new_h, new_w, _ = im.shape

                score_map, geo_map, training_mask, rectangles = generate_rbox(
                    (new_h, new_w), text_polys, text_tags)

                # Drop entries whose label equals [-1].
                mask = [not (word == [-1]) for word in text_labels]
                text_labels = list(compress(text_labels, mask))
                rectangles = list(compress(rectangles, mask))

                assert len(text_labels) == len(
                    rectangles
                ), "rotate rectangles' num is not equal to text label"

                if len(text_labels) == 0:
                    continue

                # Build every per-image item first so that a failure here
                # cannot leave the batch lists with different lengths.
                boxes_mask = np.array([count] * len(rectangles))
                image_out = im[:, :, ::-1].astype(np.float32)
                score_out = score_map[::4, ::4, np.newaxis].astype(np.float32)
                geo_out = geo_map[::4, ::4, :].astype(np.float32)
                mask_out = training_mask[::4, ::4, np.newaxis].astype(
                    np.float32)

                count += 1
                batch_images.append(image_out)
                batch_image_fns.append(im_fn)
                batch_score_maps.append(score_out)
                batch_geo_maps.append(geo_out)
                batch_training_masks.append(mask_out)

                batch_text_polyses.append(rectangles)
                batch_boxes_masks.append(boxes_mask)
                batch_text_labels.extend(text_labels)
                # NOTE(review): text_tags is appended unfiltered while
                # rectangles were filtered by `mask` -- confirm that
                # get_project_matrix_and_width expects this.
                batch_text_tagses.append(text_tags)

                if len(batch_images) == batch_size:
                    try:
                        # Use fresh names: rebinding the accumulators to
                        # ndarrays would break later .append calls if one of
                        # the helpers below raised.
                        all_polys = np.concatenate(batch_text_polyses)
                        all_tags = np.concatenate(batch_text_tagses)
                        transform_matrixes, box_widths = get_project_matrix_and_width(
                            all_polys, all_tags)
                        # TODO limit the batch size of recognition
                        labels_sparse = sparse_tuple_from(
                            np.array(batch_text_labels))

                        yield (batch_images, batch_image_fns,
                               batch_score_maps, batch_geo_maps,
                               batch_training_masks, transform_matrixes,
                               batch_boxes_masks, box_widths, labels_sparse,
                               all_polys, batch_text_labels)
                    finally:
                        # Always start the next batch from a clean state,
                        # including when finalising this one failed.
                        batch_images = []
                        batch_image_fns = []
                        batch_score_maps = []
                        batch_geo_maps = []
                        batch_training_masks = []
                        batch_text_polyses = []
                        batch_text_tagses = []
                        batch_boxes_masks = []
                        batch_text_labels = []
                        count = 0
            except Exception:
                # Best effort: report the failure and move on to the next
                # image.
                import traceback
                traceback.print_exc()
                continue
Пример #2
0
def generator_all(input_images_dir, input_gt_dir, output_geo_dir, output_score_dir, input_size=512, class_num=230):
    """Precompute and pickle the geometry and score maps of every image.

    Despite its name this is a plain function, not a generator.  Each image
    returned by ``ICDARLoader.get_images`` is zero-padded to a square,
    resized to ``input_size`` x ``input_size`` and passed with its scaled
    polygons to ``generate_maps``.  The resulting ``geo_map`` and
    ``score_map`` are written to ``<output_geo_dir>/<name>.pickle`` and
    ``<output_score_dir>/<name>.pickle``, where ``<name>`` is the image file
    name up to its first dot.  Images that cannot be read, have no
    ``gt_<stem>.txt`` annotation, or contain no polygons are skipped; any
    other per-image failure is printed and processing continues.

    Args:
        input_images_dir: Directory passed to ``data_loader.get_images`` and
            joined with each returned file name before reading.
        input_gt_dir: Directory containing the ``gt_<stem>.txt`` files.
        output_geo_dir: Directory receiving the pickled geometry maps.
        output_score_dir: Directory receiving the pickled score maps.
        input_size: Side length of the square image handed to
            ``generate_maps``.
        class_num: Unused; kept so existing callers keep working.

    Returns:
        None.
    """
    data_loader = ICDARLoader(edition='13', shuffle=True)
    image_list = np.array(data_loader.get_images(input_images_dir))
    for im_fn in image_list:
        try:
            im = cv2.imread(os.path.join(input_images_dir, im_fn))
            if im is None:
                # cv2.imread signals failure by returning None.
                print('image file {} could not be read'.format(im_fn))
                continue
            h, w, _ = im.shape

            # Accept both '\\' and '/' separators so the same listing works
            # on Windows and POSIX.
            base_name = os.path.basename(im_fn.replace('\\', '/'))
            # Swap only the real extension; str.replace on the extension
            # token would also rewrite matching text elsewhere in the name
            # and pick the wrong token for multi-dot names.
            file_name = "gt_" + os.path.splitext(base_name)[0] + ".txt"
            txt_fn = os.path.join(input_gt_dir, file_name)
            if not os.path.exists(txt_fn):
                print('text file {} does not exists'.format(txt_fn))
                continue

            # load_annotation also returns the text transcriptions.
            text_polys, text_tags, text_labels = data_loader.load_annotation(
                txt_fn)

            if text_polys.shape[0] == 0:
                continue

            text_polys, text_tags, text_labels = check_and_validate_polys(
                text_polys, text_tags, text_labels, (h, w))

            # Zero-pad to a square whose side is the larger of the image
            # sides and input_size.
            new_h, new_w, _ = im.shape
            max_h_w_i = np.max([new_h, new_w, input_size])
            im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
            im_padded[:new_h, :new_w, :] = im
            im = im_padded

            # Resize the padded square to the target size and scale the
            # polygons by the same ratios.
            new_h, new_w, _ = im.shape
            resize_h = input_size
            resize_w = input_size
            im = cv2.resize(im, dsize=(resize_w, resize_h))
            text_polys[:, :, 0] *= resize_w / float(new_w)
            text_polys[:, :, 1] *= resize_h / float(new_h)
            new_h, new_w, _ = im.shape

            # Only the score and geometry maps are persisted.
            score_map, geo_map, _, _ = generate_maps(
                (new_h, new_w), text_polys, text_tags)

            # Output name: file name up to its first dot (unchanged naming
            # scheme, so previously written pickles keep their names).
            im_name = base_name.split(".")[0]
            print(im_name)
            # The with-blocks close the files; no explicit close() needed.
            with open("{}/{}.pickle".format(output_geo_dir, im_name), "wb") as f_out:
                pickle.dump(geo_map, f_out)
            with open("{}/{}.pickle".format(output_score_dir, im_name), "wb") as f_out:
                pickle.dump(score_map, f_out)
        except Exception:
            # Best effort: report the failing image and continue.
            import traceback
            print(im_fn)
            traceback.print_exc()
            continue
Пример #3
0
def generator(input_images_dir,
              input_gt_dir,
              input_size=512,
              batch_size=12,
              random_scale=np.array([0.8, 0.85, 0.9, 0.95, 1.0, 1.1, 1.2])):
    """Endlessly yield augmented training batches.

    The image list is fetched once through ``ICDARLoader.get_images`` and
    reshuffled at the start of every pass.  Each image is randomly rescaled,
    rotated by a random integer angle in [-10, 10], cropped with
    ``crop_area``, zero-padded to a square and resized to ``input_size`` x
    ``input_size``; ``generate_rbox`` then builds the targets.  Images that
    cannot be read, have no ``gt_<stem>.txt`` annotation, or end up without
    any usable text label are skipped.  A partially filled batch left over
    at the end of a pass is discarded.

    Args:
        input_images_dir: Directory passed to ``data_loader.get_images``.
        input_gt_dir: Directory containing the ``gt_<stem>.txt`` files.
        input_size: Side length of the square image handed to
            ``generate_rbox``.
        batch_size: Number of images per yielded batch.
        random_scale: Candidate scale factors; one is drawn per image.  The
            default array is shared between calls and is only read here.

    Yields:
        An 11-tuple ``(images, image_fns, score_maps, geo_maps,
        training_masks, transform_matrixes, boxes_masks, box_widths,
        text_labels_sparse, text_polyses, text_labels)``.  ``images`` holds
        float32 arrays with the channel axis reversed relative to what
        ``cv2.imread`` returned; the three map lists are sampled with a
        stride of 4 in both spatial axes.

    Raises:
        ValueError: If the loader returns no images (the loop could
            otherwise never yield).
    """
    data_loader = ICDARLoader()
    image_list = np.array(data_loader.get_images(input_images_dir))
    print('{} training images in {} '.format(image_list.shape[0],
                                             input_images_dir))
    if image_list.shape[0] == 0:
        # Without this guard the ``while True`` loop below would spin
        # forever without ever yielding a batch.
        raise ValueError(
            'no training images found in {}'.format(input_images_dir))

    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        batch_images = []
        batch_image_fns = []
        batch_score_maps = []
        batch_geo_maps = []
        batch_training_masks = []

        batch_text_polyses = []
        batch_text_tagses = []
        batch_boxes_masks = []

        batch_text_labels = []
        count = 0
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                if im is None:
                    # cv2.imread signals failure by returning None.
                    print('image file {} could not be read'.format(im_fn))
                    continue
                h, w, _ = im.shape

                # Swap only the real extension; str.replace on the extension
                # token would also rewrite matching text elsewhere in the
                # name and pick the wrong token for multi-dot names.
                stem = os.path.splitext(os.path.basename(im_fn))[0]
                file_name = "gt_" + stem + ".txt"
                txt_fn = os.path.join(input_gt_dir, file_name)
                if not os.path.exists(txt_fn):
                    print('text file {} does not exists'.format(txt_fn))
                    continue

                # load_annotation also returns the text transcriptions.
                text_polys, text_tags, text_labels = data_loader.load_annotation(
                    txt_fn)

                if text_polys.shape[0] == 0:
                    continue

                text_polys, text_tags, text_labels = check_and_validate_polys(
                    text_polys, text_tags, text_labels, (h, w))

                ############################# Data Augmentation ##############################
                # Random isotropic rescale.
                rd_scale = np.random.choice(random_scale)
                im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
                text_polys *= rd_scale

                # Random rotation by an integer angle in [-10, 10].
                angle = random.randint(-10, 10)
                im, text_polys = rotate_image(im, text_polys, angle)

                # Random crop; selected_poly indexes the annotations kept.
                im, text_polys, text_tags, selected_poly = crop_area(
                    im, text_polys, text_tags, crop_background=False)
                text_labels = [text_labels[k] for k in selected_poly]
                if text_polys.shape[0] == 0 or len(text_labels) == 0:
                    continue

                # Zero-pad to a square whose side is the larger of the image
                # sides and input_size.
                new_h, new_w, _ = im.shape
                max_h_w_i = np.max([new_h, new_w, input_size])
                im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
                im_padded[:new_h, :new_w, :] = im
                im = im_padded

                # Resize the padded square to the network input size and
                # scale the polygons by the same ratios.
                new_h, new_w, _ = im.shape
                resize_h = input_size
                resize_w = input_size
                im = cv2.resize(im, dsize=(resize_w, resize_h))
                text_polys[:, :, 0] *= resize_w / float(new_w)
                text_polys[:, :, 1] *= resize_h / float(new_h)
                new_h, new_w, _ = im.shape

                score_map, geo_map, training_mask, rectangles = generate_rbox(
                    (new_h, new_w), text_polys, text_tags)

                # Drop entries whose label equals [-1].
                mask = [not (word == [-1]) for word in text_labels]
                text_labels = list(compress(text_labels, mask))
                rectangles = list(compress(rectangles, mask))

                assert len(text_labels) == len(
                    rectangles
                ), "rotate rectangles' num is not equal to text label"

                if len(text_labels) == 0:
                    continue

                # Build every per-image item first so that a failure here
                # cannot leave the batch lists with different lengths.
                boxes_mask = np.array([count] * len(rectangles))
                image_out = im[:, :, ::-1].astype(np.float32)
                score_out = score_map[::4, ::4, np.newaxis].astype(np.float32)
                geo_out = geo_map[::4, ::4, :].astype(np.float32)
                mask_out = training_mask[::4, ::4, np.newaxis].astype(
                    np.float32)

                count += 1
                batch_images.append(image_out)
                batch_image_fns.append(im_fn)
                batch_score_maps.append(score_out)
                batch_geo_maps.append(geo_out)
                batch_training_masks.append(mask_out)

                batch_text_polyses.append(rectangles)
                batch_boxes_masks.append(boxes_mask)
                batch_text_labels.extend(text_labels)
                # NOTE(review): text_tags is appended unfiltered while
                # rectangles were filtered by `mask` -- confirm that
                # get_project_matrix_and_width expects this.
                batch_text_tagses.append(text_tags)

                if len(batch_images) == batch_size:
                    try:
                        # Use fresh names: rebinding the accumulators to
                        # ndarrays would break later .append calls if one of
                        # the helpers below raised.
                        all_polys = np.concatenate(batch_text_polyses)
                        all_tags = np.concatenate(batch_text_tagses)
                        transform_matrixes, box_widths = get_project_matrix_and_width(
                            all_polys, all_tags)
                        # TODO limit the batch size of recognition
                        labels_sparse = sparse_tuple_from(
                            np.array(batch_text_labels))

                        yield (batch_images, batch_image_fns,
                               batch_score_maps, batch_geo_maps,
                               batch_training_masks, transform_matrixes,
                               batch_boxes_masks, box_widths, labels_sparse,
                               all_polys, batch_text_labels)
                    finally:
                        # Always start the next batch from a clean state,
                        # including when finalising this one failed.
                        batch_images = []
                        batch_image_fns = []
                        batch_score_maps = []
                        batch_geo_maps = []
                        batch_training_masks = []
                        batch_text_polyses = []
                        batch_text_tagses = []
                        batch_boxes_masks = []
                        batch_text_labels = []
                        count = 0
            except Exception:
                # Best effort: report the failing image and continue.
                import traceback
                print(im_fn)
                traceback.print_exc()
                continue
Пример #4
0
def generator(input_images_dir, input_gt_dir, input_size=512, batch_size=12, class_num=230,
              random_scale=np.array([0.8, 0.85, 0.9, 0.95, 1.0, 1.1, 1.2]), back_side=0):
    """Endlessly yield training batches built from precomputed maps.

    The image list is fetched once through ``ICDARLoader.get_images`` and
    reshuffled at the start of every pass.  Each image is zero-padded to a
    square and resized to ``input_size`` x ``input_size``; no random
    augmentation is applied.  ``generate_rbox`` supplies the training mask,
    while the geometry and score maps are loaded from
    ``<prefix>/geo_map/<name>.pickle`` and
    ``<prefix>/score_map/<name>.pickle``, where ``<prefix>`` is the part of
    the image path before ``"JPEGImages"`` and ``<name>`` is the file name
    up to its first dot.  The label of the first remaining text instance is
    joined into a digit string, converted with ``int`` and passed to
    ``dense_to_one_hot``.  A partially filled batch left over at the end of
    a pass is discarded.

    Args:
        input_images_dir: Directory passed to ``data_loader.get_images`` and
            joined with each returned file name before reading.
        input_gt_dir: Directory containing the ``gt_<stem>.txt`` files.
        input_size: Side length of the square image handed to
            ``generate_rbox``.
        batch_size: Number of images per yielded batch.
        class_num: Forwarded to ``dense_to_one_hot`` as ``n_classes``.
        random_scale: Unused in this variant; kept for interface
            compatibility.
        back_side: Forwarded to ``check_and_validate_polys``.

    Yields:
        An 11-tuple ``(images, image_fns, score_maps, geo_maps,
        training_masks, transform_matrixes, boxes_masks, box_widths,
        text_labels_sparse, text_polyses, text_labels)``.  ``images`` holds
        float32 arrays with the channel axis reversed relative to what
        ``cv2.imread`` returned; the three map lists are sampled with a
        stride of 4 in both spatial axes; ``text_labels`` is extended with
        the entries of each image's one-hot label.

    Raises:
        ValueError: If the loader returns no images (the loop could
            otherwise never yield).
    """
    data_loader = ICDARLoader(edition='13', shuffle=True)
    image_list = np.array(data_loader.get_images(input_images_dir))
    if image_list.shape[0] == 0:
        # Without this guard the ``while True`` loop below would spin
        # forever without ever yielding a batch.
        raise ValueError(
            'no training images found in {}'.format(input_images_dir))

    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        batch_images = []
        batch_image_fns = []
        batch_score_maps = []
        batch_geo_maps = []
        batch_training_masks = []

        batch_text_polyses = []
        batch_text_tagses = []
        batch_boxes_masks = []

        batch_text_labels = []
        count = 0
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(os.path.join(input_images_dir, im_fn))
                if im is None:
                    # cv2.imread signals failure by returning None.
                    print('image file {} could not be read'.format(im_fn))
                    continue
                h, w, _ = im.shape

                base_name = os.path.basename(im_fn)
                # Swap only the real extension; str.replace on the extension
                # token would also rewrite matching text elsewhere in the
                # name and pick the wrong token for multi-dot names.
                file_name = "gt_" + os.path.splitext(base_name)[0] + ".txt"
                txt_fn = os.path.join(input_gt_dir, file_name)
                if not os.path.exists(txt_fn):
                    print('text file {} does not exists'.format(txt_fn))
                    continue

                # load_annotation also returns the text transcriptions.
                text_polys, text_tags, text_labels = data_loader.load_annotation(
                    txt_fn)

                if text_polys.shape[0] == 0:
                    continue

                text_polys, text_tags, text_labels = check_and_validate_polys(
                    text_polys, text_tags, text_labels, (h, w), back_side)

                # Random scale / rotation / crop augmentation is disabled in
                # this variant, presumably because the geometry and score
                # maps loaded below are precomputed -- TODO confirm.

                # Zero-pad to a square whose side is the larger of the image
                # sides and input_size.
                new_h, new_w, _ = im.shape
                max_h_w_i = np.max([new_h, new_w, input_size])
                im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
                im_padded[:new_h, :new_w, :] = im
                im = im_padded

                # Resize the padded square to the network input size and
                # scale the polygons by the same ratios.
                new_h, new_w, _ = im.shape
                resize_h = input_size
                resize_w = input_size
                im = cv2.resize(im, dsize=(resize_w, resize_h))
                text_polys[:, :, 0] *= resize_w / float(new_w)
                text_polys[:, :, 1] *= resize_h / float(new_h)
                new_h, new_w, _ = im.shape

                # Only the training mask is taken from generate_rbox; its
                # second result is replaced by the flattened polygons below.
                training_mask, _ = generate_rbox(
                    (new_h, new_w), text_polys, text_tags)
                # All polygons of the image are flattened into one vector.
                rectangles_list = [text_polys.astype(np.float32).flatten()]

                # Load the precomputed maps stored next to the image tree.
                # NOTE: pickle.load can execute arbitrary code; only load
                # files produced by a trusted preprocessing run.
                im_name = base_name.split(".")[0]
                data_root = im_fn.split("JPEGImages")[0]
                geo_dir = "{}/geo_map".format(data_root)
                with open("{}/{}.pickle".format(geo_dir, im_name), "rb") as f_in:
                    geo_map = pickle.load(f_in)
                score_dir = "{}/score_map".format(data_root)
                with open("{}/{}.pickle".format(score_dir, im_name), "rb") as f_in:
                    score_map = pickle.load(f_in)

                # Remove the unreadable text (label equal to [-1]).
                mask = [not (word == [-1]) for word in text_labels]
                text_labels = list(compress(text_labels, mask))
                rectangles = list(compress(rectangles_list, mask))

                assert len(text_labels) == len(rectangles), "rotate rectangles' num is not equal to text label"

                if len(text_labels) == 0:
                    continue

                # Turn the label into its one-hot representation.
                label = ''.join([str(d) for d in text_labels[0]])
                one_hot_label = dense_to_one_hot(int(label), n_classes=class_num)

                # Build every per-image item first so that a failure here
                # cannot leave the batch lists with different lengths.
                boxes_mask = np.array([count] * len(rectangles))
                image_out = im[:, :, ::-1].astype(np.float32)
                score_out = score_map[::4, ::4, np.newaxis].astype(np.float32)
                geo_out = geo_map[::4, ::4, :].astype(np.float32)
                mask_out = training_mask[::4, ::4, np.newaxis].astype(np.float32)

                count += 1
                batch_images.append(image_out)
                batch_image_fns.append(im_fn)
                batch_score_maps.append(score_out)
                batch_geo_maps.append(geo_out)
                batch_training_masks.append(mask_out)

                batch_text_polyses.append(rectangles)
                batch_boxes_masks.append(boxes_mask)
                batch_text_labels.extend(one_hot_label)
                batch_text_tagses.append(text_tags)

                if len(batch_images) == batch_size:
                    try:
                        # Use fresh names: rebinding the accumulators to
                        # ndarrays would break later .append calls if one of
                        # the helpers below raised.
                        all_polys = np.concatenate(batch_text_polyses)
                        all_tags = np.concatenate(batch_text_tagses)
                        transform_matrixes, box_widths = get_project_matrix(
                            all_polys, all_tags)
                        # TODO limit the batch size of recognition
                        labels_sparse = sparse_tuple_from(np.array(batch_text_labels))

                        yield (batch_images, batch_image_fns,
                               batch_score_maps, batch_geo_maps,
                               batch_training_masks, transform_matrixes,
                               batch_boxes_masks, box_widths, labels_sparse,
                               all_polys, batch_text_labels)
                    finally:
                        # Always start the next batch from a clean state,
                        # including when finalising this one failed.
                        batch_images = []
                        batch_image_fns = []
                        batch_score_maps = []
                        batch_geo_maps = []
                        batch_training_masks = []
                        batch_text_polyses = []
                        batch_text_tagses = []
                        batch_boxes_masks = []
                        batch_text_labels = []
                        count = 0
            except Exception:
                # Best effort: report the failing image and continue.
                import traceback
                print(im_fn)
                traceback.print_exc()
                continue