def compute_ucd(partial_ls, output_ls):
    """Compute the unidirectional Chamfer distance between paired point clouds.

    Only ``dist1`` from ``distChamfer(partial, output)`` is used, i.e. a single
    direction of the Chamfer distance (presumably partial -> output; confirm
    against the ``distChamfer`` implementation). All values are scaled by 1e4.

    Args:
        partial_ls: non-empty sequence of point clouds (``np.ndarray`` or
            ``torch.Tensor``). All items must share one shape so they can be stacked.
        output_ls: point clouds paired index-by-index with ``partial_ls``. Each
            item may independently be an ndarray or a tensor.

    Returns:
        ``(mean, per_sample)``. ``mean`` is a Python float, the mean of
        ``dist1`` over all points and samples. ``per_sample`` is a list with
        one scaled per-pair distance.

    Raises:
        ValueError: if ``partial_ls`` is empty.
    """
    if len(partial_ls) == 0:
        raise ValueError('compute_ucd needs at least one point cloud pair.')

    # Convert each item on its own so that mixed ndarray/tensor inputs work.
    # The previous version decided for both lists from partial_ls[0] only.
    partial_ls = [torch.from_numpy(p) if isinstance(p, np.ndarray) else p for p in partial_ls]
    output_ls = [torch.from_numpy(o) if isinstance(o, np.ndarray) else o for o in output_ls]

    n_samples = len(partial_ls)
    if n_samples < 100:
        # Small input: evaluate everything in a single forward pass.
        partial = torch.stack(partial_ls).cuda()
        output = torch.stack(output_ls).cuda()
        dist1, _, _, _ = distChamfer(partial, output)
        cd_loss = dist1.mean() * 10000
        cd_ls = (dist1.mean(1) * 10000).detach().cpu().numpy().tolist()
        return cd_loss.item(), cd_ls

    # Large input: process fixed-size chunks to bound GPU memory usage.
    batch_size = 50
    per_batch = []
    for start in range(0, n_samples, batch_size):
        stop = start + batch_size
        partial = torch.stack(partial_ls[start:stop]).cuda()
        output = torch.stack(output_ls[start:stop]).cuda()
        dist1, _, _, _ = distChamfer(partial, output)
        per_batch.append(dist1.mean(1) * 10000)
    cd_all = torch.cat(per_batch)
    return cd_all.mean().item(), cd_all.detach().cpu().numpy().tolist()
def k_mask(self, target, x, stage=-1):
    """Mask ``x`` down to the points that are nearest neighbours of ``target`` points.

    For every point of ``target``, its ``k`` nearest points in ``x`` are found,
    with ``k = self.args.k_mask_k[stage]``. The union of those indices is kept.

    Args:
        target: partial point cloud, e.g. (1, N, 3); N may be smaller than,
            equal to, or larger than the number of points in ``x``.
        x: generated point cloud, (1, M, 3) (M was hard-coded to 2048 before).
        stage: optimisation stage used to look up ``k``. Negative values map to 0.

    Returns:
        ``x_map``. With ``masking_option == 'element_product'`` it has the same
        shape as ``x``, with unselected points zeroed. With ``'indexing'`` it is
        (1, M', 3), containing only the selected points.

    Raises:
        ValueError: if ``k < 1`` or ``self.args.masking_option`` is unknown.
            Previously these cases crashed with UnboundLocalError.
    """
    stage = max(0, stage)
    knn = self.args.k_mask_k[stage]
    if knn == 1:
        # argmin1 presumably holds, for each target point, the index of its
        # nearest point in x (TODO confirm with the distChamfer implementation).
        _, _, argmin1, _ = distChamfer(target, x)
        idx = torch.unique(argmin1).type(torch.long)
    elif knn > 1:
        # dist_mat: (B, N_target, M_x) with B = 1. topk along dim 2 picks the
        # knn closest x points for each target point.
        dist_mat = distChamfer_raw(target, x)
        _, indices = torch.topk(dist_mat, k=knn, dim=2, largest=False)
        # Union of all selected indices.
        idx = torch.unique(indices).type(torch.long)
    else:
        raise ValueError('k_mask_k[%d] must be >= 1, got %r' % (stage, knn))

    if self.args.masking_option == 'element_product':
        # Build the mask on x's own device/dtype and size it from x, instead of
        # a CPU (2048, 1) mask moved with .cuda().
        mask_tensor = torch.zeros(x.shape[1], 1, dtype=x.dtype, device=x.device)
        mask_tensor[idx] = 1
        x_map = torch.mul(x, mask_tensor.unsqueeze(0))
    elif self.args.masking_option == 'indexing':
        x_map = x[:, idx]
    else:
        raise ValueError('unknown masking_option: %r' % (self.args.masking_option,))
    return x_map
