示例#1
0
    def __getitem__(self, idx):
        """Return the line image at ``idx`` together with its ground truth.

        The image at ``self.data[idx]['im']`` is read with OpenCV, resized to
        an ``img_height x img_height`` square, optionally grid-distorted, and
        mapped to float32 via ``x / 128 - 1``.

        Returns:
            dict with ``line_img``, ``gt`` and ``filename``, or ``None`` when
            the image cannot be read.
        """
        record = self.data[idx]
        path = record['im']

        image = cv2.imread(path)
        if image is None:
            print("Warning: image is None:", path)
            return None

        # NOTE(review): this forces a square output (both sides img_height)
        # rather than preserving the aspect ratio -- confirm this is intended.
        side = self.img_height
        image = cv2.resize(image, (side, side), interpolation=cv2.INTER_CUBIC)

        # Augmentation: random grid distortion produces extra training variety
        # without hand-generating new samples.
        if self.augmentation:
            image = grid_distortion.warp_image(image)

        # For 8-bit input this maps [0, 255] to roughly [-1, 1).
        image = image.astype(np.float32) / 128.0 - 1.0

        return {"line_img": image, "gt": record['gt'], 'filename': path}
示例#2
0
    def __getitem__(self, idx):
        """Crop one text line out of a (cached) page image.

        ``self.data[idx]`` is indexed as: ``item[0]`` page image stem (read
        from ``<root_path>/<stem>.jpg``), ``item[1]`` ground-truth text, and
        ``item[2:6]`` crop bounds ``x0, x1, y0, y1`` (columns ``x0:x1``,
        rows ``y0:y1``).

        Returns:
            dict with ``line_img`` (float32, height ``self.img_height``,
            ``x / 128 - 1`` scaling), ``gt_label`` and ``gt``; or ``None``
            when the page cannot be read or the crop is empty.
        """
        item = self.data[idx]
        img_path = os.path.join(self.root_path, item[0] + '.jpg')

        # Cache the current page so lines from the same page don't reopen it.
        if item[0] != self.cur_img_name or self.cur_img is None:
            self.cur_img_name = item[0]
            self.cur_img = cv2.imread(img_path)

        # cv2.imread returns None on failure; slicing None would raise
        # TypeError, so bail out before cropping.
        if self.cur_img is None:
            print("Warning: image is None:", img_path)
            return None

        # get coords for cropping
        coords = [int(coord) for coord in item[2:6]]
        img = self.cur_img[coords[2]:coords[3], coords[0]:coords[1]]
        gt = item[1]

        # An empty crop (no rows OR no columns) cannot be resized.
        if img.shape[0] == 0 or img.shape[1] == 0:
            print("Warning: image is None:", img_path)
            return None

        # Scale uniformly so the crop height becomes self.img_height.
        percent = float(self.img_height) / img.shape[0]
        img = cv2.resize(img, (0, 0),
                         fx=percent,
                         fy=percent,
                         interpolation=cv2.INTER_CUBIC)

        if self.augmentation:
            img = grid_distortion.warp_image(img)

        img = img.astype(np.float32)
        img = img / 128.0 - 1.0

        gt_label = string_utils.str2label(gt, self.char_to_idx)

        return {"line_img": img, "gt_label": gt_label, "gt": gt}
示例#3
0
文件: hw_dataset.py 项目: jsw800/hwr
    def __getitem__(self, idx):
        """Load the line image for ``self.data[idx]`` and encode its text.

        The image ``<root_path>/<image_path>`` is rescaled uniformly to height
        ``self.img_height``, optionally grid-distorted, and mapped to float32
        via ``x / 128 - 1``.

        Returns:
            dict with ``line_img``, ``gt_label`` and ``gt``, or ``None`` when
            the image cannot be read.
        """
        entry = self.data[idx]
        full_path = os.path.join(self.root_path, entry['image_path'])

        line = cv2.imread(full_path)
        if line is None:
            print("Warning: image is None:", full_path)
            return None

        # Uniform scale factor that brings the height to self.img_height.
        scale = float(self.img_height) / line.shape[0]
        line = cv2.resize(line, (0, 0), fx=scale, fy=scale,
                          interpolation=cv2.INTER_CUBIC)

        if self.augmentation:
            line = grid_distortion.warp_image(line)

        line = line.astype(np.float32) / 128.0 - 1.0

        text = entry['gt']
        return {
            "line_img": line,
            "gt_label": string_utils.str2label(text, self.char_to_idx),
            "gt": text,
        }
