def parseAnn(self,np_img,annotations,s,imageName):
    fieldBBs = annotations['fieldBBs']
    fixAnnotations(self,annotations)
    full_bbs=annotations['byId'].values()
    bbs = getBBWithPoints(full_bbs,s,useBlankClass=(not self.no_blanks),usePairedClass=self.use_paired_class)
    numClasses = bbs.shape[2]-16 #first 16 channels are geometry, the rest are classes
    #field_bbs = getBBWithPoints(annotations['fieldBBs'],s)
    #bbs = np.concatenate([text_bbs,field_bbs],axis=1) #has batch dim
    start_of_line, end_of_line = getStartEndGT(full_bbs,s)
    try:
        table_points, table_pixels = self.getTables(
                fieldBBs,
                s,
                np_img.shape[0],
                np_img.shape[1],
                annotations['pairs'])
    except Exception:
        #tables couldn't be parsed for this image; record it once and move on
        table_points=None
        table_pixels=None
        if imageName not in self.errors:
            self.errors.append(imageName)
    pixel_gt = table_pixels
    line_gts = {
            "start_of_line": start_of_line,
            "end_of_line": end_of_line
            }
    point_gts = {
            "table_points": table_points
            }
    #count how many relationships (neighbors) each bb participates in
    numNeighbors=defaultdict(lambda:0)
    for id,bb in annotations['byId'].items():
        if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
            responseIds = getResponseBBIdList_(self,id,annotations)
            for id2,bb2 in annotations['byId'].items():
                if id!=id2 and id2 in responseIds:
                    numNeighbors[id]+=1
    numNeighbors = [numNeighbors[bb['id']] for bb in full_bbs]
    #convert id-based pairs to index-based pairs
    idToIndex={}
    for i,bb in enumerate(full_bbs):
        idToIndex[bb['id']]=i
    pairs=[ (idToIndex[id1],idToIndex[id2]) for id1,id2 in annotations['pairs'] ]
    return bbs, line_gts, point_gts, pixel_gt, numClasses, numNeighbors, pairs
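# A minimal sketch (not from the original code) of how the bbs tensor returned by
# parseAnn above is laid out. The 16/numClasses split is inferred from
# `numClasses = bbs.shape[2]-16`: getBBWithPoints is assumed to emit 16 geometry
# channels per box followed by one channel per class.
def split_bb_channels(bbs):
    #bbs: [1, num_boxes, 16+num_classes] array as produced above
    geometry = bbs[0, :, :16]      #point/rect geometry features
    class_scores = bbs[0, :, 16:]  #one channel per class
    return geometry, class_scores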
def parseAnn(self,annotations,scale):
    #fieldBBs = annotations['fieldBBs']
    fixAnnotations(self,annotations)
    bbsToUse=[]
    ids=[]
    trans={}
    for id,bb in annotations['byId'].items():
        if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
            bbsToUse.append(bb)
            ids.append(bb['id'])
            if 'transcription' in annotations:
                trans[bb['id']] = annotations['transcription'][bb['id']]
    bbs = getBBWithPoints(bbsToUse,scale,useBlankClass=(not self.no_blanks),usePairedClass=self.use_paired_class)
    numClasses = bbs.shape[2]-16 #first 16 channels are geometry, the rest are classes
    return bbs, ids, numClasses, trans
def parseAnn(self, annotations, scale):
    fixAnnotations(self, annotations)
    bbsToUse = []
    ids = []
    trans = {}
    metadata = {}
    for id, bb in annotations['byId'].items():
        if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
            bbsToUse.append(bb)
            ids.append(bb['id'])
            if 'transcriptions' in annotations and bb['id'] in annotations['transcriptions']:
                trans[bb['id']] = annotations['transcriptions'][bb['id']]
            else:
                trans[bb['id']] = None
            metadata[bb['id']] = {'type': bb['isBlank']}
    bbs = getBBWithPoints(bbsToUse, scale,
                          useBlankClass=(not self.no_blanks),
                          usePairedClass=self.use_paired_class,
                          useAllClass=self.useClasses)
    numClasses = len(self.classMap)
    if self.no_groups:
        idGroups = [[bbid] for bbid in ids] #every bb becomes its own singleton group
    else:
        idGroups = formGroups(annotations, self.group_only_same)
    #revIds = {bbId:n for n,bbId in enumerate(ids)}
    #groups = [ [revIds[bbId] for bbId in group] for group in idGroups]
    groups = idGroups
    assert bbs is not None
    assert len(groups) > 0
    return bbs, ids, numClasses, trans, groups, metadata, {}
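def _example_singleton_groups():
    #Hedged illustration (not part of the original code) of the self.no_groups
    #branch in parseAnn above: every bb id degenerates into its own singleton
    #group, so downstream code still receives a list of lists. Hypothetical ids:
    ids = ['bb0', 'bb1', 'bb2']
    idGroups = [[bbid] for bbid in ids]  # -> [['bb0'], ['bb1'], ['bb2']]
    return idGroups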
def cluster(self,k,sample_count,outPath):
    def makePointsAndRects(h,w,r=None):
        #returns [left_x,left_y,right_x,right_y,top_x,top_y,bot_x,bot_y, xc,yc, rot, h,w]
        if r is None:
            return np.array([-w/2.0,0,w/2.0,0,0,-h/2.0,0,h/2.0, 0,0, 0, h,w])
        else:
            lx = -math.cos(r)*w
            ly = -math.sin(r)*w
            rx = math.cos(r)*w
            ry = math.sin(r)*w
            tx = math.sin(r)*h
            ty = -math.cos(r)*h
            bx = -math.sin(r)*h
            by = math.cos(r)*h
            return np.array([lx,ly,rx,ry,tx,ty,bx,by, 0,0, r, h,w])
    meanH=62.42
    stdH=87.31
    meanW=393.03
    stdW=533.53
    ratios=[4.0,7.18,11.0,15.0,19.0,27.0]

    #gather cross points and rect parameters for every bb, at sample_count random scales
    pointsAndRects=[]
    for inst in self.images:
        annotationPath = inst['annotationPath']
        #rescaled = inst['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())
        fixAnnotations(self,annotations)
        for i in range(sample_count):
            if i==0:
                s = (self.rescale_range[0]+self.rescale_range[1])/2
            else:
                s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
            #partial_rescale = s/rescaled
            bbs = getBBWithPoints(annotations['byId'].values(),s)
            bbs = convertBBs(bbs,self.rotate,2).numpy()[0]
            cos_rot = np.cos(bbs[:,2])
            sin_rot = np.sin(bbs[:,2])
            p_left_x = -cos_rot*bbs[:,4]
            p_left_y = -sin_rot*bbs[:,4]
            p_right_x = cos_rot*bbs[:,4]
            p_right_y = sin_rot*bbs[:,4]
            p_top_x = sin_rot*bbs[:,3]
            p_top_y = -cos_rot*bbs[:,3]
            p_bot_x = -sin_rot*bbs[:,3]
            p_bot_y = cos_rot*bbs[:,3]
            points = np.stack([p_left_x,p_left_y,p_right_x,p_right_y,p_top_x,p_top_y,p_bot_x,p_bot_y],axis=1)
            pointsAndRects.append(np.concatenate([points,bbs[:,:5]],axis=1))
    pointsAndRects = np.concatenate(pointsAndRects,axis=0)
    #pointsAndRects layout: [0:p_left_x, 1:p_left_y, 2:p_right_x, 3:p_right_y,
    # 4:p_top_x, 5:p_top_y, 6:p_bot_x, 7:p_bot_y, 8:xc, 9:yc, 10:rot, 11:h, 12:w]

    bestDistsFromMean=None
    for attempt in range(20 if k>0 else 1):
        if k>0:
            #random initialization (k-means style)
            randomIndexes = np.random.randint(0,pointsAndRects.shape[0],(k))
            means=pointsAndRects[randomIndexes]
        else:
            #deterministic initialization over a grid of heights/widths (and rotations)
            means=[]
            rots = [0,math.pi/2,math.pi,1.5*math.pi]
            if self.rotate:
                for height in np.linspace(15,200,num=4):
                    for width in np.linspace(30,1200,num=4):
                        for rot in rots:
                            means.append(makePointsAndRects(height,width,rot))
                #long boxes
                for width in np.linspace(1600,4000,num=3):
                    for rot in rots:
                        means.append(makePointsAndRects(50,width,rot))
            else:
                #tall (rotated) boxes
                for height in np.linspace(13,300,num=3):
                    means.append(makePointsAndRects(height,20))
                #general boxes
                for height in np.linspace(15,200,num=2):
                    for width in np.linspace(30,1200,num=3):
                        means.append(makePointsAndRects(height,width))
                #long boxes
                for width in np.linspace(1600,4000,num=3):
                    means.append(makePointsAndRects(50,width))
            k=len(means)
            print('K: {}'.format(k))
            means = np.stack(means,axis=0)
        cluster_centers=means
        distsFromMean=None
        prevDistsFromMean=None
        for iteration in range(100000): #intended to break out
            print('attempt:{}, bestDistsFromMean:{}, iteration:{}, prevDistsFromMean:{}'.format(attempt,bestDistsFromMean,iteration,prevDistsFromMean), end='\r')
            #assign each box to the nearest mean
            expanded_all_points = np.tile(pointsAndRects[:,None,0:8],(1,means.shape[0],1))
            expanded_all_heights = np.tile(pointsAndRects[:,None,11],(1,means.shape[0]))
            expanded_all_widths = np.tile(pointsAndRects[:,None,12],(1,means.shape[0]))
            expanded_means_points = np.tile(means[None,:,0:8],(pointsAndRects.shape[0],1,1))
            expanded_means_heights = np.tile(means[None,:,11],(pointsAndRects.shape[0],1))
            expanded_means_widths = np.tile(means[None,:,12],(pointsAndRects.shape[0],1))
            point_deltas = expanded_all_points - expanded_means_points
            #normalize all point distances by the average of both boxes' sizes
            avg_heights = avg_widths = (expanded_means_heights+expanded_all_heights+expanded_means_widths+expanded_all_widths)/4
            normed_difference = (
                np.linalg.norm(point_deltas[:,:,0:2],2,2)/avg_widths +
                np.linalg.norm(point_deltas[:,:,2:4],2,2)/avg_widths +
                np.linalg.norm(point_deltas[:,:,4:6],2,2)/avg_heights +
                np.linalg.norm(point_deltas[:,:,6:8],2,2)/avg_heights
                )**2
            groups = normed_difference.argmin(1) #index of the nearest mean for each box
            distsFromMean = normed_difference.min(1).mean()
            if prevDistsFromMean is not None and distsFromMean >= prevDistsFromMean:
                break #converged
            prevDistsFromMean = distsFromMean
            #recompute each mean from its assigned boxes
            for ki in range(k):
                selected = (groups==ki)[:,None]
                numSel = float(selected.sum())
                if numSel==0:
                    continue #skip empty clusters
                means[ki,:] = (pointsAndRects*np.tile(selected,(1,13))).sum(0)/numSel
        if bestDistsFromMean is None or distsFromMean<bestDistsFromMean:
            bestDistsFromMean = distsFromMean
            cluster_centers=means

    #visualize the centers and save them as anchor shapes
    dH=600
    dW=3000
    draw = np.zeros([dH,dW,3],dtype=float)
    toWrite = []
    final_k=k
    for ki in range(k):
        pop = (groups==ki).sum().item()
        if pop>2:
            color = np.random.uniform(0.2,1,3).tolist()
            h=cluster_centers[ki,11]
            w=cluster_centers[ki,12]
            rot=cluster_centers[ki,10]
            toWrite.append({'height':h.item(),'width':w.item(),'rot':rot.item(),'popularity':pop})
            tr = ( int(math.cos(rot)*w-math.sin(rot)*h)+dW//2,  int(math.sin(rot)*w+math.cos(rot)*h)+dH//2 )
            tl = ( int(math.cos(rot)*-w-math.sin(rot)*h)+dW//2, int(math.sin(rot)*-w+math.cos(rot)*h)+dH//2 )
            br = ( int(math.cos(rot)*w-math.sin(rot)*-h)+dW//2, int(math.sin(rot)*w+math.cos(rot)*-h)+dH//2 )
            bl = ( int(math.cos(rot)*-w-math.sin(rot)*-h)+dW//2, int(math.sin(rot)*-w+math.cos(rot)*-h)+dH//2 )
            cv2.line(draw,tl,tr,color)
            cv2.line(draw,tr,br,color)
            cv2.line(draw,br,bl,color)
            cv2.line(draw,bl,tl,color,2)
        else:
            final_k-=1 #drop clusters with fewer than three members
    with open(outPath.format(final_k),'w') as out:
        out.write(json.dumps(toWrite))
    print('saved '+outPath.format(final_k))
    cv2.imshow('clusters',draw)
    cv2.waitKey()
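# Hedged standalone reimplementation (an assumption, not the repo's helper) of the
# clustering distance used above: corner-point deltas normalized by the average of
# both boxes' heights and widths, squared, mirroring `normed_difference`.
import numpy as np

def box_distance(points_a, points_b, h_a, w_a, h_b, w_b):
    #points_*: length-8 arrays [left_x,left_y,right_x,right_y,top_x,top_y,bot_x,bot_y]
    avg = (h_a + h_b + w_a + w_b) / 4.0  #avg_heights == avg_widths in the code above
    d = points_a - points_b
    return (np.linalg.norm(d[0:2]) / avg + np.linalg.norm(d[2:4]) / avg +
            np.linalg.norm(d[4:6]) / avg + np.linalg.norm(d[6:8]) / avg) ** 2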
def FormsGraphPair_printer(config,instance, model, gpu, metrics, outDir=None, startIndex=None, lossFunc=None):
    def __eval_metrics(output,target):
        acc_metrics = np.zeros((output.shape[0],len(metrics)))
        for ind in range(output.shape[0]):
            for i, metric in enumerate(metrics):
                acc_metrics[ind,i] += metric(output[ind:ind+1], target[ind:ind+1])
        return acc_metrics

    def __to_tensor(instance,gpu):
        image = instance['img']
        bbs = instance['bb_gt']
        adjacency = instance['adj']
        num_neighbors = instance['num_neighbors']
        if gpu is not None:
            image = image.to(gpu)
            if bbs is not None:
                bbs = bbs.to(gpu)
            if num_neighbors is not None:
                num_neighbors = num_neighbors.to(gpu)
        return image, bbs, adjacency, num_neighbors

    #which relationship thresholds to evaluate at
    rel_thresholds = [config['THRESH']] if 'THRESH' in config else [0.5]
    if ('sweep_threshold' in config and config['sweep_threshold']) or ('sweep_thresholds' in config and config['sweep_thresholds']):
        rel_thresholds = np.arange(0.1,1.0,0.05)
    if ('sweep_threshold_big' in config and config['sweep_threshold_big']) or ('sweep_thresholds_big' in config and config['sweep_thresholds_big']):
        rel_thresholds = np.arange(0,20.0,1)
    if ('sweep_threshold_small' in config and config['sweep_threshold_small']) or ('sweep_thresholds_small' in config and config['sweep_thresholds_small']):
        rel_thresholds = np.arange(0,0.1,0.01)
    draw_rel_thresh = config['draw_thresh'] if 'draw_thresh' in config else rel_thresholds[0]

    lossWeights = config['loss_weights'] if 'loss_weights' in config else {"box": 1, "rel": 1}
    if lossFunc is None:
        yolo_loss = YoloLoss(model.numBBTypes,model.rotation,model.scale,model.anchors,**config['loss_params']['box'])
    else:
        yolo_loss = lossFunc
    data = instance['img']
    batchSize = data.shape[0]
    assert batchSize==1
    targetBoxes = instance['bb_gt']
    adjacency = list(instance['adj'])
    imageName = instance['imgName']
    scale = instance['scale']
    target_num_neighbors = instance['num_neighbors']
    if not model.detector.predNumNeighbors:
        instance['num_neighbors']=None
    dataT, targetBoxesT, adjT, target_num_neighborsT = __to_tensor(instance,gpu)

    pretty = config['pretty'] if 'pretty' in config else False
    useDetections = config['useDetections'] if 'useDetections' in config else False
    if 'useDetect' in config:
        useDetections = config['useDetect']
    confThresh = config['conf_thresh'] if 'conf_thresh' in config else None

    numClasses=2 #TODO no hard code
    resultsDirName='results'

    if useDetections=='gt':
        #run the pairing model on the ground-truth boxes
        outputBoxes, outputOffsets, relPred, relIndexes, bbPred = model(dataT,targetBoxesT,target_num_neighborsT,True,
                otherThresh=confThresh, otherThreshIntur=1 if confThresh is not None else None, hard_detect_limit=600)
        outputBoxes = torch.cat((torch.ones(targetBoxes.size(1),1),targetBoxes[0,:,0:5],targetBoxes[0,:,-numClasses:]),dim=1) #add score
    elif type(useDetections) is str:
        #run the pairing model on saved detections loaded from json
        dataset = config['DATASET']
        jsonPath = os.path.join(useDetections,imageName+'.json')
        with open(jsonPath) as f:
            annotations = json.loads(f.read())
        fixAnnotations(dataset,annotations)
        savedBoxes = torch.FloatTensor(len(annotations['byId']),6+model.detector.predNumNeighbors+numClasses)
        for i,(id,bb) in enumerate(annotations['byId'].items()):
            qX, qY, qH, qW, qR, qIsText, qIsField, qIsBlank, qNN = getBBInfo(bb,dataset.rotate,useBlankClass=not dataset.no_blanks)
            savedBoxes[i,0]=1 #conf
            savedBoxes[i,1]=qX*scale #x-center, already scaled
            savedBoxes[i,2]=qY*scale #y-center
            savedBoxes[i,3]=qR #rotation
            savedBoxes[i,4]=qH*scale/2
            savedBoxes[i,5]=qW*scale/2
            if model.detector.predNumNeighbors:
                extra=1
                savedBoxes[i,6]=qNN
            else:
                extra=0
            savedBoxes[i,6+extra]=qIsText
            savedBoxes[i,7+extra]=qIsField
        if gpu is not None:
            savedBoxes=savedBoxes.to(gpu)
        outputBoxes, outputOffsets, relPred, relIndexes, bbPred = model(dataT,savedBoxes,None,"saved",
                otherThresh=confThresh, otherThreshIntur=1 if confThresh is not None else None, hard_detect_limit=600)
        outputBoxes=savedBoxes.cpu()
    elif useDetections:
        print('Unknown detection flag: '+str(useDetections))
        exit()
    else:
        #full pipeline: detect and pair
        outputBoxes, outputOffsets, relPred, relIndexes, bbPred = model(dataT,
                otherThresh=confThresh, otherThreshIntur=1 if confThresh is not None else None, hard_detect_limit=600)

    if model.predNN and bbPred is not None:
        predNN = bbPred[:,0]
    else:
        predNN=None
    if model.detector.predNumNeighbors and not useDetections:
        extraPreds=1 #the detector emits a num-neighbors channel before the classes
        if not model.predNN:
            predNN = outputBoxes[:,6]
    else:
        extraPreds=0
        if not model.predNN:
            predNN = None

    if targetBoxesT is not None:
        targetSize=targetBoxesT.size(1)
    else:
        targetSize=0
    lossThis, position_loss, conf_loss, class_loss, nn_loss, recall, precision = yolo_loss(outputOffsets,targetBoxesT,[targetSize],target_num_neighborsT)

    if 'rule' in config:
        if config['rule']=='closest':
            #heuristic baseline: score pairs by inverse center distance, zeroing same-class pairs
            dists = torch.FloatTensor(relPred.size())
            differentClass = torch.FloatTensor(relPred.size())
            predClasses = torch.argmax(outputBoxes[:,extraPreds+6:extraPreds+6+numClasses],dim=1)
            for i,(bb1,bb2) in enumerate(relIndexes):
                dists[i] = math.sqrt((outputBoxes[bb1,1]-outputBoxes[bb2,1])**2 + (outputBoxes[bb1,2]-outputBoxes[bb2,2])**2)
                differentClass[i] = predClasses[bb1]!=predClasses[bb2]
            maxDist = torch.max(dists)
            minDist = torch.min(dists)
            relPred = 1-(dists-minDist)/(maxDist-minDist)
            relPred *= differentClass
        elif config['rule']=='icdar':
            #heuristic baseline using the g4/g5/g6 features of the ICDAR paper
            height = torch.FloatTensor(relPred.size())
            dists = torch.FloatTensor(relPred.size())
            right = torch.FloatTensor(relPred.size())
            sameClass = torch.FloatTensor(relPred.size())
            predClasses = torch.argmax(outputBoxes[:,extraPreds+6:extraPreds+6+numClasses],dim=1)
            for i,(bb1,bb2) in enumerate(relIndexes):
                sameClass[i] = predClasses[bb1]==predClasses[bb2]
                #g4 of the paper
                height[i] = max(outputBoxes[bb1,4],outputBoxes[bb2,4])/min(outputBoxes[bb1,4],outputBoxes[bb2,4])
                #g5 of the paper
                if predClasses[bb1]==0:
                    widthLabel = outputBoxes[bb1,5]*2 #we predict half width
                    widthValue = outputBoxes[bb2,5]*2
                    dists[i] = math.sqrt(((outputBoxes[bb1,1]+widthLabel)-(outputBoxes[bb2,1]-widthValue))**2 + (outputBoxes[bb1,2]-outputBoxes[bb2,2])**2)
                else:
                    widthLabel = outputBoxes[bb2,5]*2 #we predict half width
                    widthValue = outputBoxes[bb1,5]*2
                    dists[i] = math.sqrt(((outputBoxes[bb1,1]-widthValue)-(outputBoxes[bb2,1]+widthLabel))**2 + (outputBoxes[bb1,2]-outputBoxes[bb2,2])**2)
                if dists[i]>2*widthLabel:
                    dists[i]/=widthLabel
                else: #undefined
                    dists[i] = min(1,dists[i]/widthLabel)
                #g6 of the paper
                if predClasses[bb1]==0:
                    widthValue = outputBoxes[bb2,5]*2
                    hDist = outputBoxes[bb1,1]-outputBoxes[bb2,1]
                else:
                    widthValue = outputBoxes[bb1,5]*2
                    hDist = outputBoxes[bb2,1]-outputBoxes[bb1,1]
                right[i] = hDist/widthValue
            relPred = 1-(height+dists+right + 10000*sameClass)
        else:
            print('ERROR, unknown rule {}'.format(config['rule']))
            exit()
    elif relPred is not None:
        relPred = torch.sigmoid(relPred)[:,0]

    relCand = relIndexes
    if relCand is None:
        relCand=[]

    #align predicted boxes to ground-truth boxes
    if model.rotation:
        bbAlignment, bbFullHit = getTargIndexForPreds_dist(targetBoxes[0],outputBoxes,0.9,numClasses,extraPreds,hard_thresh=False)
    else:
        bbAlignment, bbFullHit = getTargIndexForPreds_iou(targetBoxes[0],outputBoxes,0.5,numClasses,extraPreds,hard_thresh=False)
    if targetBoxes is not None:
        target_for_b = targetBoxes[0,:,:]
    else:
        target_for_b = torch.empty(0)

    if outputBoxes.size(0)>0:
        maxConf = outputBoxes[:,0].max().item()
        minConf = outputBoxes[:,0].min().item()
        if useDetections:
            minConf=0
    if model.rotation:
        ap_5, prec_5, recall_5 = AP_dist(target_for_b,outputBoxes,0.9,model.numBBTypes,beforeCls=extraPreds)
    else:
        ap_5, prec_5, recall_5 = AP_iou(target_for_b,outputBoxes,0.5,model.numBBTypes,beforeCls=extraPreds)

    toRet={}
    for rel_threshold in rel_thresholds:
        if 'optimize' in config and config['optimize']:
            #globally optimize the relationship decisions instead of simple thresholding
            penalty = config['penalty'] if 'penalty' in config else 0.25
            print('optimizing with penalty {}'.format(penalty))
            thresh=0.15
            while thresh<0.45:
                keep = relPred>thresh
                newRelPred = relPred[keep]
                if newRelPred.size(0)<700:
                    break
                thresh += 0.05 #step assumed; without an increment this loop could not terminate
            if newRelPred.size(0)>0:
                usePredNN = predNN is not None and config['optimize']!='gt'
                #reindex the kept candidates and gather each node's neighbor count
                idMap={}
                newId=0
                newRelCand=[]
                numNeighbors=[]
                for index,(id1,id2) in enumerate(relCand):
                    if keep[index]:
                        if id1 not in idMap:
                            idMap[id1]=newId
                            if not usePredNN:
                                numNeighbors.append(target_num_neighbors[0,bbAlignment[id1]])
                            else:
                                numNeighbors.append(predNN[id1])
                            newId+=1
                        if id2 not in idMap:
                            idMap[id2]=newId
                            if not usePredNN:
                                numNeighbors.append(target_num_neighbors[0,bbAlignment[id2]])
                            else:
                                numNeighbors.append(predNN[id2])
                            newId+=1
                        newRelCand.append( [idMap[id1],idMap[id2]] )
                decision = optimizeRelationshipsSoft(newRelPred,newRelCand,numNeighbors,penalty,rel_threshold)
                decision = torch.from_numpy( np.round_(decision).astype(int) ).to(relPred.device)
                #push rejected candidates below the threshold
                relPred[keep] = torch.where(0==decision,relPred[keep]-1,relPred[keep])
                relPred[~keep] -= 1
                rel_threshold_use=0
            else:
                rel_threshold_use=rel_threshold
        else:
            rel_threshold_use=rel_threshold #thresholded in the model

        if len(toRet)==0: #only needs to be done once
            #align bb predictions (final) with GT
            if bbPred is not None and bbPred.size(0)>0:
                #create aligned GT: remove predictions that overlapped with GT, but not enough
                if model.predNN:
                    start=1
                    toKeep = ~((bbFullHit==0) * (bbAlignment!=-1)) #toKeep = not (incomplete_overlap and did_overlap)
                    if toKeep.any():
                        bbPredNN_use = bbPred[toKeep][:,0]
                        bbAlignment_use = bbAlignment[toKeep]
                        #because we used -1 to indicate no match (in bbAlignment), we add 0 as the last position in the GT, as unmatched
                        if target_num_neighborsT is not None:
                            target_num_neighbors_use = torch.cat((target_num_neighborsT[0].float(),torch.zeros(1).to(target_num_neighborsT.device)),dim=0)
                        else:
                            target_num_neighbors_use = torch.zeros(1).to(bbPred.device)
                        alignedNN_use = target_num_neighbors_use[bbAlignment_use]
                    else:
                        bbPredNN_use=None
                        alignedNN_use=None
                else:
                    start=0
                if model.predClass:
                    #we really don't care about the class of non-overlapping instances
                    if targetBoxes is not None:
                        toKeep = bbFullHit==1
                        if toKeep.any():
                            bbPredClass_use = bbPred[toKeep][:,start:start+model.numBBTypes]
                            bbAlignment_use = bbAlignment[toKeep]
                            alignedClass_use = targetBoxesT[0][bbAlignment_use][:,13:13+model.numBBTypes] #there should be no -1 indexes in here
                        else:
                            bbPredClass_use=None
                            alignedClass_use=None
                    else:
                        alignedClass_use = None
            else:
                bbPredNN_use = None
                bbPredClass_use = None
            if model.predNN and bbPredNN_use is not None and bbPredNN_use.size(0)>0:
                nn_loss_final = F.mse_loss(bbPredNN_use,alignedNN_use).item()
            else:
                nn_loss_final=0
            if model.predNN and predNN is not None:
                predNN_p=bbPred[:,0]
                diffs=torch.abs(predNN_p-target_num_neighborsT[0][bbAlignment].float())
                nn_acc = (diffs<0.5).sum().item()
                nn_acc /= predNN.size(0)
            elif model.predNN:
                nn_acc = 0
            if model.detector.predNumNeighbors and not useDetections:
                predNN_d = outputBoxes[:,6]
                diffs=torch.abs(predNN_d-target_num_neighbors[0][bbAlignment].float())
                nn_acc_d = (diffs<0.5).sum().item()
                nn_acc_d /= predNN.size(0)
            if model.predClass and bbPredClass_use is not None and bbPredClass_use.size(0)>0:
                class_loss_final = F.binary_cross_entropy_with_logits(bbPredClass_use,alignedClass_use).item()
            else:
                class_loss_final = 0

        useOutputBBs=None

        #score every candidate relationship against the GT adjacency list
        truePred=falsePred=badPred=0
        scores=[]
        matches=0
        numMissedByHeur=0
        targGotHit=set()
        for i,(n0,n1) in enumerate(relCand):
            t0 = bbAlignment[n0].item()
            t1 = bbAlignment[n1].item()
            if t0>=0 and bbFullHit[n0]:
                targGotHit.add(t0)
            if t1>=0 and bbFullHit[n1]:
                targGotHit.add(t1)
            if t0>=0 and t1>=0 and bbFullHit[n0] and bbFullHit[n1]:
                if (min(t0,t1),max(t0,t1)) in adjacency:
                    matches+=1
                    scores.append( (relPred[i],True) )
                    if relPred[i]>rel_threshold_use:
                        truePred+=1
                else:
                    scores.append( (relPred[i],False) )
                    if relPred[i]>rel_threshold_use:
                        falsePred+=1
            else:
                scores.append( (relPred[i],False) )
                if relPred[i]>rel_threshold_use:
                    badPred+=1
        #GT relationships with no surviving candidate count as missed
        for i in range(len(adjacency)-matches):
            numMissedByHeur+=1
            scores.append( (float('nan'),True) )
        rel_ap=computeAP(scores)

        #separate misses caused by detection failure from misses caused by candidate pruning
        numMissedByDetect=0
        for t0,t1 in adjacency:
            if t0 not in targGotHit or t1 not in targGotHit:
                numMissedByHeur-=1
                numMissedByDetect+=1
        heurRecall = (len(adjacency)-numMissedByHeur)/len(adjacency)
        detectRecall = (len(adjacency)-numMissedByDetect)/len(adjacency)
        if len(adjacency)>0:
            relRecall = truePred/len(adjacency)
        else:
            relRecall = 1
        if falsePred+badPred>0:
            precision = truePred/(truePred+falsePred+badPred)
        else:
            precision = 1

        toRet['prec@{}'.format(rel_threshold)]=precision
        toRet['recall@{}'.format(rel_threshold)]=relRecall
        if relRecall+precision>0:
            toRet['F-M@{}'.format(rel_threshold)]=2*relRecall*precision/(relRecall+precision)
        else:
            toRet['F-M@{}'.format(rel_threshold)]=0
        toRet['rel_AP@{}'.format(rel_threshold)]=rel_ap

    dists=defaultdict(list)
    dists_x=defaultdict(list)
    dists_y=defaultdict(list)
    scaleDiffs=defaultdict(list)
    rotDiffs=defaultdict(list)
    b=0
    if outDir is not None:
        outputBoxes = outputBoxes.data.numpy()
        data = data.numpy()
        image = (1-((1+np.transpose(data[b][:,:,:],(1,2,0)))/2.0)).copy()
        if image.shape[2]==1:
            image = cv2.cvtColor(image,cv2.COLOR_GRAY2RGB)
        #draw GT bbs
        if not pretty:
            for j in range(targetSize):
                plotRect(image,(1,0.5,0),targetBoxes[0,j,0:5])
        #draw pred bbs
        bbs = outputBoxes
        for j in range(bbs.shape[0]):
            conf = bbs[j,0]
            shade = 0.0+(conf-minConf)/(maxConf-minConf)
            if bbs[j,6+extraPreds] > bbs[j,7+extraPreds]:
                color=(0,0,shade) #text
            else:
                color=(0,shade,shade) #field
            if pretty=='light':
                lineWidth=2
            else:
                lineWidth=1
            plotRect(image,color,bbs[j,1:6],lineWidth)
            if predNN is not None and not pretty:
                #annotate each box with predicted/GT number of neighbors
                x=int(bbs[j,1])
                y=int(bbs[j,2])
                targ_j = bbAlignment[j].item()
                if targ_j>=0:
                    gtNN = target_num_neighbors[0,targ_j].item()
                else:
                    gtNN = 0
                pred_nn = predNN[j].item()
                color = min(abs(pred_nn-gtNN),1)
                cv2.putText(image,'{:.2}/{}'.format(pred_nn,gtNN),(x,y), cv2.FONT_HERSHEY_SIMPLEX, 0.5,(color,0,0),2,cv2.LINE_AA)

        draw_rel_thresh = relPred.max() * draw_rel_thresh

        #draw pred pairings
        numrelpred=0
        hits = [False]*len(adjacency)
        for i in range(len(relCand)):
            if pretty:
                if relPred[i]>0 or pretty=='light':
                    score = relPred[i]
                    pruned=False
                    lineWidth=2
                else:
                    score = relPred[i]+1
                    pruned=True
                    lineWidth=1
            else:
                lineWidth=1
            if relPred[i]>draw_rel_thresh or (pretty and score>draw_rel_thresh):
                ind1 = relCand[i][0]
                ind2 = relCand[i][1]
                x1 = round(bbs[ind1,1])
                y1 = round(bbs[ind1,2])
                x2 = round(bbs[ind2,1])
                y2 = round(bbs[ind2,2])
                if pretty:
                    targ1 = bbAlignment[ind1].item()
                    targ2 = bbAlignment[ind2].item()
                    aId=None
                    if bbFullHit[ind1] and bbFullHit[ind2]:
                        if (targ1,targ2) in adjacency:
                            aId = adjacency.index((targ1,targ2))
                        elif (targ2,targ1) in adjacency:
                            aId = adjacency.index((targ2,targ1))
                    if aId is None:
                        if pretty=='clean' and pruned:
                            color=np.array([1,1,0])
                        else:
                            color=np.array([1,0,0])
                    else:
                        if pretty=='clean' and pruned:
                            color=np.array([1,0,1])
                        else:
                            color=np.array([0,1,0])
                        hits[aId]=True
                    cv2.line(image,(x1,y1),(x2,y2),color.tolist(),lineWidth)
                else:
                    shade = (relPred[i].item()-draw_rel_thresh)/(1-draw_rel_thresh)
                    cv2.line(image,(x1,y1),(x2,y2),(0,shade,0),lineWidth)
                numrelpred+=1
        if pretty and pretty!="light" and pretty!="clean":
            #label the drawn pairings with their scores
            for i in range(len(relCand)):
                if relPred[i]>-1:
                    score = (relPred[i]+1)/2
                    pruned=False
                else:
                    score = (relPred[i]+2+1)/2
                    pruned=True
                if relPred[i]>draw_rel_thresh or (pretty and score>draw_rel_thresh):
                    ind1 = relCand[i][0]
                    ind2 = relCand[i][1]
                    x1 = round(bbs[ind1,1])
                    y1 = round(bbs[ind1,2])
                    x2 = round(bbs[ind2,1])
                    y2 = round(bbs[ind2,2])
                    targ1 = bbAlignment[ind1].item()
                    targ2 = bbAlignment[ind2].item()
                    aId=None
                    if bbFullHit[ind1] and bbFullHit[ind2]:
                        if (targ1,targ2) in adjacency:
                            aId = adjacency.index((targ1,targ2))
                        elif (targ2,targ1) in adjacency:
                            aId = adjacency.index((targ2,targ1))
                    if aId is None:
                        color=np.array([1,0,0])
                    else:
                        color=np.array([0,1,0])
                    color=color/2
                    x = int((x1+x2)/2)
                    y = int((y1+y2)/2)
                    if pruned:
                        cv2.putText(image,'[{:.2}]'.format(score),(x,y), cv2.FONT_HERSHEY_PLAIN, 0.6,color.tolist(),1)
                    else:
                        cv2.putText(image,'{:.2}'.format(score),(x,y), cv2.FONT_HERSHEY_PLAIN,1.1,color.tolist(),1)

        #draw GT pairings
        if not pretty:
            gtcolor=(0.25,0,0.25)
            wth=3
        else:
            gtcolor=(1,0.6,0)
            wth=2
        for aId,(i,j) in enumerate(adjacency):
            if not pretty or not hits[aId]:
                x1 = round(targetBoxes[0,i,0].item())
                y1 = round(targetBoxes[0,i,1].item())
                x2 = round(targetBoxes[0,j,0].item())
                y2 = round(targetBoxes[0,j,1].item())
                cv2.line(image,(x1,y1),(x2,y2),gtcolor,wth)

        #draw alignment between gt and pred bbs
        if not pretty:
            for predI in range(bbs.shape[0]):
                targI=bbAlignment[predI].item()
                x1 = int(round(bbs[predI,1]))
                y1 = int(round(bbs[predI,2]))
                if targI>=0:
                    x2 = round(targetBoxes[0,targI,0].item())
                    y2 = round(targetBoxes[0,targI,1].item())
                    cv2.line(image,(x1,y1),(x2,y2),(1,0,1),1)
                else:
                    #draw an 'x', indicating no match
                    cv2.line(image,(x1-5,y1-5),(x1+5,y1+5),(.1,0,.1),1)
                    cv2.line(image,(x1+5,y1-5),(x1-5,y1+5),(.1,0,.1),1)

        saveName = '{}_boxes_prec:{:.2f},{:.2f}_recall:{:.2f},{:.2f}_rels_AP:{:.3f}'.format(imageName,prec_5[0],prec_5[1],recall_5[0],recall_5[1],rel_ap)
        saveName+='.png'
        io.imsave(os.path.join(outDir,saveName),image)

    print('\n{} ap:{}\tnumMissedByDetect:{}\tmissedByHeur:{}'.format(imageName,rel_ap,numMissedByDetect,numMissedByHeur))

    retData= { 'bb_ap':[ap_5],
               'bb_recall':[recall_5],
               'bb_prec':[prec_5],
               'bb_Fm': -1,#(recall_5[0]+recall_5[1]+prec_5[0]+prec_5[1])/4,
               'nn_loss': nn_loss,
               'rel_recall':relRecall,
               'rel_precision':precision,
               'rel_Fm':2*relRecall*precision/(relRecall+precision) if relRecall+precision>0 else 0,
               'relMissedByHeur':numMissedByHeur,
               'relMissedByDetect':numMissedByDetect,
               'heurRecall': heurRecall,
               'detectRecall': detectRecall,
               **toRet
             }
    if rel_ap is not None: #None ap if no relationships
        retData['rel_AP']=rel_ap
        retData['no_targs']=0
    else:
        retData['no_targs']=1
    if model.predNN:
        retData['nn_loss_final']=nn_loss_final
        retData['nn_loss_diff']=nn_loss_final-nn_loss
        retData['nn_acc_final'] = nn_acc
    if model.detector.predNumNeighbors and not useDetections:
        retData['nn_acc_detector'] = nn_acc_d
    if model.predClass:
        retData['class_loss_final']=class_loss_final
        retData['class_loss_diff']=class_loss_final-class_loss
    return (
             retData,
             (lossThis, position_loss, conf_loss, class_loss, recall, precision)
           )
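# Hedged sketch of what computeAP is assumed to do with the (score, is_true)
# tuples gathered above; this is a standard average-precision computation over
# plain float scores, not the repo's actual implementation. NaN-scored entries
# are ground-truth relationships the candidate generator missed; they count
# against recall but can never be retrieved.
import math

def average_precision(scored):
    missed = sum(1 for s, _ in scored if math.isnan(s))
    ranked = sorted((p for p in scored if not math.isnan(p[0])), key=lambda p: -p[0])
    total_true = sum(1 for _, is_true in ranked if is_true) + missed
    if total_true == 0:
        return None  #no relationships to score
    tp = 0
    ap = 0.0
    for rank, (_, is_true) in enumerate(ranked, start=1):
        if is_true:
            tp += 1
            ap += tp / rank  #precision at this recall point
    return ap / total_true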
def __init__(self, dirPath=None, split=None, config=None, instances=None, test=False):
    if split=='valid':
        valid=True
        amountPer=0.25
    else:
        valid=False
    self.cache_resized=False
    if 'augmentation_params' in config:
        self.augmentation_params=config['augmentation_params']
    else:
        self.augmentation_params=None
    if 'no_blanks' in config:
        self.no_blanks = config['no_blanks']
    else:
        self.no_blanks = False
    if 'no_print_fields' in config:
        self.no_print_fields = config['no_print_fields']
    else:
        self.no_print_fields = False
    numFeats=10 #length of the base feature vector for a candidate pair
    self.use_corners = config['corners'] if 'corners' in config else False
    self.no_graphics = config['no_graphics'] if 'no_graphics' in config else False
    if self.use_corners=='xy':
        numFeats=18
    elif self.use_corners:
        numFeats=14
    self.swapCircle = config['swap_circle'] if 'swap_circle' in config else True
    self.onlyFormStuff = config['only_form_stuff'] if 'only_form_stuff' in config else False
    self.only_opposite_pairs = config['only_opposite_pairs'] if 'only_opposite_pairs' in config else False
    self.color = config['color'] if 'color' in config else True
    self.rotate = config['rotation'] if 'rotation' in config else True
    #self.simple_dataset = config['simple_dataset'] if 'simple_dataset' in config else False
    self.special_dataset = config['special_dataset'] if 'special_dataset' in config else None
    if 'simple_dataset' in config and config['simple_dataset']:
        self.special_dataset='simple'
    self.balance = config['balance'] if 'balance' in config else False
    self.eval = config['eval'] if 'eval' in config else False
    self.altJSONDir = config['alternate_json_dir'] if 'alternate_json_dir' in config else None
    #width_mean=400.006887263
    #height_mean=47.9102279201
    xScale=400
    yScale=50
    xyScale=(xScale+yScale)/2

    if instances is not None:
        self.instances=instances
    else:
        if self.special_dataset is not None:
            splitFile = self.special_dataset+'_train_valid_test_split.json'
        else:
            splitFile = 'train_valid_test_split.json'
        with open(os.path.join(dirPath,splitFile)) as f:
            readFile = json.loads(f.read())
        if type(split) is str:
            groupsToUse = readFile[split]
        elif type(split) is list:
            groupsToUse = {}
            for spstr in split:
                newGroups = readFile[spstr]
                groupsToUse.update(newGroups)
        else:
            print("Error, unknown split {}".format(split))
            exit()
        groupNames = list(groupsToUse.keys())
        groupNames.sort()
        pair_instances=[]
        notpair_instances=[]
        for groupName in groupNames:
            imageNames=groupsToUse[groupName]
            if groupName in SKIP:
                print('Skipped group {}'.format(groupName))
                continue
            for imageName in imageNames:
                org_path = os.path.join(dirPath,'groups',groupName,imageName)
                if self.cache_resized:
                    path = os.path.join(self.cache_path,imageName)
                else:
                    path = org_path
                jsonPaths = [org_path[:org_path.rfind('.')]+'.json']
                if self.altJSONDir is not None:
                    jsonPaths = [os.path.join(self.altJSONDir,imageName[:imageName.rfind('.')]+'.json')]
                for jsonPath in jsonPaths:
                    annotations=None
                    if os.path.exists(jsonPath):
                        if annotations is None:
                            with open(jsonPath) as f:
                                annotations = json.loads(f.read())
                        #fix assumptions made in GTing
                        missedCount=fixAnnotations(self,annotations)
                        #count each bb's neighbors
                        numNeighbors=defaultdict(lambda:0)
                        for id,bb in annotations['byId'].items():
                            if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
                                responseBBList = self.__getResponseBBList(id,annotations)
                                responseIds = [bb['id'] for bb in responseBBList]
                                for id2,bb2 in annotations['byId'].items():
                                    if id!=id2 and id2 in responseIds:
                                        numNeighbors[id]+=1 #we'll catch id2 on its own pass
                        #build one instance per (query bb, candidate bb) pair
                        for id,bb in annotations['byId'].items():
                            if not self.onlyFormStuff or ('paired' in bb and bb['paired']):
                                numN1 = numNeighbors[id]-1
                                qX, qY, qH, qW, qR, qIsText, qIsField, qIsBlank, qNN = getBBInfo(bb,self.rotate,useBlankClass=not self.no_blanks)
                                tlX = bb['poly_points'][0][0]
                                tlY = bb['poly_points'][0][1]
                                trX = bb['poly_points'][1][0]
                                trY = bb['poly_points'][1][1]
                                brX = bb['poly_points'][2][0]
                                brY = bb['poly_points'][2][1]
                                blX = bb['poly_points'][3][0]
                                blY = bb['poly_points'][3][1]
                                qH /= yScale #normalize height
                                qW /= xScale #normalize width
                                qR = qR/math.pi
                                responseBBList = self.__getResponseBBList(id,annotations)
                                responseIds = [bb['id'] for bb in responseBBList]
                                for id2,bb2 in annotations['byId'].items():
                                    if id!=id2:
                                        numN2 = numNeighbors[id2]-1
                                        iX, iY, iH, iW, iR, iIsText, iIsField, iIsBlank, iNN = getBBInfo(bb2,self.rotate,useBlankClass=not self.no_blanks)
                                        tlX2 = bb2['poly_points'][0][0]
                                        tlY2 = bb2['poly_points'][0][1]
                                        trX2 = bb2['poly_points'][1][0]
                                        trY2 = bb2['poly_points'][1][1]
                                        brX2 = bb2['poly_points'][2][0]
                                        brY2 = bb2['poly_points'][2][1]
                                        blX2 = bb2['poly_points'][3][0]
                                        blY2 = bb2['poly_points'][3][1]
                                        iH /= yScale
                                        iW /= xScale
                                        iR = iR/math.pi
                                        xDiff=iX-qX
                                        yDiff=iY-qY
                                        yDiff /= yScale
                                        xDiff /= xScale
                                        tlDiff = math.sqrt( (tlX-tlX2)**2 + (tlY-tlY2)**2 )/xyScale
                                        trDiff = math.sqrt( (trX-trX2)**2 + (trY-trY2)**2 )/xyScale
                                        brDiff = math.sqrt( (brX-brX2)**2 + (brY-brY2)**2 )/xyScale
                                        blDiff = math.sqrt( (blX-blX2)**2 + (blY-blY2)**2 )/xyScale
                                        tlXDiff = (tlX2-tlX)/xScale
                                        trXDiff = (trX2-trX)/xScale
                                        brXDiff = (brX2-brX)/xScale
                                        blXDiff = (blX2-blX)/xScale
                                        tlYDiff = (tlY2-tlY)/yScale
                                        trYDiff = (trY2-trY)/yScale
                                        brYDiff = (brY2-brY)/yScale
                                        blYDiff = (blY2-blY)/yScale
                                        pair = id2 in responseIds
                                        if pair or self.eval:
                                            instances = pair_instances
                                        else:
                                            instances = notpair_instances
                                        if self.altJSONDir is None:
                                            data=[qH,qW,qR,qIsText, iH,iW,iR,iIsText, xDiff, yDiff]
                                        else:
                                            data=[qH,qW,qR,qIsText,qIsField, iH,iW,iR,iIsText,iIsField, xDiff, yDiff]
                                        if self.use_corners=='xy':
                                            data+=[tlXDiff,trXDiff,brXDiff,blXDiff,tlYDiff,trYDiff,brYDiff,blYDiff]
                                        elif self.use_corners:
                                            data+=[tlDiff, trDiff, brDiff, blDiff]
                                        if qIsBlank is not None:
                                            data+=[qIsBlank,iIsBlank]
                                        if qNN is not None:
                                            data+=[qNN,iNN]
                                        instances.append( {
                                            'data': torch.tensor([ data ]),
                                            'label': pair,
                                            'imgName': imageName,
                                            'qXY' : (qX,qY),
                                            'iXY' : (iX,iY),
                                            'qHW' : (qH,qW),
                                            'iHW' : (iH,iW),
                                            'ids' : (id,id2),
                                            'numNeighbors': torch.tensor([ [numN1,numN2] ])
                                            } )
                        if self.eval:
                            #if evaluating, pack all instances for an image into a batch
                            datas=[]
                            labels=[]
                            qXYs=[]
                            iXYs=[]
                            qHWs=[]
                            iHWs=[]
                            nodeIds=[]
                            NNs=[]
                            numTrue=0
                            for inst in pair_instances:
                                datas.append(inst['data'])
                                labels.append(inst['label'])
                                numTrue += inst['label']
                                qXYs.append(inst['qXY'])
                                iXYs.append(inst['iXY'])
                                qHWs.append(inst['qHW'])
                                iHWs.append(inst['iHW'])
                                nodeIds.append(inst['ids'])
                                NNs.append(inst['numNeighbors'])
                            if len(datas)>0:
                                data = torch.cat(datas,dim=0)
                            else:
                                data = torch.FloatTensor(0,numFeats) #empty placeholder
                            if len(NNs)>0:
                                NNs = torch.cat(NNs,dim=0)
                            else:
                                NNs = torch.FloatTensor(0,2)
                            notpair_instances.append( {
                                'data': data,
                                'label': torch.ByteTensor(labels),
                                'imgName': imageName,
                                'imgPath' : path,
                                'qXY' : qXYs,
                                'iXY' : iXYs,
                                'qHW' : qHWs,
                                'iHW' : iHWs,
                                'nodeIds' : nodeIds,
                                'numNeighbors' : NNs,
                                'missedRels': missedCount
                                } )
                            pair_instances=[]
        self.instances = notpair_instances
        if self.balance and not self.eval:
            #replicate the (minority) paired instances to roughly balance the classes
            dif = len(notpair_instances)/float(len(pair_instances))
            print('not: {}, pair: {}. Adding {}x'.format(len(notpair_instances),len(pair_instances),math.floor(dif)))
            for i in range(math.floor(dif)):
                self.instances += pair_instances
        else:
            self.instances += pair_instances
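def _example_balance():
    #Toy illustration (hypothetical numbers, not from the original code) of the
    #balancing step above: the minority paired instances are replicated
    #floor(len(notpair)/len(pair)) times so the two classes end up roughly even.
    import math
    notpair_instances = list(range(10))  #10 negative instances
    pair_instances = ['p']               #1 positive instance
    instances = list(notpair_instances)
    for _ in range(math.floor(len(notpair_instances) / float(len(pair_instances)))):
        instances += pair_instances
    return instances  #len == 20: 10 negatives + 10 copies of the positive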