def getMatchingPrimitive(dataS, dataT, dataset, representation, doCompletion):
    """
    - detect keypoint
    - get keypoint 3d position/normal/feature

    Args:
        dataS, dataT: source/target scan dicts. Keys read here: 'rgb', 'feat',
            'depth', 'normal'; scannet additionally needs 'rgb_full'.
        dataset: dataset name; must contain 'suncg', 'matterport' or 'scannet'.
        representation: forwarded to ``getPixel``.
        doCompletion: when False, only keypoints whose weight (ptsW / pttW)
            equals exactly 1 are kept.

    Returns:
        ``(pts3d, ptt3d, ptsns, ptsnt, dess, dest, ptsW, pttW)``, or a tuple of
        eight ``None`` when keypoint detection fails the early-return check.

    Raises:
        ValueError: if ``dataset`` names none of the supported datasets.
    """
    # compute keypoint
    if 'suncg' in dataset or 'matterport' in dataset:
        pts, ptsNorm, ptsW, ptt, pttNorm, pttW = getKeypoint(
            dataS['rgb'], dataT['rgb'], dataS['feat'], dataT['feat'])
    elif 'scannet' in dataset:
        pts, ptsNorm, ptsW, ptt, pttNorm, pttW = getKeypoint_kinect(
            dataS['rgb'], dataT['rgb'], dataS['feat'], dataT['feat'],
            dataS['rgb_full'], dataT['rgb_full'])
    else:
        # Without this, an unknown dataset fell through and failed later with
        # an UnboundLocalError on `pts`.
        raise ValueError(f"getMatchingPrimitive: unsupported dataset {dataset!r}")

    # early return if too few keypoint detected
    # NOTE(review): this tests shape[1]; if the detectors return points as
    # [N, 2] the "fewer than 2 keypoints" check would need shape[0] - confirm.
    if pts is None or ptt is None or pts.shape[1] < 2 or ptt.shape[1] < 2:
        return None, None, None, None, None, None, None, None

    # get the 3d location of matches
    pts3d, ptsns = getPixel(dataS['depth'], dataS['normal'], pts,
                            dataset=dataset, representation=representation)
    ptt3d, ptsnt = getPixel(dataT['depth'], dataT['normal'], ptt,
                            dataset=dataset, representation=representation)

    # interpolate the nn feature map to get feature vectors
    dess = torch_op.npy(interpolate(dataS['feat'], torch_op.v(ptsNorm))).T
    dest = torch_op.npy(interpolate(dataT['feat'], torch_op.v(pttNorm))).T

    if not doCompletion:
        # filter out those keypoint from unobserved region
        keep_s = ptsW == 1
        pts3d, ptsns, dess, ptsW = pts3d[:, keep_s], ptsns[keep_s], dess[keep_s], ptsW[keep_s]
        keep_t = pttW == 1
        ptt3d, ptsnt, dest, pttW = ptt3d[:, keep_t], ptsnt[keep_t], dest[keep_t], pttW[keep_t]
    return pts3d, ptt3d, ptsns, ptsnt, dess, dest, ptsW, pttW
def contrast_loss(self, featMaps, featMapt, denseCorres):
    """Contrastive loss on feature maps over dense source/target correspondences.

    Args:
        featMaps, featMapt: source/target feature maps indexed
            ``[batch, channel, y, x]``.
        denseCorres: dict with
            'valid':  per-sample flag; samples where it equals 1 are used,
            'idxSrc', 'idxTgt': pixel coordinates indexed
                ``[sample, keypoint, (x, y)]``.

    Returns:
        ``(loss_fl, loss_fl_pos, loss_fl_neg)``:
          * loss_fl_pos: mean squared feature distance between corresponding
            source/target pixels;
          * loss_fl_neg: mean hinge ``relu(self.args.D - dist)`` between each
            source feature and 100 uniformly random target pixels;
          * loss_fl: their sum.
        All three are zero scalars when no sample is valid.
    """
    validCorres = torch.nonzero(denseCorres['valid'] == 1).view(-1).long()
    n = featMaps.shape[0]
    if not len(validCorres):
        # Zero scalars on the features' device/dtype. The previous code did
        # `torch_op.v(np.array([0]))[0][0]`, which indexes a 0-dim tensor and
        # raises IndexError.
        zero = featMaps.new_zeros(())
        return zero, zero.clone(), zero.clone()

    Kn = denseCorres['idxSrc'].shape[1]
    C = featMaps.shape[1]

    # gather features at corresponding pixels; coordinates are stored (x, y)
    idxInst = torch.arange(n)[validCorres].view(-1, 1).repeat(1, Kn).view(-1).long()
    featS = featMaps[idxInst, :,
                     denseCorres['idxSrc'][validCorres, :, 1].view(-1).long(),
                     denseCorres['idxSrc'][validCorres, :, 0].view(-1).long()]
    featT = featMapt[idxInst, :,
                     denseCorres['idxTgt'][validCorres, :, 1].view(-1).long(),
                     denseCorres['idxTgt'][validCorres, :, 0].view(-1).long()]

    # positive example
    loss_fl_pos = (featS - featT).pow(2).sum(1).mean()

    # negative example: 100 random target pixels per source keypoint.
    # NOTE(review): negatives are not checked against the true match, so one
    # can occasionally coincide with the positive pixel.
    n_neg = Kn * 100 * len(validCorres)
    negIdy = torch.from_numpy(np.random.choice(range(featMaps.shape[2]), n_neg))
    negIdx = torch.from_numpy(np.random.choice(range(featMaps.shape[3]), n_neg))
    idx = torch.arange(n)[validCorres].view(-1, 1).repeat(1, Kn * 100).view(-1).long()
    featS_rep = featS.unsqueeze(1).repeat(1, 100, 1).view(-1, C)
    neg_dist = (featS_rep - featMapt[idx, :, negIdy, negIdx]).pow(2).sum(1)
    loss_fl_neg = F.relu(self.args.D - neg_dist).mean()

    loss_fl = loss_fl_pos + loss_fl_neg
    return loss_fl, loss_fl_pos, loss_fl_neg