示例#4
0
def combine_images(list_im, name, background, savefolder):
    """Compose several images into one synthetic line image on a background.

    Each input is cropped to the bounding box of ``255 - im`` and resized to
    the height of the shortest crop (widths unchanged). The crops are joined
    left to right, either with a random white gap or via
    ``random_touching``, and each joined piece is passed through
    ``random_resize``. The line is then warped, added to a random image from
    ``background``, and run through ``add_bg_noise``. The result is written to
    ``<savefolder>/<name>.png``.

    Args:
        list_im: sequence of image arrays. ``255 - im`` is passed to
            ``cv2.boundingRect``, so these are presumably single-channel
            uint8 with dark ink on white -- TODO confirm.
        name: stem of the output PNG file name.
        background: directory whose entries are background image files.
        savefolder: output directory; a scratch ``tmp.png`` is created there
            and always removed before returning.
    """
    # Crop every image to the bounding box of its non-white pixels.
    rect_imgs = []
    for im in list_im:
        x, y, width, height = cv2.boundingRect(255 - np.asarray(im))
        rect_imgs.append(im[y:y + height, x:x + width])

    # Pick the crop with the smallest height and resize the others to that
    # height (each keeps its own width).
    min_shape = sorted([(np.shape(i)[0], np.shape(i))
                        for i in rect_imgs])[0][1]
    print("Min shape: {}".format(min_shape))
    rect_imgs_resized = [
        cv2.resize(img, (np.shape(img)[1], min_shape[0])) for img in rect_imgs
    ]

    # Lay the pieces out left to right. The first piece always gets a white
    # gap of 0-19 px. After that, randint(100) <= 80 (81%) also appends a gap;
    # otherwise random_touching() joins the piece with its predecessor.
    stack_imgs = []
    prev = []
    for img in rect_imgs_resized:
        padding = np.ones((min_shape[0], np.random.randint(20))) * 255
        if len(prev) != 0 and np.random.randint(100) > 80:
            tmp = random_touching(img, prev)
        else:
            tmp = np.hstack((np.asarray(img), padding))
        stack_imgs.append(random_resize(tmp))
        prev = img

    # np.hstack needs a sequence. Passing a generator has been deprecated
    # since NumPy 1.16 and is rejected by newer releases.
    imgs_comb = np.hstack([np.asarray(i) for i in stack_imgs])
    imgs_comb = PIL.Image.fromarray(imgs_comb.astype(np.uint8))

    # Round-trip through a PNG so cv2.imread reloads the line as a 3-channel
    # BGR image for the warp and background merge.
    tmp_path = os.path.join(savefolder, 'tmp.png')
    imgs_comb.save(tmp_path)
    try:
        im = cv2.imread(tmp_path)
        new_im = warp_image(im, draw_grid_lines=False)

        methods = ['gaussian', 'localvar', 'poisson', 'salt', 's&p', 'speckle']
        backgrounds = os.listdir(background)
        bg_im = cv2.imread(
            os.path.join(background,
                         backgrounds[np.random.randint(len(backgrounds))]))
        # cv2.resize takes dsize as (width, height), hence the reversed shape.
        bg_im = cv2.resize(bg_im, np.shape(new_im)[:2][::-1])
        # NOTE(review): if both arrays are uint8 (as cv2.imread returns), `+`
        # wraps modulo 256 instead of saturating -- confirm this is intended.
        merge = bg_im + new_im
        noise_img = add_bg_noise(merge, methods[np.random.randint(len(methods))])
        cv2.imwrite(os.path.join(savefolder, '{}.png'.format(name)),
                    (noise_img * 255).astype(np.uint8))
    finally:
        # Remove the scratch file even if a later step raised.
        os.remove(tmp_path)
    def __getitem__(self, idx):
        """Load one handwriting line described by a JSON ground-truth file.

        ``self.detailed_ids[idx]`` gives ``(ids_idx, line_idx)``, and
        ``self.ids[ids_idx]`` gives ``(gt_json_path, img_path)``. The image
        is ``gt_json[line_idx]['hw_path']`` (file name only), resolved in the
        JSON file's directory. It is rescaled to height ``self.img_height``,
        optionally augmented, and mapped to float32 via ``x / 128 - 1``.

        Returns:
            dict with ``line_img``, ``gt`` and ``gt_label``; or ``None`` if
            the JSON cannot be loaded, the line has no ``hw_path``, the image
            cannot be read, or the transcription is empty.
        """
        ids_idx, line_idx = self.detailed_ids[idx]
        gt_json_path, _ = self.ids[ids_idx]
        gt_json = safe_load.json_state(gt_json_path)
        if gt_json is None:
            return None

        if 'hw_path' not in gt_json[line_idx]:
            return None

        # Only the last path component of hw_path is used; the file is looked
        # up next to the JSON file rather than at the stored location.
        hw_path = gt_json[line_idx]['hw_path'].split("/")[-1]
        hw_folder = os.path.dirname(gt_json_path)

        img = cv2.imread(os.path.join(hw_folder, hw_path))
        if img is None:
            return None

        if img.shape[0] != self.img_height:
            if img.shape[0] < self.img_height and not self.warning:
                # Warn only once per instance about upsampling small images.
                self.warning = True
                print("WARNING: upsampling image to fit size")
            percent = float(self.img_height) / img.shape[0]
            img = cv2.resize(img, (0, 0), fx=percent, fy=percent,
                             interpolation=cv2.INTER_CUBIC)

        if self.augmentation:
            img = augmentation.apply_random_color_rotation(img)
            img = augmentation.apply_tensmeyer_brightness(img)
            img = grid_distortion.warp_image(img)

        img = img.astype(np.float32)
        img = img / 128.0 - 1.0

        gt = gt_json[line_idx]['gt']
        if len(gt) == 0:
            return None
        gt_label = string_utils.str2label_single(gt, self.char_to_idx)

        return {
            "line_img": img,
            "gt": gt,
            "gt_label": gt_label
        }