def MMD_batch(sample_pcs, ref_pcs, batch_size=50, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False, device=None):
    '''Compute Minimum Matching Distance (MMD) between two sets of point clouds.

    Intended to take the same inputs and produce the same outputs as
    ``minimum_mathing_distance()``. Uses CUDA implementations of CD / EMD.

    For every reference cloud, the distance to every sample cloud is computed,
    ``batch_size`` samples at a time. The per-reference minimum is taken, and
    ``mmd = sqrt(mean(per-reference minimum))``.

    Args:
        sample_pcs: (n_sample, P, D) point clouds, either np.ndarray or tensor.
        ref_pcs: (n_ref, P, D) point clouds, either np.ndarray or tensor.
            ndarray inputs are converted with ``torch.from_numpy`` and moved to CUDA.
        batch_size: number of sample clouds compared against one reference at once.
        use_EMD: if True, use the MSN ``emdModule`` (called with 0.005, 50) and
            average its per-point distances. Otherwise use the Chamfer distance
            ``mean(dist1) + mean(dist2)``.
        normalize, sess, verbose, use_sqrt, device: accepted for interface
            compatibility but unused. NOTE: the square root is always applied
            to the final mean, regardless of ``use_sqrt``.

    Returns:
        ``(mmd, mmd_all, dist_mat)``:
        - ``mmd`` is a numpy scalar.
        - ``mmd_all`` is a numpy array of shape (n_ref,) with each reference's
          minimum distance.
        - ``dist_mat`` is a numpy array of shape (n_ref, n_sample).

    Raises:
        ValueError: if the point counts or point dimensions differ.
    '''
    n_ref, n_pc_points, pc_dim = ref_pcs.shape
    n_sample, n_pc_points_s, pc_dim_s = sample_pcs.shape
    if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
        raise ValueError('Incompatible size of point-clouds.')

    # Convert each input independently; the old check only looked at sample_pcs.
    if isinstance(ref_pcs, np.ndarray):
        ref_pcs = torch.from_numpy(ref_pcs).cuda()
    if isinstance(sample_pcs, np.ndarray):
        sample_pcs = torch.from_numpy(sample_pcs).cuda()

    dist_mat = torch.zeros(n_ref, n_sample)
    # Build the EMD module once, not once per (reference, batch) pair.
    # ref: https://github.com/Colin97/MSN-Point-Cloud-Completion/tree/master/emd
    emd = emdModule() if use_EMD else None

    # Pure evaluation: no autograd graph is needed. This also keeps dist_mat
    # free of grad history so .numpy() below cannot fail.
    with torch.no_grad():
        for r in range(n_ref):
            ref_pcd = ref_pcs[r].unsqueeze(0)
            for start in range(0, n_sample, batch_size):
                # Slicing past the end is safe: the last batch is just shorter.
                sample_seg = sample_pcs[start:start + batch_size]
                n_seg = sample_seg.shape[0]
                ref_pcd_e = ref_pcd.expand(n_seg, n_pc_points, pc_dim)
                if use_EMD:
                    dists, _ = emd(sample_seg, ref_pcd_e, 0.005, 50)
                    dist = dists.mean(dim=1)
                else:
                    dist1, dist2, _, _ = distChamfer(ref_pcd_e, sample_seg)
                    dist = dist1.mean(axis=1) + dist2.mean(axis=1)
                dist_mat[r, start:start + n_seg] = dist

    mmd_all, _ = dist_mat.min(dim=1)
    mmd_all = mmd_all.detach().cpu().numpy()
    mmd = np.sqrt(np.mean(mmd_all))
    return mmd, mmd_all, dist_mat.cpu().numpy()
