Example #1
    def parseAnn(self,np_img,annotations,s,imageName):
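        #Turn one image's annotation dict into training targets: oriented
        #bounding boxes, start/end-of-line GT, table point/pixel GT,
        #per-box neighbor counts, and index pairs of linked boxes.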
        fieldBBs = annotations['fieldBBs']
        fixAnnotations(self,annotations)

        full_bbs=annotations['byId'].values()

        bbs = getBBWithPoints(full_bbs,s,useBlankClass=(not self.no_blanks),usePairedClass=self.use_paired_class)
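        #getBBWithPoints appears to pack 16 geometry values per box; any
        #remaining channels are class scores (hence the -16 below).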
        numClasses = bbs.shape[2]-16
        #field_bbs = getBBWithPoints(annotations['fieldBBs'],s)
        #bbs = np.concatenate([text_bbs,field_bbs],axis=1) #has batch dim
        start_of_line, end_of_line = getStartEndGT(full_bbs,s)
        try:
            table_points, table_pixels = self.getTables(
                    fieldBBs,
                    s, 
                    np_img.shape[0], 
                    np_img.shape[1],
                    annotations['pairs'])
        except Exception as inst:
            table_points=None
            table_pixels=None
            if imageName not in self.errors:
                #print(inst)
                #print('Table error on: '+imageName)
                self.errors.append(imageName)

        ##print('getStartEndGt: '+str(timeit.default_timer()-tic))

        pixel_gt = table_pixels

        line_gts = {
                    "start_of_line": start_of_line,
                    "end_of_line": end_of_line
                    }
        point_gts = {
                        "table_points": table_points
                        }
        
        numNeighbors=defaultdict(int)
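        #Count neighbors: for each box, how many other boxes answer it as a
        #'response' (an O(n^2) scan over all box pairs).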
        for id,bb in annotations['byId'].items():
            if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
                responseIds = getResponseBBIdList_(self,id,annotations)
                for id2,bb2 in annotations['byId'].items():
                    if id!=id2:
                        pair = id2 in responseIds
                        if pair:
                            numNeighbors[id]+=1
        numNeighbors = [numNeighbors[bb['id']] for bb in full_bbs]
        #if self.pred_neighbors:
        #    bbs = torch.cat(bbs,
        idToIndex={}
        for i,bb in enumerate(full_bbs):
            idToIndex[bb['id']]=i
        pairs=[ (idToIndex[id1],idToIndex[id2]) for id1,id2 in annotations['pairs'] ]
        return bbs,line_gts,point_gts,pixel_gt,numClasses,numNeighbors, pairs
Example #2
    def parseAnn(self,annotations,scale):
        #fieldBBs = annotations['fieldBBs']
        fixAnnotations(self,annotations)
        bbsToUse=[]
        ids=[]
        trans={}
        for id,bb in annotations['byId'].items():
            if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
                bbsToUse.append(bb)
                ids.append(bb['id'])
                if 'transcription' in annotations:
                    trans[bb['id']] = annotations['transcription'][bb['id']]


        
        bbs = getBBWithPoints(bbsToUse,scale,useBlankClass=(not self.no_blanks),usePairedClass=self.use_paired_class)
        numClasses = bbs.shape[2]-16
        return bbs,ids,numClasses, trans
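
A minimal driver for this variant might look like the sketch below. The stub object, the file path, and calling parseAnn as a free function (with fixAnnotations and getBBWithPoints resolving at module scope) are all illustrative assumptions, not taken from the source.

    # Hypothetical usage sketch -- the stub, path, and call style are assumptions.
    import json
    from types import SimpleNamespace

    ds = SimpleNamespace(no_blanks=True, use_paired_class=False, onlyFormStuff=False)
    with open('form_annotation.json') as f:  # illustrative path
        annotations = json.load(f)
    bbs, ids, numClasses, trans = parseAnn(ds, annotations, scale=0.5)
    print(bbs.shape, numClasses, len(ids))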
Example #3
    def parseAnn(self, annotations, scale):
        #fieldBBs = annotations['fieldBBs']
        fixAnnotations(self, annotations)
        bbsToUse = []
        ids = []
        trans = {}
        metadata = {}
        for id, bb in annotations['byId'].items():
            if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
                bbsToUse.append(bb)
                ids.append(bb['id'])
                if 'transcriptions' in annotations and bb['id'] in annotations[
                        'transcriptions']:
                    trans[bb['id']] = annotations['transcriptions'][bb['id']]
                else:
                    trans[bb['id']] = None

            metadata[bb['id']] = {'type': bb['isBlank']}

        bbs = getBBWithPoints(bbsToUse,
                              scale,
                              useBlankClass=(not self.no_blanks),
                              usePairedClass=self.use_paired_class,
                              useAllClass=self.useClasses)
        #numClasses = bbs.shape[2]-16
        numClasses = len(self.classMap)

        #import pdb;pdb.set_trace()
        if self.no_groups:
            idGroups = [[bbid] for bbid in ids]
        else:
            idGroups = formGroups(annotations, self.group_only_same)
        #revIds = {bbId:n for n,bbId in enumerate(ids)}
        #groups = [ [revIds[bbId] for bbId in group] for group in idGroups]
        groups = idGroups
        assert (bbs is not None)
        #print(metadata)
        assert (len(groups) > 0)
        return bbs, ids, numClasses, trans, groups, metadata, {}
Example #4
    def cluster(self,k,sample_count,outPath):
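        #k-means over box geometries (as used for YOLO-style anchor boxes):
        #sample boxes at random rescales, cluster them, write the surviving
        #cluster centers to outPath, and draw them with OpenCV.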
        def makePointsAndRects(h,w,r=None):
            if r is None:
                return np.array([-w/2.0,0,w/2.0,0,0,-h/2.0,0,h/2.0, 0,0, 0, h,w])
            else:
                lx= -math.cos(r)*w
                ly= -math.sin(r)*w
                rx= math.cos(r)*w
                ry= math.sin(r)*w
                tx= math.sin(r)*h
                ty= -math.cos(r)*h
                bx= -math.sin(r)*h
                by= math.cos(r)*h
                return np.array([lx,ly,rx,ry,tx,ty,bx,by, 0,0, r, h,w])
        meanH=62.42
        stdH=87.31
        meanW=393.03
        stdW=533.53
        ratios=[4.0,7.18,11.0,15.0,19.0,27.0]
        pointsAndRects=[]
        for inst in self.images:
            annotationPath = inst['annotationPath']
            #rescaled = inst['rescaled']
            with open(annotationPath) as annFile:
                annotations = json.loads(annFile.read())
            fixAnnotations(self,annotations)
            for i in range(sample_count):
                if i==0:
                    s = (self.rescale_range[0]+self.rescale_range[1])/2
                else:
                    s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
                #partial_rescale = s/rescaled
                bbs = getBBWithPoints(annotations['byId'].values(),s)
                #field_bbs = self.getBBGT(annotations['fieldBBs'],s,fields=True)
                #bbs = np.concatenate([text_bbs,field_bbs],axis=1)
                bbs = convertBBs(bbs,self.rotate,2).numpy()[0]
                cos_rot = np.cos(bbs[:,2])
                sin_rot = np.sin(bbs[:,2])
                p_left_x = -cos_rot*bbs[:,4]
                p_left_y = -sin_rot*bbs[:,4]
                p_right_x = cos_rot*bbs[:,4]
                p_right_y = sin_rot*bbs[:,4]
                p_top_x = sin_rot*bbs[:,3]
                p_top_y = -cos_rot*bbs[:,3]
                p_bot_x = -sin_rot*bbs[:,3]
                p_bot_y = cos_rot*bbs[:,3]
                points = np.stack([p_left_x,p_left_y,p_right_x,p_right_y,p_top_x,p_top_y,p_bot_x,p_bot_y],axis=1)
                pointsAndRects.append(np.concatenate([points,bbs[:,:5]],axis=1))
        pointsAndRects = np.concatenate(pointsAndRects,axis=0)
        #all_points = pointsAndRects[:,0:8]
        #all_heights = pointsAndRects[:,11]
        #all_widths = pointsAndRects[:,12]
        
        bestDistsFromMean=None
        for attempt in range(20 if k>0 else 1):
            if k>0:
                randomIndexes = np.random.randint(0,pointsAndRects.shape[0],(k))
                means=pointsAndRects[randomIndexes]
            else:
                #minH=5
                #minW=5
                means=[]

                ##smaller than mean
                #for step in range(5):
                #    height = minH + (meanH-minH)*(step/5.0)
                #    width = minW + (meanW-minW)*(step/5.0)
                #    for ratio in ratios:
                #        means.append(makePointsAndRects(height,ratio*height))
                #        means.append(makePointsAndRects(width/ratio,width))
                #for stddev in range(0,5):
                #    for step in range(5-stddev):
                #        height = meanH + stddev*stdH + stdH*(step/(5.0-stddev))
                #        width = meanW + stddev*stdW + stdW*(step/(5.0-stddev))
                #        for ratio in ratios:
                #            means.append(makePointsAndRects(height,ratio*height))
                #            means.append(makePointsAndRects(width/ratio,width))
                rots = [0,math.pi/2,math.pi,1.5*math.pi]
                if self.rotate:
                    for height in np.linspace(15,200,num=4):
                        for width in np.linspace(30,1200,num=4):
                            for rot in rots:
                                means.append(makePointsAndRects(height,width,rot))
                        #long boxes
                    for width in np.linspace(1600,4000,num=3):
                        #for height in np.linspace(30,100,num=3):
                        #    for rot in rots:
                        #        means.append(makePointsAndRects(height,width,rot))
                        for rot in rots:
                            means.append(makePointsAndRects(50,width,rot))
                else:
                    #rotated boxes
                    #for height in np.linspace(13,300,num=4):
                    for height in np.linspace(13,300,num=3):
                        means.append(makePointsAndRects(height,20))
                    #general boxes
                    #for height in np.linspace(15,200,num=4):
                        #for width in np.linspace(30,1200,num=4):
                    for height in np.linspace(15,200,num=2):
                        for width in np.linspace(30,1200,num=3):
                            means.append(makePointsAndRects(height,width))
                    #long boxes
                    for width in np.linspace(1600,4000,num=3):
                        #for height in np.linspace(30,100,num=3):
                        #    means.append(makePointsAndRects(height,width))
                        means.append(makePointsAndRects(50,width))

                k=len(means)
                print('K: {}'.format(k))
                means = np.stack(means,axis=0)
            #pointsAndRects: [0:p_left_x, 1:p_left_y, 2:p_right_x, 3:p_right_y, 4:p_top_x, 5:p_top_y, 6:p_bot_x, 7:p_bot_y, 8:xc, 9:yc, 10:rot, 11:h, 12:w]
            cluster_centers=means
            distsFromMean=None
            prevDistsFromMean=None
            for iteration in range(100000): #intended to break out
                print('attempt:{}, bestDistsFromMean:{}, iteration:{}, prevDistsFromMean:{}'.format(attempt,bestDistsFromMean,iteration,prevDistsFromMean), end='\r')
                #means_points = means[:,0:8]
                #means_heights = means[:,11]
                #means_widths = means[:,12]
                # = groups = assignGroups(means,pointsAndRects)
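                #Assignment step: broadcast every box against every mean and
                #compute a size-normalized distance between their corner points.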
                expanded_all_points = pointsAndRects[:,None,0:8]
                expanded_all_heights = pointsAndRects[:,None,11]
                expanded_all_widths = pointsAndRects[:,None,12]

                expanded_means_points = means[None,:,0:8]
                expanded_means_heights = means[None,:,11]
                expanded_means_widths = means[None,:,12]

                #expanded_all_points = expanded_all_points.expand(all_points.shape[0], all_points.shape[1], means_points.shape[1], all_points.shape[2])
                expanded_all_points = np.tile(expanded_all_points,(1,means.shape[0],1))
                expanded_all_heights = np.tile(expanded_all_heights,(1,means.shape[0]))
                expanded_all_widths = np.tile(expanded_all_widths,(1,means.shape[0]))
                #expanded_means_points = expanded_means_points.expand(means_points.shape[0], all_points.shape[0], means_points.shape[0], means_points.shape[2])
                expanded_means_points = np.tile(expanded_means_points,(pointsAndRects.shape[0],1,1))
                expanded_means_heights = np.tile(expanded_means_heights,(pointsAndRects.shape[0],1))
                expanded_means_widths = np.tile(expanded_means_widths,(pointsAndRects.shape[0],1))

                point_deltas = (expanded_all_points - expanded_means_points)
                #avg_heights = ((expanded_means_heights+expanded_all_heights)/2)
                #avg_widths = ((expanded_means_widths+expanded_all_widths)/2)
                avg_heights=avg_widths = (expanded_means_heights+expanded_all_heights+expanded_means_widths+expanded_all_widths)/4
                #print point_deltas

                normed_difference = (
                    np.linalg.norm(point_deltas[:,:,0:2],2,2)/avg_widths +
                    np.linalg.norm(point_deltas[:,:,2:4],2,2)/avg_widths +
                    np.linalg.norm(point_deltas[:,:,4:6],2,2)/avg_heights +
                    np.linalg.norm(point_deltas[:,:,6:8],2,2)/avg_heights
                    )**2
                #print normed_difference
                #import pdb; pdb.set_trace()

                groups = normed_difference.argmin(1) #this should list the mean (index) for each element of all
                distsFromMean = normed_difference.min(1).mean()
                if prevDistsFromMean is not None and distsFromMean >= prevDistsFromMean:
                    break
                prevDistsFromMean = distsFromMean

                #means = computeMeans(groups,pointsAndRects)
                #means = np.zeros(k,13)
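                #Update step: recompute each mean as the centroid of the boxes
                #assigned to it.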
                for ki in range(k):
                    selected = (groups==ki)[:,None]
                    numSel = float(selected.sum())
                    if (numSel==0):
                        continue #skip empty clusters instead of aborting the update
                    means[ki,:] = (pointsAndRects*np.tile(selected,(1,13))).sum(0)/numSel
            if bestDistsFromMean is None or distsFromMean<bestDistsFromMean:
                bestDistsFromMean = distsFromMean
                cluster_centers=means
        #cluster_centers=means
        dH=600
        dW=3000
        draw = np.zeros([dH,dW,3],dtype=np.float32)
        toWrite = []
        final_k=k
        for ki in range(k):
            pop = (groups==ki).sum().item()
            if pop>2:
                color = np.random.uniform(0.2,1,3).tolist()
                #d=math.sqrt(mean[ki,11]**2 + mean[ki,12]**2)
                #theta = math.atan2(mean[ki,11],mean[ki,12]) + mean[ki,10]
                h=cluster_centers[ki,11]
                w=cluster_centers[ki,12]
                rot=cluster_centers[ki,10]
                toWrite.append({'height':h.item(),'width':w.item(),'rot':rot.item(),'popularity':pop})
                tr = ( int(math.cos(rot)*w-math.sin(rot)*h)+dW//2,   int(math.sin(rot)*w+math.cos(rot)*h)+dH//2 )
                tl = ( int(math.cos(rot)*-w-math.sin(rot)*h)+dW//2,  int(math.sin(rot)*-w+math.cos(rot)*h)+dH//2 )
                br = ( int(math.cos(rot)*w-math.sin(rot)*-h)+dW//2,  int(math.sin(rot)*w+math.cos(rot)*-h)+dH//2 )
                bl = ( int(math.cos(rot)*-w-math.sin(rot)*-h)+dW//2, int(math.sin(rot)*-w+math.cos(rot)*-h)+dH//2 )
                
                cv2.line(draw,tl,tr,color)
                cv2.line(draw,tr,br,color)
                cv2.line(draw,br,bl,color)
                cv2.line(draw,bl,tl,color,2)
            else:
                final_k-=1
        
        #print(toWrite)
        with open(outPath.format(final_k),'w') as out:
            out.write(json.dumps(toWrite))
            print('saved '+outPath.format(final_k))
        cv2.imshow('clusters',draw)
        cv2.waitKey()
Example #5
def FormsGraphPair_printer(config,instance, model, gpu, metrics, outDir=None, startIndex=None, lossFunc=None):
    def __eval_metrics(data,target):
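        #NOTE: this helper reads 'output', which is never defined in this
        #function; the 'data' parameter is unused, so it appears vestigial.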
        acc_metrics = np.zeros((output.shape[0],len(metrics)))
        for ind in range(output.shape[0]):
            for i, metric in enumerate(metrics):
                acc_metrics[ind,i] += metric(output[ind:ind+1], target[ind:ind+1])
        return acc_metrics

    def __to_tensor(instance,gpu):
        image = instance['img']
        bbs = instance['bb_gt']
        adjacency = instance['adj']
        num_neighbors = instance['num_neighbors']

        if gpu is not None:
            image = image.to(gpu)
            if bbs is not None:
                bbs = bbs.to(gpu)
            if num_neighbors is not None:
                num_neighbors = num_neighbors.to(gpu)
            #adjacencyMatrix = adjacencyMatrix.to(self.gpu)
        return image, bbs, adjacency, num_neighbors

    rel_thresholds = [config['THRESH']] if 'THRESH' in config else [0.5]
    if ('sweep_threshold' in config and config['sweep_threshold']) or ('sweep_thresholds' in config and config['sweep_thresholds']):
        rel_thresholds = np.arange(0.1,1.0,0.05)
    if ('sweep_threshold_big' in config and config['sweep_threshold_big']) or ('sweep_thresholds_big' in config and config['sweep_thresholds_big']):
        rel_thresholds = np.arange(0,20.0,1)
    if ('sweep_threshold_small' in config and config['sweep_threshold_small']) or ('sweep_thresholds_small' in config and config['sweep_thresholds_small']):
        rel_thresholds = np.arange(0,0.1,0.01)
    draw_rel_thresh = config['draw_thresh'] if 'draw_thresh' in config else rel_thresholds[0]
    #print(type(instance['pixel_gt']))
    #if type(instance['pixel_gt']) == list:
    #    print(instance)
    #    print(startIndex)
    #data, targetBB, targetBBSizes = instance
    lossWeights = config['loss_weights'] if 'loss_weights' in config else {"box": 1, "rel":1}
    if lossFunc is None:
        yolo_loss = YoloLoss(model.numBBTypes,model.rotation,model.scale,model.anchors,**config['loss_params']['box'])
    else:
        yolo_loss = lossFunc
    data = instance['img']
    batchSize = data.shape[0]
    assert(batchSize==1)
    targetBoxes = instance['bb_gt']
    adjacency = instance['adj']
    adjacency = list(adjacency)
    imageName = instance['imgName']
    scale = instance['scale']
    target_num_neighbors = instance['num_neighbors']
    if not model.detector.predNumNeighbors:
        instance['num_neighbors']=None
    dataT, targetBoxesT, adjT, target_num_neighborsT = __to_tensor(instance,gpu)


    pretty = config['pretty'] if 'pretty' in config else False
    useDetections = config['useDetections'] if 'useDetections' in config else False
    if 'useDetect' in config:
        useDetections = config['useDetect']
    confThresh = config['conf_thresh'] if 'conf_thresh' in config else None


    numClasses=2 #TODO no hard code

    resultsDirName='results'
    #if outDir is not None and resultsDirName is not None:
        #rPath = os.path.join(outDir,resultsDirName)
        #if not os.path.exists(rPath):
        #    os.mkdir(rPath)
        #for name in targetBoxes:
        #    nPath = os.path.join(rPath,name)
        #    if not os.path.exists(nPath):
        #        os.mkdir(nPath)

    #dataT = __to_tensor(data,gpu)
    #print('{}: {} x {}'.format(imageName,data.shape[2],data.shape[3]))
    if useDetections=='gt':
        outputBoxes, outputOffsets, relPred, relIndexes, bbPred = model(dataT,targetBoxesT,target_num_neighborsT,True,
                otherThresh=confThresh,
                otherThreshIntur=1 if confThresh is not None else None,
                hard_detect_limit=600)
        outputBoxes=torch.cat((torch.ones(targetBoxes.size(1),1),targetBoxes[0,:,0:5],targetBoxes[0,:,-numClasses:]),dim=1) #add score
    elif type(useDetections) is str:
        dataset=config['DATASET']
        jsonPath = os.path.join(useDetections,imageName+'.json')
        with open(os.path.join(jsonPath)) as f:
            annotations = json.loads(f.read())
        fixAnnotations(dataset,annotations)
        savedBoxes = torch.FloatTensor(len(annotations['byId']),6+model.detector.predNumNeighbors+numClasses)
        for i,(id,bb) in enumerate(annotations['byId'].items()):
            qX, qY, qH, qW, qR, qIsText, qIsField, qIsBlank, qNN = getBBInfo(bb,dataset.rotate,useBlankClass=not dataset.no_blanks)
            savedBoxes[i,0]=1 #conf
            savedBoxes[i,1]=qX*scale #x-center, already scaled
            savedBoxes[i,2]=qY*scale #y-center
            savedBoxes[i,3]=qR #rotation
            savedBoxes[i,4]=qH*scale/2
            savedBoxes[i,5]=qW*scale/2
            if model.detector.predNumNeighbors:
                extra=1
                savedBoxes[i,6]=qNN
            else:
                extra=0
            savedBoxes[i,6+extra]=qIsText
            savedBoxes[i,7+extra]=qIsField
            
        if gpu is not None:
            savedBoxes=savedBoxes.to(gpu)
        outputBoxes, outputOffsets, relPred, relIndexes, bbPred = model(dataT,savedBoxes,None,"saved",
                otherThresh=confThresh,
                otherThreshIntur=1 if confThresh is not None else None,
                hard_detect_limit=600)
        outputBoxes=savedBoxes.cpu()
    elif useDetections:
        print('Unknown detection flag: {}'.format(useDetections))
        exit()
    else:
        outputBoxes, outputOffsets, relPred, relIndexes, bbPred = model(dataT,
                otherThresh=confThresh,
                otherThreshIntur=1 if confThresh is not None else None,
                hard_detect_limit=600)

    if model.predNN and bbPred is not None:
        predNN = bbPred[:,0]
    else:
        predNN=None
    if model.detector.predNumNeighbors and not useDetections:
        #useOutputBBs=torch.cat((outputBoxes[:,0:6],outputBoxes[:,7:]),dim=1) #throw away NN pred
        extraPreds=1
        if not model.predNN:
            predNN = outputBoxes[:,6]
    else:
        extraPreds=0
        if not model.predNN:
            predNN = None
        #useOutputBBs=outputBoxes

    if targetBoxesT is not None:
        targetSize=targetBoxesT.size(1)
    else:
        targetSize=0
    lossThis, position_loss, conf_loss, class_loss, nn_loss, recall, precision = yolo_loss(outputOffsets,targetBoxesT,[targetSize], target_num_neighborsT)

    if 'rule' in config:
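        #Optionally replace the learned relationship scores with simple
        #geometric heuristics ('closest' or the ICDAR-style rule) as baselines.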
        if config['rule']=='closest':
            dists = torch.FloatTensor(relPred.size())
            differentClass = torch.FloatTensor(relPred.size())
            predClasses = torch.argmax(outputBoxes[:,extraPreds+6:extraPreds+6+numClasses],dim=1)
            for i,(bb1,bb2) in enumerate(relIndexes):
                dists[i] = math.sqrt((outputBoxes[bb1,1]-outputBoxes[bb2,1])**2 + (outputBoxes[bb1,2]-outputBoxes[bb2,2])**2)
                differentClass[i] = predClasses[bb1]!=predClasses[bb2]
            maxDist = torch.max(dists)
            minDist = torch.min(dists)
            relPred = 1-(dists-minDist)/(maxDist-minDist)
            relPred *= differentClass
        elif config['rule']=='icdar':
            height = torch.FloatTensor(relPred.size())
            dists = torch.FloatTensor(relPred.size())
            right = torch.FloatTensor(relPred.size())
            sameClass = torch.FloatTensor(relPred.size())
            predClasses = torch.argmax(outputBoxes[:,extraPreds+6:extraPreds+6+numClasses],dim=1)
            for i,(bb1,bb2) in enumerate(relIndexes):
                sameClass[i] = predClasses[bb1]==predClasses[bb2]
                
                #g4 of the paper
                height[i] = max(outputBoxes[bb1,4],outputBoxes[bb2,4])/min(outputBoxes[bb1,4],outputBoxes[bb2,4])

                #g5 of the paper
                if predClasses[bb1]==0:
                    widthLabel = outputBoxes[bb1,5]*2 #we predict half width
                    widthValue = outputBoxes[bb2,5]*2
                    dists[i] = math.sqrt(((outputBoxes[bb1,1]+widthLabel)-(outputBoxes[bb2,1]-widthValue))**2 + (outputBoxes[bb1,2]-outputBoxes[bb2,2])**2)
                else:
                    widthLabel = outputBoxes[bb2,5]*2 #we predict half width
                    widthValue = outputBoxes[bb1,5]*2
                    dists[i] = math.sqrt(((outputBoxes[bb1,1]-widthValue)-(outputBoxes[bb2,1]+widthLabel))**2 + (outputBoxes[bb1,2]-outputBoxes[bb2,2])**2)
                if dists[i]>2*widthLabel:
                    dists[i]/=widthLabel
                else: #undefined
                    dists[i] = min(1,dists[i]/widthLabel)
            
                #g6 of the paper
                if predClasses[bb1]==0:
                    widthValue = outputBoxes[bb2,5]*2
                    hDist = outputBoxes[bb1,1]-outputBoxes[bb2,1]
                else:
                    widthValue = outputBoxes[bb1,5]*2
                    hDist = outputBoxes[bb2,1]-outputBoxes[bb1,1]
                right[i] = hDist/widthValue

            relPred = 1-(height+dists+right + 10000*sameClass)
        else:
            print('ERROR, unknown rule {}'.format(config['rule']))
            exit()
    elif relPred is not None:
        relPred = torch.sigmoid(relPred)[:,0]




    relCand = relIndexes
    if relCand is None:
        relCand=[]

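    #Align each predicted box to a GT box: distance matching when boxes are
    #rotated, IoU otherwise; bbAlignment gives the matched target index
    #(-1 = no match) and bbFullHit appears to flag complete matches.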
    if model.rotation:
        bbAlignment, bbFullHit = getTargIndexForPreds_dist(targetBoxes[0],outputBoxes,0.9,numClasses,extraPreds,hard_thresh=False)
    else:
        bbAlignment, bbFullHit = getTargIndexForPreds_iou(targetBoxes[0],outputBoxes,0.5,numClasses,extraPreds,hard_thresh=False)
    if targetBoxes is not None:
        target_for_b = targetBoxes[0,:,:]
    else:
        target_for_b = torch.empty(0)

    if outputBoxes.size(0)>0:
        maxConf = outputBoxes[:,0].max().item()
        minConf = outputBoxes[:,0].min().item()
        if useDetections:
            minConf=0
    #threshConf = max(maxConf*THRESH,0.5)
    #if model.rotation:
    #    outputBoxes = non_max_sup_dist(outputBoxes.cpu(),threshConf,3)
    #else:
    #    outputBoxes = non_max_sup_iou(outputBoxes.cpu(),threshConf,0.4)
    if model.rotation:
        ap_5, prec_5, recall_5 =AP_dist(target_for_b,outputBoxes,0.9,model.numBBTypes,beforeCls=extraPreds)
    else:
        ap_5, prec_5, recall_5 =AP_iou(target_for_b,outputBoxes,0.5,model.numBBTypes,beforeCls=extraPreds)

    #precisionHistory={}
    #precision=-1
    #minStepSize=0.025
    #targetPrecisions=[None]
    #for targetPrecision in targetPrecisions:
    #    if len(precisionHistory)>0:
    #        closestPrec=9999
    #        for prec in precisionHistory:
    #            if abs(targetPrecision-prec)<abs(closestPrec-targetPrecision):
    #                closestPrec=prec
    #        precision=prec
    #        stepSize=precisionHistory[prec][0]
    #    else:
    #        stepSize=0.1
    #
    #    while True: #abs(precision-targetPrecision)>0.001:
    toRet={}
    for rel_threshold in rel_thresholds:
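            #Score relationships at this threshold and record precision,
            #recall, F-measure, and AP into toRet.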

            if 'optimize' in config and config['optimize']:
                if 'penalty' in config:
                    penalty = config['penalty']
                else:
                    penalty = 0.25
                print('optimizing with penalty {}'.format(penalty))
                thresh=0.15
                while thresh<0.45:
                    keep = relPred>thresh
                    newRelPred = relPred[keep]
                    if newRelPred.size(0)<700:
                        break
                    thresh += 0.05 #raise the threshold until few enough candidates remain
                if newRelPred.size(0)>0:
                    #newRelCand = [ cand for i,cand in enumerate(relCand) if keep[i] ]
                    usePredNN= predNN is not None and config['optimize']!='gt'
                    idMap={}
                    newId=0
                    newRelCand=[]
                    numNeighbors=[]
                    for index,(id1,id2) in enumerate(relCand):
                        if keep[index]:
                            if id1 not in idMap:
                                idMap[id1]=newId
                                if not usePredNN:
                                    numNeighbors.append(target_num_neighbors[0,bbAlignment[id1]])
                                else:
                                    numNeighbors.append(predNN[id1])
                                newId+=1
                            if id2 not in idMap:
                                idMap[id2]=newId
                                if not usePredNN:
                                    numNeighbors.append(target_num_neighbors[0,bbAlignment[id2]])
                                else:
                                    numNeighbors.append(predNN[id2])
                                newId+=1
                            newRelCand.append( [idMap[id1],idMap[id2]] )            


                    #if not usePredNN:
                        #    decision = optimizeRelationships(newRelPred,newRelCand,numNeighbors,penalty)
                    #else:
                    decision = optimizeRelationshipsSoft(newRelPred,newRelCand,numNeighbors,penalty, rel_threshold)
                    decision = torch.from_numpy(np.round(decision).astype(int))
                    decision = decision.to(relPred.device)
                    relPred[keep] = torch.where(0==decision,relPred[keep]-1,relPred[keep])
                    relPred[~keep] -= 1
                    rel_threshold_use=0#-0.5
                else:
                    rel_threshold_use=rel_threshold
            else:
                rel_threshold_use=rel_threshold

            #threshed in model
            #if len(precisionHistory)==0:
            if len(toRet)==0:
                #align bb predictions (final) with GT
                if bbPred is not None and bbPred.size(0)>0:
                    #create aligned GT
                    #this was wrong...
                        #first, remove unmatched predictions that didn't overlap (weren't close to) any targets
                        #toKeep = 1-((bbNoIntersections==1) * (bbAlignment==-1))
                    #remove predictions that overlapped with GT, but not enough
                    if model.predNN:
                        start=1
                        toKeep = ~((bbFullHit==0) & (bbAlignment!=-1)) #toKeep = not (incomplete_overlap and did_overlap)
                        if toKeep.any():
                            bbPredNN_use = bbPred[toKeep][:,0]
                            bbAlignment_use = bbAlignment[toKeep]
                            #because we used -1 to indicate no match (in bbAlignment), we add 0 as the last position in the GT, as unmatched
                            if target_num_neighborsT is not None:
                                target_num_neighbors_use = torch.cat((target_num_neighborsT[0].float(),torch.zeros(1).to(target_num_neighborsT.device)),dim=0)
                            else:
                                target_num_neighbors_use = torch.zeros(1).to(bbPred.device)
                            alignedNN_use = target_num_neighbors_use[bbAlignment_use]
                        else:
                            bbPredNN_use=None
                            alignedNN_use=None
                    else:
                        start=0
                    if model.predClass:
                        #We really don't care about the class of non-overlapping instances
                        if targetBoxes is not None:
                            toKeep = bbFullHit==1
                            if toKeep.any():
                                bbPredClass_use = bbPred[toKeep][:,start:start+model.numBBTypes]
                                bbAlignment_use = bbAlignment[toKeep]
                                alignedClass_use =  targetBoxesT[0][bbAlignment_use][:,13:13+model.numBBTypes] #There should be no -1 indexes in here
                            else:
                                bbPredClass_use=None
                                alignedClass_use=None
                        else:
                            bbPredClass_use = None
                            alignedClass_use = None
                else:
                    bbPredNN_use = None
                    bbPredClass_use = None
                if model.predNN and bbPredNN_use is not None and bbPredNN_use.size(0)>0:
                    nn_loss_final = F.mse_loss(bbPredNN_use,alignedNN_use)
                    #nn_loss_final *= self.lossWeights['nn']

                    #loss += nn_loss_final
                    nn_loss_final = nn_loss_final.item()
                else:
                    nn_loss_final=0
                if model.predNN and predNN is not None:
                    predNN_p=bbPred[:,0]
                    diffs=torch.abs(predNN_p-target_num_neighborsT[0][bbAlignment].float())
                    nn_acc = (diffs<0.5).sum().item()
                    nn_acc /= predNN.size(0)
                elif model.predNN:
                    nn_acc = 0 
                if model.detector.predNumNeighbors and not useDetections:
                    predNN_d = outputBoxes[:,6]
                    diffs=torch.abs(predNN_d-target_num_neighbors[0][bbAlignment].float())
                    nn_acc_d = (diffs<0.5).sum().item()
                    nn_acc_d /= predNN_d.size(0)

                if model.predClass and bbPredClass_use is not None and bbPredClass_use.size(0)>0:
                    class_loss_final = F.binary_cross_entropy_with_logits(bbPredClass_use,alignedClass_use)
                    #class_loss_final *= self.lossWeights['class']
                    #loss += class_loss_final
                    class_loss_final = class_loss_final.item()
                else:
                    class_loss_final = 0
            #class_acc=0
            useOutputBBs=None

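            #Tally predicted relationships against the GT adjacency list; a
            #prediction counts only if both endpoints fully match their targets.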
            truePred=falsePred=badPred=0
            scores=[]
            matches=0
            i=0
            numMissedByHeur=0
            targGotHit=set()
            for i,(n0,n1) in enumerate(relCand):
                t0 = bbAlignment[n0].item()
                t1 = bbAlignment[n1].item()
                if t0>=0 and bbFullHit[n0]:
                    targGotHit.add(t0)
                if t1>=0 and bbFullHit[n1]:
                    targGotHit.add(t1)
                if t0>=0 and t1>=0 and bbFullHit[n0] and bbFullHit[n1]:
                    if (min(t0,t1),max(t0,t1)) in adjacency:
                        matches+=1
                        scores.append( (relPred[i],True) )
                        if relPred[i]>rel_threshold_use:
                            truePred+=1
                    else:
                        scores.append( (relPred[i],False) )
                        if relPred[i]>rel_threshold_use:
                            falsePred+=1
                else:
                    scores.append( (relPred[i],False) )
                    if relPred[i]>rel_threshold_use:
                        badPred+=1
            for i in range(len(adjacency)-matches):
                numMissedByHeur+=1
                scores.append( (float('nan'),True) )
            rel_ap=computeAP(scores)

            numMissedByDetect=0
            for t0,t1 in adjacency:
                if t0 not in targGotHit or t1 not in targGotHit:
                    numMissedByHeur-=1
                    numMissedByDetect+=1
            if len(adjacency)>0:
                heurRecall = (len(adjacency)-numMissedByHeur)/len(adjacency)
                detectRecall = (len(adjacency)-numMissedByDetect)/len(adjacency)
                relRecall = truePred/len(adjacency)
            else:
                heurRecall = detectRecall = relRecall = 1
            #if falsePred>0:
            #    relPrec = truePred/(truePred+falsePred)
            #else:
            #    relPrec = 1
            if falsePred+badPred>0:
                precision = truePred/(truePred+falsePred+badPred)
            else:
                precision = 1
    

            toRet['prec@{}'.format(rel_threshold)]=precision
            toRet['recall@{}'.format(rel_threshold)]=relRecall
            if relRecall+precision>0:
                toRet['F-M@{}'.format(rel_threshold)]=2*relRecall*precision/(relRecall+precision)
            else:
                toRet['F-M@{}'.format(rel_threshold)]=0
            toRet['rel_AP@{}'.format(rel_threshold)]=rel_ap
            #precisionHistory[precision]=(draw_rel_thresh,stepSize)
            #if targetPrecision is not None:
            #    if abs(precision-targetPrecision)<0.001:
            #        break
            #    elif stepSize<minStepSize:
            #        if precision<targetPrecision:
            #            draw_rel_thresh += stepSize*2
            #            continue
            #        else:
            #            break
            #    elif precision<targetPrecision:
            #        draw_rel_thresh += stepSize
            #        if not wasTooSmall:
            #            reverse=True
            #            wasTooSmall=True
            #        else:
            #            reverse=False
            #    else:
            #        draw_rel_thresh -= stepSize
            #        if wasTooSmall:
            #            reverse=True
            #            wasTooSmall=False
            #        else:
            #            reverse=False
            #    if reverse:
            #        stepSize *= 0.5
            #else:
            #    break


            #import pdb;pdb.set_trace()

            #for b in range(len(outputBoxes)):
            
            
            dists=defaultdict(list)
            dists_x=defaultdict(list)
            dists_y=defaultdict(list)
            scaleDiffs=defaultdict(list)
            rotDiffs=defaultdict(list)
            b=0
            #print('image {} has {} {}'.format(startIndex+b,targetBoxesSizes[name][b],name))
        #bbImage = np.ones_like(image)

    if outDir is not None:
        outputBoxes = outputBoxes.data.numpy()
        data = data.numpy()

        image = (1-((1+np.transpose(data[b][:,:,:],(1,2,0)))/2.0)).copy()
        if image.shape[2]==1:
            image = cv2.cvtColor(image,cv2.COLOR_GRAY2RGB)
        #if name=='text_start_gt':

        #Draw GT bbs
        if not pretty:
            for j in range(targetSize):
                plotRect(image,(1,0.5,0),targetBoxes[0,j,0:5])
            #x=int(targetBoxes[b,j,0])
            #y=int(targetBoxes[b,j,1]+targetBoxes[b,j,3])
            #cv2.putText(image,'{:.2f}'.format(target_num_neighbors[b,j]),(x,y), cv2.FONT_HERSHEY_SIMPLEX, 0.5,(0.6,0.3,0),2,cv2.LINE_AA)
            #if alignmentBBs[b] is not None:
            #    aj=alignmentBBs[b][j]
            #    xc_gt = targetBoxes[b,j,0]
            #    yc_gt = targetBoxes[b,j,1]
            #    xc=outputBoxes[b,aj,1]
            #    yc=outputBoxes[b,aj,2]
            #    cv2.line(image,(xc,yc),(xc_gt,yc_gt),(0,1,0),1)
            #    shade = 0.0+(outputBoxes[b,aj,0]-threshConf)/(maxConf-threshConf)
            #    shade = max(0,shade)
            #    if outputBoxes[b,aj,6] > outputBoxes[b,aj,7]:
            #        color=(0,shade,shade) #text
            #    else:
            #        color=(shade,shade,0) #field
            #    plotRect(image,color,outputBoxes[b,aj,1:6])

        #bbs=[]
        #pred_points=[]
        #maxConf = outputBoxes[b,:,0].max()
        #threshConf = 0.5 
        #threshConf = max(maxConf*0.9,0.5)
        #print("threshConf:{}".format(threshConf))
        #for j in range(outputBoxes.shape[1]):
        #    conf = outputBoxes[b,j,0]
        #    if conf>threshConf:
        #        bbs.append((conf,j))
        #    #pred_points.append(
        #bbs.sort(key=lambda a: a[0]) #so most confident bbs are draw last (on top)
        #import pdb; pdb.set_trace()

        #Draw pred bbs
        bbs = outputBoxes
        for j in range(bbs.shape[0]):
            #circle aligned predictions
            conf = bbs[j,0]
            if outDir is not None:
                shade = 0.0+(conf-minConf)/(maxConf-minConf)
                #print(shade)
                #if name=='text_start_gt' or name=='field_end_gt':
                #    cv2.bb(bbImage[:,:,1],p1,p2,shade,2)
                #if name=='text_end_gt':
                #    cv2.bb(bbImage[:,:,2],p1,p2,shade,2)
                #elif name=='field_end_gt' or name=='field_start_gt':
                #    cv2.bb(bbImage[:,:,0],p1,p2,shade,2)
                if bbs[j,6+extraPreds] > bbs[j,7+extraPreds]:
                    color=(0,0,shade) #text
                else:
                    color=(0,shade,shade) #field
                if pretty=='light':
                    lineWidth=2
                else:
                    lineWidth=1
                plotRect(image,color,bbs[j,1:6],lineWidth)

                if predNN is not None and not pretty: #model.detector.predNumNeighbors:
                    x=int(bbs[j,1])
                    y=int(bbs[j,2])#-bbs[j,4])
                    targ_j = bbAlignment[j].item()
                    if targ_j>=0:
                        gtNN = target_num_neighbors[0,targ_j].item()
                    else:
                        gtNN = 0
                    pred_nn = predNN[j].item()
                    color = min(abs(pred_nn-gtNN),1)#*0.5
                    cv2.putText(image,'{:.2}/{}'.format(pred_nn,gtNN),(x,y), cv2.FONT_HERSHEY_SIMPLEX, 0.5,(color,0,0),2,cv2.LINE_AA)

        #for j in alignmentBBsTarg[name][b]:
        #    p1 = (targetBoxes[name][b,j,0], targetBoxes[name][b,j,1])
        #    p2 = (targetBoxes[name][b,j,0], targetBoxes[name][b,j,1])
        #    mid = ( int(round((p1[0]+p2[0])/2.0)), int(round((p1[1]+p2[1])/2.0)) )
        #    rad = round(math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)/2.0)
        #    #print(mid)
        #    #print(rad)
        #    cv2.circle(image,mid,rad,(1,0,1),1)

        draw_rel_thresh = relPred.max() * draw_rel_thresh


        #Draw pred pairings
        numrelpred=0
        hits = [False]*len(adjacency)
        for i in range(len(relCand)):
            #print('{},{} : {}'.format(relCand[i][0],relCand[i][1],relPred[i]))
            if pretty:
                if relPred[i]>0 or pretty=='light':
                    score = relPred[i]
                    pruned=False
                    lineWidth=2
                else:
                    score = relPred[i]+1
                    pruned=True
                    lineWidth=1
                #else:
                #    score = (relPred[i]+1)/2
                #    pruned=False
                #    lineWidth=2
                #if pretty=='light':
                #    lineWidth=3
            else:
                lineWidth=1
            if relPred[i]>draw_rel_thresh or (pretty and score>draw_rel_thresh):
                ind1 = relCand[i][0]
                ind2 = relCand[i][1]
                x1 = int(round(bbs[ind1,1]))
                y1 = int(round(bbs[ind1,2]))
                x2 = int(round(bbs[ind2,1]))
                y2 = int(round(bbs[ind2,2]))

                if pretty:
                    targ1 = bbAlignment[ind1].item()
                    targ2 = bbAlignment[ind2].item()
                    aId=None
                    if bbFullHit[ind1] and bbFullHit[ind2]:
                        if (targ1,targ2) in adjacency:
                            aId = adjacency.index((targ1,targ2))
                        elif (targ2,targ1) in adjacency:
                            aId = adjacency.index((targ2,targ1))
                    if aId is None:
                        if pretty=='clean' and pruned:
                            color=np.array([1,1,0])
                        else:
                            color=np.array([1,0,0])
                    else:
                        if pretty=='clean' and pruned:
                            color=np.array([1,0,1])
                        else:
                            color=np.array([0,1,0])
                        hits[aId]=True
                    #if pruned:
                    #    color = color*0.7
                    cv2.line(image,(x1,y1),(x2,y2),color.tolist(),lineWidth)
                    #color=color/3
                    #x = int((x1+x2)/2)
                    #y = int((y1+y2)/2)
                    #if pruned:
                    #    cv2.putText(image,'[{:.2}]'.format(score),(x,y), cv2.FONT_HERSHEY_PLAIN, 0.6,color.tolist(),1)
                    #else:
                    #    cv2.putText(image,'{:.2}'.format(score),(x,y), cv2.FONT_HERSHEY_PLAIN,1.1,color.tolist(),1)
                else:
                    shade = (relPred[i].item()-draw_rel_thresh)/(1-draw_rel_thresh)

                    #print('draw {} {} {} {} '.format(x1,y1,x2,y2))
                    cv2.line(image,(x1,y1),(x2,y2),(0,shade,0),lineWidth)
                numrelpred+=1
        if pretty and pretty!="light" and pretty!="clean":
            for i in range(len(relCand)):
                #print('{},{} : {}'.format(relCand[i][0],relCand[i][1],relPred[i]))
                if relPred[i]>-1:
                    score = (relPred[i]+1)/2
                    pruned=False
                else:
                    score = (relPred[i]+2+1)/2
                    pruned=True
                if relPred[i]>draw_rel_thresh or (pretty and score>draw_rel_thresh):
                    ind1 = relCand[i][0]
                    ind2 = relCand[i][1]
                    x1 = int(round(bbs[ind1,1]))
                    y1 = int(round(bbs[ind1,2]))
                    x2 = int(round(bbs[ind2,1]))
                    y2 = int(round(bbs[ind2,2]))

                    targ1 = bbAlignment[ind1].item()
                    targ2 = bbAlignment[ind2].item()
                    aId=None
                    if bbFullHit[ind1] and bbFullHit[ind2]:
                        if (targ1,targ2) in adjacency:
                            aId = adjacency.index((targ1,targ2))
                        elif (targ2,targ1) in adjacency:
                            aId = adjacency.index((targ2,targ1))
                    if aId is None:
                        color=np.array([1,0,0])
                    else:
                        color=np.array([0,1,0])
                    color=color/2
                    x = int((x1+x2)/2)
                    y = int((y1+y2)/2)
                    if pruned:
                        cv2.putText(image,'[{:.2}]'.format(score),(x,y), cv2.FONT_HERSHEY_PLAIN, 0.6,color.tolist(),1)
                    else:
                        cv2.putText(image,'{:.2}'.format(score),(x,y), cv2.FONT_HERSHEY_PLAIN,1.1,color.tolist(),1)
        #print('number of pred rels: {}'.format(numrelpred))
        #Draw GT pairings
        if not pretty:
            gtcolor=(0.25,0,0.25)
            wth=3
        else:
            #gtcolor=(1,0,0.6)
            gtcolor=(1,0.6,0)
            wth=2
        for aId,(i,j) in enumerate(adjacency):
            if not pretty or not hits[aId]:
                x1 = round(targetBoxes[0,i,0].item())
                y1 = round(targetBoxes[0,i,1].item())
                x2 = round(targetBoxes[0,j,0].item())
                y2 = round(targetBoxes[0,j,1].item())
                cv2.line(image,(x1,y1),(x2,y2),gtcolor,wth)

        #Draw alignment between GT and pred bbs
        if not pretty:
            for predI in range(bbs.shape[0]):
                targI=bbAlignment[predI].item()
                x1 = int(round(bbs[predI,1]))
                y1 = int(round(bbs[predI,2]))
                if targI>=0: #-1 indicates no match

                    x2 = round(targetBoxes[0,targI,0].item())
                    y2 = round(targetBoxes[0,targI,1].item())
                    cv2.line(image,(x1,y1),(x2,y2),(1,0,1),1)
                else:
                    #draw 'x', indicating not match
                    cv2.line(image,(x1-5,y1-5),(x1+5,y1+5),(.1,0,.1),1)
                    cv2.line(image,(x1+5,y1-5),(x1-5,y1+5),(.1,0,.1),1)



        saveName = '{}_boxes_prec:{:.2f},{:.2f}_recall:{:.2f},{:.2f}_rels_AP:{:.3f}'.format(imageName,prec_5[0],prec_5[1],recall_5[0],recall_5[1],rel_ap)
        #for j in range(metricsOut.shape[1]):
        #    saveName+='_m:{0:.3f}'.format(metricsOut[i,j])
        saveName+='.png'
        io.imsave(os.path.join(outDir,saveName),image)
        #print('saved: '+os.path.join(outDir,saveName))

    print('\n{} ap:{}\tnumMissedByDetect:{}\tnumMissedByHeur:{}'.format(imageName,rel_ap,numMissedByDetect,numMissedByHeur))
    retData= { 'bb_ap':[ap_5],
               'bb_recall':[recall_5],
               'bb_prec':[prec_5],
               'bb_Fm': -1,#(recall_5[0]+recall_5[1]+prec_5[0]+prec_5[1])/4,
               'nn_loss': nn_loss,
               'rel_recall':relRecall,
               'rel_precision':precision,
               'rel_Fm':2*relRecall*precision/(relRecall+precision) if relRecall+precision>0 else 0,
               'relMissedByHeur':numMissedByHeur,
               'relMissedByDetect':numMissedByDetect,
               'heurRecall': heurRecall,
               'detectRecall': detectRecall,
               **toRet

             }
    if rel_ap is not None: #none ap if no relationships
        retData['rel_AP']=rel_ap
        retData['no_targs']=0
    else:
        retData['no_targs']=1
    if model.predNN:
        retData['nn_loss_final']=nn_loss_final
        retData['nn_loss_diff']=nn_loss_final-nn_loss
        retData['nn_acc_final'] = nn_acc
    if model.detector.predNumNeighbors and not useDetections:
        retData['nn_acc_detector'] = nn_acc_d
    if model.predClass:
        retData['class_loss_final']=class_loss_final
        retData['class_loss_diff']=class_loss_final-class_loss
    return (
             retData,
             (lossThis, position_loss, conf_loss, class_loss, recall, precision)
            )
Example #6
    def __init__(self, dirPath=None, split=None, config=None, instances=None, test=False):
        if split=='valid':
            valid=True
            amountPer=0.25
        else:
            valid=False
        self.cache_resized=False
        if 'augmentation_params' in config:
            self.augmentation_params=config['augmentation_params']
        else:
            self.augmentation_params=None
        if 'no_blanks' in config:
            self.no_blanks = config['no_blanks']
        else:
            self.no_blanks = False
        if 'no_print_fields' in config:
            self.no_print_fields = config['no_print_fields']
        else:
            self.no_print_fields = False
        numFeats=10
        self.use_corners = config['corners'] if 'corners' in config else False
        self.no_graphics =  config['no_graphics'] if 'no_graphics' in config else False
        if self.use_corners=='xy':
            numFeats=18
        elif self.use_corners:
            numFeats=14
        self.swapCircle = config['swap_circle'] if 'swap_circle' in config else True
        self.onlyFormStuff = config['only_form_stuff'] if 'only_form_stuff' in config else False
        self.only_opposite_pairs = config['only_opposite_pairs'] if 'only_opposite_pairs' in config else False
        self.color = config['color'] if 'color' in config else True
        self.rotate = config['rotation'] if 'rotation' in config else True

        #self.simple_dataset = config['simple_dataset'] if 'simple_dataset' in config else False
        self.special_dataset = config['special_dataset'] if 'special_dataset' in config else None
        if 'simple_dataset' in config and config['simple_dataset']:
            self.special_dataset='simple'
        self.balance = config['balance'] if 'balance' in config else False

        self.eval = config['eval'] if 'eval' in config else False

        self.altJSONDir = config['alternate_json_dir'] if 'alternate_json_dir' in config else None
        

        #width_mean=400.006887263
        #height_mean=47.9102279201
        xScale=400
        yScale=50
        xyScale=(xScale+yScale)/2

        if instances is not None:
            self.instances=instances
        else:
            if self.special_dataset is not None:
                splitFile = self.special_dataset+'_train_valid_test_split.json'
            else:
                splitFile = 'train_valid_test_split.json'
            with open(os.path.join(dirPath,splitFile)) as f:
                readFile = json.loads(f.read())
                if type(split) is str:
                    groupsToUse = readFile[split]
                elif type(split) is list:
                    groupsToUse = {}
                    for spstr in split:
                        newGroups = readFile[spstr]
                        groupsToUse.update(newGroups)
                else:
                    print("Error, unknown split {}".format(split))
                    exit()
            groupNames = list(groupsToUse.keys())
            groupNames.sort()
            pair_instances=[]
            notpair_instances=[]
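            #Instances are built per ordered box pair; true pairs and non-pairs
            #are collected separately so they can be balanced below.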
            for groupName in groupNames:
                imageNames=groupsToUse[groupName]
                if groupName in SKIP:
                    print('Skipped group {}'.format(groupName))
                    continue
                
                for imageName in imageNames:
                    org_path = os.path.join(dirPath,'groups',groupName,imageName)
                    #print(org_path)
                    if self.cache_resized:
                        path = os.path.join(self.cache_path,imageName)
                    else:
                        path = org_path
                    jsonPaths = [org_path[:org_path.rfind('.')]+'.json']
                    if self.altJSONDir is not None:
                        jsonPaths = [os.path.join(self.altJSONDir,imageName[:imageName.rfind('.')]+'.json')]
                    for jsonPath in jsonPaths:
                        annotations=None
                        if os.path.exists(jsonPath):
                            if annotations is None:
                                with open(os.path.join(jsonPath)) as f:
                                    annotations = json.loads(f.read())
                                #print(os.path.join(jsonPath))

                                #fix assumptions made in GTing
                                missedCount=fixAnnotations(self,annotations)

                            #print(path)
                            numNeighbors=defaultdict(int)
                            for id,bb in annotations['byId'].items():
                                if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
                                    responseBBList = self.__getResponseBBList(id,annotations)
                                    responseIds = [bb['id'] for bb in responseBBList]
                                    for id2,bb2 in annotations['byId'].items():
                                        if id!=id2:
                                            pair = id2 in responseIds
                                            if pair:
                                                numNeighbors[id]+=1
                                                #we'll catch id2 on its own pass
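                            #Second pass: build a geometric feature vector for
                            #every ordered pair of boxes.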
                            for id,bb in annotations['byId'].items():
                                if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
                                    numN1 = numNeighbors[id]-1
                                    qX, qY, qH, qW, qR, qIsText, qIsField, qIsBlank, qNN = getBBInfo(bb,self.rotate,useBlankClass=not self.no_blanks)
                                    tlX = bb['poly_points'][0][0]
                                    tlY = bb['poly_points'][0][1]
                                    trX = bb['poly_points'][1][0]
                                    trY = bb['poly_points'][1][1]
                                    brX = bb['poly_points'][2][0]
                                    brY = bb['poly_points'][2][1]
                                    blX = bb['poly_points'][3][0]
                                    blY = bb['poly_points'][3][1]
                                    qH /= yScale #math.log( (qH+0.375*height_mean)/height_mean ) #rescaling so 0 height is -1, big height is 1+
                                    qW /= xScale #math.log( (qW+0.375*width_mean)/width_mean ) #rescaling so 0 width is -1, big width is 1+
                                    qR = qR/math.pi
                                    responseBBList = self.__getResponseBBList(id,annotations)
                                    responseIds = [bb['id'] for bb in responseBBList]
                                    for id2,bb2 in annotations['byId'].items():
                                        if id!=id2:
                                            numN2 = numNeighbors[id2]-1
                                            iX, iY, iH, iW, iR, iIsText, iIsField, iIsBlank, iNN  = getBBInfo(bb2,self.rotate,useBlankClass=not self.no_blanks)
                                            tlX2 = bb2['poly_points'][0][0]
                                            tlY2 = bb2['poly_points'][0][1]
                                            trX2 = bb2['poly_points'][1][0]
                                            trY2 = bb2['poly_points'][1][1]
                                            brX2 = bb2['poly_points'][2][0]
                                            brY2 = bb2['poly_points'][2][1]
                                            blX2 = bb2['poly_points'][3][0]
                                            blY2 = bb2['poly_points'][3][1]
                                            iH /=yScale #math.log( (iH+0.375*height_mean)/height_mean ) 
                                            iW /=xScale #math.log( (iW+0.375*width_mean)/width_mean ) 
                                            iR = iR/math.pi
                                            xDiff=iX-qX
                                            yDiff=iY-qY
                                            yDiff /= yScale #math.log( (yDiff+0.375*yDiffScale)/yDiffScale ) 
                                            xDiff /= xScale #math.log( (xDiff+0.375*xDiffScale)/xDiffScale ) 
                                            tlDiff = math.sqrt( (tlX-tlX2)**2 + (tlY-tlY2)**2 )/xyScale
                                            trDiff = math.sqrt( (trX-trX2)**2 + (trY-trY2)**2 )/xyScale
                                            brDiff = math.sqrt( (brX-brX2)**2 + (brY-brY2)**2 )/xyScale
                                            blDiff = math.sqrt( (blX-blX2)**2 + (blY-blY2)**2 )/xyScale
                                            tlXDiff = (tlX2-tlX)/xScale
                                            trXDiff = (trX2-trX)/xScale
                                            brXDiff = (brX2-brX)/xScale
                                            blXDiff = (blX2-blX)/xScale
                                            tlYDiff = (tlY2-tlY)/yScale
                                            trYDiff = (trY2-trY)/yScale
                                            brYDiff = (brY2-brY)/yScale
                                            blYDiff = (blY2-blY)/yScale
                                            pair = id2 in responseIds
                                            if pair or self.eval:
                                                instances = pair_instances
                                            else:
                                                instances = notpair_instances
                                            if self.altJSONDir is None:
                                                data=[qH,qW,qR,qIsText, iH,iW,iR,iIsText, xDiff, yDiff]
                                            else:
                                                data=[qH,qW,qR,qIsText,qIsField, iH,iW,iR,iIsText,iIsField, xDiff, yDiff]
                                            if self.use_corners=='xy':
                                                data+=[tlXDiff,trXDiff,brXDiff,blXDiff,tlYDiff,trYDiff,brYDiff,blYDiff]
                                            elif self.use_corners:
                                                data+=[tlDiff, trDiff, brDiff, blDiff]
                                            if qIsBlank is not None:
                                                data+=[qIsBlank,iIsBlank]
                                            if qNN is not None:
                                                data+=[qNN,iNN]
                                            instances.append( {
                                                'data': torch.tensor([ data ]),
                                                'label': pair,
                                                'imgName': imageName,
                                                'qXY' : (qX,qY),
                                                'iXY' : (iX,iY),
                                                'qHW' : (qH,qW),
                                                'iHW' : (iH,iW),
                                                'ids' : (id,id2),
                                                'numNeighbors': torch.tensor([ [numN1,numN2] ])
                                                } )
                            if self.eval:
                                #if evaluating, pack all instances for an image into a batch
                                datas=[]
                                labels=[]
                                qXYs=[]
                                iXYs=[]
                                qHWs=[]
                                iHWs=[]
                                nodeIds=[]
                                NNs=[]
                                numTrue=0
                                for inst in pair_instances:
                                    datas.append(inst['data'])
                                    labels.append(inst['label'])
                                    numTrue += inst['label']
                                    qXYs.append(inst['qXY'])
                                    iXYs.append(inst['iXY'])
                                    qHWs.append(inst['qHW'])
                                    iHWs.append(inst['iHW'])
                                    nodeIds.append(inst['ids'])
                                    NNs.append(inst['numNeighbors'])
                                if len(datas)>0:
                                    data = torch.cat(datas,dim=0)
                                else:
                                    data = torch.empty(0,numFeats)
                                if len(NNs)>0:
                                    NNs = torch.cat(NNs,dim=0)
                                else:
                                    NNs = torch.empty(0,2)
                                #missedCount=0
                                #for id1,id2 in annotations['pairs']:
                                #    if id1 not in annotations['byId'] or id2 not in annotations['byId']:
                                #        missedCount+=1
                                notpair_instances.append( {
                                    'data': data,
                                    'label': torch.ByteTensor(labels),
                                    'imgName': imageName,
                                    'imgPath' : path,
                                    'qXY' : qXYs,
                                    'iXY' : iXYs,
                                    'qHW' : qHWs,
                                    'iHW' : iHWs,
                                    'nodeIds' : nodeIds,
                                    'numNeighbors' : NNs,
                                    'missedRels': missedCount
                                    } )
                                pair_instances=[]
            self.instances = notpair_instances
            if self.balance and not self.eval:
                dif = len(notpair_instances)/float(len(pair_instances))
                print('not: {}, pair: {}. Adding {}x'.format(len(notpair_instances),len(pair_instances),math.floor(dif)))
                for i in range(math.floor(dif)):
                    self.instances += pair_instances
            else:
                self.instances += pair_instances