示例#6
0
    def __getitem__(self, idx):
        """Return one line image with its transcription and writer metadata.

        ``<root>/<item['image_path']>`` is read as color when
        ``self.num_of_channels == 3`` or as grayscale when it is 1. It is then
        rescaled to height ``self.img_height``, optionally warped, and mapped
        to float32 via ``x / 128 - 1``.

        Returns:
            dict with ``line_img``, ``gt_label``, ``gt``, ``actual_writer_id``,
            ``writer_id``, ``path`` and ``online``; or ``None`` when the image
            cannot be read.

        Raises:
            Exception: if ``self.num_of_channels`` is neither 1 nor 3.
        """
        item = self.data[idx]
        image_path = os.path.join(self.root, item['image_path'])

        channels = self.num_of_channels
        if channels not in (3, 1):
            raise Exception("Unexpected number of channels")
        # Flag 0 requests a grayscale decode.
        img = cv2.imread(image_path) if channels == 3 else cv2.imread(image_path, 0)
        if img is None:
            print("Warning: image is None:", image_path)
            return None

        ratio = float(self.img_height) / img.shape[0]
        img = cv2.resize(img, (0, 0), fx=ratio, fy=ratio,
                         interpolation=cv2.INTER_CUBIC)

        if self.warp:
            img = grid_distortion.warp_image(img)

        # resize and warp only keep a non-trivial channel axis, so put the
        # single channel back for grayscale input.
        if channels == 1:
            img = img[:, :, np.newaxis]

        img = img.astype(np.float32) / 128.0 - 1.0

        gt = item['gt']  # transcription text
        gt_label = string_utils.str2label(gt, self.char_to_idx)  # per-character indices

        # HACK (original author: "FIX THIS"): writer ids above 700 are treated
        # as online samples; the intended source was item.get('online', False).
        actual_writer_id = int(item['actual_writer_id'])
        online = actual_writer_id > 700

        return {
            "line_img": img,
            "gt_label": gt_label,
            "gt": gt,
            "actual_writer_id": actual_writer_id,
            "writer_id": int(item['writer_id']),
            "path": image_path,
            "online": online,
        }