def mutual_distance(pcd_ls):
    """Total mutual (pairwise Chamfer) distance of a set of point clouds.

    Computes ``2 / (n - 1) * sum_{i<j} CD(pcd_i, pcd_j) * 1e4``, where
    ``CD = mean(dist1) + mean(dist2)`` from ``distChamfer``.

    Args:
        pcd_ls: sequence of at least two point clouds.
            - ``np.ndarray`` items get a leading batch dimension via
              ``unsqueeze(0)``.
            - Tensor items are used as-is (presumably already (1, N, 3) and on
              the device ``distChamfer`` expects; verify against callers).

    Returns:
        The scaled total mutual distance as a Python float.

    Raises:
        ValueError: if fewer than two point clouds are given. The formula
            divides by ``n - 1``.
    """
    n = len(pcd_ls)
    if n < 2:
        raise ValueError('mutual_distance needs at least two point clouds, got %d.' % n)

    # Per-item conversion, so that mixed ndarray/tensor lists also work.
    pcd_ls = [torch.from_numpy(p).unsqueeze(0) if isinstance(p, np.ndarray) else p for p in pcd_ls]

    sum_dist = 0
    for i, pcd_a in enumerate(pcd_ls):
        for pcd_b in pcd_ls[i + 1:]:
            dist1, dist2, _, _ = distChamfer(pcd_a, pcd_b)
            dist = dist1.mean(axis=1) + dist2.mean(axis=1)
            sum_dist += dist
    # Each unordered pair contributes to both of its members, hence the factor 2.
    mean_dist = sum_dist * 2 / (n - 1)
    return mean_dist.item() * 10000
