import os
import os.path as osp
import time

import numpy as np
import torch

# distChamfer / distChamfer_raw and emdModule are assumed to be provided by the
# compiled CUDA Chamfer-distance and EMD extensions used by the original code.


def compute_ucd(partial_ls, output_ls):
    """
    Unidirectional Chamfer distance (UCD), partial -> output only.
    Takes two (short) lists of point clouds and returns the overall mean UCD
    (scaled by 1e4) together with the per-cloud values.
    """
    if isinstance(partial_ls[0],np.ndarray):
        partial_ls = [torch.from_numpy(itm) for itm in partial_ls]
        output_ls = [torch.from_numpy(itm) for itm in output_ls]
    if len(partial_ls) < 100:
        # small enough to evaluate in a single CUDA batch
        partial = torch.stack(partial_ls).cuda()
        output = torch.stack(output_ls).cuda()
        dist1, dist2, _, _ = distChamfer(partial, output)
        # only the partial -> output direction is used; scaled by 1e4 for readability
        cd_loss = dist1.mean() * 10000
        cd_ls = (dist1.mean(1) * 10000).cpu().numpy().tolist()
        return cd_loss.item(), cd_ls
    else:
        batch_size = 50
        n_samples = len(partial_ls)
        n_batches = (n_samples + batch_size - 1) // batch_size  # ceiling division
        cd_ls = []
        for i in range(n_batches):
            partial = torch.stack(partial_ls[i * batch_size:(i + 1) * batch_size]).cuda()
            output = torch.stack(output_ls[i * batch_size:(i + 1) * batch_size]).cuda()
            dist1, dist2 , _, _ = distChamfer(partial, output)
            cd_loss = dist1.mean(1)*10000
            cd_ls.append(cd_loss)
        cd = torch.cat(cd_ls).mean().item()
        cd_ls = torch.cat(cd_ls).cpu().numpy().tolist()
        return cd, cd_ls
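
# All of these snippets call distChamfer, which in the source repositories is a
# compiled CUDA extension. As a rough reference only, a brute-force pure-PyTorch
# stand-in with the assumed return signature (per-point squared distances plus
# nearest-neighbour indices in both directions) could look like this; it is only
# practical for small point counts.
def dist_chamfer_torch(a, b):
    """a: (B, N, 3), b: (B, M, 3) -> dist1 (B, N), dist2 (B, M), idx1, idx2."""
    d = torch.cdist(a, b) ** 2   # (B, N, M) pairwise squared distances
    dist1, idx1 = d.min(dim=2)   # nearest point in b for every point of a
    dist2, idx2 = d.min(dim=1)   # nearest point in a for every point of b
    return dist1, dist2, idx1, idx2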
    def k_mask(self, target, x, stage=-1):
        """
        Mask the generated cloud based on its Chamfer nearest neighbours to the partial input.
        target: (1, N, 3) partial cloud; N may be smaller than, equal to, or larger than 2048
        x:      (1, 2048, 3) generated cloud
        returns x_map:
            'indexing' option:        (1, N', 3) with N' <= 2048, keeping only matched points
            'element_product' option: (1, 2048, 3) with unmatched points zeroed out
        """
        stage = max(0, stage)
        knn = self.args.k_mask_k[stage]
        if knn == 1:
            cd1, cd2, argmin1, argmin2 = distChamfer(target, x)
            idx = torch.unique(argmin1).type(torch.long)
        elif knn > 1:
            # dist_mat shape (B, 2048, 2048), where B = 1
            dist_mat = distChamfer_raw(target, x)
            # indices (B, 2048, k)
            val, indices = torch.topk(dist_mat, k=knn, dim=2,largest=False)
            # union of all the indices
            idx = torch.unique(indices).type(torch.long)

        if self.args.masking_option == 'element_product':   
            mask_tensor = torch.zeros(2048,1)
            mask_tensor[idx] = 1
            mask_tensor = mask_tensor.cuda().unsqueeze(0)
            x_map = torch.mul(x, mask_tensor) 
        elif self.args.masking_option == 'indexing':  
            x_map = x[:, idx]

        return x_map
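
# distChamfer_raw (used by k_mask when knn > 1) is expected to return the dense
# (B, N, M) distance matrix rather than per-point minima. A stand-in under that
# assumption is simply the pairwise squared-distance matrix:
def dist_chamfer_raw_torch(target, x):
    """target: (B, N, 3), x: (B, M, 3) -> (B, N, M) squared distances."""
    return torch.cdist(target, x) ** 2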
def MMD_batch(sample_pcs,
              ref_pcs,
              batch_size=50,
              normalize=True,
              sess=None,
              verbose=False,
              use_sqrt=False,
              use_EMD=False,
              device=None):
    '''
    Compute MMD with CD or EMD between two point sets.
    Same inputs and outputs as minimum_mathing_distance(), but using the CUDA
    implementations of CD and EMD.
    input:
        sample_pcs and ref_pcs can be numpy arrays or tensors
        (the full data, e.g. 1000 sample_pcs and 1000 ref_pcs)
    '''
    n_ref, n_pc_points, pc_dim = ref_pcs.shape
    n_sample, n_pc_points_s, pc_dim_s = sample_pcs.shape
    if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
        raise ValueError('Incompatible size of point-clouds.')

    dist_mat = torch.zeros(n_ref, n_sample)

    # np to cuda tensor if start from np
    if isinstance(sample_pcs, np.ndarray):
        ref_pcs = torch.from_numpy(ref_pcs).cuda()
        sample_pcs = torch.from_numpy(sample_pcs).cuda()

    for r in range(n_ref):
        for i in range(0, n_sample, batch_size):
            if i + batch_size < n_sample:
                sample_pcd_seg = sample_pcs[i:i + batch_size]
            else:
                sample_pcd_seg = sample_pcs[i:]

            ref_pcd = ref_pcs[r].unsqueeze(0)
            ref_pcd_e = ref_pcd.expand(sample_pcd_seg.shape[0], n_pc_points,
                                       pc_dim)

            if use_EMD:
                # EMD
                # ref: https://github.com/Colin97/MSN-Point-Cloud-Completion/tree/master/emd
                emd = emdModule()
                dists, assignment = emd(sample_pcd_seg, ref_pcd_e, 0.005, 50)  # eps, iterations
                dist = dists.mean(dim=1)
            else:
                # CD
                dist1, dist2, _, _ = distChamfer(ref_pcd_e, sample_pcd_seg)
                dist = dist1.mean(dim=1) + dist2.mean(dim=1)

            if i + batch_size < n_sample:
                dist_mat[r, i:i + batch_size] = dist
            else:
                dist_mat[r, i:] = dist
    mmd_all, _ = dist_mat.min(dim=1)
    mmd_all = mmd_all.detach().cpu().numpy()
    mmd = np.mean(mmd_all)
    mmd = np.sqrt(mmd)  # note: the square root is applied regardless of use_sqrt
    return mmd, mmd_all, dist_mat.cpu().numpy()
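
# Hypothetical smoke test for MMD_batch (shapes and data are made up); it assumes a
# CUDA device and the compiled Chamfer op are available.
ref = np.random.rand(20, 2048, 3).astype(np.float32)
gen = np.random.rand(30, 2048, 3).astype(np.float32)
mmd, mmd_per_ref, dist_mat = MMD_batch(gen, ref, batch_size=10, use_EMD=False)
print(mmd, dist_mat.shape)  # dist_mat has shape (n_ref, n_sample) = (20, 30)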
def mutual_distance(pcd_ls):
    """
    Diversity measure over a list of point clouds: for each cloud, the average
    bidirectional CD to all the others, summed over clouds (equivalently
    2/(k-1) times the sum of pairwise distances), scaled by 1e4.
    """
    if isinstance(pcd_ls[0], np.ndarray):
        pcd_ls = [torch.from_numpy(itm).unsqueeze(0) for itm in pcd_ls]
    sum_dist = 0
    for i in range(len(pcd_ls)):
        for j in range(i + 1, len(pcd_ls)):
            dist1, dist2, _, _ = distChamfer(pcd_ls[i], pcd_ls[j])
            dist = dist1.mean(dim=1) + dist2.mean(dim=1)
            sum_dist += dist
    mean_dist = sum_dist * 2 / (len(pcd_ls) - 1)
    return mean_dist.item() * 10000
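
# Example usage (hypothetical data): diversity over k = 10 candidate completions of the
# same partial input. The tensors are passed already batched as (1, N, 3) and on the GPU,
# since mutual_distance leaves them on whatever device they arrive on.
completions = [torch.rand(1, 2048, 3).cuda() for _ in range(10)]
diversity = mutual_distance(completions)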
def compute_cd_small_batch(gt, output, batch_size=50):
    """
    Per-cloud bidirectional CD, computed in mini-batches when n_pcd is large;
    returns a list of values scaled by 1e4.
    """
    n_pcd = gt.shape[0]
    dist = []
    for i in range(0, n_pcd, batch_size):
        last_idx = min(i+batch_size,n_pcd)
        dist1, dist2 , _, _ = distChamfer(gt[i:last_idx], output[i:last_idx])
        cd_loss = dist1.mean(1) + dist2.mean(1)
        dist.append(cd_loss)
    dist_tensor = torch.cat(dist)
    cd_ls = (dist_tensor*10000).cpu().numpy().tolist()
    return cd_ls
def eval_completion_with_gt(input_dir, cd_verbose=False):

    ours_gt, ours_output, _ = retrieve_ours_pcs(input_dir)
    cd_ls, acc_ls, comp_ls, f1_ls = compute_4_metrics(ours_gt, ours_output)
    if cd_verbose:
        # print the per-cloud CD (scaled by 1e4) for inspection
        dist1, dist2, _, _ = distChamfer(ours_gt, ours_output)
        cd_pcds = (dist1.mean(1) + dist2.mean(1)) * 10000
        for i in range(cd_pcds.shape[0]):
            print(i, int(cd_pcds[i].item()))
    print('cd.mean:', np.mean(cd_ls))
    print('acc :', np.mean(acc_ls))
    print('comp:', np.mean(comp_ls))
    print('f1  :', np.mean(f1_ls))
def compute_4_metrics(pcn_gt, pcn_output):
    """
    compute cd, acc, comp, f1 for a batch
    """
    if pcn_gt.shape[0] <= 150:
        dist1, dist2 , _, _ = distChamfer(pcn_gt, pcn_output)
        cd = dist1.mean(1)+ dist2.mean(1)
        cd_ls  = (cd*10000).cpu().numpy().tolist()
    else:
        # compute in small batches to avoid exhausting GPU memory
        cd_ls = compute_cd_small_batch(pcn_gt, pcn_output, batch_size=50)

    acc_ls, comp_ls, f1_ls = compute_3_metrics(pcn_gt, pcn_output, thre=0.03)

    return cd_ls, acc_ls, comp_ls, f1_ls
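
# compute_3_metrics is called above but not shown in this excerpt. A minimal sketch,
# assuming distChamfer returns squared distances and that accuracy / completeness are
# the usual fractions of points within the threshold (F1 being their harmonic mean);
# the real implementation may batch the computation and differ in details.
def compute_3_metrics(gt, output, thre=0.03):
    dist1, dist2, _, _ = distChamfer(gt, output)    # squared distances
    comp = (dist1 < thre ** 2).float().mean(dim=1)  # gt points covered by the output
    acc = (dist2 < thre ** 2).float().mean(dim=1)   # output points close to the gt
    f1 = 2 * acc * comp / (acc + comp + 1e-8)
    return acc.cpu().numpy().tolist(), comp.cpu().numpy().tolist(), f1.cpu().numpy().tolist()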
    def select_z(self, select_y=False):
        tic = time.time()
        with torch.no_grad():
            if self.select_num == 0:
                self.z.zero_()
                return
            elif self.select_num == 1:
                self.z.normal_()
                return
            z_all, y_all, loss_all = [], [], []
            # draw select_num random latent codes and keep the one with the lowest feature loss
            for i in range(self.select_num):
                z = torch.randn(1, 1, 96).cuda()  # latent code of dimension 96
                tree = [z]
                with torch.no_grad():
                    x = self.G(tree)
                ftr_loss = self.criterion(self.ftr_net, x, self.target) 
                z_all.append(z)
                loss_all.append(ftr_loss.detach().cpu().numpy())
            
            toc = time.time()
            loss_all = np.array(loss_all)
            idx = np.argmin(loss_all)
            
            self.z.copy_(z_all[idx])
            if select_y:
                self.y.copy_(y_all[idx])
            
            x = self.G([self.z])

            # visualization
            if self.gt is not None:
                x_map = self.pre_process(x, stage=-1)
                dist1, dist2 , _, _ = distChamfer(x,self.gt)
                cd_loss = dist1.mean() + dist2.mean()
                
                with open(self.args.log_pathname, "a") as file_object:
                    msg = str(self.pcd_id) + ',' + 'init' + ',' + 'cd' +',' + '{:6.5f}'.format(cd_loss.item())
                    # print(msg)
                    file_object.write(msg+'\n')
                self.checkpoint_flags.append('init x')
                self.checkpoint_pcd.append(x)
                self.checkpoint_flags.append('init x_map')
                self.checkpoint_pcd.append(x_map)
            return z_all[idx]
    def diversity_search(self, select_y=False):
        """
        produce batch by batch
        search by 2pf and partial
        but constrainted to z dimension are large
        """
        batch_size = 50

        num_batch = int(self.select_num / batch_size)  # assumes select_num is a multiple of batch_size
        x_ls = []
        z_ls = []
        cd_ls = []
        tic = time.time()
        with torch.no_grad():
            for i in range(num_batch):
                z = torch.randn(batch_size, 1, 96).cuda()
                tree = [z]
                x = self.G(tree)
                dist1, dist2 , _, _ = distChamfer(self.target.repeat(batch_size,1,1),x)
                cd = dist1.mean(1) # single directional CD

                x_ls.append(x)
                z_ls.append(z)
                cd_ls.append(cd)
                
        x_full = torch.cat(x_ls)
        cd_full = torch.cat(cd_ls)
        z_full = torch.cat(z_ls)

        toc = time.time()
        
        cd_candidates, idx = torch.topk(cd_full,self.args.n_z_candidates,largest=False)
        z_t = z_full[idx].transpose(0,1)
        seeds = farthest_point_sample(z_t, self.args.n_outputs).squeeze(0) 
        z_ten = z_full[idx][seeds]

        self.zs = [itm.unsqueeze(0) for itm in z_ten]
        self.xs = []
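
# farthest_point_sample is called in diversity_search but not defined in this excerpt.
# A minimal greedy FPS sketch, assuming it takes a (B, N, C) tensor and returns the
# (B, n_samples) indices of the selected points:
def farthest_point_sample(xyz, n_samples):
    B, N, _ = xyz.shape
    device = xyz.device
    batch = torch.arange(B, device=device)
    idx = torch.zeros(B, n_samples, dtype=torch.long, device=device)
    dist = torch.full((B, N), float('inf'), device=device)  # distance to the selected set
    farthest = torch.zeros(B, dtype=torch.long, device=device)
    for i in range(n_samples):
        idx[:, i] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)                 # (B, 1, C)
        dist = torch.minimum(dist, ((xyz - centroid) ** 2).sum(-1))  # update nearest-selected distance
        farthest = dist.argmax(dim=1)
    return idx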
    def run(self, ith=-1):
        loss_dict = {}
        curr_step = 0
        count = 0
        for stage, iteration in enumerate(self.iterations):

            for i in range(iteration):
                curr_step += 1
                # setup learning rate
                self.G_scheduler.update(curr_step, self.args.G_lrs[stage])
                self.z_scheduler.update(curr_step, self.args.z_lrs[stage])

                # forward
                self.z_optim.zero_grad()
                
                if self.update_G_stages[stage]:
                    self.G.optim.zero_grad()
                             
                tree = [self.z]
                x = self.G(tree)
                
                # masking
                x_map = self.pre_process(x,stage=stage)

                ### compute losses
                ftr_loss = self.criterion(self.ftr_net, x_map, self.target)

                dist1, dist2 , _, _ = distChamfer(x_map, self.target)
                cd_loss = dist1.mean() + dist2.mean()
                # optional early stopping
                if self.args.early_stopping:
                    if cd_loss.item() < self.args.stop_cd:
                        break

                # nll corresponds to a negative log-likelihood loss
                nll = self.z**2 / 2
                nll = nll.mean()
                
                ### loss
                loss = ftr_loss * self.w_D_loss[stage] + nll * self.args.w_nll \
                        + cd_loss * 1
                
                # optional to use directed_hausdorff
                if self.args.directed_hausdorff:
                    directed_hausdorff_loss = self.directed_hausdorff(self.target, x)
                    loss += directed_hausdorff_loss*self.args.w_directed_hausdorff_loss
                
                # backward
                loss.backward()
                self.z_optim.step()
                if self.update_G_stages[stage]:
                    self.G.optim.step()

            # save checkpoint for each stage
            self.checkpoint_flags.append('s_'+str(stage)+' x')
            self.checkpoint_pcd.append(x)
            self.checkpoint_flags.append('s_'+str(stage)+' x_map')
            self.checkpoint_pcd.append(x_map)

            # test only for each stage
            if self.gt is not None:
                dist1, dist2 , _, _ = distChamfer(x,self.gt)
                test_cd = dist1.mean() + dist2.mean()
                with open(self.args.log_pathname, "a") as file_object:
                    msg = str(self.pcd_id) + ',' + 'stage'+str(stage) + ',' + 'cd' +',' + '{:6.5f}'.format(test_cd.item())
                    file_object.write(msg+'\n')
        
        if self.gt is not None:
            loss_dict = {
                'ftr_loss': ftr_loss.item(),
                'nll': nll.item(),
                'cd': test_cd.item(),
            }
            self.loss_log.append(loss_dict)
                
        ### save point clouds
        self.x = x
        if not osp.isdir(self.args.save_inversion_path):
            os.mkdir(self.args.save_inversion_path)
        x_np = x[0].detach().cpu().numpy()
        x_map_np = x_map[0].detach().cpu().numpy()
        target_np = self.target[0].detach().cpu().numpy()
        if ith == -1:
            basename = str(self.pcd_id)
        else:
            basename = str(self.pcd_id)+'_'+str(ith)
        if self.gt is not None:
            gt_np = self.gt[0].detach().cpu().numpy()
            np.savetxt(osp.join(self.args.save_inversion_path,basename+'_gt.txt'), gt_np, fmt = "%f;%f;%f")  
        np.savetxt(osp.join(self.args.save_inversion_path,basename+'_x.txt'), x_np, fmt = "%f;%f;%f")  
        np.savetxt(osp.join(self.args.save_inversion_path,basename+'_xmap.txt'), x_map_np, fmt = "%f;%f;%f")  
        np.savetxt(osp.join(self.args.save_inversion_path,basename+'_target.txt'), target_np, fmt = "%f;%f;%f")  

        # jittering mode
        if self.args.inversion_mode == 'jittering':
            self.jitter(self.target)
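
# run() optionally adds self.directed_hausdorff(self.target, x) to the loss, which is not
# defined in this excerpt. Under the usual definition of the directed Hausdorff distance
# (the largest nearest-neighbour distance from the partial input to the generated cloud),
# a brute-force stand-in could look like this sketch:
def directed_hausdorff_distance(source, target):
    """source: (B, N, 3), target: (B, M, 3) -> scalar, averaged over the batch."""
    d = torch.cdist(source, target)     # (B, N, M) pairwise distances
    nn = d.min(dim=2).values            # nearest target point for every source point
    return nn.max(dim=1).values.mean()  # worst-covered source point per cloud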