def get_shape_chamfer_loss(self, pts, quat1, quat2, center1, center2, part_cnt):
    """Per-shape chamfer loss between two posings of the same parts.

    Every part's points ``pts`` are posed twice: with ``quat1``/``center1``
    and with ``quat2``/``center2``. Parts of one shape are the consecutive
    run of length ``part_cnt[k]``. Each run is flattened into a single cloud,
    reduced to 2048 points with ``furthest_point_sample``, and the two
    versions are compared with ``chamfer_distance``.

    Returns:
        Tensor with one value per shape: the point-averaged first chamfer
        output plus the point-averaged second one.
    """
    n_pts = pts.shape[1]

    def pose(q, c):
        # rotate every point of each part, then translate it
        rotated = qrot(q.unsqueeze(1).repeat(1, n_pts, 1), pts)
        return rotated + c.unsqueeze(1).repeat(1, n_pts, 1)

    posed_a = pose(quat1, center1)
    posed_b = pose(quat2, center2)

    sampled_a, sampled_b = [], []
    start = 0
    for n_parts in part_cnt:
        stop = start + n_parts
        whole_a = posed_a[start:stop].view(1, -1, 3)
        whole_b = posed_b[start:stop].view(1, -1, 3)
        # only the sample indices are computed without gradient; the
        # gathered points below still carry it
        with torch.no_grad():
            keep_a = furthest_point_sample(whole_a, 2048).long()[0]
            keep_b = furthest_point_sample(whole_b, 2048).long()[0]
        sampled_a.append(whole_a[:, keep_a])
        sampled_b.append(whole_b[:, keep_b])
        start = stop

    batch_a = torch.cat(sampled_a, dim=0)  # num_shapes x 2048 x 3
    batch_b = torch.cat(sampled_b, dim=0)  # num_shapes x 2048 x 3
    d_ab, d_ba = chamfer_distance(batch_a, batch_b, transpose=False)
    return d_ab.mean(dim=1) + d_ba.mean(dim=1)
def get_quat_loss(self, pts, quat1, quat2):
    """Chamfer loss between ``pts`` rotated by ``quat1`` and by ``quat2``.

    No translation is applied, so only the rotations are compared.

    Returns:
        Per-batch-item sum of the two point-averaged ``chamfer_distance``
        outputs.
    """
    n_pts = pts.shape[1]
    rotated = [qrot(q.unsqueeze(1).repeat(1, n_pts, 1), pts) for q in (quat1, quat2)]
    fwd, bwd = chamfer_distance(rotated[0], rotated[1], transpose=False)
    return fwd.mean(dim=1) + bwd.mean(dim=1)
def get_anchor_loss(self, box_size, quat1, quat2, center1, center2):
    """Chamfer loss between the anchor points of two posed boxes.

    Each box is packed as ``[center, box_size, quat]`` along the last
    dimension (both share ``box_size``). It is applied to ``self.unit_anchor``
    via ``transform_pc_batch(..., anchor=True)``.

    Returns:
        Per-batch-item average of the two point-averaged
        ``chamfer_distance`` outputs.
    """
    anchors = [
        transform_pc_batch(
            self.unit_anchor, torch.cat([c, box_size, q], dim=-1), anchor=True)
        for c, q in ((center1, quat1), (center2, quat2))
    ]
    fwd, bwd = chamfer_distance(anchors[0], anchors[1], transpose=False)
    return (fwd.mean(dim=1) + bwd.mean(dim=1)) / 2