def compute_cd_small_batch(gt, output, batch_size=50):
    """Compute per-sample Chamfer distance in chunks, for when there are many point clouds.

    Args:
        gt: batched ground-truth point clouds; dim 0 indexes samples.
        output: batched predicted point clouds, paired index-by-index with ``gt``.
        batch_size: number of samples passed to ``distChamfer`` at once. Must be >= 1.

    Returns:
        A list with one value per sample: ``(mean(dist1) + mean(dist2)) * 1e4``.
        An empty list if ``gt`` has no samples.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
    n_pcd = gt.shape[0]
    if n_pcd == 0:
        # torch.cat would fail on an empty list of chunks.
        return []
    chunks = []
    for start in range(0, n_pcd, batch_size):
        stop = min(start + batch_size, n_pcd)
        dist1, dist2, _, _ = distChamfer(gt[start:stop], output[start:stop])
        chunks.append(dist1.mean(1) + dist2.mean(1))
    dist_tensor = torch.cat(chunks)
    # detach(): .numpy() refuses tensors that still require grad.
    return (dist_tensor * 10000).detach().cpu().numpy().tolist()
def eval_completion_with_gt(input_dir, cd_verbose=False):
    """Evaluate the completions stored in ``input_dir`` against ground truth and print mean metrics.

    Loads ground-truth and output point clouds with ``retrieve_ours_pcs`` (its
    third return value is ignored). It then computes CD, accuracy, completeness
    and F1 with ``compute_4_metrics`` and prints the mean of each.

    Args:
        input_dir: directory passed to ``retrieve_ours_pcs``.
        cd_verbose: if True, also print ``index, int(cd)`` for every sample,
            where ``cd`` is the per-sample Chamfer distance scaled by 1e4.
    """
    ours_gt, ours_output, _ = retrieve_ours_pcs(input_dir)
    cd_ls, acc_ls, comp_ls, f1_ls = compute_4_metrics(ours_gt, ours_output)
    if cd_verbose:
        # Reuse the per-sample CD from compute_4_metrics instead of re-running
        # Chamfer on the whole set. The full-set pass is the memory spike that
        # compute_4_metrics batches around. Print every sample instead of a
        # hard-coded 150, which raised IndexError on smaller sets.
        for i, cd in enumerate(cd_ls):
            print(i, int(cd))
    print('cd.mean:',np.mean(cd_ls))
    print('acc :',np.mean(acc_ls))
    print('comp:',np.mean(comp_ls))
    print('f1 :',np.mean(f1_ls))
def compute_4_metrics(pcn_gt, pcn_output, thre=0.03):
    """Compute CD, accuracy, completeness and F1 for a batch of paired point clouds.

    Args:
        pcn_gt: batched ground-truth point clouds; dim 0 indexes samples.
        pcn_output: batched predicted point clouds paired with ``pcn_gt``.
        thre: distance threshold forwarded to ``compute_3_metrics``
            (default 0.03, the previously hard-coded value).

    Returns:
        ``(cd_ls, acc_ls, comp_ls, f1_ls)``:
        - ``cd_ls`` is a list of per-sample ``(mean(dist1) + mean(dist2)) * 1e4``.
        - The other three are returned as-is from ``compute_3_metrics``.
    """
    if pcn_gt.shape[0] <= 150:
        # Small sets: one Chamfer pass over the whole batch.
        dist1, dist2, _, _ = distChamfer(pcn_gt, pcn_output)
        cd = dist1.mean(1) + dist2.mean(1)
        # detach(): .numpy() refuses tensors that still require grad.
        cd_ls = (cd * 10000).detach().cpu().numpy().tolist()
    else:
        # Large sets: compute with small batches to bound GPU memory.
        cd_ls = compute_cd_small_batch(pcn_gt, pcn_output, batch_size=50)
    acc_ls, comp_ls, f1_ls = compute_3_metrics(pcn_gt, pcn_output, thre=thre)
    return cd_ls, acc_ls, comp_ls, f1_ls
def select_z(self, select_y=False):
    """Initialise ``self.z`` by random search over latent codes (runs under ``torch.no_grad``).

    Behaviour by ``self.select_num``:
    - ``0``: zero ``self.z`` in place and return None.
    - ``1``: resample ``self.z`` in place with ``normal_()`` and return None.
    - Otherwise: draw ``select_num`` codes of shape (1, 1, 96), generate a point
      cloud from each with ``self.G``, and score it with
      ``self.criterion(self.ftr_net, x, self.target)``. The lowest-loss code is
      copied into ``self.z`` and returned.

      If ``self.gt`` is set, it also:
      - appends ``"<pcd_id>,init,cd,<cd>"`` (CD of G(z) vs gt) to
        ``self.args.log_pathname``;
      - records ``x`` and its masked version in the checkpoint lists.

    Raises:
        NotImplementedError: if ``select_y`` is True and ``select_num > 1``. No y
            candidates are ever sampled here; previously this indexed an
            always-empty list and crashed with IndexError after the whole search.
    """
    with torch.no_grad():
        if self.select_num == 0:
            self.z.zero_()
            return
        elif self.select_num == 1:
            self.z.normal_()
            return
        if select_y:
            raise NotImplementedError('select_z(select_y=True) is not supported: no y candidates are sampled.')

        z_all, loss_all = [], []
        for _ in range(self.select_num):
            z = torch.randn(1, 1, 96).cuda()
            x = self.G([z])
            ftr_loss = self.criterion(self.ftr_net, x, self.target)
            z_all.append(z)
            loss_all.append(ftr_loss.detach().cpu().numpy())

        idx = np.argmin(np.array(loss_all))
        self.z.copy_(z_all[idx])
        x = self.G([self.z])

        # visualization
        if self.gt is not None:
            x_map = self.pre_process(x, stage=-1)
            dist1, dist2, _, _ = distChamfer(x, self.gt)
            cd_loss = dist1.mean() + dist2.mean()
            with open(self.args.log_pathname, "a") as file_object:
                msg = str(self.pcd_id) + ',' + 'init' + ',' + 'cd' +',' + '{:6.5f}'.format(cd_loss.item())
                file_object.write(msg+'\n')
            self.checkpoint_flags.append('init x')
            self.checkpoint_pcd.append(x)
            self.checkpoint_flags.append('init x_map')
            self.checkpoint_pcd.append(x_map)
        return z_all[idx]
def diversity_search(self, select_y=False):
    """Pick a diverse set of latent codes whose outputs fit the partial target.

    Steps:
    1. Sample ``self.select_num`` codes of shape (1, 96), in batches of at most
       50, under ``torch.no_grad``.
    2. Score each code's generated cloud by the single-directional Chamfer term
       ``dist1.mean(1)`` of ``distChamfer(target, x)``.
    3. Keep the ``self.args.n_z_candidates`` lowest-scoring codes.
    4. Choose ``self.args.n_outputs`` of them with ``farthest_point_sample`` in
       latent space.

    Results:
    - ``self.zs`` is set to a list of (1, 1, 96) tensors.
    - ``self.xs`` is reset to ``[]``.

    ``select_y`` is accepted for interface compatibility but unused.

    Raises:
        ValueError: if ``self.select_num < 1``. This previously crashed in
            ``torch.cat([])``, as did any ``select_num < 50``; the remainder of
            ``select_num`` is no longer silently dropped.
    """
    if self.select_num < 1:
        raise ValueError('diversity_search needs select_num >= 1, got %r' % (self.select_num,))
    batch_size = 50
    z_ls = []
    cd_ls = []
    with torch.no_grad():
        for start in range(0, self.select_num, batch_size):
            # The last batch may be smaller, so exactly select_num codes are drawn.
            n = min(batch_size, self.select_num - start)
            z = torch.randn(n, 1, 96).cuda()
            x = self.G([z])
            dist1, _, _, _ = distChamfer(self.target.repeat(n, 1, 1), x)
            cd_ls.append(dist1.mean(1))  # single directional CD
            z_ls.append(z)
    cd_full = torch.cat(cd_ls)
    z_full = torch.cat(z_ls)

    _, idx = torch.topk(cd_full, self.args.n_z_candidates, largest=False)
    z_cand = z_full[idx]  # (n_z_candidates, 1, 96)
    # farthest_point_sample gets the candidates as one (1, K, 96) "cloud"; its
    # result presumably has shape (1, n_outputs) (TODO confirm), hence squeeze(0).
    seeds = farthest_point_sample(z_cand.transpose(0, 1), self.args.n_outputs).squeeze(0)
    self.zs = [itm.unsqueeze(0) for itm in z_cand[seeds]]
    self.xs = []
def run(self, ith=-1):
    """Optimise the latent code (and optionally G) to fit ``self.target``, then save the results.

    Each stage runs ``self.iterations[stage]`` steps. Each step:
    - updates the G/z schedulers with the stage's learning rates;
    - generates ``x = G([z])`` and masks it with ``self.pre_process``;
    - minimises ``ftr_loss * w_D_loss[stage] + mean(z**2 / 2) * w_nll + CD(x_map, target)``,
      plus an optional weighted directed-Hausdorff term;
    - steps ``self.G.optim`` only when ``self.update_G_stages[stage]`` is true.

    With ``args.early_stopping``, a stage's loop ends as soon as the CD drops
    below ``args.stop_cd``; later stages still run.

    After each stage:
    - ``x`` and ``x_map`` are appended to the checkpoint lists;
    - if ``self.gt`` is set, ``"<pcd_id>,stage<k>,cd,<cd>"`` is appended to
      ``args.log_pathname``.

    After all stages:
    - if ``self.gt`` is set, the last ftr_loss / nll / test CD are appended to
      ``self.loss_log``;
    - ``self.x`` is set;
    - x, x_map, target (and gt) are written as ``;``-separated text files into
      ``args.save_inversion_path``, named after ``pcd_id`` plus ``_<ith>``
      when ``ith != -1``;
    - in ``'jittering'`` inversion mode, ``self.jitter(self.target)`` is called.
    """
    curr_step = 0
    for stage, iteration in enumerate(self.iterations):
        for _ in range(iteration):
            curr_step += 1
            # setup learning rate
            self.G_scheduler.update(curr_step, self.args.G_lrs[stage])
            self.z_scheduler.update(curr_step, self.args.z_lrs[stage])

            # forward
            self.z_optim.zero_grad()
            if self.update_G_stages[stage]:
                self.G.optim.zero_grad()
            tree = [self.z]
            x = self.G(tree)

            # masking
            x_map = self.pre_process(x, stage=stage)

            ### compute losses
            ftr_loss = self.criterion(self.ftr_net, x_map, self.target)
            dist1, dist2, _, _ = distChamfer(x_map, self.target)
            cd_loss = dist1.mean() + dist2.mean()
            # nll corresponds to a negative log-likelihood loss. It is computed
            # before the early-stop check so it is always defined (and matches
            # the step of ftr_loss) when loss_log is written below.
            nll = (self.z ** 2 / 2).mean()

            # optional early stopping: leaves only this stage's loop
            if self.args.early_stopping and cd_loss.item() < self.args.stop_cd:
                break

            ### loss
            loss = ftr_loss * self.w_D_loss[stage] + nll * self.args.w_nll + cd_loss

            # optional to use directed_hausdorff
            if self.args.directed_hausdorff:
                directed_hausdorff_loss = self.directed_hausdorff(self.target, x)
                loss += directed_hausdorff_loss * self.args.w_directed_hausdorff_loss

            # backward
            loss.backward()
            self.z_optim.step()
            if self.update_G_stages[stage]:
                self.G.optim.step()

        # save checkpoint for each stage
        self.checkpoint_flags.append('s_'+str(stage)+' x')
        self.checkpoint_pcd.append(x)
        self.checkpoint_flags.append('s_'+str(stage)+' x_map')
        self.checkpoint_pcd.append(x_map)

        # test only for each stage
        if self.gt is not None:
            dist1, dist2, _, _ = distChamfer(x, self.gt)
            test_cd = dist1.mean() + dist2.mean()
            with open(self.args.log_pathname, "a") as file_object:
                msg = str(self.pcd_id) + ',' + 'stage'+str(stage) + ',' + 'cd' +',' + '{:6.5f}'.format(test_cd.item())
                file_object.write(msg+'\n')

    if self.gt is not None:
        # np.asscalar was removed in NumPy 1.23. Tensor.item() yields the same
        # Python float for these single-element tensors.
        loss_dict = {
            'ftr_loss': ftr_loss.detach().cpu().item(),
            'nll': nll.detach().cpu().item(),
            'cd': test_cd.detach().cpu().item(),
        }
        self.loss_log.append(loss_dict)

    ### save point clouds
    self.x = x
    # makedirs(exist_ok=True) also creates missing parent directories and is
    # safe if another process creates the directory concurrently.
    os.makedirs(self.args.save_inversion_path, exist_ok=True)
    x_np = x[0].detach().cpu().numpy()
    x_map_np = x_map[0].detach().cpu().numpy()
    target_np = self.target[0].detach().cpu().numpy()
    if ith == -1:
        basename = str(self.pcd_id)
    else:
        basename = str(self.pcd_id)+'_'+str(ith)
    if self.gt is not None:
        gt_np = self.gt[0].detach().cpu().numpy()
        np.savetxt(osp.join(self.args.save_inversion_path,basename+'_gt.txt'), gt_np, fmt = "%f;%f;%f")
    np.savetxt(osp.join(self.args.save_inversion_path,basename+'_x.txt'), x_np, fmt = "%f;%f;%f")
    np.savetxt(osp.join(self.args.save_inversion_path,basename+'_xmap.txt'), x_map_np, fmt = "%f;%f;%f")
    np.savetxt(osp.join(self.args.save_inversion_path,basename+'_target.txt'), target_np, fmt = "%f;%f;%f")

    # jittering mode
    if self.args.inversion_mode == 'jittering':
        self.jitter(self.target)