Example #1
    def __getitem__(self, idx):
        img_path, gt = self.lines[idx]

        img = cv2.imread(img_path, 0)

        if img is None:
            return None

        if img.shape[0] != self.img_height:
            if img.shape[0] < self.img_height and not self.warning:
                self.warning = True
                print("WARNING: upsampling image to fit size")
            percent = float(self.img_height) / img.shape[0]
            img = cv2.resize(img, (0, 0),
                             fx=percent,
                             fy=percent,
                             interpolation=cv2.INTER_CUBIC)

        if img is None:
            return None

        if self.augmentation:
            img = augmentation.apply_random_color_rotation(img)
            img = augmentation.apply_tensmeyer_brightness(img)
            img = grid_distortion.warp_image(img)

        img = img.astype(np.float32)
        img = img / 128.0 - 1.0
        img = img[..., None]

        if len(gt) == 0:
            return None
        gt_label = string_utils.str2label_single(gt, self.char_to_idx)

        return {"line_img": img, "gt": gt, "gt_label": gt_label}
    def __getitem__(self, idx):

        gt_json_path, img_path = self.ids[idx]

        gt_json = safe_load.json_state(gt_json_path)
        if gt_json is None:
            return None

        # print('img_path: {}'.format(img_path))
        org_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # print('img.size: {}'.format(org_img.shape))
        # median = np.median(org_img, axis=(0,1))
        # org_img = cv2.copyMakeBorder(org_img,100,100,100,100,cv2.BORDER_CONSTANT,value=median)
        target_dim1 = int(np.random.uniform(self.rescale_range[0], self.rescale_range[1]))

        s = target_dim1 / float(org_img.shape[1])
        target_dim0 = int(org_img.shape[0]/float(org_img.shape[1]) * target_dim1)
        org_img = cv2.resize(org_img,(target_dim1, target_dim0), interpolation = cv2.INTER_CUBIC)

        gt = np.zeros((1,len(gt_json['corners']), 4), dtype=np.float32)

        for j, gt_item in enumerate(gt_json['corners']):

            x0 = gt_item[0]
            x1 = gt_item[0]
            y0 = gt_item[1]
            y1 = gt_item[1]

            gt[:,j,0] = x0 * s
            gt[:,j,1] = y0 * s
            gt[:,j,2] = x1 * s
            gt[:,j,3] = y1 * s

        if self.transform is not None:
            out = self.transform({
                "img": org_img,
                "sol_gt": gt
            })
            org_img = out['img']
            gt = out['sol_gt']

            org_img = augmentation.apply_random_color_rotation(org_img)
            org_img = augmentation.apply_tensmeyer_brightness(org_img)
            org_img = augmentation.apply_random_blur(org_img)


        img = org_img.transpose([2,1,0])[None,...]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = img / 128.0 - 1.0

        if gt.shape[1] == 0:
            gt = None
        else:
            gt = torch.from_numpy(gt)

        return {
            "img": img,
            "sol_gt": gt
        }
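# Sketch (not part of the loader): `img / 128.0 - 1.0` maps uint8 pixels from
# [0, 255] to roughly [-1, 1], and the transpose above reorders the axes to
# [1, color, col, row]. To inspect such a tensor as an image again, invert both:
import numpy as np

def denormalize_for_display(t):
    arr = (t[0].numpy() + 1.0) * 128.0           # back to the 0-255 range
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return arr.transpose([2, 1, 0])              # back to [row, col, color]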
Example #3
    def __getitem__(self, idx):

        gt_json_path, img_path = self.ids[idx]

        gt_json = safe_load.json_state(gt_json_path)
        if gt_json is None:
            return None

        org_img = cv2.imread(img_path)
        target_dim1 = int(np.random.uniform(self.rescale_range[0], self.rescale_range[1]))

        s = target_dim1 / float(org_img.shape[1])
        target_dim0 = int(org_img.shape[0]/float(org_img.shape[1]) * target_dim1)
        org_img = cv2.resize(org_img,(target_dim1, target_dim0), interpolation = cv2.INTER_CUBIC)

        gt = np.zeros((1,len(gt_json), 4), dtype=np.float32)

        for j, gt_item in enumerate(gt_json):
            if 'sol' not in gt_item:
                continue

            x0 = gt_item['sol']['x0']
            x1 = gt_item['sol']['x1']
            y0 = gt_item['sol']['y0']
            y1 = gt_item['sol']['y1']

            gt[:,j,0] = x0 * s
            gt[:,j,1] = y0 * s
            gt[:,j,2] = x1 * s
            gt[:,j,3] = y1 * s

        if self.transform is not None:
            out = self.transform({
                "img": org_img,
                "sol_gt": gt
            })
            org_img = out['img']
            gt = out['sol_gt']


            org_img = augmentation.apply_random_color_rotation(org_img)
            org_img = augmentation.apply_tensmeyer_brightness(org_img)


        img = org_img.transpose([2,1,0])[None,...]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = img / 128.0 - 1.0

        if gt.shape[1] == 0:
            gt = None
        else:
            gt = torch.from_numpy(gt)

        return {
            "img": img,
            "sol_gt": gt
        }
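# For reference, a hedged sketch of the per-line annotation entry this loader
# expects (field names are taken from the loaders in this file; the values are
# invented):
example_gt_entry = {
    "sol": {"x0": 120.0, "y0": 45.0, "x1": 118.0, "y1": 88.0},  # start-of-line
    # other loaders below also read 'lf' (line-follower steps), 'gt' (the
    # transcription) and 'hw_path' (path to the cropped line image)
}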
Example #4
    def __getitem__(self, idx):

        ids_idx, line_idx = self.detailed_ids[idx]
        gt_json_path, img_path = self.ids[ids_idx]
        gt_json = safe_load.json_state(gt_json_path)

        positions = []
        positions_xy = []

        if 'lf' not in gt_json[line_idx]:
            return None

        for step in gt_json[line_idx]['lf']:
            x0 = step['x0']
            x1 = step['x1']
            y0 = step['y0']
            y1 = step['y1']

            positions_xy.append((torch.Tensor([[x1, x0], [y1, y0]])))

            dx = x0 - x1
            dy = y0 - y1

            d = math.sqrt(dx**2 + dy**2)

            mx = (x0 + x1) / 2.0
            my = (y0 + y1) / 2.0

            #Not sure if this is right...
            theta = -math.atan2(dx, -dy)

            positions.append(torch.Tensor([mx, my, theta, d / 2, 1.0]))

        img = cv2.imread(img_path)
        if self.augmentation:
            img = augmentation.apply_random_color_rotation(img)
            img = augmentation.apply_tensmeyer_brightness(img)

        img = img.astype(np.float32)
        img = img.transpose()
        img = img / 128.0 - 1.0
        img = torch.from_numpy(img)

        gt = gt_json[line_idx]['gt']

        result = {
            "img": img,
            "lf_xyrs": positions,
            "lf_xyxy": positions_xy,
            "gt": gt
        }
        return result
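# Quick numeric check (a sketch) of the endpoint -> (mx, my, theta, d/2)
# conversion used above, following the code's own sign convention for theta:
import math
x0, y0, x1, y1 = 10.0, 20.0, 10.0, 60.0         # a vertical segment
dx, dy = x0 - x1, y0 - y1                       # (0.0, -40.0)
d = math.sqrt(dx**2 + dy**2)                    # 40.0
mx, my = (x0 + x1) / 2.0, (y0 + y1) / 2.0       # (10.0, 40.0)
theta = -math.atan2(dx, -dy)                    # -atan2(0.0, 40.0) == 0.0
# resulting positions entry: [10.0, 40.0, 0.0, 20.0, 1.0]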
    def __getitem__(self, idx):
        ids_idx, line_idx = self.detailed_ids[idx]
        gt_json_path, img_path = self.ids[ids_idx]
        gt_json = safe_load.json_state(gt_json_path)
        if gt_json is None:
            return None

        if 'hw_path' not in gt_json[line_idx]:
            return None

        hw_path = gt_json[line_idx]['hw_path']

        hw_path = hw_path.split("/")[-1:]
        hw_path = "/".join(hw_path)

        hw_folder = os.path.dirname(gt_json_path)

        img = cv2.imread(os.path.join(hw_folder, hw_path))

        if img is None:
            return None

        if img.shape[0] != self.img_height:
            if img.shape[0] < self.img_height and not self.warning:
                self.warning = True
                print "WARNING: upsampling image to fit size"
            percent = float(self.img_height) / img.shape[0]
            img = cv2.resize(img, (0,0), fx=percent, fy=percent, interpolation = cv2.INTER_CUBIC)

        if img is None:
            return None

        if self.augmentation:
            img = augmentation.apply_random_color_rotation(img)
            img = augmentation.apply_tensmeyer_brightness(img)
            img = grid_distortion.warp_image(img)

        img = img.astype(np.float32)
        img = img / 128.0 - 1.0

        gt = gt_json[line_idx]['gt']
        if len(gt) == 0:
            return None
        gt_label = string_utils.str2label_single(gt, self.char_to_idx)


        return {
            "line_img": img,
            "gt": gt,
            "gt_label": gt_label
        }
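# Note (a sketch, equivalent to the split/join above): only the base name of
# 'hw_path' is kept, and it is resolved relative to the annotation's folder.
import os
hw_name = os.path.basename('some/dir/line_0005.png')             # 'line_0005.png'
full_path = os.path.join(os.path.dirname('gt/doc.json'), hw_name)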
    def __getitem__(self, idx):

        img_path = os.path.join(self.directory, '{}.png'.format(idx))
        img = cv2.imread(img_path, 0)

        if len(img.shape) == 2:
            img = img[..., None]
        if self.augmentation is not None:
            #img = augmentation.apply_random_color_rotation(img)
            if 'brightness' in self.augmentation:
                img = augmentation.apply_tensmeyer_brightness(img)
            if 'warp' in self.augmentation:
                img = grid_distortion.warp_image(img)
        if len(img.shape) == 2:
            img = img[..., None]

        img = img.astype(np.float32)
        img = 1.0 - img / 128.0

        gt = self.labels[idx]
        if gt is None:
            #metadata = pyexiv2.ImageMetadata(img_path)
            #metadata.read()
            #metadata = piexif.load(img_path)
            #if 'gt' in metadata:
            #    gt = metadata['gt']
            #else:
            print('Error unknown label for image: {}'.format(img_path))
            return self.__getitem__((idx + 7) % self.set_size)

        gt_label = string_utils.str2label_single(gt, self.char_to_idx)

        return {
            "image": img,
            "gt": gt,
            "gt_label": gt_label,
            #"author": author
        }
    def __getitem__(self, idx):

        inst = self.lineIndex[idx]
        author = inst[0]
        words = inst[1]
        batch = []
        for word in words:
            if word >= len(self.w_authors[author]):
                word = (word + 37) % len(self.w_authors[author])
            img_path, lb, gt, id = self.w_authors[author][word]
            img = cv2.imread(img_path,
                             0)[lb[0]:lb[1],
                                lb[2]:lb[3]]  #read as grayscale, crop word
            if img.shape[0] == 0 or img.shape[1] == 0:
                return self.__getitem__((idx + 1) % self.__len__())

            if img is None:
                return None

            if img.shape[0] != self.img_height:
                if img.shape[0] < self.img_height and not self.warning:
                    self.warning = True
                    print("WARNING: upsampling image to fit size")
                percent = float(self.img_height) / img.shape[0]
                img = cv2.resize(img, (0, 0),
                                 fx=percent,
                                 fy=percent,
                                 interpolation=cv2.INTER_CUBIC)

            if img is None:
                return None

            if len(img.shape) == 2:
                img = img[..., None]
            if self.augmentation is not None:
                #img = augmentation.apply_random_color_rotation(img)
                img = augmentation.apply_tensmeyer_brightness(img)
                img = grid_distortion.warp_image(img)

            img = img.astype(np.float32)
            img = 1.0 - img / 128.0

            if len(gt) == 0:
                return None
            gt_label = string_utils.str2label_single(gt, self.char_to_idx)

            if self.styles:
                style_i = self.npr.choice(len(self.styles[author][id]))
                style = self.styles[author][id][style_i]
            else:
                style = None
            batch.append({
                "image": img,
                "gt": gt,
                "style": style,
                "gt_label": gt_label,
                "name": '{}_{}'.format(author, word),
                "author": author
            })
        #batch = [b for b in batch if b is not None]
        #These all should be the same size or error
        assert len(set([b['image'].shape[0] for b in batch])) == 1
        assert len(set([b['image'].shape[2] for b in batch])) == 1

        dim0 = batch[0]['image'].shape[0]
        dim1 = max([b['image'].shape[1] for b in batch])
        dim2 = batch[0]['image'].shape[2]

        all_labels = []
        label_lengths = []

        input_batch = np.full((len(batch), dim0, dim1, dim2),
                              PADDING_CONSTANT).astype(np.float32)
        for i in range(len(batch)):
            b_img = batch[i]['image']
            #toPad = (dim1-b_img.shape[1])
            input_batch[i, :, 0:b_img.shape[1], :] = b_img

            l = batch[i]['gt_label']
            all_labels.append(l)
            label_lengths.append(len(l))

        #all_labels = np.concatenate(all_labels)
        label_lengths = torch.IntTensor(label_lengths)
        max_len = label_lengths.max()
        all_labels = [
            np.pad(l, ((0, max_len - l.shape[0]), ), 'constant')
            for l in all_labels
        ]
        all_labels = np.stack(all_labels, axis=1)

        images = input_batch.transpose([0, 3, 1, 2])
        images = torch.from_numpy(images)
        labels = torch.from_numpy(all_labels.astype(np.int32))
        #label_lengths = torch.from_numpy(label_lengths.astype(np.int32))
        if batch[0]['style'] is not None:
            styles = np.stack([b['style'] for b in batch], axis=0)
            styles = torch.from_numpy(styles).float()
        else:
            styles = None

        return {
            "image": images,
            "mask": makeMask(images, self.mask_post),
            "label": labels,
            "style": styles,
            "label_lengths": label_lengths,
            "gt": [b['gt'] for b in batch],
            "name": [b['name'] for b in batch],
            "author": [b['author'] for b in batch]
        }
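# Hedged sketch of consuming this batch with PyTorch's CTC loss. It assumes the
# recognizer emits log-probabilities shaped [T, N, num_classes] and that label
# index 0 is the blank; neither assumption is stated by the loader itself.
import torch
import torch.nn as nn

ctc = nn.CTCLoss(blank=0, zero_infinity=True)

def ctc_step(log_probs, batch):
    targets = batch['label'].t()                 # stacked as [max_len, N] above
    input_lengths = torch.full((log_probs.size(1),), log_probs.size(0),
                               dtype=torch.int32)
    return ctc(log_probs, targets, input_lengths, batch['label_lengths'])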
    def __getitem__(self, idx):
        if type( self.augmentation) is str and 'affine' in self.augmentation:
            strech = (self.max_strech*2)*np.random.random() - self.max_strech +1
            #self.max_rot_rad = self.max_rot_deg/180 * np.pi
            skew = (self.max_rot_rad*2)*np.random.random() - self.max_rot_rad
        if self.include_stroke_aug:
            thickness_change= np.random.randint(-4,5)
            fg_shade = np.random.random()*0.25 + 0.75
            bg_shade = np.random.random()*0.2
            blur_size = np.random.randint(2,4)
            noise_sigma = np.random.random()*0.02

        batch=[]

        if self.triplet=='hard':
            authors = random.sample(list(self.authors.keys()), self.triplet_author_size)
            alines=[]
            for author in authors:
                if len(self.authors[author])>=self.triplet_sample_size*self.batch_size:
                    lines = random.sample(range(len(self.authors[author])),self.triplet_sample_size*self.batch_size)
                else:
                    lines = list(range(len(self.authors[author])))
                    random.shuffle(lines)
                    dif = self.triplet_sample_size*self.batch_size-len(self.authors[author])
                    lines += lines[:dif]
                alines += [(author,l) for l in lines]
        else:


            inst = self.lineIndex[idx]
            author=inst[0]
            lines=inst[1]


            alines = [(author,l) for l in lines]
            used_lines = set(lines)
            if self.triplet:
                if len(self.authors[author])<=2*self.batch_size:
                    for l in range(len(self.authors[author])):
                        if l not in used_lines:
                            alines.append((author,l))
                    if len(alines)<2*self.batch_size:
                        dif = 2*self.batch_size - len(alines)
                        for i in range(dif):
                            alines.append(alines[self.batch_size+i])
                else:
                    unused_lines = set(range(len(self.authors[author])))-used_lines
                    for i in range(self.batch_size):
                        l = random.choice(list(unused_lines))
                        unused_lines.remove(l)
                        alines.append((author,l))
                
                other_authors = set(range(len(self.authors)))
                other_authors.remove(author)
                author = random.choice(list(other_authors))
                unused_lines = set(range(len(self.authors[author])))-used_lines
                for i in range(self.batch_size):
                    l = random.choice(list(unused_lines))
                    unused_lines.remove(l)
                    alines.append((author,l))

            

        images=[]
        for author,line in alines:
            if line>=len(self.authors[author]):
                line = (line+37)%len(self.authors[author])
            img_path, lb, gt = self.authors[author][line]
            img_path = os.path.join(self.dirPath,'images_gray',img_path)

            if self.no_spaces:
                gt = gt.replace(' ','')
            if type(self.augmentation) is str and 'normalization' in  self.augmentation and self.normalized_dir is not None and os.path.exists(os.path.join(self.normalized_dir,'{}_{}.png'.format(author,line))):
                img = cv2.imread(os.path.join(self.normalized_dir,'{}_{}.png'.format(author,line)),0)
                readNorm=True
            else:
                img = cv2.imread(img_path,0)
                if img is None:
                    print('Error, could not read image: {}'.format(img_path))
                    return None
                lb[0] = max(lb[0],0)
                lb[2] = max(lb[2],0)
                lb[1] = min(lb[1],img.shape[0])
                lb[3] = min(lb[3],img.shape[1])
                img = img[lb[0]:lb[1],lb[2]:lb[3]] #read as grayscale, crop line
                readNorm=False


            if img.shape[0] != self.img_height:
                if img.shape[0] < self.img_height and not self.warning:
                    self.warning = True
                    print("WARNING: upsampling image to fit size")
                percent = float(self.img_height) / img.shape[0]
                if img.shape[1]*percent > self.max_width:
                    percent = self.max_width/img.shape[1]
                img = cv2.resize(img, (0,0), fx=percent, fy=percent, interpolation = cv2.INTER_CUBIC)
                if img.shape[0]<self.img_height:
                    diff = self.img_height-img.shape[0]
                    img = np.pad(img,((diff//2,diff//2+diff%2),(0,0)),'constant',constant_values=255)
            elif img.shape[1]> self.max_width:
                percent = self.max_width/img.shape[1]
                img = cv2.resize(img, (0,0), fx=percent, fy=percent, interpolation = cv2.INTER_CUBIC)
                if img.shape[0]<self.img_height:
                    diff = self.img_height-img.shape[0]
                    img = np.pad(img,((diff//2,diff//2+diff%2),(0,0)),'constant',constant_values=255)

            if self.augmentation=='affine':
                if img.shape[1]*strech > self.max_width:
                    strech = self.max_width/img.shape[1]
            images.append( (line,gt,img,author) )
            #we split the processing here so that strech will be adjusted for longest image in author batch


        for line,gt,img,author in images:
            if self.fg_masks_dir is not None:
                fg_path = os.path.join(self.fg_masks_dir,'{}_{}.png'.format(author,line))
                fg_mask = cv2.imread(fg_path,0)
                fg_mask = fg_mask/255
                if fg_mask.shape!=img[:,:].shape:
                    print('Error, fg_mask ({}, {}) not the same size as image ({})'.format(fg_path,fg_mask.shape,img.shape))
                    th,fg_mask = cv2.threshold(img,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
                    fg_mask = 255-fg_mask
                    ele = cv2.getStructuringElement(  cv2.MORPH_ELLIPSE, (9,9) )
                    fg_mask = cv2.dilate(fg_mask,ele)
                    fg_mask = fg_mask/255
            else:
                fg_mask=None

                    
            if type(self.augmentation) is str and 'normalization' in  self.augmentation and not readNorm:
                img = normalize_line.deskew(img)
                img = normalize_line.skeletonize(img)
                if self.normalized_dir is not None:
                    cv2.imwrite(os.path.join(self.normalized_dir,'{}_{}.png'.format(author,line)),img)
            if type(self.augmentation) is str and 'affine' in  self.augmentation:
                img,fg_mask = augmentation.affine_trans(img,fg_mask,skew,strech)
            elif self.augmentation is not None and (type(self.augmentation) is not str or 'warp' in self.augmentation):
                #img = augmentation.apply_random_color_rotation(img)
                img = augmentation.apply_tensmeyer_brightness(img)
                img = grid_distortion.warp_image(img)
                assert(fg_mask is None)

            if self.include_stroke_aug:
                new_img = augmentation.change_thickness(img,thickness_change,fg_shade,bg_shade,blur_size,noise_sigma)
                if len(new_img.shape)==2:
                    new_img = new_img[...,None]
                new_img = new_img*2 -1.0

            if len(img.shape)==2:
                img = img[...,None]

            img = img.astype(np.float32)
            if self.remove_bg:
                img = 1.0 - img / 256.0
                #kernel = torch.FloatTensor(7,7).fill_(1/49)
                #blurred_mask = F.conv2d(fg_mask,kernel,padding=3)
                blurred_mask = cv2.blur(fg_mask,(7,7))
                img *= blurred_mask[...,None]
                img = 2*img -1
            else:
                img = 1.0 - img / 128.0



            if len(gt) == 0:
                return None
            gt_label = string_utils.str2label_single(gt, self.char_to_idx)

            if self.styles:
                # 'id' is undefined in this scope; index the styles by line, matching the '{}_{}'.format(author, line) name below
                style_i = self.npr.choice(len(self.styles[author][line]))
                style = self.styles[author][line][style_i]
            else:
                style=None
            name = '{}_{}'.format(author,line)
            if self.identity_spaced:
                spaced_label = gt_label[:,None].astype(np.int64)
            else:
                spaced_label = None if self.spaced_by_name is None else self.spaced_by_name[name]
                if spaced_label is not None:
                    assert(spaced_label.shape[1]==1)
            toAppend= {
                "image": img,
                "gt": gt,
                "style": style,
                "gt_label": gt_label,
                "spaced_label": spaced_label,
                "name": name,
                "center": self.center,
                "author": author,
                "author_idx": self.author_list.index(author)
                
                }
            if self.fg_masks_dir is not None:
                toAppend['fg_mask'] = fg_mask
            if self.include_stroke_aug:
                toAppend['changed_image'] = new_img
            batch.append(toAppend)
            
        #batch = [b for b in batch if b is not None]
        #These all should be the same size or error
        assert len(set([b['image'].shape[0] for b in batch])) == 1
        assert len(set([b['image'].shape[2] for b in batch])) == 1

        dim0 = batch[0]['image'].shape[0]
        dim1 = max([b['image'].shape[1] for b in batch])
        dim2 = batch[0]['image'].shape[2]

        all_labels = []
        label_lengths = []
        if self.spaced_by_name is not None or self.identity_spaced:
            spaced_labels = []
        else:
            spaced_labels = None
        max_spaced_len=0

        input_batch = np.full((len(batch), dim0, dim1, dim2), PADDING_CONSTANT).astype(np.float32)
        if self.fg_masks_dir is not None:
            fg_masks = np.full((len(batch), dim0, dim1, 1), 0).astype(np.float32)
        if self.include_stroke_aug:
            changed_batch = np.full((len(batch), dim0, dim1, dim2), PADDING_CONSTANT).astype(np.float32)
        for i in range(len(batch)):
            b_img = batch[i]['image']
            toPad = (dim1-b_img.shape[1])
            if 'center' in batch[0] and batch[0]['center']:
                toPad //=2
            else:
                toPad = 0
            input_batch[i,:,toPad:toPad+b_img.shape[1],:] = b_img
            if self.fg_masks_dir is not None:
                fg_masks[i,:,toPad:toPad+b_img.shape[1],0] = batch[i]['fg_mask']
            if self.include_stroke_aug:
                changed_batch[i,:,toPad:toPad+b_img.shape[1],:] = batch[i]['changed_image']

            l = batch[i]['gt_label']
            all_labels.append(l)
            label_lengths.append(len(l))

            if spaced_labels is not None:
                sl = batch[i]['spaced_label']
                spaced_labels.append(sl)
                max_spaced_len = max(max_spaced_len,sl.shape[0])

        #all_labels = np.concatenate(all_labels)
        label_lengths = torch.IntTensor(label_lengths)
        max_len = label_lengths.max()
        all_labels = [np.pad(l,((0,max_len-l.shape[0]),),'constant') for l in all_labels]
        all_labels = np.stack(all_labels,axis=1)
        if self.spaced_by_name is not None or self.identity_spaced:
            spaced_labels = [np.pad(l,((0,max_spaced_len-l.shape[0]),(0,0)),'constant') for l in spaced_labels]
            ddd = spaced_labels
            spaced_labels = np.concatenate(spaced_labels,axis=1)
            spaced_labels = torch.from_numpy(spaced_labels)
            assert(spaced_labels.size(1) == len(batch))


        images = input_batch.transpose([0,3,1,2])
        images = torch.from_numpy(images)
        labels = torch.from_numpy(all_labels.astype(np.int32))
        #label_lengths = torch.from_numpy(label_lengths.astype(np.int32))
        if self.fg_masks_dir is not None:
            fg_masks = fg_masks.transpose([0,3,1,2])
            fg_masks = torch.from_numpy(fg_masks)
        
        if batch[0]['style'] is not None:
            styles = np.stack([b['style'] for b in batch], axis=0)
            styles = torch.from_numpy(styles).float()
        else:
            styles=None
        mask, top_and_bottom, center_line = makeMask(images,self.mask_post, self.mask_random)
        ##DEBUG
        #for i in range(5):
        #    mask2, top_and_bottom2 = makeMask(images,self.mask_post, self.mask_random)
        #    #extra_masks.append(mask2)
        #    mask2 = ((mask2[0,0]+1)/2).numpy().astype(np.uint8)*255
        #    cv2.imshow('mask{}'.format(i),mask2)
        #mask = ((mask[0,0]+1)/2).numpy().astype(np.uint8)*255
        #cv2.imshow('mask'.format(i),mask)
        #cv2.waitKey()
        toRet= {
            "image": images,
            "mask": mask,
            "top_and_bottom": top_and_bottom,
            "center_line": center_line,
            "label": labels,
            "style": styles,
            "label_lengths": label_lengths,
            "gt": [b['gt'] for b in batch],
            "spaced_label": spaced_labels,
            "name": [b['name'] for b in batch],
            "author": [b['author'] for b in batch],
            "author_idx": [b['author_idx'] for b in batch],
        }
        if self.fg_masks_dir is not None:
            toRet['fg_mask'] = fg_masks
        if self.include_stroke_aug:
            changed_images = changed_batch.transpose([0,3,1,2])
            changed_images = torch.from_numpy(changed_images)
            toRet['changed_image']=changed_images
        return toRet
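# Tiny sketch of the horizontal placement logic above: with center=True a line
# image of width w starts at (dim1 - w) // 2 inside the padded canvas, otherwise
# it is left-aligned at column 0 (PADDING_CONSTANT is stood in for by -1.0).
import numpy as np
dim0, dim1, dim2 = 64, 200, 1
b_img = np.zeros((dim0, 120, dim2), dtype=np.float32)
canvas = np.full((dim0, dim1, dim2), -1.0, dtype=np.float32)
toPad = (dim1 - b_img.shape[1]) // 2             # 40 when centering, 0 otherwise
canvas[:, toPad:toPad + b_img.shape[1], :] = b_img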
    def getitem(self, index, scaleP=None, cropPoint=None):
        if self.useRandomAugProb is not None and np.random.rand(
        ) < self.useRandomAugProb and scaleP is None and cropPoint is None:
            return self.getRandomImage()
        ##ticFull=timeit.default_timer()
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        #print(annotationPath)
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        ##tic=timeit.default_timer()
        np_img = cv2.imread(imagePath, 1 if self.color else 0)  #/255.0
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())

        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            #this is a check to be sure we don't send too big images through
            pixel_count = partial_rescale * partial_rescale * np_img.shape[
                0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale

        ##tic=timeit.default_timer()
        #np_img = cv2.resize(np_img,(target_dim1, target_dim0), interpolation = cv2.INTER_CUBIC)
        np_img = cv2.resize(np_img, (0, 0),
                            fx=partial_rescale,
                            fy=partial_rescale,
                            interpolation=cv2.INTER_CUBIC)
        if not self.color:
            np_img = np_img[..., None]  #add 'color' channel
        ##print('resize: {}  [{}, {}]'.format(timeit.default_timer()-tic,np_img.shape[0],np_img.shape[1]))

        bbs, line_gts, point_gts, pixel_gt, numClasses, numNeighbors, pairs = self.parseAnn(
            np_img, annotations, s, imagePath)

        if self.coordConv:  #add absolute position information
            xs = 255 * np.arange(np_img.shape[1]) / (np_img.shape[1])
            xs = np.repeat(xs[None, :, None], np_img.shape[0], axis=0)
            ys = 255 * np.arange(np_img.shape[0]) / (np_img.shape[0])
            ys = np.repeat(ys[:, None, None], np_img.shape[1], axis=1)
            np_img = np.concatenate(
                (np_img, xs.astype(np_img.dtype), ys.astype(np_img.dtype)),
                axis=2)

        ##ticTr=timeit.default_timer()
        if self.transform is not None:
            pairs = None
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": bbs,
                    "bb_auxs": numNeighbors,
                    "line_gt": line_gts,
                    "point_gt": point_gts,
                    "pixel_gt": pixel_gt,
                }, cropPoint)
            np_img = out['img']
            bbs = out['bb_gt']
            numNeighbors = out['bb_auxs']
            #if 'table_points' in out['point_gt']:
            #    table_points = out['point_gt']['table_points']
            #else:
            #    table_points=None
            point_gts = out['point_gt']
            pixel_gt = out['pixel_gt']
            #start_of_line = out['line_gt']['start_of_line']
            #end_of_line = out['line_gt']['end_of_line']
            line_gts = out['line_gt']

            ##tic=timeit.default_timer()
            if self.color:
                np_img[:, :, :3] = augmentation.apply_random_color_rotation(
                    np_img[:, :, :3])
                np_img[:, :, :3] = augmentation.apply_tensmeyer_brightness(
                    np_img[:, :, :3])
            else:
                np_img[:, :, 0:1] = augmentation.apply_tensmeyer_brightness(
                    np_img[:, :, 0:1])
            ##print('augmentation: {}'.format(timeit.default_timer()-tic))
        ##print('transfrm: {}  [{}, {}]'.format(timeit.default_timer()-ticTr,org_img.shape[0],org_img.shape[1]))

        #if len(np_img.shape)==2:
        #    img=np_img[None,None,:,:] #add "color" channel and batch
        #else:
        img = np_img.transpose(
            [2, 0, 1])[None,
                       ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0
        #img = 1.0 - img / 255.0 #this way ink is on, page is off
        if pixel_gt is not None:
            pixel_gt = pixel_gt.transpose([2, 0, 1])[None, ...]
            pixel_gt = torch.from_numpy(pixel_gt)

        #start_of_line = None if start_of_line is None or start_of_line.shape[1] == 0 else torch.from_numpy(start_of_line)
        #end_of_line = None if end_of_line is None or end_of_line.shape[1] == 0 else torch.from_numpy(end_of_line)
        for name in line_gts:
            line_gts[name] = None if line_gts[name] is None or line_gts[
                name].shape[1] == 0 else torch.from_numpy(line_gts[name])

        #import pdb; pdb.set_trace()
        #bbs = None if bbs.shape[1] == 0 else torch.from_numpy(bbs)
        bbs = convertBBs(bbs, self.rotate, numClasses)
        if len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None
            #start_of_line = convertLines(start_of_line,numClasses)
        #end_of_line = convertLines(end_of_line,numClasses)
        for name in point_gts:
            #if table_points is not None:
            #table_points = None if table_points.shape[1] == 0 else torch.from_numpy(table_points)
            if point_gts[name] is not None:
                point_gts[name] = None if point_gts[name].shape[
                    1] == 0 else torch.from_numpy(point_gts[name])

        ##print('__getitem__: '+str(timeit.default_timer()-ticFull))
        if self.only_types is None:
            return {
                "img": img,
                "bb_gt": bbs,
                "num_neighbors": numNeighbors,
                "line_gt": line_gts,
                "point_gt": point_gts,
                "pixel_gt": pixel_gt,
                "imgName": imageName,
                "scale": s,
                "cropPoint": cropPoint,
                "pairs": pairs
            }
        else:
            if 'boxes' not in self.only_types or not self.only_types['boxes']:
                bbs = None
            line_gt = {}
            if 'line' in self.only_types:
                for ent in self.only_types['line']:
                    if type(ent) == list:
                        toComb = []
                        for inst in ent[1:]:
                            einst = line_gts[inst]
                            if einst is not None:
                                toComb.append(einst)
                        if len(toComb) > 0:
                            comb = torch.cat(toComb, dim=1)
                            line_gt[ent[0]] = comb
                        else:
                            line_gt[ent[0]] = None
                    else:
                        line_gt[ent] = line_gts[ent]
            point_gt = {}
            if 'point' in self.only_types:
                for ent in self.only_types['point']:
                    if type(ent) == list:
                        toComb = []
                        for inst in ent[1:]:
                            einst = point_gts[inst]
                            if einst is not None:
                                toComb.append(einst)
                        if len(toComb) > 0:
                            comb = torch.cat(toComb, dim=1)
                            point_gt[ent[0]] = comb
                        else:
                            line_gt[ent[0]] = None
                    else:
                        point_gt[ent] = point_gts[ent]
            pixel_gtR = None
            #for ent in self.only_types['pixel']:
            #    if type(ent)==list:
            #        comb = ent[1]
            #        for inst in ent[2:]:
            #            comb = (comb + inst)==2 #:eq(2) #pixel-wise AND
            #        pixel_gt[ent[0]]=comb
            #    else:
            #        pixel_gt[ent]=eval(ent)
            if 'pixel' in self.only_types:  # and self.only_types['pixel'][0]=='table_pixels':
                pixel_gtR = pixel_gt

            return {
                "img": img,
                "bb_gt": bbs,
                "num_neighbors": numNeighbors,
                "line_gt": line_gt,
                "point_gt": point_gt,
                "pixel_gt": pixel_gtR,
                "imgName": imageName,
                "scale": s,
                "cropPoint": cropPoint,
                "pairs": pairs,
            }
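# Sketch of the coordConv step used above in isolation: two extra channels that
# encode absolute column and row position, scaled to [0, 255):
import numpy as np
h, w = 4, 6
xs = 255 * np.arange(w) / w                        # per-column position
xs = np.repeat(xs[None, :, None], h, axis=0)       # shape [h, w, 1]
ys = 255 * np.arange(h) / h                        # per-row position
ys = np.repeat(ys[:, None, None], w, axis=1)       # shape [h, w, 1]
coord_channels = np.concatenate((xs, ys), axis=2)  # appended after the image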
    def getitem(self, index, scaleP=None, cropPoint=None):
        ##ticFull=timeit.default_timer()
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        #print(annotationPath)
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        ##tic=timeit.default_timer()
        np_img = cv2.imread(imagePath, 1 if self.color else 0)  #/255.0
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())
        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            #this is a check to be sure we don't send too big images through
            pixel_count = partial_rescale * partial_rescale * np_img.shape[
                0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale

        ##tic=timeit.default_timer()
        #np_img = cv2.resize(np_img,(target_dim1, target_dim0), interpolation = cv2.INTER_CUBIC)
        np_img = cv2.resize(np_img, (0, 0),
                            fx=partial_rescale,
                            fy=partial_rescale,
                            interpolation=cv2.INTER_CUBIC)
        if not self.color:
            np_img = np_img[..., None]  #add 'color' channel
        ##print('resize: {}  [{}, {}]'.format(timeit.default_timer()-tic,np_img.shape[0],np_img.shape[1]))

        ##tic=timeit.default_timer()

        bbs, ids, numClasses, trans = self.parseAnn(annotations, s)

        #start_of_line, end_of_line = getStartEndGT(annotations['byId'].values(),s)
        #Try:
        #    table_points, table_pixels = self.getTables(
        #            fieldBBs,
        #            s,
        #            np_img.shape[0],
        #            np_img.shape[1],
        #            annotations['samePairs'])
        #Except Exception as inst:
        #    if imageName not in self.errors:
        #        table_points=None
        #        table_pixels=None
        #        print(inst)
        #        print('Table error on: '+imagePath)
        #        self.errors.append(imageName)

        #pixel_gt = table_pixels

        ##ticTr=timeit.default_timer()
        if self.transform is not None:
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": bbs,
                    'bb_auxs': ids,
                    #"line_gt": {
                    #    "start_of_line": start_of_line,
                    #    "end_of_line": end_of_line
                    #    },
                    #"point_gt": {
                    #        "table_points": table_points
                    #        },
                    #"pixel_gt": pixel_gt,
                },
                cropPoint)
            np_img = out['img']
            bbs = out['bb_gt']
            ids = out['bb_auxs']

            ##tic=timeit.default_timer()
            if np_img.shape[2] == 3:
                np_img = augmentation.apply_random_color_rotation(np_img)
                np_img = augmentation.apply_tensmeyer_brightness(np_img)
            else:
                np_img = augmentation.apply_tensmeyer_brightness(np_img)
            ##print('augmentation: {}'.format(timeit.default_timer()-tic))
        ##print('transfrm: {}  [{}, {}]'.format(timeit.default_timer()-ticTr,org_img.shape[0],org_img.shape[1]))
        pairs = set()
        #import pdb;pdb.set_trace()
        numNeighbors = [0] * len(ids)
        for index1, id in enumerate(ids):  #updated
            responseBBIdList = self.getResponseBBIdList(id, annotations)
            for bbId in responseBBIdList:
                try:
                    index2 = ids.index(bbId)
                    #adjMatrix[min(index1,index2),max(index1,index2)]=1
                    pairs.add((min(index1, index2), max(index1, index2)))
                    numNeighbors[index1] += 1
                except ValueError:
                    pass
        #ones = torch.ones(len(pairs))
        #if len(pairs)>0:
        #    pairs = torch.LongTensor(list(pairs)).t()
        #else:
        #    pairs = torch.LongTensor(pairs)
        #adjMatrix = torch.sparse.FloatTensor(pairs,ones,(len(ids),len(ids))) # This is an upper diagonal matrix as pairings are bi-directional

        #if len(np_img.shape)==2:
        #    img=np_img[None,None,:,:] #add "color" channel and batch
        #else:
        img = np_img.transpose(
            [2, 0, 1])[None,
                       ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0
        #if pixel_gt is not None:
        #    pixel_gt = pixel_gt.transpose([2,0,1])[None,...]
        #    pixel_gt = torch.from_numpy(pixel_gt)

        #start_of_line = None if start_of_line is None or start_of_line.shape[1] == 0 else torch.from_numpy(start_of_line)
        #end_of_line = None if end_of_line is None or end_of_line.shape[1] == 0 else torch.from_numpy(end_of_line)

        bbs = convertBBs(bbs, self.rotate, numClasses)
        if len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None
        #if table_points is not None:
        #    table_points = None if table_points.shape[1] == 0 else torch.from_numpy(table_points)

        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "adj": pairs,  #adjMatrix,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "transcription": [trans[id] for id in ids if id in trans]
        }
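# Sketch: the returned "adj" is a set of (i, j) index pairs with i < j. If a
# dense adjacency matrix were needed downstream (an assumption, not something
# this loader does), it could be rebuilt like this:
import torch

def pairs_to_adj(pairs, n):
    adj = torch.zeros(n, n)
    for i, j in pairs:
        adj[i, j] = adj[j, i] = 1.0   # relationships are bi-directional
    return adj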
Example #11
    def getitem(self, index, scaleP=None, cropPoint=None):
        ##ticFull=timeit.default_timer()
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        #print(annotationPath)
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        ##tic=timeit.default_timer()
        np_img = img_f.imread(imagePath, 1 if self.color else 0)  #*255.0
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())
        if np_img.max() < 200:
            np_img *= 255
        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            #this is a check to be sure we don't send too big images through
            pixel_count = partial_rescale * partial_rescale * np_img.shape[
                0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale

        ##tic=timeit.default_timer()
        #np_img = img_f.resize(np_img,(target_dim1, target_dim0))
        np_img = img_f.resize(
            np_img,
            (0, 0),
            fx=partial_rescale,
            fy=partial_rescale,
        )
        if len(np_img.shape) == 2:
            np_img = np_img[..., None]  #add 'color' channel
        if self.color and np_img.shape[2] == 1:
            np_img = np.repeat(np_img, 3, axis=2)
        ##print('resize: {}  [{}, {}]'.format(timeit.default_timer()-tic,np_img.shape[0],np_img.shape[1]))

        ##tic=timeit.default_timer()

        bbs, ids, numClasses, trans, groups, metadata, form_metadata = self.parseAnn(
            annotations, s)
        #trans = {i:v for i,v in enumerate(trans)}
        #metadata = {i:v for i,v in enumerate(metadata)}

        #start_of_line, end_of_line = getStartEndGT(annotations['byId'].values(),s)
        #Try:
        #    table_points, table_pixels = self.getTables(
        #            fieldBBs,
        #            s,
        #            np_img.shape[0],
        #            np_img.shape[1],
        #            annotations['samePairs'])
        #Except Exception as inst:
        #    if imageName not in self.errors:
        #        table_points=None
        #        table_pixels=None
        #        print(inst)
        #        print('Table error on: '+imagePath)
        #        self.errors.append(imageName)

        #pixel_gt = table_pixels

        ##ticTr=timeit.default_timer()
        if self.questions:  #we need to do questions before crop to have full context
            #we have to build the relationships to get questions
            pairs = set()
            for index1, id in enumerate(ids):  #updated
                responseBBIdList = self.getResponseBBIdList(id, annotations)
                for bbId in responseBBIdList:
                    try:
                        index2 = ids.index(bbId)
                        pairs.add((min(index1, index2), max(index1, index2)))
                    except ValueError:
                        pass
            groups_adj = set()
            if groups is not None:
                for n0, n1 in pairs:
                    g0 = -1
                    g1 = -1
                    for i, ns in enumerate(groups):
                        if n0 in ns:
                            g0 = i
                            if g1 != -1:
                                break
                        if n1 in ns:
                            g1 = i
                            if g0 != -1:
                                break
                    if g0 != g1:
                        groups_adj.add((min(g0, g1), max(g0, g1)))
            questions_and_answers = self.makeQuestions(bbs, trans, groups,
                                                       groups_adj)
        else:
            questions_and_answers = None

        if self.transform is not None:
            if 'word_boxes' in form_metadata:
                word_bbs = form_metadata['word_boxes']
                dif_f = bbs.shape[2] - word_bbs.shape[1]
                blank = np.zeros([word_bbs.shape[0], dif_f])
                prep_word_bbs = np.concatenate([word_bbs, blank], axis=1)[None,
                                                                          ...]
                crop_bbs = np.concatenate([bbs, prep_word_bbs], axis=1)
                crop_ids = ids + [
                    'word{}'.format(i) for i in range(word_bbs.shape[0])
                ]
            else:
                crop_bbs = bbs
                crop_ids = ids
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": crop_bbs,
                    'bb_auxs': crop_ids,
                    #'word_bbs':form_metadata['word_boxes'] if 'word_boxes' in form_metadata else None
                    #"line_gt": {
                    #    "start_of_line": start_of_line,
                    #    "end_of_line": end_of_line
                    #    },
                    #"point_gt": {
                    #        "table_points": table_points
                    #        },
                    #"pixel_gt": pixel_gt,
                },
                cropPoint)
            np_img = out['img']

            if 'word_boxes' in form_metadata:
                saw_word = False
                word_index = -1
                for i, ii in enumerate(out['bb_auxs']):
                    if not saw_word:
                        if type(ii) is str and 'word' in ii:
                            saw_word = True
                            word_index = i
                    else:
                        assert 'word' in ii
                bbs = out['bb_gt'][:, :word_index]
                ids = out['bb_auxs'][:word_index]
                form_metadata['word_boxes'] = out['bb_gt'][0, word_index:, :8]
                word_ids = out['bb_auxs'][word_index:]
                form_metadata['word_trans'] = [
                    form_metadata['word_trans'][int(id[4:])] for id in word_ids
                ]
            else:
                bbs = out['bb_gt']
                ids = out['bb_auxs']

            if questions_and_answers is not None:
                questions = []
                answers = []
                questions_and_answers = [
                    (q, a, qids) for q, a, qids in questions_and_answers
                    if all((i in ids) for i in qids)
                ]
        if questions_and_answers is not None:
            if len(questions_and_answers) > self.questions:
                questions_and_answers = random.sample(questions_and_answers,
                                                      k=self.questions)
            if len(questions_and_answers) > 0:
                questions, answers, _ = zip(*questions_and_answers)
            else:
                return self.getitem((index + 1) % len(self))
        else:
            questions = answers = None

            ##tic=timeit.default_timer()
            if np_img.shape[2] == 3:
                np_img = augmentation.apply_random_color_rotation(np_img)
                np_img = augmentation.apply_tensmeyer_brightness(
                    np_img, **self.aug_params)
            else:
                np_img = augmentation.apply_tensmeyer_brightness(
                    np_img, **self.aug_params)
            ##print('augmentation: {}'.format(timeit.default_timer()-tic))
        newGroups = []
        for group in groups:
            newGroup = [ids.index(bbId) for bbId in group if bbId in ids]
            if len(newGroup) > 0:
                newGroups.append(newGroup)
                #print(len(newGroups)-1,newGroup)
        groups = newGroups
        ##print('transfrm: {}  [{}, {}]'.format(timeit.default_timer()-ticTr,org_img.shape[0],org_img.shape[1]))
        pairs = set()
        #import pdb;pdb.set_trace()
        numNeighbors = [0] * len(ids)
        for index1, id in enumerate(ids):  #updated
            responseBBIdList = self.getResponseBBIdList(id, annotations)
            for bbId in responseBBIdList:
                try:
                    index2 = ids.index(bbId)
                    #adjMatrix[min(index1,index2),max(index1,index2)]=1
                    pairs.add((min(index1, index2), max(index1, index2)))
                    numNeighbors[index1] += 1
                except ValueError:
                    pass
        #ones = torch.ones(len(pairs))
        #if len(pairs)>0:
        #    pairs = torch.LongTensor(list(pairs)).t()
        #else:
        #    pairs = torch.LongTensor(pairs)
        #adjMatrix = torch.sparse.FloatTensor(pairs,ones,(len(ids),len(ids))) # This is an upper diagonal matrix as pairings are bi-directional

        #if len(np_img.shape)==2:
        #    img=np_img[None,None,:,:] #add "color" channel and batch
        #else:
        img = np_img.transpose(
            [2, 0, 1])[None,
                       ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0
        #if pixel_gt is not None:
        #    pixel_gt = pixel_gt.transpose([2,0,1])[None,...]
        #    pixel_gt = torch.from_numpy(pixel_gt)

        #start_of_line = None if start_of_line is None or start_of_line.shape[1] == 0 else torch.from_numpy(start_of_line)
        #end_of_line = None if end_of_line is None or end_of_line.shape[1] == 0 else torch.from_numpy(end_of_line)

        bbs = convertBBs(bbs, self.rotate, numClasses)
        if 'word_boxes' in form_metadata:
            form_metadata['word_boxes'] = convertBBs(
                form_metadata['word_boxes'][None, ...], self.rotate, 0)[0, ...]
        if len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None
        #if table_points is not None:
        #    table_points = None if table_points.shape[1] == 0 else torch.from_numpy(table_points)
        groups_adj = set()
        if groups is not None:
            for n0, n1 in pairs:
                g0 = -1
                g1 = -1
                for i, ns in enumerate(groups):
                    if n0 in ns:
                        g0 = i
                        if g1 != -1:
                            break
                    if n1 in ns:
                        g1 = i
                        if g0 != -1:
                            break
                if g0 != g1:
                    groups_adj.add((min(g0, g1), max(g0, g1)))
            for group in groups:
                for i in group:
                    assert (i < bbs.shape[1])
            targetIndexToGroup = {}
            for groupId, bbIds in enumerate(groups):
                targetIndexToGroup.update({bbId: groupId for bbId in bbIds})

        transcription = [trans[id] for id in ids]

        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "adj": pairs,  #adjMatrix,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "transcription": transcription,
            "metadata": [metadata[id] for id in ids if id in metadata],
            "form_metadata": form_metadata,
            "gt_groups": groups,
            "targetIndexToGroup": targetIndexToGroup,
            "gt_groups_adj": groups_adj,
            "questions": questions,
            "answers": answers
        }
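# Sketch of the group bookkeeping returned above: gt_groups is a list of lists
# of bb indices and targetIndexToGroup is its inverse mapping.
gt_groups = [[0, 1], [2], [3, 4, 5]]
targetIndexToGroup = {bb: g for g, bbs in enumerate(gt_groups) for bb in bbs}
# -> {0: 0, 1: 0, 2: 1, 3: 2, 4: 2, 5: 2}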
Example #12
    def __getitem__(self, idx):

        gt_json_path, img_path = self.ids[idx]

        gt_json = safe_load.json_state(gt_json_path)
        if gt_json is None:
            return None

        org_img = cv2.imread(img_path)
        target_dim1 = int(
            np.random.uniform(self.rescale_range[0], self.rescale_range[1]))

        s = target_dim1 / float(org_img.shape[1])
        target_dim0 = int(org_img.shape[0] / float(org_img.shape[1]) *
                          target_dim1)
        org_img = cv2.resize(org_img, (target_dim1, target_dim0),
                             interpolation=cv2.INTER_CUBIC)

        gt = np.zeros((1, len(gt_json), 4), dtype=np.float32)

        positions = []
        positions_xy = []

        for j, gt_item in enumerate(gt_json):
            if 'sol' not in gt_item:
                continue

            x0 = gt_item['sol']['x0'] * s
            x1 = gt_item['sol']['x1'] * s
            y0 = gt_item['sol']['y0'] * s
            y1 = gt_item['sol']['y1'] * s

            positions_xy.append([(torch.Tensor([[x1, x0], [y1, y0]]))])
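            #parameterize the SOL segment as (center x, center y, angle, half-length, confidence)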
            dx = x0 - x1
            dy = y0 - y1
            d = math.sqrt(dx**2 + dy**2)
            mx = (x0 + x1) / 2.0
            my = (y0 + y1) / 2.0
            # Not sure if this is right...
            theta = -math.atan2(dx, -dy)
            positions.append([torch.Tensor([mx, my, theta, d / 2, 1.0])])

            gt[:, j, 0] = x0
            gt[:, j, 1] = y0
            gt[:, j, 2] = x1
            gt[:, j, 3] = y1

        if self.transform is not None:
            out = self.transform({"img": org_img, "sol_gt": gt})
            org_img = out['img']
            gt = out['sol_gt']
            org_img = augmentation.apply_random_color_rotation(org_img)
            org_img = augmentation.apply_tensmeyer_brightness(org_img)

        img = org_img.transpose([2, 1, 0])[None, ...]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = img / 128.0 - 1.0

        if gt.shape[1] == 0:
            gt = None
        else:
            gt = torch.from_numpy(gt)

        return {
            "scale": s,
            "img_path": img_path,
            "img": img,
            "sol_gt": gt,
            "lf_xyrs": positions,
            "lf_xyxy": positions_xy,
        }
    def __getitem__(self, idx):

        author, line = self.lineIndex[idx]
        img_path, lb, gt = self.authors[author][line]
        if self.add_spaces:
            gt = ' ' + gt + ' '
        if self.normalized_dir is not None:
            normalized_path = os.path.join(self.normalized_dir,
                                           '{}_{}.png'.format(author, line))
        else:
            normalized_path = None
        if (type(self.augmentation) is str
                and 'normalization' in self.augmentation
                and normalized_path is not None
                and os.path.exists(normalized_path)):
            img = cv2.imread(normalized_path, 0)  #read cached normalized line
            readNorm = True
        else:
            img = cv2.imread(img_path,
                             0)[lb[0]:lb[1],
                                lb[2]:lb[3]]  #read as grayscale, crop line
            readNorm = False

        if img is None:
            return None

        if img.shape[0] != self.img_height:
            if img.shape[0] < self.img_height and not self.warning:
                self.warning = True
                print("WARNING: upsampling image to fit size")
            percent = float(self.img_height) / img.shape[0]
            img = cv2.resize(img, (0, 0),
                             fx=percent,
                             fy=percent,
                             interpolation=cv2.INTER_CUBIC)

        if img is None:
            return None

        if len(img.shape) == 2:
            img = img[..., None]
        if (type(self.augmentation) is str
                and 'normalization' in self.augmentation and not readNorm):
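            #deskew and skeletonize the line, then cache the normalized image for reuse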
            img = normalize_line.deskew(img)
            img = normalize_line.skeletonize(img)
            if self.normalized_dir is not None:
                cv2.imwrite(
                    os.path.join(self.normalized_dir,
                                 '{}_{}.png'.format(author, line)), img)
        elif self.augmentation is not None and (type(
                self.augmentation) is not str or 'warp' in self.augmentation):
            #img = augmentation.apply_random_color_rotation(img)
            if type(self.augmentation) is str and "low" in self.augmentation:
                if random.random() > 0.1:
                    img = augmentation.apply_tensmeyer_brightness(img)
                if random.random() > 0.01:
                    img = grid_distortion.warp_image(img,
                                                     w_mesh_std=0.7,
                                                     h_mesh_std=0.7)
            else:
                img = augmentation.apply_tensmeyer_brightness(img)
                img = grid_distortion.warp_image(img)
        if len(img.shape) == 2:
            img = img[..., None]

        img = img.astype(np.float32)
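        #scale grayscale values from [0,255] to roughly [-1,1], dark ink becoming positive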
        img = 1.0 - img / 128.0

        if len(gt) == 0:
            return None
        gt_label = string_utils.str2label_single(gt, self.char_to_idx)

        return {
            "image": img,
            "gt": gt,
            "gt_label": gt_label,
            "name": '{}_{}'.format(author, line),
            "center": self.center,
            "author": author
        }
Example #14
0
    def __getitem__(self, idx):

        if self.train and self.refresh_self:
            if self.used_instances >= self.set_size:
                self.refresh_data(None, None, logged=False)
                self.used_instances = 0
            self.used_instances += self.batch_size

        if self.augmentation is not None and 'affine' in self.augmentation:
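            #draw a horizontal stretch in [1-max_strech, 1+max_strech] and a rotation in
            #[-max_rot_rad, max_rot_rad] radians, shared by every image in this batch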
            strech = (self.max_strech *
                      2) * np.random.random() - self.max_strech + 1
            skew = (self.max_rot_rad *
                    2) * np.random.random() - self.max_rot_rad
        if self.include_stroke_aug:
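            #random stroke thickness, ink/background shades, blur and noise levels, also shared across the batch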
            thickness_change = np.random.randint(-4, 5)
            fg_shade = np.random.random() * 0.25 + 0.75
            bg_shade = np.random.random() * 0.2
            blur_size = np.random.randint(2, 4)
            noise_sigma = np.random.random() * 0.02

        batch = []
        for b in range(self.batch_size):
            img_path = os.path.join(self.directory, '{}.png'.format(idx + b))
            img = cv2.imread(img_path, 0)
            if img is None:
                print('Error, could not read {}'.format(img_path))
                return self[(idx + 1) % len(self)]

            #use the same 'affine' in ... test as above so the stretch cap also applies to combined augmentations
            if self.augmentation is not None and 'affine' in self.augmentation:
                if img.shape[1] * strech > self.max_width:
                    strech = self.max_width / img.shape[1]
            if img.shape[1] > self.max_width:
                percent = float(self.max_width) / img.shape[1]
                img = cv2.resize(img, (0, 0),
                                 fx=percent,
                                 fy=1,
                                 interpolation=cv2.INTER_CUBIC)

            if img.shape[1] > self.clip_width:
                img = img[:, :self.clip_width]

            if self.use_fg_mask:
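                #estimate a foreground (ink) mask with Otsu thresholding, then dilate it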
                th, fg_mask = cv2.threshold(
                    img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                fg_mask = 255 - fg_mask
                ele = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
                fg_mask = cv2.dilate(fg_mask, ele)
                fg_mask = fg_mask / 255
            else:
                fg_mask = None

            if len(img.shape) == 2:
                img = img[..., None]
            if self.augmentation is not None:
                #img = augmentation.apply_random_color_rotation(img)
                if 'affine' in self.augmentation:
                    img, fg_mask = augmentation.affine_trans(
                        img, fg_mask, skew, strech)
                if 'brightness' in self.augmentation:
                    img = augmentation.apply_tensmeyer_brightness(img)
                    assert (fg_mask is None)
                if 'warp' in self.augmentation and random.random() < self.warp_freq:
                    try:
                        img = grid_distortion.warp_image(img)
                    except cv2.error as e:
                        print(e)
                        print(img.shape)
                    assert (fg_mask is None)
                if 'invert' in self.augmentation and random.random() < 0.25:
                    img = 1 - img

            if self.include_stroke_aug:
                new_img = augmentation.change_thickness(
                    img, thickness_change, fg_shade, bg_shade, blur_size,
                    noise_sigma)
                new_img = new_img * 2 - 1.0

            if len(img.shape) == 2:
                img = img[..., None]

            img = img.astype(np.float32)
            img = 1.0 - img / 128.0

            if self.train:
                gt = self.labels[idx]
            else:
                with open(self.gt_filename) as f:
                    for i in range(0, idx + 1):
                        gt = f.readline()
                gt = gt.strip()
            if gt is None:
                #metadata = pyexiv2.ImageMetadata(img_path)
                #metadata.read()
                #metadata = piexif.load(img_path)
                #if 'gt' in metadata:
                #    gt = metadata['gt']
                #else:
                print('Error unknown label for image: {}'.format(img_path))
                return self.__getitem__((idx + 7) % self.set_size)

            gt_label = string_utils.str2label_single(gt, self.char_to_idx)

            font_idx = '?'
            toRet = {
                "image": img,
                "gt": gt,
                "gt_label": gt_label,
                "author": font_idx,
                "name": '{}_{}'.format(idx + b, font_idx),
                "style": None,
                "spaced_label": None
            }
            if self.use_fg_mask:
                toRet['fg_mask'] = fg_mask
            if self.include_stroke_aug:
                toRet['changed_image'] = new_img
            batch.append(toRet)

        dim0 = batch[0]['image'].shape[0]
        dim1 = max([b['image'].shape[1] for b in batch])
        dim2 = batch[0]['image'].shape[2]

        all_labels = []
        label_lengths = []
        if self.spaced_by_name is not None:
            spaced_labels = []
        else:
            spaced_labels = None
        max_spaced_len = 0

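        #pad every image in the batch to the width of the widest one (centered if requested)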
        input_batch = np.full((len(batch), dim0, dim1, dim2),
                              PADDING_CONSTANT).astype(np.float32)
        if self.use_fg_mask:
            fg_masks = np.full((len(batch), dim0, dim1, 1),
                               0).astype(np.float32)
        if self.include_stroke_aug:
            changed_batch = np.full((len(batch), dim0, dim1, dim2),
                                    PADDING_CONSTANT).astype(np.float32)
        for i in range(len(batch)):
            b_img = batch[i]['image']
            toPad = (dim1 - b_img.shape[1])
            if 'center' in batch[0] and batch[0]['center']:
                toPad //= 2
            else:
                toPad = 0
            input_batch[i, :, toPad:toPad + b_img.shape[1], :] = b_img
            if self.use_fg_mask:
                fg_masks[i, :, toPad:toPad + b_img.shape[1],
                         0] = batch[i]['fg_mask']
            if self.include_stroke_aug:
                changed_batch[i, :, toPad:toPad + b_img.shape[1],
                              0] = batch[i]['changed_image']

            l = batch[i]['gt_label']
            all_labels.append(l)
            label_lengths.append(len(l))

            if spaced_labels is not None:
                sl = batch[i]['spaced_label']
                spaced_labels.append(sl)
                max_spaced_len = max(max_spaced_len, sl.shape[0])

        #all_labels = np.concatenate(all_labels)
        label_lengths = torch.IntTensor(label_lengths)
        max_len = label_lengths.max()
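        #right-pad each label sequence to the batch maximum and stack to shape [max_len, batch_size]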
        all_labels = [
            np.pad(l, ((0, max_len - l.shape[0]), ), 'constant')
            for l in all_labels
        ]
        all_labels = np.stack(all_labels, axis=1)
        if self.spaced_by_name is not None:
            spaced_labels = [
                np.pad(l, ((0, max_spaced_len - l.shape[0]), (0, 0)),
                       'constant') for l in spaced_labels
            ]
            spaced_labels = np.concatenate(spaced_labels, axis=1)
            spaced_labels = torch.from_numpy(spaced_labels)
            assert (spaced_labels.size(1) == len(batch))

        images = input_batch.transpose([0, 3, 1, 2])
        images = torch.from_numpy(images)
        labels = torch.from_numpy(all_labels.astype(np.int32))
        #label_lengths = torch.from_numpy(label_lengths.astype(np.int32))
        if self.use_fg_mask:
            fg_masks = fg_masks.transpose([0, 3, 1, 2])
            fg_masks = torch.from_numpy(fg_masks)

        if batch[0]['style'] is not None:
            styles = np.stack([b['style'] for b in batch], axis=0)
            styles = torch.from_numpy(styles).float()
        else:
            styles = None
        mask, top_and_bottom, center_line = makeMask(images, self.mask_post,
                                                     self.mask_random)
        toRet = {
            "image": images,
            "mask": mask,
            "top_and_bottom": top_and_bottom,
            "center_line": center_line,
            "label": labels,
            "style": styles,
            "label_lengths": label_lengths,
            "gt": [b['gt'] for b in batch],
            "spaced_label": spaced_labels,
            "name": [b['name'] for b in batch],
            "author": [b['author'] for b in batch],
        }
        if self.use_fg_mask:
            toRet['fg_mask'] = fg_masks
        if self.include_stroke_aug:
            changed_images = changed_batch.transpose([0, 3, 1, 2])
            changed_images = torch.from_numpy(changed_images)
            toRet['changed_image'] = changed_images
        return toRet
    def __getitem__(self, idx):

        inst = self.lineIndex[idx]
        author = inst[0]
        lines = inst[1]
        batch = []
        for line in lines:
            if line >= len(self.authors[author]):
                line = (line + 37) % len(self.authors[author])
            img_path, gt, pad_above, pad_below = self.authors[author][line]
            img = cv2.imread(img_path, 0)  #read as grayscale
            if img is None:
                return None

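            #negative pad values mean the stored crop extends past the line, so crop instead of padding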
            if pad_above < 0:
                img = img[-pad_above:, :]
                pad_above = 0
            if pad_below < 0:
                img = img[:pad_below, :]
                pad_below = 0
            #if pad_above>0 or pad_below>0:
            img = np.pad(img, ((pad_above, pad_below), (10, 10)),
                         'constant',
                         constant_values=255)
            #we also pad a bit on the sides
            #print('{}, {} {}'.format(img_path,pad_above,pad_below))

            if img.shape[0] != self.img_height:
                if img.shape[0] < self.img_height and not self.warning:
                    self.warning = True
                    print("WARNING: upsampling image to fit size")
                percent = float(self.img_height) / img.shape[0]
                if img.shape[1] * percent > self.max_width:
                    percent = self.max_width / img.shape[1]
                img = cv2.resize(img, (0, 0),
                                 fx=percent,
                                 fy=percent,
                                 interpolation=cv2.INTER_CUBIC)
                if img.shape[0] < self.img_height:
                    diff = self.img_height - img.shape[0]
                    img = np.pad(img,
                                 ((diff // 2, diff // 2 + diff % 2), (0, 0)),
                                 'constant',
                                 constant_values=255)

            if len(img.shape) == 2:
                img = img[..., None]
            if self.fg_masks_dir is not None:
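                #load the precomputed foreground mask; if its size doesn't match, fall back to an Otsu estimate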
                fg_path = os.path.join(self.fg_masks_dir,
                                       '{}_{}.png'.format(author, line))
                fg_mask = cv2.imread(fg_path, 0)
                fg_mask = fg_mask / 255
                if fg_mask.shape != img[:, :, 0].shape:
                    print(
                        'Error, fg_mask ({}, {}) not the same size as image ({})'
                        .format(fg_path, fg_mask.shape, img[:, :, 0].shape))
                    th, fg_mask = cv2.threshold(
                        img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                    fg_mask = 255 - fg_mask
                    ele = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
                    fg_mask = cv2.dilate(fg_mask, ele)
                    fg_mask = fg_mask / 255

            if self.augmentation is not None:
                #img = augmentation.apply_random_color_rotation(img)
                img = augmentation.apply_tensmeyer_brightness(img)
                img = grid_distortion.warp_image(img)
                if len(img.shape) == 2:
                    img = img[..., None]

            img = img.astype(np.float32)
            img = 1.0 - img / 128.0

            if len(gt) == 0:
                return None
            gt_label = string_utils.str2label_single(gt, self.char_to_idx)

            if self.styles:
                style_i = self.npr.choice(len(self.styles[author][id]))
                style = self.styles[author][id][style_i]
            else:
                style = None
            name = img_path[img_path.rfind('/') + 1:img_path.rfind('.')]
            spaced_label = None if self.spaced_by_name is None else self.spaced_by_name[
                img_path]
            if spaced_label is not None:
                assert (spaced_label.shape[1] == 1)
            toAppend = {
                "image": img,
                "gt": gt,
                "style": style,
                "gt_label": gt_label,
                "spaced_label": spaced_label,
                "name": name,
                "center": self.center,
                "author": author
            }
            if self.fg_masks_dir is not None:
                toAppend['fg_mask'] = fg_mask
            batch.append(toAppend)
        #batch = [b for b in batch if b is not None]
        #These all should be the same size or error
        assert len(set([b['image'].shape[0] for b in batch])) == 1
        assert len(set([b['image'].shape[2] for b in batch])) == 1

        dim0 = batch[0]['image'].shape[0]
        dim1 = max([b['image'].shape[1] for b in batch])
        dim2 = batch[0]['image'].shape[2]

        all_labels = []
        label_lengths = []
        if self.spaced_by_name is not None:
            spaced_labels = []
        else:
            spaced_labels = None
        max_spaced_len = 0

        input_batch = np.full((len(batch), dim0, dim1, dim2),
                              PADDING_CONSTANT).astype(np.float32)
        if self.fg_masks_dir is not None:
            fg_masks = np.full((len(batch), dim0, dim1, 1),
                               0).astype(np.float32)
        for i in range(len(batch)):
            b_img = batch[i]['image']
            toPad = (dim1 - b_img.shape[1])
            if 'center' in batch[0] and batch[0]['center']:
                toPad //= 2
            else:
                toPad = 0
            input_batch[i, :, toPad:toPad + b_img.shape[1], :] = b_img
            if self.fg_masks_dir is not None:
                fg_masks[i, :, toPad:toPad + b_img.shape[1],
                         0] = batch[i]['fg_mask']

            l = batch[i]['gt_label']
            all_labels.append(l)
            label_lengths.append(len(l))

            if spaced_labels is not None:
                sl = batch[i]['spaced_label']
                spaced_labels.append(sl)
                max_spaced_len = max(max_spaced_len, sl.shape[0])

        #all_labels = np.concatenate(all_labels)
        label_lengths = torch.IntTensor(label_lengths)
        max_len = label_lengths.max()
        all_labels = [
            np.pad(l, ((0, max_len - l.shape[0]), ), 'constant')
            for l in all_labels
        ]
        all_labels = np.stack(all_labels, axis=1)
        if self.spaced_by_name is not None:
            spaced_labels = [
                np.pad(l, ((0, max_spaced_len - l.shape[0]), (0, 0)),
                       'constant') for l in spaced_labels
            ]
            spaced_labels = np.concatenate(spaced_labels, axis=1)
            spaced_labels = torch.from_numpy(spaced_labels)
            assert (spaced_labels.size(1) == len(batch))

        images = input_batch.transpose([0, 3, 1, 2])
        images = torch.from_numpy(images)
        labels = torch.from_numpy(all_labels.astype(np.int32))
        #label_lengths = torch.from_numpy(label_lengths.astype(np.int32))
        if self.fg_masks_dir is not None:
            fg_masks = fg_masks.transpose([0, 3, 1, 2])
            fg_masks = torch.from_numpy(fg_masks)

        if batch[0]['style'] is not None:
            styles = np.stack([b['style'] for b in batch], axis=0)
            styles = torch.from_numpy(styles).float()
        else:
            styles = None
        mask, top_and_bottom, center_line = makeMask(images, self.mask_post,
                                                     self.mask_random)
        toRet = {
            "image": images,
            "mask": mask,
            "top_and_bottom": top_and_bottom,
            "center_line": center_line,
            "label": labels,
            "style": styles,
            "label_lengths": label_lengths,
            "gt": [b['gt'] for b in batch],
            "spaced_label": spaced_labels,
            "name": [b['name'] for b in batch],
            "author": [b['author'] for b in batch],
        }
        if self.fg_masks_dir is not None:
            toRet['fg_mask'] = fg_masks
        return toRet
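Note: several of the __getitem__ methods above return None for unreadable images
or empty ground truth. Below is a minimal sketch (not from the original code,
assuming a standard PyTorch DataLoader) of a collate_fn that filters those
entries; the name skip_none_collate and the base_collate hook are hypothetical,
and line-image datasets whose items differ in width would need a padding collate
(like the batch-building code above) as base_collate instead of default_collate.

    import torch
    from torch.utils.data.dataloader import default_collate

    def skip_none_collate(batch, base_collate=default_collate):
        #drop items the dataset returned as None (bad image / empty transcription)
        batch = [b for b in batch if b is not None]
        if len(batch) == 0:
            return None  #the training loop must skip an empty batch
        return base_collate(batch)

    #loader = torch.utils.data.DataLoader(dataset, batch_size=8,
    #                                     collate_fn=skip_none_collate)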