示例#1
0
    def cluster(self,k,sample_count,outPath):
        """Cluster the dataset's box shapes k-means style, then save and display the centers.

        Every image in ``self.images`` is sampled ``sample_count`` times. The first
        sample uses the midpoint of ``self.rescale_range``; the rest are uniform
        random in that range. Each box becomes a 13-value row:
        0:p_left_x 1:p_left_y 2:p_right_x 3:p_right_y 4:p_top_x 5:p_top_y
        6:p_bot_x 7:p_bot_y, followed by ``bbs[:, :5]`` from ``convertBBs``,
        read here as 8:xc 9:yc 10:rot 11:h 12:w.

        The distance between a box and a center is the squared sum of the four
        edge-point distances. Each distance is divided by the mean of both
        boxes' heights and widths.

        Args:
            k: number of clusters. If ``k > 0``, 20 restarts are run from distinct
                randomly chosen boxes (repeats are allowed only when there are fewer
                than ``k`` boxes), and the cheapest run is kept. If ``k <= 0``, a
                single run starts from a fixed grid of box shapes, and ``k`` becomes
                the size of that grid.
            sample_count: number of rescale samples drawn per image.
            outPath: format string with one ``{}`` placeholder. It receives the
                number of clusters written.

        Side effects:
            Writes a JSON list of ``{'height','width','rot','popularity'}`` dicts to
            ``outPath.format(final_k)``. Only clusters with more than two members
            are written. Then draws the kept shapes, shows them with
            ``cv2.imshow``, and blocks on ``cv2.waitKey()``.
        """
        def makePointsAndRects(h,w,r=None):
            """Build one 13-value seed row: 8 edge-point offsets, then 0, 0, rot, h, w."""
            # NOTE(review): the unrotated branch puts the edge points at +-w/2 and
            # +-h/2, while the rotated branch (and the data rows built below from
            # bbs[:,3] / bbs[:,4]) use +-w and +-h. This only affects the initial
            # seeds; confirm which convention is intended.
            if r is None:
                return np.array([-w/2.0,0,w/2.0,0,0,-h/2.0,0,h/2.0, 0,0, 0, h,w])
            lx= -math.cos(r)*w
            ly= -math.sin(r)*w
            rx= math.cos(r)*w
            ry= math.sin(r)*w
            tx= math.sin(r)*h
            ty= -math.cos(r)*h
            bx= -math.sin(r)*h
            by= math.cos(r)*h
            return np.array([lx,ly,rx,ry,tx,ty,bx,by, 0,0, r, h,w])

        def defaultSeeds():
            """Hand-designed initial centers, used when ``k <= 0``."""
            seeds=[]
            if self.rotate:
                rots = [0,math.pi/2,math.pi,1.5*math.pi]
                for height in np.linspace(15,200,num=4):
                    for width in np.linspace(30,1200,num=4):
                        for rot in rots:
                            seeds.append(makePointsAndRects(height,width,rot))
                #long boxes
                for width in np.linspace(1600,4000,num=3):
                    for rot in rots:
                        seeds.append(makePointsAndRects(50,width,rot))
            else:
                #narrow (width 20) boxes of varying height
                for height in np.linspace(13,300,num=3):
                    seeds.append(makePointsAndRects(height,20))
                #general boxes
                for height in np.linspace(15,200,num=2):
                    for width in np.linspace(30,1200,num=3):
                        seeds.append(makePointsAndRects(height,width))
                #long boxes
                for width in np.linspace(1600,4000,num=3):
                    seeds.append(makePointsAndRects(50,width))
            return seeds

        def assignmentCosts(centers):
            """Return a (num_boxes, num_centers) matrix of box-to-center costs."""
            # broadcasting (boxes on axis 0, centers on axis 1) instead of np.tile copies
            deltas = pointsAndRects[:,None,0:8] - centers[None,:,0:8]
            # one shared normaliser: mean of both boxes' heights and widths
            norm = (centers[None,:,11] + pointsAndRects[:,None,11] +
                    centers[None,:,12] + pointsAndRects[:,None,12])/4
            return (
                np.linalg.norm(deltas[:,:,0:2],2,2)/norm +
                np.linalg.norm(deltas[:,:,2:4],2,2)/norm +
                np.linalg.norm(deltas[:,:,4:6],2,2)/norm +
                np.linalg.norm(deltas[:,:,6:8],2,2)/norm
                )**2

        # collect one row per box per rescale sample
        pointsAndRects=[]
        for inst in self.images:
            with open(inst['annotationPath']) as annFile:
                annotations = json.loads(annFile.read())
            fixAnnotations(self,annotations)
            for i in range(sample_count):
                if i==0:
                    # first sample always uses the middle of the rescale range
                    s = (self.rescale_range[0]+self.rescale_range[1])/2
                else:
                    s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
                bbs = getBBWithPoints(annotations['byId'].values(),s)
                bbs = convertBBs(bbs,self.rotate,2).numpy()[0]
                cos_rot = np.cos(bbs[:,2])
                sin_rot = np.sin(bbs[:,2])
                heights = bbs[:,3]
                widths = bbs[:,4]
                points = np.stack([
                    -cos_rot*widths, -sin_rot*widths,   # left
                    cos_rot*widths, sin_rot*widths,     # right
                    sin_rot*heights, -cos_rot*heights,  # top
                    -sin_rot*heights, cos_rot*heights,  # bottom
                    ],axis=1)
                pointsAndRects.append(np.concatenate([points,bbs[:,:5]],axis=1))
        pointsAndRects = np.concatenate(pointsAndRects,axis=0)
        numBoxes = pointsAndRects.shape[0]

        bestDistsFromMean=None
        cluster_centers=None
        bestGroups=None
        for attempt in range(20 if k>0 else 1):
            if k>0:
                # distinct seeds when possible, so no two clusters start out identical
                randomIndexes = np.random.choice(numBoxes, k, replace=numBoxes<k)
                means=pointsAndRects[randomIndexes]
            else:
                means = defaultSeeds()
                k=len(means)
                print('K: {}'.format(k))
                means = np.stack(means,axis=0)

            prevDistsFromMean=None
            prevMeans=None
            prevGroups=None
            for iteration in range(100000): #intended to break out
                print('attempt:{}, bestDistsFromMean:{}, iteration:{}, prevDistsFromMean:{}'.format(attempt,bestDistsFromMean,iteration,prevDistsFromMean), end='\r')
                costs = assignmentCosts(means)
                groups = costs.argmin(1) # index of the nearest center for every box
                distsFromMean = costs.min(1).mean()
                if prevDistsFromMean is not None and distsFromMean >= prevDistsFromMean:
                    break # no improvement: keep the previous (cheaper) clustering
                prevDistsFromMean = distsFromMean
                prevMeans = means.copy()
                prevGroups = groups
                for ki in range(k):
                    selected = groups==ki
                    if not selected.any():
                        continue # empty cluster: leave its center where it was
                    means[ki,:] = pointsAndRects[selected].mean(0)
            if bestDistsFromMean is None or prevDistsFromMean<bestDistsFromMean:
                bestDistsFromMean = prevDistsFromMean
                cluster_centers = prevMeans
                bestGroups = prevGroups

        dH=600
        dW=3000
        draw = np.zeros([dH,dW,3],dtype=np.float64) # the np.float alias no longer exists in NumPy
        toWrite = []
        final_k=k
        for ki in range(k):
            pop = (bestGroups==ki).sum().item()
            if pop<=2:
                final_k-=1 # too few members to keep
                continue
            color = np.random.uniform(0.2,1,3).tolist()
            h=cluster_centers[ki,11]
            w=cluster_centers[ki,12]
            rot=cluster_centers[ki,10]
            toWrite.append({'height':h.item(),'width':w.item(),'rot':rot.item(),'popularity':pop})
            cosR = math.cos(rot)
            sinR = math.sin(rot)
            # rotate each corner (+-w, +-h) by rot and move it to the canvas center
            tr, tl, br, bl = [
                (int(cosR*x-sinR*y)+dW//2, int(sinR*x+cosR*y)+dH//2)
                for x, y in ((w,h), (-w,h), (w,-h), (-w,-h))]
            cv2.line(draw,tl,tr,color)
            cv2.line(draw,tr,br,color)
            cv2.line(draw,br,bl,color)
            cv2.line(draw,bl,tl,color,2)

        with open(outPath.format(final_k),'w') as out:
            out.write(json.dumps(toWrite))
            print('saved '+outPath.format(final_k))
        cv2.imshow('clusters',draw)
        cv2.waitKey()
    def getitem(self, index, scaleP=None, cropPoint=None):
        """Load one image with its boxes, neighbour counts and response pairs.

        Args:
            index: position in ``self.images``.
            scaleP: rescale factor to use. When None, one is drawn uniformly from
                ``self.rescale_range``.
            cropPoint: forwarded to ``self.transform`` (only used when a transform
                is set).

        Returns:
            A dict with the following keys:

            - ``img``: float tensor, batch x channel x row x col, with values
              mapped by ``1 - pixel/128``.
            - ``bb_gt``: output of ``convertBBs``.
            - ``num_neighbors``: 1 x num_boxes tensor, or None when there are no
              boxes.
            - ``adj``: set of ``(i, j)`` box-index pairs with ``i <= j``.
            - ``imgName``, ``scale``, ``cropPoint``.
            - ``transcription``: the ``trans`` entries of the kept ids.

        If the image cannot be read, an error is printed and the next index is
        loaded via ``self.__getitem__``.
        """
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        np_img = cv2.imread(imagePath, 1 if self.color else 0)
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())
        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        # resize factor for the loaded image; the effective scale is s = rescaled * partial_rescale
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            # uncropped images are capped in total pixels and in longest side
            pixel_count = partial_rescale * partial_rescale * np_img.shape[
                0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale

        np_img = cv2.resize(np_img, (0, 0),
                            fx=partial_rescale,
                            fy=partial_rescale,
                            interpolation=cv2.INTER_CUBIC)
        if not self.color:
            np_img = np_img[..., None]  #add 'color' channel

        bbs, ids, numClasses, trans = self.parseAnn(annotations, s)

        if self.transform is not None:
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": bbs,
                    'bb_auxs': ids,
                },
                cropPoint)
            np_img = out['img']
            bbs = out['bb_gt']
            ids = out['bb_auxs']

            if np_img.shape[2] == 3:
                np_img = augmentation.apply_random_color_rotation(np_img)
            np_img = augmentation.apply_tensmeyer_brightness(np_img)

        # first position of every id (same result as ids.index, but O(1) per lookup)
        firstIndex = {}
        for i, bb_id in enumerate(ids):
            firstIndex.setdefault(bb_id, i)
        pairs = set()
        numNeighbors = [0] * len(ids)
        for index1, bb_id in enumerate(ids):
            for bbId in self.getResponseBBIdList(bb_id, annotations):
                index2 = firstIndex.get(bbId)
                if index2 is None:
                    continue  # response box not present (e.g. cropped away)
                # store each pair once, smaller index first
                pairs.add((min(index1, index2), max(index1, index2)))
                numNeighbors[index1] += 1

        img = np_img.transpose(
            [2, 0, 1])[None,
                       ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0

        bbs = convertBBs(bbs, self.rotate, numClasses)
        if len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None

        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "adj": pairs,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "transcription": [trans[bb_id] for bb_id in ids if bb_id in trans]
        }
    def getitem(self, index, scaleP=None, cropPoint=None):
        """Load one image with its box, line, point and pixel ground truth.

        With probability ``self.useRandomAugProb`` (only when neither ``scaleP`` nor
        ``cropPoint`` is given), returns ``self.getRandomImage()`` instead.

        Args:
            index: position in ``self.images``.
            scaleP: rescale factor. Drawn uniformly from ``self.rescale_range``
                when None.
            cropPoint: forwarded to ``self.transform`` when a transform is set.

        Returns:
            A dict with the following keys:

            - ``img``: float tensor, batch x channel x row x col, with values
              ``1 - pixel/128``.
            - ``bb_gt``, ``num_neighbors``, ``line_gt``, ``point_gt``,
              ``pixel_gt``.
            - ``imgName``, ``scale``, ``cropPoint``.
            - ``pairs``: None when a transform was applied.

            When ``self.only_types`` is set, ``bb_gt``, ``line_gt``, ``point_gt``
            and ``pixel_gt`` are filtered or merged according to it.

        If the image cannot be read, an error is printed and the next index is
        loaded via ``self.__getitem__``.
        """
        if self.useRandomAugProb is not None and np.random.rand(
        ) < self.useRandomAugProb and scaleP is None and cropPoint is None:
            return self.getRandomImage()
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        np_img = cv2.imread(imagePath, 1 if self.color else 0)
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())

        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        # resize factor for the loaded image; the effective scale is s = rescaled * partial_rescale
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            # uncropped images are capped in total pixels and in longest side
            pixel_count = partial_rescale * partial_rescale * np_img.shape[
                0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale

        np_img = cv2.resize(np_img, (0, 0),
                            fx=partial_rescale,
                            fy=partial_rescale,
                            interpolation=cv2.INTER_CUBIC)
        if not self.color:
            np_img = np_img[..., None]  #add 'color' channel

        bbs, line_gts, point_gts, pixel_gt, numClasses, numNeighbors, pairs = self.parseAnn(
            np_img, annotations, s, imagePath)

        if self.coordConv:  #add absolute position information
            # two extra channels: column and row position, each scaled to [0, 255)
            xs = 255 * np.arange(np_img.shape[1]) / (np_img.shape[1])
            xs = np.repeat(xs[None, :, None], np_img.shape[0], axis=0)
            ys = 255 * np.arange(np_img.shape[0]) / (np_img.shape[0])
            ys = np.repeat(ys[:, None, None], np_img.shape[1], axis=1)
            np_img = np.concatenate(
                (np_img, xs.astype(np_img.dtype), ys.astype(np_img.dtype)),
                axis=2)

        if self.transform is not None:
            pairs = None  # pairs are not carried through the transform
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": bbs,
                    "bb_auxs": numNeighbors,
                    "line_gt": line_gts,
                    "point_gt": point_gts,
                    "pixel_gt": pixel_gt,
                }, cropPoint)
            np_img = out['img']
            bbs = out['bb_gt']
            numNeighbors = out['bb_auxs']
            point_gts = out['point_gt']
            pixel_gt = out['pixel_gt']
            line_gts = out['line_gt']

            # augment only the image channels, never the coordConv channels
            if self.color:
                np_img[:, :, :3] = augmentation.apply_random_color_rotation(
                    np_img[:, :, :3])
                np_img[:, :, :3] = augmentation.apply_tensmeyer_brightness(
                    np_img[:, :, :3])
            else:
                np_img[:, :, 0:1] = augmentation.apply_tensmeyer_brightness(
                    np_img[:, :, 0:1])

        img = np_img.transpose(
            [2, 0, 1])[None,
                       ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0
        if pixel_gt is not None:
            pixel_gt = pixel_gt.transpose([2, 0, 1])[None, ...]
            pixel_gt = torch.from_numpy(pixel_gt)

        # empty (shape[1] == 0) or missing line/point sets become None
        for name in line_gts:
            line_gts[name] = None if line_gts[name] is None or line_gts[
                name].shape[1] == 0 else torch.from_numpy(line_gts[name])

        bbs = convertBBs(bbs, self.rotate, numClasses)
        if len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None
        for name in point_gts:
            if point_gts[name] is not None:
                point_gts[name] = None if point_gts[name].shape[
                    1] == 0 else torch.from_numpy(point_gts[name])

        def selectGts(kind, gts):
            """Filter/merge ``gts`` according to ``self.only_types[kind]`` ({} if absent).

            A plain entry ``name`` copies ``gts[name]``. A list entry
            ``[out, a, b, ...]`` concatenates the non-None ``gts[a], gts[b], ...``
            along dim 1 into ``out``, or stores None if they are all None.
            """
            selected = {}
            if kind not in self.only_types:
                return selected
            for ent in self.only_types[kind]:
                if isinstance(ent, list):
                    toComb = [gts[inst] for inst in ent[1:] if gts[inst] is not None]
                    selected[ent[0]] = torch.cat(toComb, dim=1) if toComb else None
                else:
                    selected[ent] = gts[ent]
            return selected

        if self.only_types is not None:
            if 'boxes' not in self.only_types or not self.only_types['boxes']:
                bbs = None
            # a fully-empty merged point entry now lands in point_gt (it used to be
            # written into line_gt by mistake)
            line_gts = selectGts('line', line_gts)
            point_gts = selectGts('point', point_gts)
            if 'pixel' not in self.only_types:
                pixel_gt = None

        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "line_gt": line_gts,
            "point_gt": point_gts,
            "pixel_gt": pixel_gt,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "pairs": pairs,
        }