def apply_mask(x, maskMethod, *arg):
    """Hide the unobserved part of a batch of scans and build a per-pixel weight map.

    Args:
        x: tensor indexed ``[n, c, h, w]``.
        maskMethod: 'second' keeps columns ``h:2h`` over the full height;
            'kinect' keeps a fixed crop and requires ``w == 640`` and ``h == 160``;
            any other value keeps nothing.
        *arg: accepted for call compatibility; unused.

    Returns:
        ``(masked_x, mask, weight)``. ``mask`` is 1 on kept pixels, shape
        ``[n, 1, h, w]``, converted with ``torch_op.v``. ``weight`` is
          * 'second': ``exp(-d / (2 * 0.7**2))`` where ``d`` is the column
            distance (in units of ``h``) to the nearest of columns
            ``h, 2h, w+h, w+2h``, and 0 inside the kept block;
          * 'kinect': 20 inside the crop, 1 elsewhere;
          * otherwise: an all-zero numpy array (not converted to a tensor).
    """
    batch, height, width = x.shape[0], x.shape[2], x.shape[3]
    mask = np.zeros([batch, 1, height, width])
    weight = np.zeros([batch, 1, height, width])

    if maskMethod == 'second':
        mask[:, :, :height, height:2 * height] = 1
        _, cols = np.meshgrid(range(height), range(width), indexing='ij')
        # edges of the kept block, plus the same edges shifted by one width (wrap-around)
        edges = (height, 2 * height, width + height, width + 2 * height)
        col_dist = np.min(np.stack([np.abs(cols - e) for e in edges], 0), axis=0) / height
        sigma_geom = 0.7
        falloff = np.exp(-col_dist / (2 * sigma_geom ** 2))
        falloff[:, height:2 * height] = 0
        weight = torch_op.v(np.tile(falloff.reshape(1, 1, *falloff.shape), [batch, 1, 1, 1]))
    elif maskMethod == 'kinect':
        assert (width == 640 and height == 160)
        half_w = int(89.67 // 2)
        half_h = int(67.25 // 2)
        mask[:, :, 80 - half_h:80 + half_h, 160 + 80 - half_w:160 + 80 + half_w] = 1
        weight = torch_op.v(np.where(mask == 0, 1.0, mask * 20))

    mask = torch_op.v(mask)
    return x * mask, mask, weight
def pnlayer(depth, normal, plane, dataList, representation):
    """Loss tying predicted plane offsets + normals to the depth map.

    For every pixel with a non-zero plane value, two 3D points are built:
      * pcPn = plane / (ray . normal) * ray, with ray = (-x, -y, 1),
      * pcD  = (x * z, y * z, -z) from the depth z,
    and the mean of their clamped (to [-5, 5]) absolute difference is summed
    over the four faces.

    Args:
        depth: tensor indexed ``[n, 1, h, w]`` with ``h == w // 4`` (4 faces).
        normal: tensor indexed ``[n, 3, h, w]``.
        plane: tensor indexed ``[n, 1, h, w]``; zero entries are ignored.
        dataList: dataset name(s); suncg and matterport are implemented
            (they differ only in which face rotation is applied to the normals).
        representation: not used on the implemented path.

    Returns:
        Scalar loss tensor (0 if no face has a non-zero plane value).

    Raises:
        NotImplementedError: for scannet.
        ValueError: for any other dataset.
    """
    # dp: [n,1,h,w]
    # n: [n,3,h,w]
    if 'suncg' in dataList or 'matterport' in dataList:
        n, h, w = depth.shape[0], depth.shape[2], depth.shape[3]
        assert (h == w // 4)
        Rs = np.zeros([4, 4, 4])
        Rs[0] = np.eye(4)
        Rs[1] = np.array([[0, 0, -1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]])
        Rs[2] = np.array([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        Rs[3] = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]])
        Rs = torch_op.v(Rs)

        # normalized coordinates of one h x h face (identical for all faces)
        ys, xs = np.meshgrid(range(h), range(h), indexing='ij')
        ys, xs = (0.5 - ys / h) * 2, (xs / h - 0.5) * 2
        # tile over the batch so entries line up with view(-1) of an [n, h, h] slice;
        # the old code indexed the un-tiled h*h grid with an n*h*h mask
        xs_all = np.tile(xs.flatten(), n)
        ys_all = np.tile(ys.flatten(), n)

        loss_pn = depth.new_zeros(())
        for i in range(4):
            plane_this = plane[:, 0, :, i * h:(i + 1) * h].contiguous()
            depth_this = depth[:, 0, :, i * h:(i + 1) * h].contiguous()
            zs = plane_this.view(-1)
            mask = (zs != 0)
            masknpy = torch_op.npy(mask).astype('bool')
            if not masknpy.any():
                # nothing to compare on this face; avoid mean() of an empty tensor
                continue

            normal_this = normal[:, :, :, i * h:(i + 1) * h].permute(0, 2, 3, 1).contiguous().view(-1, 3)
            if 'suncg' in dataList:
                normal_this = torch.matmul(Rs[i][:3, :3].t(), normal_this.t()).t()
            elif 'matterport' in dataList:
                normal_this = torch.matmul(Rs[(i - 1) % 4][:3, :3].t(), normal_this.t()).t()

            # mask every per-pixel quantity consistently (previously only `ray` was masked)
            xs_m, ys_m = xs_all[masknpy], ys_all[masknpy]
            ray = torch_op.v(np.stack((-xs_m, -ys_m, np.ones(len(xs_m))), 1))
            pcPn = (zs[mask] / (ray * normal_this[mask] + 1e-6).sum(1)).unsqueeze(1) * ray

            zd = depth_this.view(-1)[mask]
            pcD = torch.stack((torch_op.v(xs_m) * zd, torch_op.v(ys_m) * zd, -zd), 1)
            loss_pn = loss_pn + (pcD - pcPn).clamp(-5, 5).abs().mean()
    elif 'scannet' in dataList:
        raise NotImplementedError("not implemented: scannet/skybox representation")
    else:
        # previously fell through to `return loss_pn` with loss_pn unbound
        raise ValueError(f"pnlayer: unsupported dataList {dataList!r}")
    return loss_pn
def forward(self, x, pair, img, imgPCid):
    """Fuse point features with sampled image features and regress per-pair outputs.

    Args:
        x: point input indexed ``[B, C_in, N]``. For ``pair`` lookup the first
            ``N // 2`` points and the remaining points are treated as two sets.
        pair: indices indexed ``[B, P, 2]``; column 0 indexes the first point
            set, column 1 the second.
        img: image batch; entries ``[:B]`` give the source image features and
            ``[B:]`` the target image features.
        imgPCid: pixel coordinates indexed ``[B, 2, K, (x, y)]`` (index 0 =
            source image, 1 = target image) at which image features are sampled.
            NOTE(review): concatenation with the pair features requires K == P.

    Returns:
        ``(pred, pred_semantic, pred_semantic_img)``: ``pred`` viewed as
        ``[B, -1, 4]``, ``pred_semantic`` as ``[B, classes, N]``, and
        ``pred_semantic_img`` from ``self.conv_semantic``.

    Side effect: stores the ``self.h1`` output in ``self.t_out_h1``.
    """
    batch, _, num_points = x.shape

    # ---- image branch
    img_feat = self.imgBackbone(img)
    pred_semantic_img = self.conv_semantic(img_feat)
    src_img_feat, tgt_img_feat = img_feat[:batch], img_feat[batch:]
    pix_batch = v(np.repeat(np.arange(batch), imgPCid.shape[2])).long()

    def sample_pixels(feature_map, which):
        # coordinates are stored (x, y) -> index as [.., row=y, col=x]
        rows = imgPCid[:, which, :, 1].contiguous().view(-1).long()
        cols = imgPCid[:, which, :, 0].contiguous().view(-1).long()
        return feature_map[pix_batch, :, rows, cols]

    features_s = sample_pixels(src_img_feat, 0)
    features_t = sample_pixels(tgt_img_feat, 1)

    # ---- point branch: local features + broadcast global feature
    local_feat = self.h1(x)
    self.t_out_h1 = local_feat  # local features
    global_feat = flatten(self.sy(self.h2(local_feat)))
    fused = torch.cat((local_feat, global_feat.unsqueeze(2).repeat(1, 1, num_points)), dim=1)

    # NOTE(review): 1088 is presumably local (64) + global (1024) channels — confirm
    per_point = fused.permute(0, 2, 1).contiguous().view(-1, 1088)
    pred_semantic = self.regressor_s(per_point).view(batch, num_points, -1).permute(0, 2, 1)

    # ---- pair branch
    half = num_points // 2
    first_set, second_set = fused[:, :, :half], fused[:, :, half:]
    pair_batch = torch.arange(batch)[:, None].repeat(1, pair.shape[1]).view(-1).long()
    feat1 = first_set[pair_batch, :, pair[:, :, 0].view(-1).long()]
    feat2 = second_set[pair_batch, :, pair[:, :, 1].view(-1).long()]
    pred = self.regressor(torch.cat((feat1, features_s, feat2, features_t), -1))
    return pred.view(batch, -1, 4), pred_semantic, pred_semantic_img
def apply_mask(x, maskMethod, *arg):
    """Zero out everything but the observed block of each scan and build a weight map.

    Args:
        x: tensor indexed ``[n, c, h, w]``.
        maskMethod: only 'second' keeps anything (columns ``h:2h``, full
            height); any other value keeps nothing.
        *arg: accepted for call compatibility; unused.

    Returns:
        ``(masked_x, mask, geow)``. ``mask`` is ``[n, 1, h, w]`` with 1 on kept
        pixels (converted with ``torch_op.v``). For 'second', ``geow`` is
        ``exp(-d / (2 * 0.7**2))`` of the column distance ``d`` (in units of
        ``h``) to the nearest of columns ``h, 2h, w+h, w+2h``, zeroed inside
        the kept block; otherwise it is an all-zero numpy array.
    """
    n_scan, rows, cols = x.shape[0], x.shape[2], x.shape[3]
    keep = np.zeros([n_scan, 1, rows, cols])
    geow = np.zeros([n_scan, 1, rows, cols])

    if maskMethod == 'second':
        keep[:, :, :rows, rows:2 * rows] = 1
        col_idx = np.meshgrid(range(rows), range(cols), indexing='ij')[1]
        dist = np.minimum.reduce([
            np.abs(col_idx - rows),
            np.abs(col_idx - 2 * rows),
            np.abs(col_idx - cols - rows),
            np.abs(col_idx - cols - 2 * rows),
        ]) / rows
        sigma = 0.7
        dist = np.exp(-dist / (2 * sigma**2))
        dist[:, rows:2 * rows] = 0
        geow = torch_op.v(np.tile(dist[np.newaxis, np.newaxis], (n_scan, 1, 1, 1)))

    keep = torch_op.v(keep)
    return x * keep, keep, geow
def getMatchingPrimitive(dataS, dataT, dataset, representation, doCompletion):
    """
    - detect keypoint
    - get keypoint 3d position/normal/feature

    Args:
        dataS, dataT: source/target scan dicts. Keys read here: 'rgb', 'feat',
            'depth', 'normal'; scannet additionally needs 'rgb_full'.
        dataset: dataset name; must contain 'suncg', 'matterport' or 'scannet'.
        representation: forwarded to ``getPixel``.
        doCompletion: when False, only keypoints whose weight equals exactly 1
            (observed region) are kept.

    Returns:
        ``(pts3d, ptt3d, ptsns, ptsnt, dess, dest, ptsW, pttW)``, or a tuple of
        eight ``None`` when keypoint detection fails the early-return check.

    Raises:
        ValueError: if ``dataset`` names none of the supported datasets.
    """
    # compute keypoint
    # ptsW contains two values, 1: in the observed region, 0.99: not in the observed region
    if 'suncg' in dataset or 'matterport' in dataset:
        pts, ptsNorm, ptsW, ptt, pttNorm, pttW = getKeypoint(
            dataS['rgb'], dataT['rgb'], dataS['feat'], dataT['feat'])
    elif 'scannet' in dataset:
        # NOTE(review): a source 'rgb' with 200 rows selects the 120-degree-FOV
        # detector — confirm this shape convention against the data loader.
        if dataS['rgb'].shape[0] == 200:
            keypoint_fn = getKeypoint_kinect_120fov
        else:
            keypoint_fn = getKeypoint_kinect
        pts, ptsNorm, ptsW, ptt, pttNorm, pttW = keypoint_fn(
            dataS['rgb'], dataT['rgb'], dataS['feat'], dataT['feat'],
            dataS['rgb_full'], dataT['rgb_full'])
    else:
        # Without this, an unknown dataset fell through and failed later with
        # an UnboundLocalError on `pts`.
        raise ValueError(f"getMatchingPrimitive: unsupported dataset {dataset!r}")

    # early return if too few keypoint detected
    # NOTE(review): this tests shape[1]; if the detectors return points as
    # [N, 2] the "fewer than 2 keypoints" check would need shape[0] - confirm.
    if pts is None or ptt is None or pts.shape[1] < 2 or ptt.shape[1] < 2:
        return None, None, None, None, None, None, None, None

    # get the 3d location of matches
    pts3d, ptsns = getPixel(dataS['depth'], dataS['normal'], pts,
                            dataset=dataset, representation=representation)
    ptt3d, ptsnt = getPixel(dataT['depth'], dataT['normal'], ptt,
                            dataset=dataset, representation=representation)

    # interpolate the nn feature map to get feature vectors
    dess = torch_op.npy(interpolate(dataS['feat'], torch_op.v(ptsNorm))).T
    dest = torch_op.npy(interpolate(dataT['feat'], torch_op.v(pttNorm))).T

    if not doCompletion:
        # filter out those keypoint from unobserved region
        keep_s = ptsW == 1
        pts3d, ptsns, dess, ptsW = pts3d[:, keep_s], ptsns[keep_s], dess[keep_s], ptsW[keep_s]
        keep_t = pttW == 1
        ptt3d, ptsnt, dest, pttW = ptt3d[:, keep_t], ptsnt[keep_t], dest[keep_t], pttW[keep_t]
    return pts3d, ptt3d, ptsns, ptsnt, dess, dest, ptsW, pttW
def step(self, data, mode='train'):
    """Run one step of the scan-completion network on a batch.

    Builds four network inputs per scan pair:
      * source + blank partner,
      * target + blank partner,
      * source + warped target,
      * target + warped source.
    It then runs ``self.netG`` on them and ``self.netF`` on the complete scans.
    Next it computes the enabled losses (reconstruction, optional plane/normal
    ``pnloss``, optional contrastive feature loss). In 'train' mode it
    back-propagates and steps ``self.optimizerG``. Finally it updates the
    meters, and periodically writes visualisations and scalars to
    tensorboard.

    Args:
        data: batch dict. Keys read: 'rgb', 'norm', 'depth', 'dataMask', 'Q',
            'proj_rgb_p', 'proj_n_p', 'proj_d_p', 'proj_mask_p'; also 'segm'
            when 's' is in outputType, 'proj_box_p' with dynamicWeighting,
            and 'denseCorres' with featurelearning.
        mode: 'train' enables gradients and the optimizer step; any other
            value evaluates (and fills the eval buffers on vis steps).

    Returns:
        dotdict with key 'suffix': a progress string built from meter averages.
    """
    torch.cuda.empty_cache()
    if self.speed_benchmark:
        step_start = time.time()

    with torch.set_grad_enabled(mode == 'train'):
        np.random.seed()
        self.optimizerG.zero_grad()

        # per-pixel CE ('none') so it can be multiplied by the weight map below
        CEcriterion = nn.CrossEntropyLoss(weight=self.class_balance_weights, reduction='none')

        rgb, norm, depth, dataMask, Q = (v(data['rgb']), v(data['norm']), v(data['depth']),
                                         v(data['dataMask']), v(data['Q']))
        proj_rgb_p, proj_n_p, proj_d_p, proj_mask_p = (v(data['proj_rgb_p']), v(data['proj_n_p']),
                                                       v(data['proj_d_p']), v(data['proj_mask_p']))
        if 's' in self.args.outputType:
            segm = v(data['segm'])
        if self.args.dynamicWeighting:
            dynamicW = v(data['proj_box_p'])
            dynamicW[dynamicW == 0] = 0.2
            dynamicW = torch.cat((dynamicW[:, 0, :, :, :], dynamicW[:, 1, :, :, :]))
        else:
            dynamicW = 1
        errG_rgb, errG_d, errG_n, errG_k, errG_s = (
            torch.FloatTensor([0]), torch.FloatTensor([0]), torch.FloatTensor([0]),
            torch.FloatTensor([0]), torch.FloatTensor([0]))
        n = Q.shape[0]

        complete_s = torch.cat((rgb[:, 0, :, :, :], norm[:, 0, :, :, :], depth[:, 0:1, :, :]), 1)
        complete_t = torch.cat((rgb[:, 1, :, :, :], norm[:, 1, :, :, :], depth[:, 1:2, :, :]), 1)
        view_s, mask_s, geow_s = apply_mask(complete_s.clone(), self.args.maskMethod, self.args.ObserveRatio)
        view_s = torch.cat((view_s, mask_s), 1)
        view_t, mask_t, geow_t = apply_mask(complete_t.clone(), self.args.maskMethod, self.args.ObserveRatio)
        view_t = torch.cat((view_t, mask_t), 1)
        view_t2s = torch.cat((proj_rgb_p[:, 0, :, :, :], proj_n_p[:, 0, :, :, :],
                              proj_d_p[:, 0, :, :, :], proj_mask_p[:, 0, :, :, :]), 1)
        view_s2t = torch.cat((proj_rgb_p[:, 1, :, :, :], proj_n_p[:, 1, :, :, :],
                              proj_d_p[:, 1, :, :, :], proj_mask_p[:, 1, :, :, :]), 1)

        # netG need to tolerate three type of input:
        # 0.correct s + blank t
        # 1.correct s + wrong t
        # 2.correct s + correct t
        view_s_type0 = torch.cat((view_s, torch.zeros(view_s.shape).cuda()), 1)
        view_s_type1 = torch.cat((view_s, view_t2s), 1)
        view_t_type0 = torch.cat((view_t, torch.zeros(view_t.shape).cuda()), 1)
        view_t_type1 = torch.cat((view_t, view_s2t), 1)
        if 's' in self.args.outputType:
            segm = torch.cat((segm[:, 0, :, :, :], segm[:, 1, :, :, :])).repeat(2, 1, 1, 1)

        # mask the pano; batch layout from here on is
        # [s+blank | t+blank | s+warped t | t+warped s], n samples each
        view = torch.cat((view_s_type0, view_t_type0, view_s_type1, view_t_type1))
        mask = torch.cat((mask_s, mask_t)).repeat(2, 1, 1, 1)
        geow = torch.cat((geow_s, geow_t)).repeat(2, 1, 1, 1)
        complete = torch.cat((complete_s, complete_t)).repeat(2, 1, 1, 1)
        dataMask = torch.cat((dataMask[:, 0, :, :, :], dataMask[:, 1, :, :, :])).repeat(2, 1, 1, 1)

        fake = self.netG(view)
        with torch.set_grad_enabled(False):
            fakec = self.netF(complete)

        if 'f' in self.args.outputType:
            f_lo, f_hi = self.args.idx_f_start, self.args.idx_f_end
            featMapsc = fakec[:n]
            featMaptc = fakec[n:n * 2]
            # randomly pick the blank-partner or the warped-partner pair
            if np.random.rand() > 0.5:
                featMaps = fake[:n, f_lo:f_hi, :, :]
                featMapt = fake[n:n * 2, f_lo:f_hi, :, :]
            else:
                featMaps = fake[n * 2:n * 3, f_lo:f_hi, :, :]
                featMapt = fake[n * 3:n * 4, f_lo:f_hi, :, :]

        if self.args.featurelearning:
            denseCorres = data['denseCorres']
            validCorres = torch.nonzero(denseCorres['valid'] == 1).view(-1).long()
            loss_fl, loss_fl_pos, loss_fl_neg = self.contrast_loss(featMaps, featMapt, data['denseCorres'])

            # categorize each correspondence by whether it contain unobserved point
            allCorres = torch.cat((denseCorres['idxSrc'], denseCorres['idxTgt']))
            corresShape = allCorres.shape
            allCorres = allCorres.view(-1, 2).long()
            typeIdx = torch.arange(corresShape[0]).view(-1, 1).repeat(1, corresShape[1]).view(-1).long()
            typeIcorresP = mask[typeIdx, 0, allCorres[:, 1], allCorres[:, 0]]
            typeIcorresP = typeIcorresP.view(2, -1, corresShape[1]).sum(0)
            denseCorres['observe'] = typeIcorresP

            loss_fc = torch.pow(
                (fake[:, self.args.idx_f_start:self.args.idx_f_end, :, :] - fakec.detach()) * dataMask * geow,
                2).sum(1).mean()

        errG_recon = 0
        if self.args.GeometricWeight:
            total_weight = geow[:, 0:1, :, :] * dynamicW * dataMask
        else:
            total_weight = dynamicW * dataMask
        if 'rgb' in self.args.outputType:
            errG_rgb = ((fake[:, self.args.idx_rgb_start:self.args.idx_rgb_end, :, :]
                         - complete[:, 0:3, :, :]) * total_weight).abs().mean()
            errG_recon = errG_recon + errG_rgb
        if 'n' in self.args.outputType:
            errG_n = ((fake[:, self.args.idx_n_start:self.args.idx_n_end, :, :]
                       - complete[:, 3:6, :, :]) * total_weight).abs().mean()
            errG_recon = errG_recon + errG_n
        if 'd' in self.args.outputType:
            errG_d = ((fake[:, self.args.idx_d_start:self.args.idx_d_end, :, :]
                       - complete[:, 6:7, :, :]) * total_weight).abs().mean()
            errG_recon = errG_recon + errG_d
        if 'k' in self.args.outputType:
            # NOTE(review): `complete` has 7 channels (rgb, normal, depth), so
            # complete[:, 7:8] is empty - confirm where the keypoint target lives.
            errG_k = ((fake[:, self.args.idx_k_start:self.args.idx_k_end, :, :]
                       - complete[:, 7:8, :, :]) * total_weight).abs().mean()
            errG_recon = errG_recon + errG_k
        if 's' in self.args.outputType:
            errG_s = (CEcriterion(fake[:, self.args.idx_s_start:self.args.idx_s_end, :, :],
                                  segm.squeeze(1).long()) * total_weight).mean() * 0.1
            errG_recon = errG_recon + errG_s

        # Use out-of-place additions: the old `errG = errG_recon; errG += ...`
        # mutated errG_recon in place, so the logged recon error also
        # contained the pn/feature losses.
        errG = errG_recon
        if self.args.pnloss:
            # depth holds 2n samples while `fake` holds 4n; repeat it to the
            # same [s | t | s | t] layout used for `complete` / `dataMask`
            depth_pn = torch.cat((depth[:, 0:1, :, :], depth[:, 1:2, :, :])).repeat(2, 1, 1, 1)
            loss_pn = util.pnlayer(depth_pn, fake[:, 3:6, :, :], fake[:, 6:7, :, :] * 4,
                                   self.args.dataList, self.args.representation) * 1e-1
            errG = errG + loss_pn
        if self.args.featurelearning:
            errG = errG + (loss_fl + loss_fc)

        if mode == 'train':
            errG.backward()
            self.optimizerG.step()

        batch = Q.size(0)
        self.logger_errG.update(errG.data, batch)
        self.logger_errG_rgb.update(errG_rgb.data, batch)
        self.logger_errG_n.update(errG_n.data, batch)
        self.logger_errG_d.update(errG_d.data, batch)
        self.logger_errG_s.update(errG_s.data, batch)
        self.logger_errG_k.update(errG_k.data, batch)
        if self.args.pnloss:
            self.logger_errG_pn.update(loss_pn.data, batch)
        if self.args.featurelearning:
            self.logger_errG_fl.update(loss_fl.data, batch)
            self.logger_errG_fl_pos.update(loss_fl_pos.data, batch)
            self.logger_errG_fl_neg.update(loss_fl_neg.data, batch)
            self.logger_errG_fc.update(loss_fc.data, batch)
        if self.args.objectFreqLoss:
            # NOTE(review): loss_freq is never assigned in this method; enabling
            # objectFreqLoss raises NameError here.
            self.logger_errG_freq.update(loss_freq.data, batch)

        suffix = (f"| errG {self.logger_errG.avg:.6f}| | errG_fl {self.logger_errG_fl.avg:.6f}"
                  f" | errG_fl_pos {self.logger_errG_fl_pos.avg:.6f} | errG_fl_neg {self.logger_errG_fl_neg.avg:.6f}"
                  f" | errG_fc {self.logger_errG_fc.avg:.6f} | errG_pn {self.logger_errG_pn.avg:.6f}"
                  f" | errG_freq {self.logger_errG_freq.avg:.6f}")

        if self.global_step % getattr(self.learnerParam, f"{mode}_step_vis") == 0:
            print(f"total image trasversed:{len(self.sancheck)}\n")
            # do logging and visualizing
            if 'n' in self.args.outputType:
                # normalized normal
                faken = fake[:, self.args.idx_n_start:self.args.idx_n_end, :, :]
                faken = faken / torch.norm(faken, dim=1, keepdim=True)

            vis = []
            if 'rgb' in self.args.outputType:
                # draw rgb
                visrgb = complete[:, 0:3, :, :]
                visrgbm = view[:, 0:3, :, :]
                visrgbm2 = view[:, 8 + 0:8 + 3, :, :]
                visrgbf = fake[:, self.args.idx_rgb_start:self.args.idx_rgb_end, :, :]
                visrgbf = visNorm(visrgbf)
                visrgbc = (fake[:, self.args.idx_rgb_start:self.args.idx_rgb_end, :, :] * (1 - mask)
                           + visrgb * mask)
                visrgbc = visNorm(visrgbc)
                visrgb = torch.cat((visrgbm, visrgbm2, visrgbf, visrgbc, visrgb), 2)
                visrgb = visNorm(visrgb)
                vis.append(visrgb)
            if 'n' in self.args.outputType:
                # draw normal
                visn = complete[:, 3:6, :, :]
                visnm = view[:, 3:6, :, :]
                visnm2 = view[:, 8 + 3:8 + 6, :, :]
                visnf = faken
                visnc = (faken * (1 - mask) + visn * mask)
                visn = torch.cat((visnm, visnm2, visnf, visnc, visn), 2)
                visn = visNorm(visn)
                vis.append(visn)
            if 'd' in self.args.outputType:
                # draw depth
                visd = complete[:, 6:7, :, :]
                visdm = view[:, 6:7, :, :]
                visdm2 = view[:, 8 + 6:8 + 7, :, :]
                visdf = fake[:, self.args.idx_d_start:self.args.idx_d_end, :, :]
                visdc = (visdf * (1 - mask) + visd * mask)
                visd = torch.cat((visdm, visdm2, visdf, visdc, visd), 2)
                visd = visNorm(visd)
                visd = visd.repeat(1, 3, 1, 1)
                vis.append(visd)
            if 'k' in self.args.outputType:
                # draw keypoint
                visk = complete[:, 7:8, :, :]
                viskm = view[:, 7:8, :, :]
                viskf = fake[:, self.args.idx_k_start:self.args.idx_k_end, :, :]
                viskc = fake[:, self.args.idx_k_start:self.args.idx_k_end, :, :].clone()
                # fill the observed region with each sample's minimum before normalising
                viskc = viskc * (1 - mask) + (viskc.view(viskc.shape[0], -1).min(1)[0].view(-1, 1, 1, 1)) * mask
                viskc = visNorm(viskc)
                viskc = util.extractKeypoint(viskc)
                viskc = (viskc * (1 - mask) + visk * mask)
                visk = torch.cat((viskm, viskf, viskc, visk), 2)
                visk = visk.repeat(1, 3, 1, 1)
                vis.append(visk)
            if 's' in self.args.outputType:
                # draw semantic
                viss = segm
                vissm = viss * mask[:, 0:1, :, :]
                vissf = fake[:, self.args.idx_s_start:self.args.idx_s_end, :, :]
                vissf = torch.argmax(vissf, 1, keepdim=True).float()
                vissc = (vissf * (1 - mask) + viss * mask)
                viss = torch.cat((vissm, vissf, vissc, viss), 2)
                visstp = torch_op.npy(viss)
                visstp = np.expand_dims(np.squeeze(visstp, 1), 3)
                visstp = self.colors[visstp.flatten().astype('int'), :].reshape(
                    visstp.shape[0], visstp.shape[1], visstp.shape[2], 3)
                viss = torch_op.v(visstp.transpose(0, 3, 1, 2)) / 255.
                vis.append(viss)
            if self.args.dynamicWeighting:
                visdw = dynamicW.repeat(1, 3, 1, 1)
                vis.append(visdw)
            if 'f' in self.args.outputType:
                # draw feature error map
                visf = fake[:, self.args.idx_f_start:self.args.idx_f_end, :, :]
                visf = (visf - fakec).pow(2).sum(1, keepdim=True)
                visf = visNorm(visf)
                visf = visf.repeat(1, 3, 1, 1)
                vis.append(visf)

            visw = total_weight.repeat(1, 3, 1, 1)
            vis.append(visw)

            # concate all vis
            vis = torch.cat(vis, 2)[::2]
            permute = [2, 1, 0]  # bgr to rgb
            vis = vis[:, permute, :, :]

            if mode != 'train':
                with torch.set_grad_enabled(False):
                    # The old test `'n' and 'd' in outputType` only checked 'd'
                    # and hit a NameError on `faken` when 'n' was absent.
                    if 'n' in self.args.outputType and 'd' in self.args.outputType:
                        # evaluate strcuture prediction
                        ## 1. normal angle
                        mask_n = (1 - mask[:, 0:1, :, :]).cpu()
                        mask_n = mask_n * dataMask.cpu()
                        cos_n = (faken.cpu() * complete[:, 3:6, :, :].cpu()).sum(1, keepdim=True)[mask_n != 0]
                        evalErrN = (torch.acos(cos_n.clamp(-1, 1)) / np.pi * 180)
                        self.evalErrN.extend(npy(evalErrN))
                        ## 2. plane distance (predicted channel located via idx_d_*, not a fixed offset)
                        pred_d = fake[:, self.args.idx_d_start:self.args.idx_d_end, :, :].cpu()
                        evalErrD = ((pred_d - complete[:, 6:7, :, :].cpu())[mask_n != 0]).abs()
                        self.evalErrD.extend(npy(evalErrD))

                    # evaluate the learned feature
                    if self.args.featurelearning:
                        ## 1. descriptive power of learned feature
                        if len(validCorres):
                            self.evalFeatRatioSift.extend(self.evalSiftDescriptor(rgb, denseCorres))
                            obs, unobs, _ = self.evalDLDescriptor(
                                featMapsc, featMaptc, denseCorres, complete_s[:, 0:3, :, :],
                                complete_t[:, 0:3, :, :], mask[0:1, 0:1, :, :])
                            self.evalFeatRatioDLc_obs.extend(obs)
                            self.evalFeatRatioDLc_unobs.extend(unobs)
                            obs, unobs, _ = self.evalDLDescriptor(
                                featMaps, featMapt, denseCorres, complete_s[:, 0:3, :, :],
                                complete_t[:, 0:3, :, :], mask[0:1, 0:1, :, :])
                            self.evalFeatRatioDL_obs.extend(obs)
                            self.evalFeatRatioDL_unobs.extend(unobs)

                    if self.args.objectFreqLoss:
                        # NOTE(review): freq_pred / freq_gt are never assigned in this method.
                        freq_pred = freq_pred / freq_pred.sum(1, keepdim=True)
                        freq_gt = freq_gt / freq_gt.sum(1, keepdim=True)
                        self.evalSemantic.append(torch_op.npy(freq_pred))
                        self.evalSemantic_gt.append(torch_op.npy(freq_gt))

            train_op.tboard_add_img(self.tensorboardX, vis, f"{mode}/loss", self.global_step)

        if self.global_step % getattr(self.learnerParam, f"{mode}_step_log") == 0:
            self.tensorboardX.add_scalars('data/errG_recon', {f"{mode}": errG_recon}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_rgb', {f"{mode}": errG_rgb}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_n', {f"{mode}": errG_n}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_d', {f"{mode}": errG_d}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_s', {f"{mode}": errG_s}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_k', {f"{mode}": errG_k}, self.global_step)
            if self.args.pnloss:
                self.tensorboardX.add_scalars('data/errG_pnloss', {f"{mode}": loss_pn}, self.global_step)
            if self.args.featurelearning:
                self.tensorboardX.add_scalars('data/errG_fl', {f"{mode}_complete": loss_fl}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_fl_pos', {f"{mode}_complete": loss_fl_pos}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_fl_neg', {f"{mode}_complete": loss_fl_neg}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_fc', {f"{mode}": loss_fc}, self.global_step)
            if self.args.objectFreqLoss:
                self.tensorboardX.add_scalars('data/errG_freq', {f"{mode}": loss_freq}, self.global_step)

        summary = {'suffix': suffix}
        self.global_step += 1

    if self.speed_benchmark:
        self.time_per_step.update(time.time() - step_start, 1)
        print(f"time elapse per step: {self.time_per_step.avg}")

    return dotdict(summary)
args.idx_f_start += 3 if 'n' in args.outputType: args.idx_f_start += 3 if 'd' in args.outputType: args.idx_f_start += 1 if 's' in args.outputType: args.idx_f_start += args.snumclass if 'f' in args.outputType: args.idx_f_end = args.idx_f_start + args.featureDim with torch.set_grad_enabled(False): R_hat = np.eye(4) # get the complete scans complete_s = torch.cat( (torch_op.v(data_s['rgb']), torch_op.v( data_s['norm']), torch_op.v(data_s['depth']).unsqueeze(2)), 2).permute(2, 0, 1).unsqueeze(0) complete_t = torch.cat( (torch_op.v(data_t['rgb']), torch_op.v( data_t['norm']), torch_op.v(data_t['depth']).unsqueeze(2)), 2).permute(2, 0, 1).unsqueeze(0) # apply the observation mask view_s, mask_s, _ = util.apply_mask(complete_s.clone(), args.maskMethod) view_t, mask_t, _ = util.apply_mask(complete_t.clone(), args.maskMethod) mask_s = torch_op.npy(mask_s[0, :, :, :]).transpose(1, 2, 0) mask_t = torch_op.npy(mask_t[0, :, :, :]).transpose(1, 2, 0)
def userConfig(self):
    """Task-specific setup: output-channel layout, hyper-parameters, loggers and eval buffers.

    Each requested output kind in ``args.outputType`` gets a contiguous channel
    range ``args.idx_<key>_start:args.idx_<key>_end``. The order is rgb (3),
    n (3), d (1), k (1), s (snumclass), f (featureDim). The total width is
    stored in ``args.num_output``.
    """
    args = self.args
    if args.featurelearning:
        assert ('f' in args.outputType)

    # (output key, channel count); counts given as strings name an args
    # attribute and are only read when that output is enabled
    channel_layout = (
        ('rgb', 3),
        ('n', 3),
        ('d', 1),
        ('k', 1),
        ('s', 'snumclass'),  # 21 class
        ('f', 'featureDim'),
    )
    pointer = 0
    for key, width in channel_layout:
        if key not in args.outputType:
            continue
        if isinstance(width, str):
            width = getattr(args, width)
        setattr(args, f'idx_{key}_start', pointer)
        setattr(args, f'idx_{key}_end', pointer + width)
        pointer += width
    args.num_output = pointer

    args.num_input = 8 * 2
    args.ngpu = 1
    args.nz = 100
    args.ngf = 64
    args.ndf = 64
    args.nef = 64
    args.nBottleneck = 4000
    args.wt_recon = 0.998
    args.wtlD = 0.002
    args.overlapL2Weight = 10

    # setup logger
    self.tensorboardX = SummaryWriter(log_dir=os.path.join(args.EXP_DIR, 'tensorboard'))
    self.logger = log.logging(args.EXP_DIR_LOG)
    meter_names = (
        'errG', 'errG_recon', 'errG_rgb', 'errG_d', 'errG_n', 'errG_s', 'errG_k',
        'errD_fake', 'errD_real',
        'errG_fl', 'errG_fl_pos', 'errG_fl_neg', 'errG_fl_f', 'errG_fc',
        'errG_pn', 'errG_freq',
    )
    for meter in meter_names:
        setattr(self, f'logger_{meter}', AverageMeter())

    self.global_step = 0
    self.speed_benchmark = True
    if self.speed_benchmark:
        self.time_per_step = AverageMeter()

    self.sift = cv2.xfeatures2d.SIFT_create()

    # evaluation buffers
    self.evalFeatRatioDL_obs, self.evalFeatRatioDL_unobs = [], []
    self.evalFeatRatioDLc_obs, self.evalFeatRatioDLc_unobs = [], []
    self.evalFeatRatioSift = []
    self.evalErrN, self.evalErrD = [], []
    self.evalSemantic, self.evalSemantic_gt = [], []
    self.sancheck = {}

    # semantic encoding: palette of the first matching dataset
    for dataset_name, palette_attr in (('scannet', 'scannet_color_palette'),
                                       ('matterport', 'matterport_color_palette'),
                                       ('suncg', 'suncg_color_palette')):
        if dataset_name in args.dataList:
            self.colors = getattr(config, palette_attr)
            break

    self.class_balance_weights = torch_op.v(np.ones([args.snumclass]))
def RelativePoseEstimationViaCompletion(net, data_s, data_t, args):
    """
    The main algorithm:
    Given two set of scans, alternate between scan completion and pairwise matching.

    Each of ``args.alterStep`` rounds does the following:
      1. Warps each partial scan into the other's frame with the current
         estimate.
      2. Runs ``net`` on both (partial + warped partner) inputs.
      3. Pastes the ground-truth normal/depth back into the observed region.
      4. Calls ``RelativePoseEstimation`` on the completed scans to get the
         next estimate.

    args need to contain:
        snumclass: number of semantic class
        featureDim: feature dimension
        outputType: ['rgb':color,'d':depth,'n':normal,'k':keypoint,'s':semantic,'f':feature]
        maskMethod: ['second']
        alterStep: number of completion/matching rounds
        dataset: dataset name, forwarded to warping and pose estimation
        representation: forwarded to RelativePoseEstimation
        completion: forwarded to RelativePoseEstimation as doCompletion
        para: pose-estimation parameters; sigmaAngle1, sigmaAngle2, sigmaDist
            and sigmaFeat are indexed by round

    Side effect: sets ``args.idx_f_start`` (and ``args.idx_f_end`` when 'f' is
    in outputType).

    Returns:
        R_hat: the last round's estimate (``np.eye(4)`` when alterStep is 0).
    """
    EPS = 1e-12
    # Feature channels start after every other predicted output, in the same
    # order the training config lays them out: rgb 3, n 3, d 1, k 1, s snumclass.
    args.idx_f_start = 0
    if 'rgb' in args.outputType:
        args.idx_f_start += 3
    if 'n' in args.outputType:
        args.idx_f_start += 3
    if 'd' in args.outputType:
        args.idx_f_start += 1
    if 'k' in args.outputType:
        # was missing: the feature slice was off by one channel when keypoints were predicted
        args.idx_f_start += 1
    if 's' in args.outputType:
        args.idx_f_start += args.snumclass
    if 'f' in args.outputType:
        args.idx_f_end = args.idx_f_start + args.featureDim

    with torch.set_grad_enabled(False):
        R_hat = np.eye(4)

        # get the complete scans
        complete_s = torch.cat((torch_op.v(data_s['rgb']), torch_op.v(data_s['norm']),
                                torch_op.v(data_s['depth']).unsqueeze(2)), 2).permute(2, 0, 1).unsqueeze(0)
        complete_t = torch.cat((torch_op.v(data_t['rgb']), torch_op.v(data_t['norm']),
                                torch_op.v(data_t['depth']).unsqueeze(2)), 2).permute(2, 0, 1).unsqueeze(0)

        # apply the observation mask
        view_s, mask_s, _ = util.apply_mask(complete_s.clone(), args.maskMethod)
        view_t, mask_t, _ = util.apply_mask(complete_t.clone(), args.maskMethod)
        mask_s = torch_op.npy(mask_s[0, :, :, :]).transpose(1, 2, 0)
        mask_t = torch_op.npy(mask_t[0, :, :, :]).transpose(1, 2, 0)

        # append mask for valid data (non-zero depth, channel 6)
        tpmask = (view_s[:, 6:7, :, :] != 0).float().cuda()
        view_s = torch.cat((view_s, tpmask), 1)
        tpmask = (view_t[:, 6:7, :, :] != 0).float().cuda()
        view_t = torch.cat((view_t, tpmask), 1)

        for alter_ in range(args.alterStep):
            # warp the second scan using current transformation estimation
            view_t2s = torch_op.v(util.warping(torch_op.npy(view_t), np.linalg.inv(R_hat), args.dataset))
            view_s2t = torch_op.v(util.warping(torch_op.npy(view_s), R_hat, args.dataset))
            # append the warped scans
            view0 = torch.cat((view_s, view_t2s), 1)
            view1 = torch.cat((view_t, view_s2t), 1)

            # generate complete scans
            f = net(torch.cat((view0, view1)))
            f0 = f[0:1, :, :, :]
            f1 = f[1:2, :, :, :]

            data_sc, data_tc = {}, {}
            # replace the observed region with gt depth/normal
            data_sc['normal'] = (1 - mask_s) * torch_op.npy(f0[0, 3:6, :, :]).transpose(1, 2, 0) + mask_s * data_s['norm']
            data_tc['normal'] = (1 - mask_t) * torch_op.npy(f1[0, 3:6, :, :]).transpose(1, 2, 0) + mask_t * data_t['norm']
            data_sc['normal'] /= (np.linalg.norm(data_sc['normal'], axis=2, keepdims=True) + EPS)
            data_tc['normal'] /= (np.linalg.norm(data_tc['normal'], axis=2, keepdims=True) + EPS)
            data_sc['depth'] = (1 - mask_s[:, :, 0]) * torch_op.npy(f0[0, 6, :, :]) + mask_s[:, :, 0] * data_s['depth']
            data_tc['depth'] = (1 - mask_t[:, :, 0]) * torch_op.npy(f1[0, 6, :, :]) + mask_t[:, :, 0] * data_t['depth']
            data_sc['obs_mask'] = mask_s.copy()
            data_tc['obs_mask'] = mask_t.copy()
            data_sc['rgb'] = (mask_s * data_s['rgb'] * 255).astype('uint8')
            data_tc['rgb'] = (mask_t * data_t['rgb'] * 255).astype('uint8')

            # for scannet, we use the original size rgb image(480x640) to extract sift keypoint
            if 'scannet' in args.dataset:
                data_sc['rgb_full'] = (data_s['rgb_full'] * 255).astype('uint8')
                data_tc['rgb_full'] = (data_t['rgb_full'] * 255).astype('uint8')
                data_sc['depth_full'] = data_s['depth_full']
                data_tc['depth_full'] = data_t['depth_full']

            # extract feature maps
            f0_feat = f0[:, args.idx_f_start:args.idx_f_end, :, :]
            f1_feat = f1[:, args.idx_f_start:args.idx_f_end, :, :]
            data_sc['feat'] = f0_feat.squeeze(0)
            data_tc['feat'] = f1_feat.squeeze(0)

            # per-round parameters (attributes are reassigned, so a shallow copy suffices)
            para_this = copy.copy(args.para)
            para_this.sigmaAngle1 = para_this.sigmaAngle1[alter_]
            para_this.sigmaAngle2 = para_this.sigmaAngle2[alter_]
            para_this.sigmaDist = para_this.sigmaDist[alter_]
            para_this.sigmaFeat = para_this.sigmaFeat[alter_]

            # run relative pose module to get next estimate
            R_hat = RelativePoseEstimation(data_sc, data_tc, para_this, args.dataset, args.representation,
                                           doCompletion=args.completion, maskMethod=args.maskMethod,
                                           index=None)

    return R_hat
confusion_matrix_data = {'pred': [], 'gt': []} fp = './tmp/%s_local_result_%.3f_%.3f_%.3f.txt' % ( args.dataset, args.thre_coplane, args.thre_parallel, args.thre_perp) with torch.set_grad_enabled(False): for batch_id, data in enumerate(val_loader): if 'suncg' in args.dataList: THRESH = THRESH = np.array( [args.thre_coplane, args.thre_parallel, args.thre_perp]) elif 'scannet' in args.dataList: THRESH = np.array( [args.thre_coplane, args.thre_parallel, args.thre_perp]) elif 'matterport' in args.dataList: THRESH = np.array( [args.thre_coplane, args.thre_parallel, args.thre_perp]) rgb, depth, dataMask, R, overlap = v(data['rgb']), v( data['depth']), v(data['dataMask']), v(data['R']), npy( data['overlap']) pointcloud = v(data['pointcloud']) igt = v(data['igt']) if args.local_method == 'patch': pair = v(data['pair']) rel_dst = v(data['rel_dst']) rel_cls = v(data['rel_cls']) rel_ndot = v(data['rel_ndot']) rel_valid = v(data['rel_valid']) plane_idx = v(data['plane_idx']) plane_center = v(data['plane_center']) elif args.local_method == 'point':
def pnlayer(depth, normal, plane, dataList, representation):
    """Loss tying predicted plane offsets + normals to the depth map.

    For each selected pixel, two 3D points are built:
      * pcPn = plane / (ray . normal) * ray (from plane offset and normal),
      * pcD from the depth map.
    The mean of their absolute difference, clamped to [-5, 5], is accumulated.

    suncg / matterport (4-face layout, ``h == w // 4``):
        ray = (-x, -y, 1), pcD = (x * z, y * z, -z). Pixels with a zero plane
        value are skipped. The face's normals are rotated by ``Rs`` first.
    scannet with ``representation == 'expand'``:
        x and y are divided by 0.89218745 / 1.18958327 (presumably a
        field-of-view normalisation — confirm). Pixels with zero depth are
        skipped, and the loss is summed per sample.

    Args:
        depth: tensor indexed ``[n, 1, h, w]``.
        normal: tensor indexed ``[n, 3, h, w]``.
        plane: tensor indexed ``[n, 1, h, w]``.
        dataList: dataset name(s).
        representation: only consulted for scannet.

    Returns:
        Scalar loss tensor (0 if no pixel was selected).

    Raises:
        NotImplementedError: scannet with a representation other than 'expand'.
        ValueError: unsupported dataset.
    """
    # dp: [n,1,h,w]
    # n: [n,3,h,w]
    if 'suncg' in dataList or 'matterport' in dataList:
        n, h, w = depth.shape[0], depth.shape[2], depth.shape[3]
        assert (h == w // 4)
        Rs = np.zeros([4, 4, 4])
        Rs[0] = np.eye(4)
        Rs[1] = np.array([[0, 0, -1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]])
        Rs[2] = np.array([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        Rs[3] = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]])
        Rs = torch_op.v(Rs)

        # normalized coordinates of one h x h face (identical for all faces)
        ys, xs = np.meshgrid(range(h), range(h), indexing='ij')
        ys, xs = (0.5 - ys / h) * 2, (xs / h - 0.5) * 2
        # tile over the batch so entries line up with view(-1) of an [n, h, h] slice;
        # the old code indexed the un-tiled h*h grid with an n*h*h mask
        xs_all = np.tile(xs.flatten(), n)
        ys_all = np.tile(ys.flatten(), n)

        loss_pn = depth.new_zeros(())
        for i in range(4):
            # (a leftover `import ipdb; ipdb.set_trace()` used to halt here)
            plane_this = plane[:, 0, :, i * h:(i + 1) * h].contiguous()
            depth_this = depth[:, 0, :, i * h:(i + 1) * h].contiguous()
            zs = plane_this.view(-1)
            mask = (zs != 0)
            masknpy = torch_op.npy(mask).astype('bool')
            if not masknpy.any():
                # nothing to compare on this face; avoid mean() of an empty tensor
                continue

            normal_this = normal[:, :, :, i * h:(i + 1) * h].permute(0, 2, 3, 1).contiguous().view(-1, 3)
            if 'suncg' in dataList:
                normal_this = torch.matmul(Rs[i][:3, :3].t(), normal_this.t()).t()
            elif 'matterport' in dataList:
                normal_this = torch.matmul(Rs[(i - 1) % 4][:3, :3].t(), normal_this.t()).t()

            # mask every per-pixel quantity consistently (previously only `ray` was masked)
            xs_m, ys_m = xs_all[masknpy], ys_all[masknpy]
            ray = torch_op.v(np.stack((-xs_m, -ys_m, np.ones(len(xs_m))), 1))
            pcPn = (zs[mask] / (ray * normal_this[mask] + 1e-6).sum(1)).unsqueeze(1) * ray

            zd = depth_this.view(-1)[mask]
            pcD = torch.stack((torch_op.v(xs_m) * zd, torch_op.v(ys_m) * zd, -zd), 1)
            loss_pn = loss_pn + (pcD - pcPn).clamp(-5, 5).abs().mean()
    elif 'scannet' in dataList:
        if representation != 'expand':
            # previously fell through to `return loss_pn` with loss_pn unbound
            raise NotImplementedError(
                f"pnlayer: representation {representation!r} is not implemented for scannet")
        n, h, w = depth.shape[0], depth.shape[2], depth.shape[3]
        ys_grid, xs_grid = np.meshgrid(range(h), range(w), indexing='ij')
        ys_grid, xs_grid = (0.5 - ys_grid / h) * 2, (xs_grid / w - 0.5) * 2
        xs_grid, ys_grid = xs_grid.flatten(), ys_grid.flatten()

        loss_pn = depth.new_zeros(())
        for ii in range(n):
            plane_this = plane[ii, 0, :, :].contiguous()
            depth_this = depth[ii, 0, :, :].contiguous()
            mask = (depth_this.view(-1) != 0)
            masknpy = torch_op.npy(mask).astype('bool')
            if not masknpy.any():
                # no valid depth in this sample; avoid mean() of an empty tensor
                continue
            xs = xs_grid[masknpy]
            ys = ys_grid[masknpy]
            zs = plane_this.view(-1)[mask]
            normal_this = normal[ii].permute(1, 2, 0).contiguous().view(-1, 3)
            ray = torch_op.v(np.stack((-xs / 0.89218745, -ys / 1.18958327, np.ones(len(xs))), 1))
            pcPn = (zs / (ray * normal_this[mask] + 1e-6).sum(1)).unsqueeze(1) * ray

            zd = depth_this.view(-1)[mask]
            xd = torch_op.v(xs) * zd / 0.89218745
            yd = torch_op.v(ys) * zd / 1.18958327
            pcD = torch.stack((xd, yd, -zd), 1)
            loss_pn = loss_pn + (pcD - pcPn).clamp(-5, 5).abs().mean()
    else:
        raise ValueError(f"pnlayer: unsupported dataList {dataList!r}")
    return loss_pn
def getKeypoint_kinect(rs, rt, feats, featt, rs_full, rt_full):
    """Propose keypoints for a pair whose observed part is a Kinect view.

    NOTE(review): this file contains another ``def getKeypoint_kinect``
    further down; whichever is defined last is the one bound at import time.

    SIFT runs on the full-resolution images ``rs_full`` / ``rt_full``
    (``KINECT_W x KINECT_H``).  Detections are scaled into the
    ``KINECT_FOV_W x KINECT_FOV_H`` window centred at column ``H + H // 2``,
    row ``H // 2`` of the ``H x W`` panorama.  Keypoints are then built from:
      1. ``N_SIFT`` SIFT detections per image (resampled to that fixed count);
      2. candidates returned by ``Sampling`` on feature distances, in the other
         image, for up to ``N_SIFT_MATCH`` distinct detections;
      3. up to ``N_RANDOM`` distinct random source pixels outside the Kinect
         window, plus their ``Sampling`` candidates in the target.

    Args:
        rs, rt: unused (the previous body converted them to grayscale and
            discarded the result); kept for interface compatibility.
        feats, featt: source / target feature maps; the first axis is the
            channel axis and each holds ``C * H * W`` values
            (presumably ``[C, H, W]``).
        rs_full, rt_full: full-resolution BGR images used for SIFT.

    Returns:
        ``(pts, ptsNorm, ptsW, ptt, pttNorm, pttW)``: pixel coordinates
        ``(x, y)``, the same divided by ``(W, H)``, and weights that are 1
        inside the Kinect window and ``MARKER`` outside.  Six ``None`` if SIFT
        finds no keypoint in either full image.
    """
    H, W = 160, 640
    KINECT_W, KINECT_H = 640, 480
    # the observed region size of kinect camera is [88x66]
    KINECT_FOV_W, KINECT_FOV_H = 88, 66
    N_SIFT = 300
    N_SIFT_MATCH = 30
    N_RANDOM_CANDIDATES = 120
    N_RANDOM = 100
    MARKER = 0.99
    SIFT_THRE = 0.02  # OpenCV's default contrastThreshold is 0.04
    TOPK = 2
    # inclusive bounds of the Kinect window in panorama pixels
    X0, X1 = H + H // 2 - KINECT_FOV_W // 2, H + H // 2 + KINECT_FOV_W // 2
    Y0, Y1 = H // 2 - KINECT_FOV_H // 2, H // 2 + KINECT_FOV_H // 2

    # SIFT lives in the main cv2 module from OpenCV 4.4; older builds only have xfeatures2d
    sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
    sift = sift_create(contrastThreshold=SIFT_THRE)

    def detect(rgb_full):
        # SIFT keypoints of a full image, mapped into the panorama's Kinect window
        gray = cv2.cvtColor(rgb_full, cv2.COLOR_BGR2GRAY)
        kps, _ = sift.detectAndCompute(gray, None)
        if not len(kps):
            return None
        p = np.array([kp.pt for kp in kps], dtype=float)
        p[:, 0] = p[:, 0] / KINECT_W * KINECT_FOV_W + X0
        p[:, 1] = p[:, 1] / KINECT_H * KINECT_FOV_H + Y0
        return p

    def normalize(p):
        # (x, y) divided by (W, H); astype returns a copy
        q = p.astype('float')
        q[:, 0] /= W
        q[:, 1] /= H
        return q

    def in_window(p):
        return (p[:, 0] >= X0) & (p[:, 0] <= X1) & (p[:, 1] >= Y0) & (p[:, 1] <= Y1)

    def subset(count, k):
        # up to k DISTINCT indices; the min() cap only makes sense without replacement
        return np.random.choice(count, min(k, count), replace=False)

    C = feats.shape[0]

    def best_matches(f0, select, fmap):
        # find the most probable correspondence using feature map
        dist = (f0[:, select].unsqueeze(2) - fmap.view(C, 1, -1)).pow(2).sum(0).view(len(select), H, W)
        cand = Sampling(torch_op.npy(dist), TOPK).reshape(-1, 2)
        # drop candidates on the last column / row
        return cand[(cand[:, 0] < W - 1) & (cand[:, 1] < H - 1)]

    pts = detect(rs_full)
    if pts is None:
        return None, None, None, None, None, None
    ptt = detect(rt_full)
    if ptt is None:
        return None, None, None, None, None, None

    # fixed-size resample; draw with replacement only when there are too few detections
    pts = pts[np.random.choice(len(pts), N_SIFT, replace=len(pts) < N_SIFT)]
    ptt = ptt[np.random.choice(len(ptt), N_SIFT, replace=len(ptt) < N_SIFT)]

    # SIFT detections -> candidate matches in the other image
    fs0 = interpolate(feats, torch_op.v(normalize(pts)))
    ft0 = interpolate(featt, torch_op.v(normalize(ptt)))
    fsselect = subset(len(pts), N_SIFT_MATCH)
    ftselect = subset(len(ptt), N_SIFT_MATCH)
    pttAug = best_matches(fs0, fsselect, featt)
    ptsAug = best_matches(ft0, ftselect, feats)
    pts = np.concatenate((pts, ptsAug))
    ptt = np.concatenate((ptt, pttAug))

    # random source pixels outside the Kinect window -> candidate matches in the target
    xs = (np.random.rand(N_RANDOM_CANDIDATES) * W).astype('int').clip(0, W - 2)
    ys = (np.random.rand(N_RANDOM_CANDIDATES) * H).astype('int').clip(0, H - 2)
    ptsrnd = np.stack((xs, ys), 1)
    ptsrnd = ptsrnd[~in_window(ptsrnd)]
    fs0 = interpolate(feats, torch_op.v(normalize(ptsrnd)))
    fsselect = subset(len(ptsrnd), N_RANDOM)
    pttAug = best_matches(fs0, fsselect, featt)
    pts = np.concatenate((pts, ptsrnd[fsselect]))
    ptt = np.concatenate((ptt, pttAug))

    # weight 1 inside the Kinect window, MARKER elsewhere
    ptsW = np.where(in_window(pts), 1.0, MARKER)
    pttW = np.where(in_window(ptt), 1.0, MARKER)
    return pts, normalize(pts), ptsW, ptt, normalize(ptt), pttW
def getKeypoint(rs, rt, feats, featt):
    """Propose keypoints in a source / target panorama pair.

    NOTE(review): this file contains another ``def getKeypoint``; whichever is
    defined last is the one bound at import time.

    Keypoints come from three sources:
      1. SIFT detections in columns ``[H, 2H)`` of each grayscale image;
      2. candidates returned by ``Sampling`` on feature distances, in the other
         image, for up to ``N_SIFT_MATCH`` distinct detections;
      3. up to ``N_RANDOM`` random source pixels outside columns ``[H, 2H]``,
         plus their ``Sampling`` candidates in the target.

    Args:
        rs, rt: source / target BGR images (``H x W`` = 160 x 640 assumed).
        feats, featt: feature maps; the first axis is the channel axis and each
            holds ``C * H * W`` values (presumably ``[C, H, W]``).

    Returns:
        ``(pts, ptsNorm, ptsW, ptt, pttNorm, pttW)``: pixel coordinates
        ``(x, y)``, the same divided by ``(W, H)``, and weights that are 1 for
        columns in ``[H, 2H]`` and ``MARKER`` elsewhere.  Six ``None`` if SIFT
        finds nothing in either image.
    """
    H, W = 160, 640
    N_SIFT_MATCH = 30
    N_RANDOM = 30
    MARKER = 0.99
    SIFT_THRE = 0.02  # OpenCV's default contrastThreshold is 0.04
    TOPK = 2

    # SIFT lives in the main cv2 module from OpenCV 4.4; older builds only have xfeatures2d
    sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
    sift = sift_create(contrastThreshold=SIFT_THRE)

    def detect(rgb):
        # SIFT on columns [H, 2H) only, shifted back to panorama columns
        gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)[:, H:H * 2]
        kps, _ = sift.detectAndCompute(gray, None)
        if not len(kps):
            return None
        p = np.array([kp.pt for kp in kps], dtype=float)
        p[:, 0] += H
        return p

    def normalize(p):
        # (x, y) divided by (W, H); astype returns a copy
        q = p.astype('float')
        q[:, 0] /= W
        q[:, 1] /= H
        return q

    def observed(p):
        return (p[:, 0] >= H) & (p[:, 0] <= H * 2)

    def subset(count, k):
        # up to k DISTINCT indices; the min() cap only makes sense without replacement
        return np.random.choice(count, min(k, count), replace=False)

    C = feats.shape[0]

    def best_matches(f0, select, fmap):
        # find the most probable correspondence using feature map
        dist = (f0[:, select].unsqueeze(2) - fmap.view(C, 1, -1)).pow(2).sum(0).view(len(select), H, W)
        cand = Sampling(torch_op.npy(dist), TOPK).reshape(-1, 2)
        # drop candidates on the last column / row
        return cand[(cand[:, 0] < W - 1) & (cand[:, 1] < H - 1)]

    pts = detect(rs)
    if pts is None:
        return None, None, None, None, None, None
    ptt = detect(rt)
    if ptt is None:
        return None, None, None, None, None, None

    # SIFT detections -> candidate matches in the other image
    fs0 = interpolate(feats, torch_op.v(normalize(pts)))
    ft0 = interpolate(featt, torch_op.v(normalize(ptt)))
    fsselect = subset(len(pts), N_SIFT_MATCH)
    ftselect = subset(len(ptt), N_SIFT_MATCH)
    pttAug = best_matches(fs0, fsselect, featt)
    ptsAug = best_matches(ft0, ftselect, feats)
    pts = np.concatenate((pts, ptsAug))
    ptt = np.concatenate((ptt, pttAug))

    # random source pixels outside columns [H, 2H] -> candidate matches in the target
    xs = (np.random.rand(N_RANDOM) * W).astype('int').clip(0, W - 2)
    ys = (np.random.rand(N_RANDOM) * H).astype('int').clip(0, H - 2)
    ptsrnd = np.stack((xs, ys), 1)
    ptsrnd = ptsrnd[~observed(ptsrnd)]
    fs0 = interpolate(feats, torch_op.v(normalize(ptsrnd)))
    fsselect = subset(len(ptsrnd), N_RANDOM)
    pttAug = best_matches(fs0, fsselect, featt)
    pts = np.concatenate((pts, ptsrnd[fsselect]))
    ptt = np.concatenate((ptt, pttAug))

    # weight 1 for columns in [H, 2H], MARKER elsewhere
    ptsW = np.where(observed(pts), 1.0, MARKER)
    pttW = np.where(observed(ptt), 1.0, MARKER)
    return pts, normalize(pts), ptsW, ptt, normalize(ptt), pttW
def getKeypoint(rs, rt, feats, featt):
    """Propose keypoints in a source / target panorama pair.

    NOTE(review): this file contains an earlier ``def getKeypoint``; this later
    definition rebinds the name at import time.

    Keypoints come from three sources:
      1. SIFT detections in columns ``[H, 2H)`` of each grayscale image;
      2. candidates returned by ``Sampling`` on feature distances, in the other
         image, for up to ``N_SIFT_MATCH`` distinct detections;
      3. up to ``N_RANDOM`` random source pixels outside columns ``[H, 2H]``,
         plus their ``Sampling`` candidates in the target.

    Args:
        rs, rt: source / target BGR images (``H x W`` = 160 x 640 assumed).
        feats, featt: feature maps; the first axis is the channel axis and each
            holds ``C * H * W`` values (presumably ``[C, H, W]``).

    Returns:
        ``(pts, ptsNorm, ptsW, ptt, pttNorm, pttW)``: pixel coordinates
        ``(x, y)``, the same divided by ``(W, H)``, and weights that are 1 for
        columns in ``[H, 2H]`` and ``MARKER`` elsewhere.  Six ``None`` if SIFT
        finds nothing in either image.
    """
    H, W = 160, 640
    N_SIFT_MATCH = 30
    # 300 // 10 random candidates; at most that many survive the filter, so the
    # previous min(100, ...) cap on their selection was equivalent to this one
    N_RANDOM = 30
    MARKER = 0.99
    SIFT_THRE = 0.02  # OpenCV's default contrastThreshold is 0.04
    TOPK = 2

    # SIFT lives in the main cv2 module from OpenCV 4.4; older builds only have xfeatures2d
    sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
    sift = sift_create(contrastThreshold=SIFT_THRE)

    def detect(rgb):
        # SIFT on columns [H, 2H) only, shifted back to panorama columns
        gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)[:, H:H * 2]
        kps, _ = sift.detectAndCompute(gray, None)
        if not len(kps):
            return None
        p = np.array([kp.pt for kp in kps], dtype=float)
        p[:, 0] += H
        return p

    def normalize(p):
        # (x, y) divided by (W, H); astype returns a copy
        q = p.astype('float')
        q[:, 0] /= W
        q[:, 1] /= H
        return q

    def observed(p):
        return (p[:, 0] >= H) & (p[:, 0] <= H * 2)

    def subset(count, k):
        # up to k DISTINCT indices; the min() cap only makes sense without replacement
        return np.random.choice(count, min(k, count), replace=False)

    C = feats.shape[0]

    def best_matches(f0, select, fmap):
        # find the most probable correspondence using feature map
        dist = (f0[:, select].unsqueeze(2) - fmap.view(C, 1, -1)).pow(2).sum(0).view(len(select), H, W)
        cand = Sampling(torch_op.npy(dist), TOPK).reshape(-1, 2)
        # drop candidates on the last column / row
        return cand[(cand[:, 0] < W - 1) & (cand[:, 1] < H - 1)]

    pts = detect(rs)
    if pts is None:
        return None, None, None, None, None, None
    ptt = detect(rt)
    if ptt is None:
        return None, None, None, None, None, None

    # SIFT detections -> candidate matches in the other image
    fs0 = interpolate(feats, torch_op.v(normalize(pts)))
    ft0 = interpolate(featt, torch_op.v(normalize(ptt)))
    fsselect = subset(len(pts), N_SIFT_MATCH)
    ftselect = subset(len(ptt), N_SIFT_MATCH)
    pttAug = best_matches(fs0, fsselect, featt)
    ptsAug = best_matches(ft0, ftselect, feats)
    pts = np.concatenate((pts, ptsAug))
    ptt = np.concatenate((ptt, pttAug))

    # random source pixels outside columns [H, 2H] -> candidate matches in the target
    xs = (np.random.rand(N_RANDOM) * W).astype('int').clip(0, W - 2)
    ys = (np.random.rand(N_RANDOM) * H).astype('int').clip(0, H - 2)
    ptsrnd = np.stack((xs, ys), 1)
    ptsrnd = ptsrnd[~observed(ptsrnd)]
    fs0 = interpolate(feats, torch_op.v(normalize(ptsrnd)))
    fsselect = subset(len(ptsrnd), N_RANDOM)
    pttAug = best_matches(fs0, fsselect, featt)
    pts = np.concatenate((pts, ptsrnd[fsselect]))
    ptt = np.concatenate((ptt, pttAug))

    # weight 1 for columns in [H, 2H], MARKER elsewhere
    ptsW = np.where(observed(pts), 1.0, MARKER)
    pttW = np.where(observed(ptt), 1.0, MARKER)
    return pts, normalize(pts), ptsW, ptt, normalize(ptt), pttW
def forward(self, topdown_gt, rgb, imgPCid, img2ind, partial, pc2ind, use_predicted_plane):
    """Build per-point features, scatter them into top-down grids and decode them.

    Per-point features are the image-backbone features sampled at ``imgPCid``
    concatenated with point-network features of ``partial``.  They are averaged
    (``torch_scatter.scatter_mean``) into 4 grids of ``height x width`` cells
    according to ``pc2ind`` and passed to ``self.topdownnet``.

    Args:
        topdown_gt: only read by the disabled ``else`` branch below.
        rgb: image batch fed to ``self.imgBackbone``; ``rgb.shape[0]`` is the
            batch size.
        imgPCid: per-point pixel coordinates; ``[..., 0]`` is used as the
            column and ``[..., 1]`` as the row when sampling ``imgFeat``.
        img2ind: per-point cell indices ``(col, row, layer)`` used when
            ``use_predicted_plane`` is false.
        partial: point-network input; channels ``:3`` are read as xyz when
            projecting with the predicted plane.
        pc2ind: NOTE(review): overwritten with ``img2ind`` before any read, so
            the passed value is ignored — confirm intended.
        use_predicted_plane: if true, recompute the cell indices by projecting
            ``partial`` into the frame derived from ``plane_pred``.

    Returns:
        ``(topdownfeat_gt, topdown_pred, topdown_feat, features_out,
        features_out_pnet, plane_pred, origins, axis_xs, axis_ys, axis_zs)``.
        ``topdownfeat_gt`` is ``None`` (the ``if 1:`` branch), and the frame
        tensors are zeros unless ``use_predicted_plane`` is true.
    """
    imgFeat = self.imgBackbone(rgb)
    # `if 1:` is always taken; the dense-SIFT `else` branch is disabled
    if 1:
        # topdownfeat_gt = self.imgBackbone_topdown(topdown_gt)
        topdownfeat_gt = None
    else:
        topdownfeat_gt = []
        for i in range(topdown_gt.shape[0]):
            gray = cv2.cvtColor((npy(topdown_gt[i]).transpose(1, 2, 0) * 255).astype('uint8'), cv2.COLOR_BGR2GRAY)
            sift = cv2.xfeatures2d.SIFT_create()
            step_size = 5
            # dense keypoint grid with a stride of step_size pixels
            kp = [
                cv2.KeyPoint(x, y, step_size) for y in range(0, gray.shape[0], step_size)
                for x in range(0, gray.shape[1], step_size)
            ]
            dense_feat = sift.compute(gray, kp)[1]
            # 45 points per axis: assumes a 224 x 224 topdown_gt (ceil(224 / 5))
            dense_feat = dense_feat.reshape(45, 45, 128)
            dense_feat = cv2.resize(dense_feat, (224, 224))
            topdownfeat_gt.append(dense_feat.transpose(2, 0, 1))
        topdownfeat_gt = np.stack(topdownfeat_gt)
        topdownfeat_gt = v(topdownfeat_gt)

    # batch index of every point, flattened to match imgPCid.view(-1)
    bindex = v(
        np.tile(np.arange(rgb.shape[0])[:, None], [1, imgPCid.shape[1]]).reshape(-1)).long()
    features = imgFeat[bindex, :, imgPCid[:, :, 1].view(-1).long(), imgPCid[:, :, 0].view(-1).long()]
    #features = rgb[bindex, :, imgPCid[:,:,1].view(-1).long(), imgPCid[:,:,0].view(-1).long()]
    # back to (batch, channels, points)
    pointwise_features = features.view(rgb.shape[0], imgPCid.shape[1], -1).permute(0, 2, 1)
    # pointwise_features = self.drop_layer(pointwise_features)
    pc2ind = img2ind
    n = partial.shape[0]
    features = self.conv1(partial)
    features = self.conv2(features)
    # max-pool over points and broadcast the global feature back to every point
    features_global = features.max(-1)[0][..., None].repeat(
        1, 1, partial.shape[-1])
    features = torch.cat([features, features_global], dim=1)
    features = self.conv3(features)
    features = self.conv4(features)
    features = self.conv5(features)
    features = self.conv6(features)
    features_plane = features.max(2)[0][:, :, None]
    plane_pred = self.conv_plane(features_plane).squeeze(-1)
    # unit-normalize the first three entries (used below as the plane normal);
    # the fourth entry is kept as-is and used below as the plane offset
    plane_pred = torch.cat(
        (plane_pred[:, :3] / torch.norm(plane_pred[:, :3], dim=1, keepdim=True),
         plane_pred[:, 3:4]), -1)
    if 1:
        pointwise_features_pnet = self.conv7(features)
        pointwise_features = torch.cat(
            (pointwise_features, pointwise_features_pnet), 1)
    # features = partial[:,6:9,:]
    features_out_pnet = self.conv_semantic_pnet(
        pointwise_features).squeeze(1)
    features_out = self.conv_semantic(imgFeat)
    origins = np.zeros([n, 3])
    axis_xs = np.zeros([n, 3])
    axis_ys = np.zeros([n, 3])
    axis_zs = np.zeros([n, 3])
    height = 224
    width = 224
    if use_predicted_plane:
        pc2ind = np.zeros([n, partial.shape[-1], 3])
        for i in range(n):
            # frame origin: -normal * offset
            origin_0 = npy(-plane_pred[i, :3] * plane_pred[i, 3])
            # axis [0,0,-1], []
            # y axis: [0, 0, -1] with its component along the normal removed
            axis_base = np.array([0, 0, -1])
            axis_y_0 = axis_base - np.dot(axis_base, npy(
                plane_pred[i, :3])) * npy(plane_pred[i, :3])
            axis_y_0 /= (np.linalg.norm(axis_y_0) + 1e-16)
            axis_x_0 = np.cross(axis_y_0, npy(plane_pred[i, :3]))
            axis_x_0 /= (np.linalg.norm(axis_x_0) + 1e-16)
            axis_z_0 = npy(plane_pred[i, :3])
            origins[i] = origin_0
            axis_xs[i] = axis_x_0
            axis_ys[i] = axis_y_0
            axis_zs[i] = axis_z_0
            pc0 = npy(partial[i, :3, :]).T
            # random colors: only the index output of the projection is used below
            colors = np.random.rand(self.nclass, 3)
            topdown_c_partial_0, _, topdown_ind_0 = util.topdown_projection(
                pc0, np.ones([pc0.shape[0]]).astype('uint8'), colors, origin_0,
                axis_x_0, axis_y_0, axis_z_0, height, width, self.resolution)
            pc2ind[i] = topdown_ind_0
        pc2ind = v(pc2ind)
    # keep points whose (col, row) falls inside the grid; the layer index
    # pc2ind[..., 2] is not bounds-checked (presumably always in [0, 4))
    mask_u = (pc2ind[:, :, 0] >= 0) & (pc2ind[:, :, 0] < width)
    mask_v = (pc2ind[:, :, 1] >= 0) & (pc2ind[:, :, 1] < height)
    pointwise_features = pointwise_features.permute(0, 2, 1)
    mask = mask_u & mask_v
    featgrids = []
    for i in range(n):
        feat = pointwise_features[i][mask[i]]
        # index = (pc2ind[i][mask[i]][:,1]*400 + pc2ind[i][mask[i]][:,0]).long()
        # flat cell index: layer * height * width + row * width + col
        index = (pc2ind[i][mask[i]][:, 2] * height * width + pc2ind[i][mask[i]][:, 1] * width + pc2ind[i][mask[i]][:, 0]).long()
        # featgrid = torch_scatter.scatter_mean(feat, index,dim=0,dim_size=400*400).view(400,400,-1)
        # featgrid = torch_scatter.scatter_mean(feat, index,dim=0,dim_size=400*400*4).view(400,400,-1)
        # mean feature per cell -> (4, height, width, C)
        featgrid = torch_scatter.scatter_mean(
            feat, index, dim=0, dim_size=height * width * 4).view(4, height, width, -1)
        # featgrid = torch_scatter.scatter_max(feat, index,dim=0,dim_size=height*width*4,fill_value=0)[0].view(4,height,width,-1)
        featgrids.append(featgrid)
    featgrids = torch.stack(featgrids, 0)
    # featgrids = featgrids.permute(0,3,1,2)
    # (n, 4, h, w, C) -> (n, 4, C, h, w) -> (n, 4 * C, h, w)
    featgrids = featgrids.permute(0, 1, 4, 2, 3).contiguous()
    featgrids = featgrids.view(featgrids.shape[0], -1, featgrids.shape[3], featgrids.shape[4])
    # torch_scatter.scatter_mean(pointwise_features, (pc2ind[:,:,1]*400 + pc2ind[:,:,0]).unsqueeze(1).clamp(0,10).long(),dim=-1,dim_size=400*400)
    #cv2.imwrite('test.png',npy(1-featgrids[1,3:3*2,:,:]).transpose(1,2,0)*255)
    #cv2.imwrite('test2.png',(1-topdown_vis[1])*255)
    #util.write_ply('test.ply',npy(partial[1,:3]).T)
    #util.write_ply('test1.ply',npy(roompc[0]))
    topdown_pred, topdown_feat = self.topdownnet(featgrids)
    return topdownfeat_gt, topdown_pred, topdown_feat, features_out, features_out_pnet, plane_pred, v(
        origins), v(axis_xs), v(axis_ys), v(axis_zs)
def step(self, data, mode='train'):
    """Run one feature-learning step (train or eval) on a batch of scan pairs.

    Both views are stacked along the batch axis and fed through ``self.netF``.
    A contrastive loss is computed on the dense correspondences in
    ``data['denseCorres']``.  Optionally, a down-weighted (x0.1) segmentation
    loss on ``self.netSemg``'s output is added.  Losses are logged to the
    averaging meters and to tensorboard.

    Args:
        data: batch dict; reads ``'rgb'``, ``'norm'``, ``'depth'``,
            ``'dataMask'``, ``'Q'``, ``'segm'`` and ``'denseCorres'``.
        mode: ``'train'`` enables gradients and the optimizer step; any other
            value also runs the descriptor evaluation on visualisation steps.

    Returns:
        ``dotdict`` with key ``'suffix'``: a progress string of running averages.
    """
    torch.cuda.empty_cache()
    if self.speed_benchmark:
        step_start = time.time()
    with torch.set_grad_enabled(mode == 'train'):
        np.random.seed()
        self.optimizerF.zero_grad()
        # per-pixel cross entropy ('reduce=False' is deprecated in favour of reduction='none')
        CEcriterion = nn.CrossEntropyLoss(weight=self.class_balance_weights, reduction='none')
        rgb, norm, depth, dataMask, Q = v(data['rgb']), v(data['norm']), v(data['depth']), v(data['dataMask']), v(data['Q'])
        segm = v(data['segm'])
        # stack the two views along the batch axis: [view 0 items; view 1 items]
        segm = torch.cat((segm[:, 0, :, :, :], segm[:, 1, :, :, :]))
        errG_s = torch.FloatTensor([0])
        n = Q.shape[0]

        # compose the input: [rgb, normal, depth]
        complete0 = torch.cat((rgb[:, 0, :, :, :], norm[:, 0, :, :, :], depth[:, 0:1, :, :]), 1)
        complete1 = torch.cat((rgb[:, 1, :, :, :], norm[:, 1, :, :, :], depth[:, 1:2, :, :]), 1)
        _, mask0, _ = apply_mask(complete0.clone(), self.args.maskMethod, self.args.ObserveRatio)
        _, mask1, _ = apply_mask(complete1.clone(), self.args.maskMethod, self.args.ObserveRatio)
        mask = torch.cat((mask0, mask1))

        # NOTE(review): the network receives the unmasked scans; the masks are
        # only used below to label correspondences and for evaluation
        complete = torch.cat((complete0, complete1))
        dataMask = torch.cat((dataMask[:, 0, :, :, :], dataMask[:, 1, :, :, :]))
        fakec = self.netF(complete)
        segm_pred = self.netSemg(fakec)
        featMapsc = fakec[:n]
        featMaptc = fakec[n:]

        denseCorres = data['denseCorres']
        validCorres = torch.nonzero(denseCorres['valid'] == 1).view(-1).long()
        if not len(validCorres):
            # no usable pair: zero losses, 0-dim like the mean() losses below
            # (the former `[0][0]` indexed a 0-dim tensor and raised IndexError)
            loss_fl_pos = torch_op.v(np.array([0]))[0]
            loss_fl_neg = torch_op.v(np.array([0]))[0]
            loss_fl = torch_op.v(np.array([0]))[0]
        else:
            # categorize each correspondence by whether it contain unobserved point:
            # 'observe' sums the mask at the source and target endpoints
            allCorres = torch.cat((denseCorres['idxSrc'], denseCorres['idxTgt']))
            corresShape = allCorres.shape
            allCorres = allCorres.view(-1, 2).long()
            typeIdx = torch.arange(corresShape[0]).view(-1, 1).repeat(1, corresShape[1]).view(-1).long()
            typeIcorresP = mask[typeIdx, 0, allCorres[:, 1], allCorres[:, 0]]
            typeIcorresP = typeIcorresP.view(2, -1, corresShape[1]).sum(0)
            denseCorres['observe'] = typeIcorresP

            # consistency of keypoint proposal across different view
            idxInst = torch.arange(n)[validCorres].view(-1, 1).repeat(1, denseCorres['idxSrc'].shape[1]).view(-1).long()
            featS = featMapsc[idxInst, :, denseCorres['idxSrc'][validCorres, :, 1].view(-1).long(), denseCorres['idxSrc'][validCorres, :, 0].view(-1).long()]
            featT = featMaptc[idxInst, :, denseCorres['idxTgt'][validCorres, :, 1].view(-1).long(), denseCorres['idxTgt'][validCorres, :, 0].view(-1).long()]

            # positive example: squared feature distance of matched points
            loss_fl_pos = (featS - featT).pow(2).sum(1).mean()

            # negative example: 100 random target pixels per source keypoint,
            # hinge on squared distance with margin self.args.D
            # NOTE(review): nothing here excludes the true match from the negatives
            Kn = denseCorres['idxSrc'].shape[1]
            C = featMapsc.shape[1]
            negIdy = torch.from_numpy(np.random.choice(range(featMapsc.shape[2]), Kn * 100 * len(validCorres)))
            negIdx = torch.from_numpy(np.random.choice(range(featMapsc.shape[3]), Kn * 100 * len(validCorres)))
            idx = torch.arange(n)[validCorres].view(-1, 1).repeat(1, Kn * 100).view(-1).long()
            loss_fl_neg = F.relu(self.args.D - (featS.unsqueeze(1).repeat(1, 100, 1).view(-1, C) - featMaptc[idx, :, negIdy, negIdx]).pow(2).sum(1)).mean()
            loss_fl = loss_fl_pos + loss_fl_neg

        errG = loss_fl
        if self.args.featlearnSegm:
            # CE output is (2n, h, w); assuming dataMask is (2n, 1, h, w) like segm,
            # squeeze its channel so the weighting is per pixel instead of
            # broadcasting to (2n, 2n, h, w). squeeze(1) is a no-op otherwise.
            total_weight = dataMask.squeeze(1)
            errG_s = (CEcriterion(segm_pred, segm.squeeze(1).long()) * total_weight).mean() * 0.1
            # out-of-place: `errG += ...` also modified loss_fl, which is logged below
            errG = errG + errG_s

        if mode == 'train' and len(validCorres) > 0:
            errG.backward()
            self.optimizerF.step()

        self.logger_errG.update(errG.data, Q.size(0))
        self.logger_errG_fl.update(loss_fl.data, Q.size(0))
        self.logger_errG_fl_pos.update(loss_fl_pos.data, Q.size(0))
        self.logger_errG_fl_neg.update(loss_fl_neg.data, Q.size(0))

        suffix = (f"| errG {self.logger_errG.avg:.6f}| errG_fl {self.logger_errG_fl.avg:.6f}"
                  f" | errG_fl_pos {self.logger_errG_fl_pos.avg:.6f} | errG_fl_neg {self.logger_errG_fl_neg.avg:.6f}"
                  f" | errG_fc {self.logger_errG_fc.avg:.6f} | errG_pn {self.logger_errG_pn.avg:.6f}"
                  f" | errG_freq {self.logger_errG_freq.avg:.6f}")

        if self.global_step % getattr(self.learnerParam, f"{mode}_step_vis") == 0:
            if mode != 'train':
                with torch.set_grad_enabled(False):
                    if len(validCorres):
                        self.evalFeatRatioSift.extend(self.evalSiftDescriptor(rgb, denseCorres))
                        obs, unobs, visCPc = self.evalDLDescriptor(featMapsc, featMaptc, denseCorres, complete0[:, 0:3, :, :], complete1[:, 0:3, :, :], mask[0:1, 0:1, :, :])
                        self.evalFeatRatioDLc_obs.extend(obs)
                        self.evalFeatRatioDLc_unobs.extend(unobs)
                        visCPdir = os.path.join(self.args.EXP_DIR_SAMPLES, f"step_{self.global_step}")
                        os.makedirs(visCPdir, exist_ok=True)
                        for ii, img in enumerate(visCPc):
                            cv2.imwrite(os.path.join(visCPdir, f"complete_{ii}.png"), img)

            vis = []
            # draw semantic: predicted labels above ground truth (concatenated along height)
            vissf = torch.argmax(segm_pred, 1, keepdim=True).float()
            viss = torch.cat((vissf, segm), 2)
            visstp = torch_op.npy(viss)
            visstp = np.expand_dims(np.squeeze(visstp, 1), 3)
            visstp = self.colors[visstp.flatten().astype('int'), :].reshape(visstp.shape[0], visstp.shape[1], visstp.shape[2], 3)
            viss = torch_op.v(visstp.transpose(0, 3, 1, 2)) / 255.
            vis.append(viss)
            # every other item of the stacked batch
            vis = torch.cat(vis, 2)[::2]
            permute = [2, 1, 0]  # bgr to rgb
            vis = vis[:, permute, :, :]
            train_op.tboard_add_img(self.tensorboardX, vis, f"{mode}/loss", self.global_step)

        if self.global_step % getattr(self.learnerParam, f"{mode}_step_log") == 0:
            self.tensorboardX.add_scalars('data/errG_fl', {f"{mode}_complete": loss_fl}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_fl_pos', {f"{mode}_complete": loss_fl_pos}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_fl_neg', {f"{mode}_complete": loss_fl_neg}, self.global_step)
            self.tensorboardX.add_scalars('data/errG_s', {f"{mode}": errG_s}, self.global_step)

        summary = {'suffix': suffix}
        self.global_step += 1

    if self.speed_benchmark:
        self.time_per_step.update(time.time() - step_start, 1)
        print(f"time elapse per step: {self.time_per_step.avg}")

    return dotdict(summary)
# NOTE(review): this run of statements begins mid-sequence; the condition that
# guards the first increment (if any) and the enclosing scope are outside this view.
# idx_f_start / idx_f_end presumably locate the feature channels in the network
# output: channels of preceding outputs are skipped (3, then 1 for 'd',
# snumclass for 's') — confirm against the consumer.
args.idx_f_start += 3
if 'd' in args.outputType:
    args.idx_f_start += 1
if 's' in args.outputType:
    args.idx_f_start += args.snumclass
if 'f' in args.outputType:
    args.idx_f_end = args.idx_f_start + args.featureDim
with torch.set_grad_enabled(False):
    R_hat=np.eye(4)
    # get the complete scans: [rgb, normal, depth] of view 0 (source) and view 1 (target)
    complete_s=torch.cat((torch_op.v(data['rgb'][:,0,:,:,:]),torch_op.v(data['norm'][:,0,:,:,:]),torch_op.v(data['depth'][:,0:1,:,:])),1)
    complete_t=torch.cat((torch_op.v(data['rgb'][:,1,:,:,:]),torch_op.v(data['norm'][:,1,:,:,:]),torch_op.v(data['depth'][:,1:2,:,:])),1)
    # apply the observation mask
    view_s,mask_s,_ = util.apply_mask(complete_s.clone(),args.maskMethod)
    view_t,mask_t,_ = util.apply_mask(complete_t.clone(),args.maskMethod)
    # first batch item's mask as a numpy (h, w, c) array
    mask_s=torch_op.npy(mask_s[0,:,:,:]).transpose(1,2,0)
    mask_t=torch_op.npy(mask_t[0,:,:,:]).transpose(1,2,0)
    # append mask for valid data: channel 6 is depth if rgb and normal have
    # 3 channels each (assumed); non-zero depth -> 1
    tpmask = (view_s[:,6:7,:,:]!=0).float().cuda()
    view_s=torch.cat((view_s,tpmask),1)
    tpmask = (view_t[:,6:7,:,:]!=0).float().cuda()
    view_t=torch.cat((view_t,tpmask),1)
def getKeypoint_kinect(rs, rt, feats, featt, rs_full, rt_full):
    """Propose keypoints for a pair whose observed part is a Kinect view.

    NOTE(review): this file contains an earlier ``def getKeypoint_kinect``;
    this later definition rebinds the name at import time.

    SIFT runs on the full-resolution images ``rs_full`` / ``rt_full``
    (``KINECT_W x KINECT_H``).  Detections are scaled into the
    ``KINECT_FOV_W x KINECT_FOV_H`` window centred at column ``H + H // 2``,
    row ``H // 2`` of the ``H x W`` panorama.  Keypoints are then built from:
      1. ``N_SIFT`` SIFT detections per image (resampled to that fixed count);
      2. candidates returned by ``Sampling`` on feature distances, in the other
         image, for up to ``N_SIFT_MATCH`` distinct detections;
      3. up to ``N_RANDOM`` distinct random source pixels outside the Kinect
         window, plus their ``Sampling`` candidates in the target.

    Args:
        rs, rt: unused (the previous body converted them to grayscale and
            discarded the result); kept for interface compatibility.
        feats, featt: source / target feature maps; the first axis is the
            channel axis and each holds ``C * H * W`` values
            (presumably ``[C, H, W]``).
        rs_full, rt_full: full-resolution BGR images used for SIFT.

    Returns:
        ``(pts, ptsNorm, ptsW, ptt, pttNorm, pttW)``: pixel coordinates
        ``(x, y)``, the same divided by ``(W, H)``, and weights that are 1
        inside the Kinect window and ``MARKER`` outside.  Six ``None`` if SIFT
        finds no keypoint in either full image.
    """
    H, W = 160, 640
    KINECT_W, KINECT_H = 640, 480
    # the observed region size of kinect camera is [88x66]
    KINECT_FOV_W, KINECT_FOV_H = 88, 66
    N_SIFT = 300
    N_SIFT_MATCH = 30
    N_RANDOM_CANDIDATES = 120
    N_RANDOM = 100
    MARKER = 0.99
    SIFT_THRE = 0.02  # OpenCV's default contrastThreshold is 0.04
    TOPK = 2
    # inclusive bounds of the Kinect window in panorama pixels
    X0, X1 = H + H // 2 - KINECT_FOV_W // 2, H + H // 2 + KINECT_FOV_W // 2
    Y0, Y1 = H // 2 - KINECT_FOV_H // 2, H // 2 + KINECT_FOV_H // 2

    # SIFT lives in the main cv2 module from OpenCV 4.4; older builds only have xfeatures2d
    sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
    sift = sift_create(contrastThreshold=SIFT_THRE)

    def detect(rgb_full):
        # SIFT keypoints of a full image, mapped into the panorama's Kinect window
        gray = cv2.cvtColor(rgb_full, cv2.COLOR_BGR2GRAY)
        kps, _ = sift.detectAndCompute(gray, None)
        if not len(kps):
            return None
        p = np.array([kp.pt for kp in kps], dtype=float)
        p[:, 0] = p[:, 0] / KINECT_W * KINECT_FOV_W + X0
        p[:, 1] = p[:, 1] / KINECT_H * KINECT_FOV_H + Y0
        return p

    def normalize(p):
        # (x, y) divided by (W, H); astype returns a copy
        q = p.astype('float')
        q[:, 0] /= W
        q[:, 1] /= H
        return q

    def in_window(p):
        return (p[:, 0] >= X0) & (p[:, 0] <= X1) & (p[:, 1] >= Y0) & (p[:, 1] <= Y1)

    def subset(count, k):
        # up to k DISTINCT indices; the min() cap only makes sense without replacement
        return np.random.choice(count, min(k, count), replace=False)

    C = feats.shape[0]

    def best_matches(f0, select, fmap):
        # find the most probable correspondence using feature map
        dist = (f0[:, select].unsqueeze(2) - fmap.view(C, 1, -1)).pow(2).sum(0).view(len(select), H, W)
        cand = Sampling(torch_op.npy(dist), TOPK).reshape(-1, 2)
        # drop candidates on the last column / row
        return cand[(cand[:, 0] < W - 1) & (cand[:, 1] < H - 1)]

    pts = detect(rs_full)
    if pts is None:
        return None, None, None, None, None, None
    ptt = detect(rt_full)
    if ptt is None:
        return None, None, None, None, None, None

    # fixed-size resample; draw with replacement only when there are too few detections
    pts = pts[np.random.choice(len(pts), N_SIFT, replace=len(pts) < N_SIFT)]
    ptt = ptt[np.random.choice(len(ptt), N_SIFT, replace=len(ptt) < N_SIFT)]

    # SIFT detections -> candidate matches in the other image
    fs0 = interpolate(feats, torch_op.v(normalize(pts)))
    ft0 = interpolate(featt, torch_op.v(normalize(ptt)))
    fsselect = subset(len(pts), N_SIFT_MATCH)
    ftselect = subset(len(ptt), N_SIFT_MATCH)
    pttAug = best_matches(fs0, fsselect, featt)
    ptsAug = best_matches(ft0, ftselect, feats)
    pts = np.concatenate((pts, ptsAug))
    ptt = np.concatenate((ptt, pttAug))

    # random source pixels outside the Kinect window -> candidate matches in the target
    xs = (np.random.rand(N_RANDOM_CANDIDATES) * W).astype('int').clip(0, W - 2)
    ys = (np.random.rand(N_RANDOM_CANDIDATES) * H).astype('int').clip(0, H - 2)
    ptsrnd = np.stack((xs, ys), 1)
    ptsrnd = ptsrnd[~in_window(ptsrnd)]
    fs0 = interpolate(feats, torch_op.v(normalize(ptsrnd)))
    fsselect = subset(len(ptsrnd), N_RANDOM)
    pttAug = best_matches(fs0, fsselect, featt)
    pts = np.concatenate((pts, ptsrnd[fsselect]))
    ptt = np.concatenate((ptt, pttAug))

    # hacks to get the observed region for kinect camera configuration:
    # weight 1 inside the Kinect window, MARKER elsewhere
    ptsW = np.where(in_window(pts), 1.0, MARKER)
    pttW = np.where(in_window(ptt), 1.0, MARKER)
    return pts, normalize(pts), ptsW, ptt, normalize(ptt), pttW