示例#7
0
    def __getitem__(self, index):
        """Read sample ``index`` from the LMDB environment ``self.env``.

        Records are stored under 1-based keys (``image-%09d``, ``label-%09d``,
        ``file-%09d``, plus ``howe-image-%09d`` and ``simplebin-image-%09d``
        when ``self.binarize`` is set), so sample 0 lives at key 1.

        If an image of the sample cannot be decoded, a message is printed and
        sample ``index + 1`` is returned instead.

        Returns:
            ``(final_image, label, file_name)``. Before ``self.transform``,
            ``final_image`` is a grayscale PIL image, or an RGB merge of
            (original, howe, simplebin) when binarizing. It is optionally
            grid-warped and randomly rescaled.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        # The old `assert index <= len(self)` let index == len(self) through,
        # which then read a missing key.
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        key_index = index + 1  # LMDB keys are 1-based

        def read_gray(txn, key):
            # Decode the image stored under `key` as 8-bit grayscale ('L');
            # return None if PIL cannot decode it.
            buf = six.BytesIO()
            buf.write(txn.get(key))
            buf.seek(0)
            try:
                return Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % key_index)
                return None

        with self.env.begin(write=False) as txn:
            # On a corrupt image fall back to the next sample. `index` is still
            # 0-based here; the old code recursed with the already-incremented
            # value and silently skipped one sample.
            img = read_gray(txn, 'image-%09d' % key_index)
            if img is None:
                return self[index + 1]
            if self.binarize:
                img_howe = read_gray(txn, 'howe-image-%09d' % key_index)
                if img_howe is None:
                    return self[index + 1]
                img_simplebin = read_gray(txn, 'simplebin-image-%09d' % key_index)
                if img_simplebin is None:
                    return self[index + 1]

            if self.test:
                label = u''
            else:
                # Hopefully this still works with unicode
                label = unicode(txn.get('label-%09d' % key_index),
                                encoding=encoding)

            file_name = str(txn.get('file-%09d' % key_index))

        if self.target_transform is not None:
            label = self.target_transform(label)

        if self.binarize:
            if img_howe.size != img.size:  # need to resize the howe image
                img_howe = pad_size(img_howe, img.size)
            final_image = Image.merge("RGB", (img, img_howe, img_simplebin))
        else:
            final_image = img

        if self.augment:
            from grid_distortion import warp_image
            # Params chosen based on the BYU data-augmentation paper by
            # Wigington et al.: control points every 26 px (slightly larger
            # than the average baseline height), perturbed with a normal std of
            # 1.7 px. Those values are for 80-px-tall images, so scale both by
            # the actual height. (The former READ / non-READ branches used
            # identical parameters and were merged.)
            _, h = final_image.size
            mesh_i = h / 80.0 * 26
            std = h / 80.0 * 1.7
            final_image = Image.fromarray(
                warp_image(np.array(final_image),
                           w_mesh_interval=mesh_i,
                           h_mesh_interval=mesh_i,
                           w_mesh_std=std,
                           h_mesh_std=std))

        # Randomly rescale by a factor in [1/scale_dim, scale_dim], keeping the
        # aspect ratio.
        if self.scale:
            s = random.uniform(1.0 / self.scale_dim, self.scale_dim)
            w, h = final_image.size
            ar = float(w) / h
            new_h = int(round(s * h))
            new_w = int(round(ar * new_h))
            final_image = final_image.resize((new_w, new_h),
                                             resample=Image.BILINEAR)

        if self.transform is not None:
            final_image = self.transform(final_image)

        return (final_image, label, file_name)