示例#4
0
    def getitem(self, index, scaleP=None, cropPoint=None):
        ##ticFull=timeit.default_timer()
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        #print(annotationPath)
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        ##tic=timeit.default_timer()
        np_img = img_f.imread(imagePath, 1 if self.color else 0)  #*255.0
        if np_img.max() < 200:
            np_img *= 255
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())
        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            #this is a check to be sure we don't send too big images through
            pixel_count = partial_rescale * partial_rescale * np_img.shape[
                0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale

        ##tic=timeit.default_timer()
        #np_img = img_f.resize(np_img,(target_dim1, target_dim0))
        np_img = img_f.resize(
            np_img,
            (0, 0),
            fx=partial_rescale,
            fy=partial_rescale,
        )
        if len(np_img.shape) == 2:
            np_img = np_img[..., None]  #add 'color' channel
        if self.color and np_img.shape[2] == 1:
            np_img = np.repeat(np_img, 3, axis=2)
        ##print('resize: {}  [{}, {}]'.format(timeit.default_timer()-tic,np_img.shape[0],np_img.shape[1]))

        ##tic=timeit.default_timer()

        bbs, ids, numClasses, trans, groups, metadata, form_metadata = self.parseAnn(
            annotations, s)
        #trans = {i:v for i,v in enumerate(trans)}
        #metadata = {i:v for i,v in enumerate(metadata)}

        #start_of_line, end_of_line = getStartEndGT(annotations['byId'].values(),s)
        #Try:
        #    table_points, table_pixels = self.getTables(
        #            fieldBBs,
        #            s,
        #            np_img.shape[0],
        #            np_img.shape[1],
        #            annotations['samePairs'])
        #Except Exception as inst:
        #    if imageName not in self.errors:
        #        table_points=None
        #        table_pixels=None
        #        print(inst)
        #        print('Table error on: '+imagePath)
        #        self.errors.append(imageName)

        #pixel_gt = table_pixels

        ##ticTr=timeit.default_timer()
        if self.questions:  #we need to do questions before crop to have full context
            #we have to relationships to get questions
            pairs = set()
            for index1, id in enumerate(ids):  #updated
                responseBBIdList = self.getResponseBBIdList(id, annotations)
                for bbId in responseBBIdList:
                    try:
                        index2 = ids.index(bbId)
                        pairs.add((min(index1, index2), max(index1, index2)))
                    except ValueError:
                        pass
            groups_adj = set()
            if groups is not None:
                for n0, n1 in pairs:
                    g0 = -1
                    g1 = -1
                    for i, ns in enumerate(groups):
                        if n0 in ns:
                            g0 = i
                            if g1 != -1:
                                break
                        if n1 in ns:
                            g1 = i
                            if g0 != -1:
                                break
                    if g0 != g1:
                        groups_adj.add((min(g0, g1), max(g0, g1)))
            questions_and_answers = self.makeQuestions(bbs, trans, groups,
                                                       groups_adj)
        else:
            questions_and_answers = None

        if self.transform is not None:
            if 'word_boxes' in form_metadata:
                word_bbs = form_metadata['word_boxes']
                dif_f = bbs.shape[2] - word_bbs.shape[1]
                blank = np.zeros([word_bbs.shape[0], dif_f])
                prep_word_bbs = np.concatenate([word_bbs, blank], axis=1)[None,
                                                                          ...]
                crop_bbs = np.concatenate([bbs, prep_word_bbs], axis=1)
                crop_ids = ids + [
                    'word{}'.format(i) for i in range(word_bbs.shape[0])
                ]
            else:
                crop_bbs = bbs
                crop_ids = ids
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": crop_bbs,
                    'bb_auxs': crop_ids,
                    #'word_bbs':form_metadata['word_boxes'] if 'word_boxes' in form_metadata else None
                    #"line_gt": {
                    #    "start_of_line": start_of_line,
                    #    "end_of_line": end_of_line
                    #    },
                    #"point_gt": {
                    #        "table_points": table_points
                    #        },
                    #"pixel_gt": pixel_gt,
                },
                cropPoint)
            np_img = out['img']

            if 'word_boxes' in form_metadata:
                saw_word = False
                word_index = -1
                for i, ii in enumerate(out['bb_auxs']):
                    if not saw_word:
                        if type(ii) is str and 'word' in ii:
                            saw_word = True
                            word_index = i
                    else:
                        assert 'word' in ii
                bbs = out['bb_gt'][:, :word_index]
                ids = out['bb_auxs'][:word_index]
                form_metadata['word_boxes'] = out['bb_gt'][0, word_index:, :8]
                word_ids = out['bb_auxs'][word_index:]
                form_metadata['word_trans'] = [
                    form_metadata['word_trans'][int(id[4:])] for id in word_ids
                ]
            else:
                bbs = out['bb_gt']
                ids = out['bb_auxs']

            if questions_and_answers is not None:
                questions = []
                answers = []
                questions_and_answers = [
                    (q, a, qids) for q, a, qids in questions_and_answers
                    if all((i in ids) for i in qids)
                ]
        if questions_and_answers is not None:
            if len(questions_and_answers) > self.questions:
                questions_and_answers = random.sample(questions_and_answers,
                                                      k=self.questions)
            if len(questions_and_answers) > 0:
                questions, answers, _ = zip(*questions_and_answers)
            else:
                return self.getitem((index + 1) % len(self))
        else:
            questions = answers = None

            ##tic=timeit.default_timer()
            if np_img.shape[2] == 3:
                np_img = augmentation.apply_random_color_rotation(np_img)
                np_img = augmentation.apply_tensmeyer_brightness(
                    np_img, **self.aug_params)
            else:
                np_img = augmentation.apply_tensmeyer_brightness(
                    np_img, **self.aug_params)
            ##print('augmentation: {}'.format(timeit.default_timer()-tic))
        newGroups = []
        for group in groups:
            newGroup = [ids.index(bbId) for bbId in group if bbId in ids]
            if len(newGroup) > 0:
                newGroups.append(newGroup)
                #print(len(newGroups)-1,newGroup)
        groups = newGroups
        ##print('transfrm: {}  [{}, {}]'.format(timeit.default_timer()-ticTr,org_img.shape[0],org_img.shape[1]))
        pairs = set()
        #import pdb;pdb.set_trace()
        numNeighbors = [0] * len(ids)
        for index1, id in enumerate(ids):  #updated
            responseBBIdList = self.getResponseBBIdList(id, annotations)
            for bbId in responseBBIdList:
                try:
                    index2 = ids.index(bbId)
                    #adjMatrix[min(index1,index2),max(index1,index2)]=1
                    pairs.add((min(index1, index2), max(index1, index2)))
                    numNeighbors[index1] += 1
                except ValueError:
                    pass
        #ones = torch.ones(len(pairs))
        #if len(pairs)>0:
        #    pairs = torch.LongTensor(list(pairs)).t()
        #else:
        #    pairs = torch.LongTensor(pairs)
        #adjMatrix = torch.sparse.FloatTensor(pairs,ones,(len(ids),len(ids))) # This is an upper diagonal matrix as pairings are bi-directional

        #if len(np_img.shape)==2:
        #    img=np_img[None,None,:,:] #add "color" channel and batch
        #else:
        img = np_img.transpose(
            [2, 0, 1])[None,
                       ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0
        #if pixel_gt is not None:
        #    pixel_gt = pixel_gt.transpose([2,0,1])[None,...]
        #    pixel_gt = torch.from_numpy(pixel_gt)

        #start_of_line = None if start_of_line is None or start_of_line.shape[1] == 0 else torch.from_numpy(start_of_line)
        #end_of_line = None if end_of_line is None or end_of_line.shape[1] == 0 else torch.from_numpy(end_of_line)

        bbs = convertBBs(bbs, self.rotate, numClasses)
        if 'word_boxes' in form_metadata:
            form_metadata['word_boxes'] = convertBBs(
                form_metadata['word_boxes'][None, ...], self.rotate, 0)[0, ...]
        if len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None
        #if table_points is not None:
        #    table_points = None if table_points.shape[1] == 0 else torch.from_numpy(table_points)
        groups_adj = set()
        if groups is not None:
            for n0, n1 in pairs:
                g0 = -1
                g1 = -1
                for i, ns in enumerate(groups):
                    if n0 in ns:
                        g0 = i
                        if g1 != -1:
                            break
                    if n1 in ns:
                        g1 = i
                        if g0 != -1:
                            break
                if g0 != g1:
                    groups_adj.add((min(g0, g1), max(g0, g1)))
            for group in groups:
                for i in group:
                    assert (i < bbs.shape[1])
            targetIndexToGroup = {}
            for groupId, bbIds in enumerate(groups):
                targetIndexToGroup.update({bbId: groupId for bbId in bbIds})

        transcription = [trans[id] for id in ids]

        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "adj": pairs,  #adjMatrix,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "transcription": transcription,
            "metadata": [metadata[id] for id in ids if id in metadata],
            "form_metadata": form_metadata,
            "gt_groups": groups,
            "targetIndexToGroup": targetIndexToGroup,
            "gt_groups_adj": groups_adj,
            "questions": questions,
            "answers": answers
        }