def get_box_loss(self, box_size, quat1, quat2, center1, center2):
    """Weighted chamfer loss between two posed unit cubes.

    Each box is packed as ``[center, box_size, quat]`` along the last
    dimension and applied to ``self.unit_cube`` with ``transform_pc_batch``.
    Per-point weights come from ``get_surface_reweighting_batch`` applied to
    columns 3:6 of the packed box (the size columns, assuming centers are
    3-wide — confirm). The weights are computed without gradient.

    Returns:
        Per-batch-item sum of the two weighted-mean ``chamfer_distance``
        outputs (1e-12 guards the weight normaliser).
    """
    boxes = [torch.cat([c, box_size, q], dim=-1)
             for c, q in ((center1, quat1), (center2, quat2))]
    clouds = [transform_pc_batch(self.unit_cube, b) for b in boxes]
    n_cube = self.unit_cube.size(0)
    # weights are treated as constants: keep them out of the autograd graph
    with torch.no_grad():
        weights = [get_surface_reweighting_batch(b[:, 3:6], n_cube) for b in boxes]
    fwd, bwd = chamfer_distance(clouds[0], clouds[1], transpose=False)

    def weighted_mean(d, w):
        return (d * w).sum(dim=1) / (w.sum(dim=1) + 1e-12)

    return weighted_mean(fwd, weights[0]) + weighted_mean(bwd, weights[1])
def get_adj_loss(self, pts, quat, center, adjs):
    """Contact loss between parts that are declared adjacent.

    Every part cloud in ``pts`` is rotated by ``quat`` and translated by
    ``center``. For each adjacency pair ``(idx1, idx2)`` the pair loss is the
    minimum over points of the first ``chamfer_distance`` output plus the
    minimum of the second one. It is therefore small when the two posed
    parts come close to each other.

    Args:
        pts: per-part point clouds; ``pts.shape[1]`` is the point count.
        quat: per-part rotation quaternions (passed to ``qrot``).
        center: per-part translations.
        adjs: one iterable per shape of ``(idx1, idx2)`` part-index pairs
            into ``pts``.

    Returns:
        List with one 0-d tensor per shape: the mean pair loss of that shape,
        or a zero tensor when the shape has no adjacency pairs.
    """
    num_point = pts.shape[1]
    part_pcs = qrot(quat.unsqueeze(1).repeat(1, num_point, 1), pts) + \
        center.unsqueeze(1).repeat(1, num_point, 1)

    loss = []
    for cur_shape_adj in adjs:
        pairs = [(int(idx1), int(idx2)) for idx1, idx2 in cur_shape_adj]
        if not pairs:
            # torch.stack([]) would raise; keep one entry per shape instead
            loss.append(part_pcs.new_zeros(()))
            continue
        first_idx, second_idx = zip(*pairs)
        first_idx = torch.tensor(first_idx, dtype=torch.long, device=part_pcs.device)
        second_idx = torch.tensor(second_idx, dtype=torch.long, device=part_pcs.device)
        # one batched chamfer call for all pairs of this shape instead of one
        # call per pair; each batch row is the same computation as before
        dist1, dist2 = chamfer_distance(part_pcs[first_idx], part_pcs[second_idx],
                                        transpose=False)
        pair_loss = dist1.min(dim=1)[0] + dist2.min(dim=1)[0]
        loss.append(pair_loss.mean())
    return loss
fn = input_shape_id[t] + '-' + input_view_id[t] + '.npy' np.save(os.path.join(test_res_matched_dir, fn), to_save) print('saved', fn) t += cnt # get metric stats pred_pts = qrot( matched_pred_quat2_all.unsqueeze(1).repeat(1, num_point, 1), input_pts) + matched_pred_center2_all.unsqueeze(1).repeat( 1, num_point, 1) gt_pts = qrot( matched_gt_quat_all.unsqueeze(1).repeat(1, num_point, 1), input_pts) + matched_gt_center_all.unsqueeze(1).repeat( 1, num_point, 1) dist1, dist2 = chamfer_distance(gt_pts, pred_pts, transpose=False, sqrt=True) dist = torch.mean(dist1, dim=1).cpu().numpy() + torch.mean( dist2, dim=1).cpu().numpy() dist /= 2.0 batch_size = input_box_size.shape[0] correct = 0 cur_batch_accu = [] cur_batch_correct = [] t = 0 for cnt in input_total_part_cnt: cur_shape_correct = 0 for i in range(cnt): print('part dist', dist[i], 'thresh', eval_conf.thresh, 'correct', dist[i] < eval_conf.thresh) if dist[t + i] < eval_conf.thresh:
def linear_assignment(self, mask1, mask2, similar_cnt, pts, centers1, quats1,
                      centers2, quats2):
    """Match parts between two pose/mask predictions within groups of similar parts.

    Parts are processed group by group. ``similar_cnt[t].item()`` is the size
    of the group that starts at part index ``t``, e.g. per-part counts
    ``1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4, ...``. Parts in one group are
    interchangeable, so they are matched with ``linear_sum_assignment`` on a
    cost that adds a mask term and a point-cloud chamfer term.

    Args:
        mask1, mask2: per-part masks; the last dimension is used as the
            (square) image size.
        similar_cnt: one group-size count per part (read with ``.item()``).
        pts: per-part point clouds, reshaped as ``(-1, num_point, 3)``.
        centers1, quats1: per-part translations / quaternions, side 1.
        centers2, quats2: per-part translations / quaternions, side 2.

    Returns:
        ``(inds1, inds2)``: lists of part indices such that part ``inds1[k]``
        on side 1 is matched to part ``inds2[k]`` on side 2.

    Raises:
        ValueError: if a group size is smaller than 1. The loop would
            otherwise never advance.
    """
    num_point = pts.shape[1]
    img_size = mask1.shape[-1]
    inds1 = []
    inds2 = []

    def posed(quats, centers, start, cnt):
        # rotate and translate the parts [start, start + cnt)
        cur_quats = quats[start:start + cnt].unsqueeze(1).repeat(1, num_point, 1)
        cur_centers = centers[start:start + cnt].unsqueeze(1).repeat(1, num_point, 1)
        return qrot(cur_quats, pts[start:start + cnt]) + cur_centers

    t = 0
    with torch.no_grad():
        while t < similar_cnt.shape[0]:
            cnt = similar_cnt[t].item()
            if cnt < 1:
                raise ValueError(f"similar_cnt[{t}] must be >= 1, got {cnt}")

            # all (i, j) pairs inside the group: side-1 entries repeat along
            # dim 1, side-2 entries along dim 0, flattened row-major
            cur_mask1 = mask1[t:t + cnt].unsqueeze(1).repeat(
                1, cnt, 1, 1).view(-1, img_size, img_size)
            cur_mask2 = mask2[t:t + cnt].unsqueeze(0).repeat(
                cnt, 1, 1, 1).view(-1, img_size, img_size)
            dist_mat_mask = self.get_mask_loss(cur_mask1, cur_mask2).view(cnt, cnt)
            # NOTE(review): caps the mask cost at -0.1; presumably
            # get_mask_loss returns a negated similarity — confirm
            dist_mat_mask = torch.clamp(dist_mat_mask, max=-0.1)

            cur_pts1 = posed(quats1, centers1, t, cnt)
            cur_pts2 = posed(quats2, centers2, t, cnt)
            cur_pts1 = cur_pts1.unsqueeze(1).repeat(1, cnt, 1, 1).view(
                -1, num_point, 3)
            cur_pts2 = cur_pts2.unsqueeze(0).repeat(cnt, 1, 1, 1).view(
                -1, num_point, 3)
            dist1, dist2 = chamfer_distance(cur_pts1, cur_pts2, transpose=False)
            dist_mat_pts = (dist1.mean(1) + dist2.mean(1)).view(cnt, cnt)
            # point term is capped at 1 and down-weighted relative to the mask term
            dist_mat_pts = torch.clamp(dist_mat_pts, max=1) * 0.1
            dist_mat = dist_mat_mask + dist_mat_pts

            rind, cind = linear_sum_assignment(dist_mat.cpu().numpy())
            # assignment indices are group-local; offset them by the group start
            inds1 += [t + r for r in rind]
            inds2 += [t + c for c in cind]
            t += cnt
    return inds1, inds2
def forward(self, part_pcs, part_feat_old, part_cnt, equiv_edge_indices):
    """Refine per-part features with graph convolutions and decode part poses.

    1. ``self.pointnet(part_pcs)`` is concatenated with ``part_feat_old``
       (dim 1) and passed through ``self.mlp1`` + ReLU.
    2. For each shape (the consecutive run of ``part_cnt[i]`` parts), a graph
       is built from:
       - its equivalence edges ``equiv_edge_indices[i]``, with edge feature 0;
       - adjacency edges from each part to its ``min(5, cnt - 1)`` nearest
         other parts, with edge feature 1. "Nearest" means the smallest
         minimum of ``chamfer_distance``'s first output between the two clouds.
    3. ``self.num_iteration`` rounds of message passing are run. Messages are
       averaged onto each edge's first index with ``torch_scatter.scatter_mean``.
    4. The features from all rounds are concatenated and fed to
       ``self.pose_decoder``.

    Returns:
        ``(center, quat)`` as produced by ``self.pose_decoder``.
    """
    after_part_feat_3d = self.pointnet(part_pcs)
    part_feat_8d = torch.cat([part_feat_old, after_part_feat_3d], dim=1)
    part_feat_8d = torch.relu(self.mlp1(part_feat_8d))

    # perform graph-conv, one shape at a time
    t = 0
    output_part_feat_8d = []
    for i, cnt in enumerate(part_cnt):
        cnt = int(cnt)
        child_feats = part_feat_8d[t:t + cnt]
        iter_feats = [child_feats]

        cur_equiv_edge_indices = equiv_edge_indices[i]
        cur_equiv_edge_feats = cur_equiv_edge_indices.new_zeros(
            cur_equiv_edge_indices.shape[0], 1)

        # detect adjacency: k nearest *other* parts for every part
        topk = max(0, min(5, cnt - 1))
        with torch.no_grad():
            cur_part_pcs = part_pcs[t:t + cnt]
            A = cur_part_pcs.unsqueeze(0).repeat(cnt, 1, 1, 1).view(cnt * cnt, -1, 3)
            B = cur_part_pcs.unsqueeze(1).repeat(1, cnt, 1, 1).view(cnt * cnt, -1, 3)
            dist1, _ = chamfer_distance(A, B, transpose=False)
            dist = dist1.min(dim=1)[0].view(cnt, cnt)
            # Mask the self-pair explicitly instead of assuming the zero
            # self-distance sorts first: with ties (another part also at 0)
            # argsort()[1:] could keep j->j and drop a real neighbour.
            dist.fill_diagonal_(float('inf'))
            nbrs = dist.topk(topk, dim=1, largest=False).indices  # cnt x topk
            src = torch.arange(cnt, device=nbrs.device).repeat_interleave(topk)
            # row-major (j, k) pairs, shape (cnt * topk, 2), long dtype
            cur_adj_edge_indices = torch.stack(
                [src, nbrs.reshape(-1)], dim=1).to(cur_equiv_edge_indices.device)
        cur_adj_edge_feats = cur_adj_edge_indices.new_ones(
            cur_adj_edge_indices.shape[0], 1)

        cur_edge_indices = torch.cat(
            [cur_equiv_edge_indices, cur_adj_edge_indices], dim=0)
        cur_edge_feats = torch.cat(
            [cur_equiv_edge_feats, cur_adj_edge_feats], dim=0).float()

        for j in range(self.num_iteration):
            node_edge_feats = torch.cat([
                child_feats[cur_edge_indices[:, 0], :],
                child_feats[cur_edge_indices[:, 1], :],
                cur_edge_feats,
            ], dim=1)
            node_edge_feats = torch.relu(self.node_edge_op[j](node_edge_feats))
            # output width follows node_edge_op instead of a hard-coded 256
            new_child_feats = child_feats.new_zeros(cnt, node_edge_feats.shape[1])
            child_feats = torch_scatter.scatter_mean(
                node_edge_feats, cur_edge_indices[:, 0], dim=0, out=new_child_feats)
            iter_feats.append(child_feats)

        output_part_feat_8d.append(torch.cat(iter_feats, dim=1))
        t += cnt

    feat = torch.cat(output_part_feat_8d, dim=0)
    center, quat = self.pose_decoder(feat)
    return center, quat