def forward(self, global_feat, point_input):
    """Decode a coarse cloud from ``global_feat`` and refine it into a fine cloud.

    Args:
        global_feat: global feature per sample; dim 0 is the batch size. It is
            passed through ``self.fc1`` -> ``self.fc2`` -> ``self.fc3``.
        point_input: input points, concatenated with the coarse points along
            dim 2 (presumably channel-major (B, 3, N) -- TODO confirm).

    Returns:
        (coarse, fine): ``coarse`` is (B, 3, self.num_coarse); ``fine`` is the
        output of ``self.conv2``, FPS-reduced to ``self.num_fine`` points when
        it has more than that.
    """
    batch_size = global_feat.size()[0]

    # Coarse branch: fully connected layers reshaped to (B, 3, num_coarse).
    coarse = F.relu(self.fc1(global_feat))
    coarse = F.relu(self.fc2(coarse))
    coarse = self.fc3(coarse).view(batch_size, 3, self.num_coarse)

    # Optionally downsample the input so coarse + input points total about 2048.
    if self.downsample_im:
        if self.mirror_im:
            # symmetric_sample returns twice the requested count (samples + mirror).
            org_points_input = symmetric_sample(
                point_input.transpose(1, 2).contiguous(),
                int((2048 - self.num_coarse) / 2))
            org_points_input = org_points_input.transpose(1, 2).contiguous()
        else:
            org_points_input = pn2.gather_operation(
                point_input,
                pn2.furthest_point_sample(
                    point_input.transpose(1, 2).contiguous(),
                    int(2048 - self.num_coarse)))
    else:
        org_points_input = point_input

    if self.points_label:
        # Extra label channel: 0 for generated coarse points, 1 for input points.
        # new_zeros/new_ones follow the data's device and dtype; `.cuda()` always
        # targets the current default GPU and fails on CPU or another GPU.
        id0 = coarse.new_zeros(coarse.shape[0], 1, coarse.shape[2])
        coarse_input = torch.cat((coarse, id0), 1)
        id1 = org_points_input.new_ones(
            org_points_input.shape[0], 1, org_points_input.shape[2])
        org_points_input = torch.cat((org_points_input, id1), 1)
        points = torch.cat((coarse_input, org_points_input), 2)
    else:
        points = torch.cat((coarse, org_points_input), 2)

    dense_feat = self.encoder(points)
    if self.scale >= 2:
        dense_feat = self.expansion(dense_feat)

    point_feat = F.relu(self.conv1(dense_feat))
    fine = self.conv2(point_feat)

    # Trim surplus output points back to num_fine with farthest point sampling.
    num_out = fine.size()[2]
    if num_out > self.num_fine:
        fine = pn2.gather_operation(
            fine,
            pn2.furthest_point_sample(fine.transpose(1, 2).contiguous(), self.num_fine))
    return coarse, fine
def edge_preserve_sampling(feature_input, point_input, num_samples, k=10):
    """FPS-downsample the points and build an edge-preserving feature per centre.

    Each centre's output feature is its own feature concatenated with the
    channel-wise max over the features of its ``k`` nearest neighbours (``k``
    is capped at the number of points).

    Args:
        feature_input: (B, C, N) features, per the size() lookups below.
        point_input: coordinates used for FPS and k-NN (presumably (B, N, 3) --
            TODO confirm).
        num_samples: number of centres M to keep.
        k: neighbourhood size.

    Returns:
        net: [centre feature, neighbour max] concatenated on dim 1
            (presumably (B, 2C, M)).
        p_idx: FPS indices of the centres.
        pn_idx: int32 neighbour indices, (B, M, pk).
        point_output: centre coordinates, (B, M, 3).
    """
    dims = feature_input.size()
    batch_size, feature_size, num_points = dims[0], dims[1], dims[2]

    # Choose centres by farthest point sampling and gather their coordinates.
    centre_idx = pn2.furthest_point_sample(point_input, num_samples)
    points_cf = point_input.transpose(1, 2).contiguous()
    centres = pn2.gather_operation(points_cf, centre_idx)
    centres = centres.transpose(1, 2).contiguous()  # B M 3

    # Neighbours of every centre; the indices are only used for gathering.
    neighbours = int(min(k, num_points))
    _, knn_idx = knn_point(neighbours, point_input, centres)
    knn_idx = knn_idx.detach().int()  # B M pk

    # Gather all neighbour features with one flat index, then max-pool them.
    flat_idx = knn_idx.view(batch_size, num_samples * neighbours)
    gathered = pn2.gather_operation(feature_input, flat_idx)
    gathered = gathered.view(batch_size, feature_size, num_samples, neighbours)
    neighbour_max = torch.max(gathered, 3)[0]

    centre_feat = pn2.grouping_operation(feature_input, centre_idx.unsqueeze(2))
    centre_feat = centre_feat.view(batch_size, -1, num_samples)

    net = torch.cat((centre_feat, neighbour_max), 1)
    return net, centre_idx, knn_idx, centres
def get_uniform_loss(pcd, percentages=(0.004, 0.006, 0.008, 0.010, 0.012), radius=1.0):
    """Uniformity loss averaged over several neighbourhood sizes.

    For each fraction ``p`` in ``percentages``, the function does the following:
    - Collect ball-query neighbourhoods of ``int(N * p)`` points around
      ``int(N * 0.05)`` FPS seeds.
    - Measure each point's nearest-neighbour distance inside its neighbourhood.
    - Penalise the squared deviation from the expected spacing
      ``sqrt(pi * radius**2 * p / nsample)``.
    - Weight the result by ``(100 * p) ** 2``.

    Args:
        pcd: (B, N, 3) point cloud (three dims are unpacked; the grouped points
            are reshaped to (..., 3)).
        percentages: neighbourhood sizes as fractions of N.
        radius: scale used for the ball-query radius and the expected spacing.

    Returns:
        The mean of the per-percentage terms.

    Raises:
        ValueError: if ``percentages`` is empty or ``pcd`` has too few points for
            the requested sizes (at least one seed and two points per ball).
    """
    _, num_points, _ = pcd.size()
    if not percentages:
        raise ValueError("percentages must not be empty")
    npoint = int(num_points * 0.05)
    if npoint < 1:
        raise ValueError(
            "point cloud has %d points; need at least 20 for one FPS seed" % num_points)
    for p in percentages:
        # nsample divides disk_area and must cover the 2-NN query below.
        if int(num_points * p) < 2:
            raise ValueError(
                "percentage %r keeps fewer than 2 of %d points" % (p, num_points))

    # The FPS seeds do not depend on p, so they are computed once rather than
    # once per percentage.
    pcd_cf = pcd.transpose(1, 2).contiguous()
    new_xyz = pn2.gather_operation(
        pcd_cf, pn2.furthest_point_sample(pcd, npoint)).transpose(1, 2).contiguous()

    loss = 0
    for p in percentages:
        nsample = int(num_points * p)
        r = math.sqrt(p * radius)
        disk_area = math.pi * (radius ** 2) * p / nsample
        expect_len = math.sqrt(disk_area)

        idx = pn2.ball_query(r, nsample, pcd, new_xyz)
        grouped_pcd = pn2.grouping_operation(pcd_cf, idx)
        grouped_pcd = grouped_pcd.permute(0, 2, 3, 1).contiguous().view(-1, nsample, 3)

        var, _ = knn_point(2, grouped_pcd, grouped_pcd)
        # Assumes knn_point returns negated squared distances with column 0 being
        # the point itself -- TODO confirm against knn_point.
        uniform_dis = -var[:, :, 1:]
        uniform_dis = torch.sqrt(torch.abs(uniform_dis + 1e-8))
        uniform_dis = torch.mean(uniform_dis, dim=-1)
        uniform_dis = (uniform_dis - expect_len) ** 2 / (expect_len + 1e-8)
        mean = torch.mean(uniform_dis) * math.pow(p * 100, 2)
        loss += mean
    return loss / len(percentages)
def symmetric_sample(points, num=512):
    """FPS-sample ``num`` points and append copies mirrored across the z = 0 plane.

    Args:
        points: point tensor whose last dim holds xyz (presumably (B, N, 3)).
        num: number of farthest points to sample.

    Returns:
        (B, 2 * num, 3): the sampled points followed by the same points with z negated.
    """
    sample_idx = pn2.furthest_point_sample(points, num)
    sampled = pn2.gather_operation(points.transpose(1, 2).contiguous(), sample_idx)
    sampled = sampled.transpose(1, 2).contiguous()
    # Keep x and y, negate z.
    mirrored = torch.cat([sampled[:, :, :2], -sampled[:, :, 2:3]], dim=2)
    return torch.cat([sampled, mirrored], dim=1)
def forward(self, xyz, features=None):
    # type: (_PointnetSAModuleBase, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
    r"""Set abstraction: sample centres, group, run per-grouper MLPs and max-pool.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) point coordinates.
    features : torch.Tensor or None
        Point descriptors, forwarded unchanged to every grouper.

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) sampled centres; ``xyz`` itself when npoint == N, and
        None when ``self.npoint`` is None.
    new_features : torch.Tensor
        Outputs of ``self.out_mlps[i]`` for every grouper, concatenated on dim 1.
    """
    if self.npoint is None:
        new_xyz = None
        fps_idx = None
    elif self.npoint == xyz.size(1):
        # Every point is kept: use the identity index instead of running FPS.
        fps_idx = torch.arange(self.npoint, device=xyz.device, dtype=torch.int)
        fps_idx = fps_idx.unsqueeze(0).expand(xyz.size(0), self.npoint).contiguous()
        new_xyz = xyz
    else:
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
        new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx)
        new_xyz = new_xyz.transpose(1, 2).contiguous()
        fps_idx = fps_idx.data

    pooled = []
    for i, grouper in enumerate(self.groupers):
        if self.npoint is None:
            grouped = grouper(xyz, new_xyz, features)
        else:
            grouped = grouper(xyz, new_xyz, features, fps_idx)  # (B, C, npoint, nsample)
        grouped = self.mlps[i](grouped)  # (B, mlp[-1], npoint, nsample)
        # Max over the nsample neighbours, then drop that axis.
        grouped = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)])
        grouped = grouped.squeeze(-1)  # (B, mlp[-1], npoint)
        pooled.append(self.out_mlps[i](grouped))
    return new_xyz, torch.cat(pooled, dim=1)
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    r"""Set abstraction that fuses several groupers by element-wise max.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) point coordinates.
    features : torch.Tensor or None
        Point descriptors, forwarded unchanged to every grouper.

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) centres; ``xyz`` itself when npoint == N, None when
        ``self.npoint`` is None.
    new_features : torch.Tensor
        With several groupers, the element-wise max of their MLP outputs;
        with one, that output unchanged.
    """
    if self.npoint is None:
        new_xyz = None
        fps_idx = None
    elif self.npoint == xyz.size(1):
        # All points kept: identity indices instead of FPS.
        fps_idx = torch.arange(self.npoint, device=xyz.device, dtype=torch.int)
        fps_idx = fps_idx.unsqueeze(0).expand(xyz.size(0), self.npoint).contiguous()
        new_xyz = xyz
    else:
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
        new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx)
        new_xyz = new_xyz.transpose(1, 2).contiguous()
        fps_idx = fps_idx.data

    outputs = []
    for i, grouper in enumerate(self.groupers):
        if self.npoint is None:
            grouped = grouper(xyz, new_xyz, features)
        else:
            grouped = grouper(xyz, new_xyz, features, fps_idx)  # (B, C, npoint, nsample)
        outputs.append(self.mlps[i](grouped))  # (B, mlp[-1], npoint)

    if len(outputs) <= 1:
        return new_xyz, torch.cat(outputs, dim=1)
    # Several groupers: element-wise max across them (outputs must share a shape).
    return new_xyz, torch.stack(outputs, dim=0).max(dim=0)[0]
def forward(self, xyz, features, end_points):
    """Sample proposals from votes, apply context learning, and decode grasp scores.

    Args:
        xyz: (B, K, 3) vote coordinates.
        features: (B, C, K) vote features.
        end_points: dict of intermediate results. The 'seed_fps' and 'random'
            strategies read 'seed_xyz' from it. It is updated in place and returned.

    Returns:
        end_points with 'aggregated_vote_xyz', 'aggregated_vote_inds' and the
        entries written by ``decode_scores``.

    Raises:
        ValueError: if ``self.sampling`` is not 'vote_fps', 'seed_fps' or 'random'.
    """
    if self.sampling == 'vote_fps':
        # Farthest point sampling (FPS) on votes
        xyz, features, fps_inds = self.vote_aggregation(xyz, features)
        sample_inds = fps_inds
    elif self.sampling == 'seed_fps':
        # FPS on seed and choose the votes corresponding to the seeds
        # This gets us a slightly better coverage of *object* votes than vote_fps (which tends to get more cluster votes)
        sample_inds = pointnet2_utils.furthest_point_sample(
            end_points['seed_xyz'], self.num_proposal)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    elif self.sampling == 'random':
        # Random sampling from the votes
        seed_xyz = end_points['seed_xyz']
        num_seed = seed_xyz.shape[1]
        batch_size = seed_xyz.shape[0]
        # Draw on the CPU generator (same stream as before), then move to the
        # data's device instead of whatever .cuda() considers current.
        sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal),
                                    dtype=torch.int).to(seed_xyz.device)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    else:
        # exit() here would end the whole process with status 0; fail loudly instead.
        raise ValueError('Unknown sampling strategy: %s' % (self.sampling,))

    end_points['aggregated_vote_xyz'] = xyz  # (batch_size, num_proposal, 3)
    end_points['aggregated_vote_inds'] = sample_inds  # (batch_size, num_proposal,)

    # --------- CONTEXT LEARNING ---------
    feature_dim = features.shape[1]
    batch_size = features.shape[0]
    # self.sa is fed a 16x16 grid, so this view needs exactly 256 proposals.
    features = features.contiguous().view(batch_size, feature_dim, 16, 16)
    net = self.sa(features)
    net = net.contiguous().view(batch_size, feature_dim, self.num_proposal)

    # --------- GRASP/PROPOSAL GENERATION ---------
    net = F.relu(self.bn1(self.conv1(net)))
    net = F.relu(self.bn2(self.conv2(net)))
    net = self.conv3(net)  # (batch_size, 2+3+1+1+num_angle_bin*2+num_viewpoint+self.num_class, num_proposal)

    end_points = decode_scores(net, end_points, self.num_class,
                               self.num_angle_bin, self.num_viewpoint)
    return end_points
def fps_sample(numpy_x, num_point, num_channels=6):
    """Farthest-point-sample ``num_point`` points from a NumPy array (on CUDA).

    Args:
        numpy_x: NumPy array holding ``num_channels`` values per point. The first
            three values of each point are used as xyz for FPS.
        num_point: number of points to keep.
        num_channels: values per point. The default of 6 is the width this
            function previously hard-coded.

    Returns:
        NumPy float32 array of shape (num_point, num_channels).
    """
    points = torch.from_numpy(numpy_x).float().view(1, -1, num_channels).cuda()
    xyz = points[:, :, :3].contiguous()
    fps_idx = pointnet2_utils.furthest_point_sample(xyz, num_point)  # (1, num_point)
    points = pointnet2_utils.gather_operation(
        points.transpose(1, 2).contiguous(), fps_idx
    ).transpose(1, 2).contiguous()  # (1, num_point, num_channels)
    return points.detach().cpu().numpy()[0]
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    r"""Group around sampled centres and run ``self.split`` MLP heads on the groups.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) tensor of the xyz coordinates of the features
    features : torch.Tensor
        (B, C, N) tensor of the descriptors of the features
    inds : torch.Tensor
        (B, npoint) optional indices into xyz (values in 0..N-1); FPS is used
        when omitted. Ignored when ``self.same_idx`` is set.

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) centres; ``xyz`` itself when ``self.same_idx`` is set,
        None when ``self.npoint`` is None.
    new_features : torch.Tensor
        (B, split, mlp[-1], npoint, nsample): head outputs stacked on dim 1.
    idx : torch.Tensor
        Grouping indices returned by ``self.grouper``.
    grouped_features : torch.Tensor
        (B, C, npoint, nsample) grouped input features.

    Raises
    ------
    NotImplementedError
        If ``self.ret_unique_cnt`` is set. That grouper mode does not return
        the grouping indices this method returns, and the previous code
        crashed with a NameError on ``idx``.
    """
    if self.same_idx:
        # Use the input points themselves as centres.
        new_xyz = xyz
    elif self.npoint is None:
        new_xyz = None
    else:
        if inds is None:
            inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        else:
            assert inds.shape[1] == self.npoint
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        new_xyz = pointnet2_utils.gather_operation(xyz_flipped, inds)
        new_xyz = new_xyz.transpose(1, 2).contiguous()

    if self.ret_unique_cnt:
        raise NotImplementedError(
            "ret_unique_cnt=True is not supported: the grouper then returns a "
            "unique count instead of the grouping indices returned here")

    grouped_features, grouped_xyz, idx = self.grouper(
        xyz, new_xyz, features)  # (B, C, npoint, nsample)

    # One MLP head per split, each (B, mlp[-1], npoint, nsample), stacked on dim 1.
    new_features = [self.mlp_module[i](grouped_features) for i in range(self.split)]
    new_features = torch.stack(new_features, 1)
    return new_xyz, new_features, idx, grouped_features
def forward(self, xyz, features, end_points):
    """Sample proposals from votes and decode box scores.

    Args:
        xyz: (B, K, 3) vote coordinates.
        features: (B, C, K) vote features.
        end_points: dict of intermediate results. The 'seed_fps' and 'random'
            strategies read 'seed_xyz' from it. It is updated in place and returned.

    Returns:
        end_points with 'seeds_in_prop_indices', 'aggregated_vote_xyz',
        'aggregated_vote_inds', the ``decode_scores`` outputs and
        'proposal_lastlayer_features'.

    Raises:
        ValueError: if ``self.sampling`` is not 'vote_fps', 'seed_fps' or 'random'.
    """
    if self.sampling == 'vote_fps':
        # Farthest point sampling (FPS) on votes
        xyz, features, fps_inds, group_idx = self.vote_aggregation(xyz, features)
        sample_inds = fps_inds
        # group_idx (B, num_proposal, 16); 16 is the number of samples per proposal.
    elif self.sampling == 'seed_fps':
        # FPS on seed and choose the votes corresponding to the seeds
        # This gets us a slightly better coverage of *object* votes than vote_fps (which tends to get more cluster votes)
        sample_inds = pointnet2_utils.furthest_point_sample(
            end_points['seed_xyz'], self.num_proposal)
        xyz, features, _, group_idx = self.vote_aggregation(xyz, features, sample_inds)
    elif self.sampling == 'random':
        # Random sampling from the votes
        seed_xyz = end_points['seed_xyz']
        num_seed = seed_xyz.shape[1]
        batch_size = seed_xyz.shape[0]
        # Same CPU RNG stream as before, but placed on the data's device.
        sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal),
                                    dtype=torch.int).to(seed_xyz.device)
        xyz, features, _, group_idx = self.vote_aggregation(xyz, features, sample_inds)
    else:
        # exit() here would end the whole process with status 0; fail loudly instead.
        raise ValueError('Unknown sampling strategy: %s' % (self.sampling,))

    end_points['seeds_in_prop_indices'] = group_idx  # (B, num_proposal, 16) values are in 0..num_seeds
    end_points['aggregated_vote_xyz'] = xyz  # (batch_size, num_proposal, 3)
    end_points['aggregated_vote_inds'] = sample_inds  # (batch_size, num_proposal,)

    # --------- PROPOSAL GENERATION ---------
    net_conv1 = F.relu(self.bn1(self.conv1(features)))
    net_conv2 = self.bn2(self.conv2(net_conv1))
    net = F.relu(net_conv2)
    net = self.conv3(net)  # (batch_size, 2+3+num_heading_bin*2+num_size_cluster*4, num_proposal)

    end_points = decode_scores(net, end_points, self.num_class, self.num_heading_bin,
                               self.num_size_cluster, self.mean_size_arr)
    # Pre-ReLU output of the second conv block, exposed for later consumers.
    end_points['proposal_lastlayer_features'] = net_conv2
    return end_points
def forward(self, xyz, features, end_points, mode=''):
    """Sample proposals from votes and decode their scores.

    Args:
        xyz: (B, K, 3) vote coordinates.
        features: (B, C, K) vote features.
        end_points: dict of intermediate results. The 'seed_fps' and 'random'
            strategies read 'seed_xyz' from it. It is updated in place.
        mode: suffix appended to every key written into ``end_points``.

    Returns:
        (newcenter, features, end_points): ``newcenter`` is the first output of
        ``decode_scores``; ``features`` are the aggregated proposal features.

    Raises:
        ValueError: if ``self.sampling`` is not 'vote_fps', 'seed_fps' or 'random'.
    """
    if self.sampling == 'vote_fps':
        # Farthest point sampling (FPS) on votes
        xyz, features, fps_inds = self.vote_aggregation(xyz, features)
        sample_inds = fps_inds
    elif self.sampling == 'seed_fps':
        # FPS on seed and choose the votes corresponding to the seeds
        # This gets us a slightly better coverage of *object* votes than vote_fps (which tends to get more cluster votes)
        sample_inds = pointnet2_utils.furthest_point_sample(
            end_points['seed_xyz'], self.num_proposal)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    elif self.sampling == 'random':
        # Random sampling from the votes
        seed_xyz = end_points['seed_xyz']
        num_seed = seed_xyz.shape[1]
        # batch_size was previously never defined on this path (NameError).
        batch_size = seed_xyz.shape[0]
        sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal),
                                    dtype=torch.int).to(seed_xyz.device)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    else:
        # exit() here would end the whole process with status 0; fail loudly instead.
        raise ValueError('Unknown sampling strategy: %s' % (self.sampling,))

    end_points['aggregated_vote_xyz' + mode] = xyz  # (batch_size, num_proposal, 3)
    end_points['aggregated_vote_inds' + mode] = sample_inds  # (batch_size, num_proposal,)
    end_points['aggregated_feature' + mode] = features  # (batch_size, 128, num_proposal)

    # --------- PROPOSAL GENERATION ---------
    net = F.relu(self.bn1(self.conv1(features)))
    last_net = F.relu(self.bn2(self.conv2(net)))
    net = self.conv3(last_net)  # (batch_size, 2+3+num_heading_bin*2+num_size_cluster*4, num_proposal)

    newcenter, end_points = decode_scores(net, end_points, self.num_class, mode=mode)
    return newcenter.contiguous(), features.contiguous(), end_points
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
    """Multi-scale set abstraction with max or average pooling over each group.

    :param xyz: (B, N, 3) point coordinates
    :param features: (B, C, N) point features, forwarded to every grouper
    :param new_xyz: optional (B, npoint, 3) centres; sampled with FPS when None
        and ``self.npoint`` is set (FPS assumes npoint < N)
    :return: new_xyz (B, npoint, 3) or None, and new_features
        (B, sum_k(mlps[k][-1]), npoint). The channel count is the sum of the
        last width of every mlp, e.g. mlps [[16, 32], [32, 64]] -> 32 + 64.
    :raises NotImplementedError: when ``self.pool_method`` is neither
        'max_pool' nor 'avg_pool'
    """
    if new_xyz is None and self.npoint is not None:
        xyz_flipped = xyz.transpose(1, 2).contiguous()  # (B, 3, N)
        centre_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
        new_xyz = pointnet2_utils.gather_operation(features=xyz_flipped, idx=centre_idx)
        new_xyz = new_xyz.transpose(1, 2).contiguous()  # (B, npoint, 3)

    pooled = []
    for i, grouper in enumerate(self.groupers):
        grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
        grouped = self.mlps[i](grouped)  # (B, mlp[i][-1], npoint, nsample)
        window = [1, grouped.size(3)]
        if self.pool_method == 'max_pool':
            grouped = F.max_pool2d(grouped, kernel_size=window)
        elif self.pool_method == 'avg_pool':
            grouped = F.avg_pool2d(grouped, kernel_size=window)
        else:
            raise NotImplementedError
        pooled.append(grouped.squeeze(-1))  # (B, mlp[i][-1], npoint)

    return new_xyz, torch.cat(pooled, dim=1)  # (B, npoint, 3), (B, sum_k mlp[k][-1], npoint)
def forward(self, xyz, features, end_points):
    """Sample proposals, refine with sa1/sa2 and a global gate, and decode box scores.

    Args:
        xyz: (B, K, 3) vote coordinates.
        features: (B, C, K) vote features.
        end_points: dict of intermediate results. Must contain 'seed_features'
            (and 'seed_xyz' for the 'seed_fps'/'random' strategies). It is
            updated in place and returned.

    Returns:
        end_points with 'aggregated_vote_xyz', 'aggregated_vote_inds' and the
        entries written by ``decode_scores``.

    Raises:
        ValueError: if ``self.sampling`` is not 'vote_fps', 'seed_fps' or 'random'.
    """
    if self.sampling == 'vote_fps':
        # Farthest point sampling (FPS) on votes
        xyz, features, fps_inds = self.vote_aggregation(xyz, features)
        sample_inds = fps_inds
    elif self.sampling == 'seed_fps':
        # FPS on seed and choose the votes corresponding to the seeds
        # This gets us a slightly better coverage of *object* votes than vote_fps (which tends to get more cluster votes)
        sample_inds = pointnet2_utils.furthest_point_sample(end_points['seed_xyz'], self.num_proposal)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    elif self.sampling == 'random':
        # Random sampling from the votes
        seed_xyz = end_points['seed_xyz']
        num_seed = seed_xyz.shape[1]
        batch_size = seed_xyz.shape[0]
        sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal),
                                    dtype=torch.int).to(seed_xyz.device)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    else:
        # exit() here would end the whole process with status 0; fail loudly instead.
        raise ValueError('Unknown sampling strategy: %s' % (self.sampling,))

    end_points['aggregated_vote_xyz'] = xyz  # (batch_size, num_proposal, 3)
    end_points['aggregated_vote_inds'] = sample_inds  # (batch_size, num_proposal,)

    feature_dim = features.shape[1]
    batch_size = features.shape[0]
    # sa1/sa2 are fed a 16x16 grid, so this view needs exactly 256 proposals.
    features = features.contiguous().view(batch_size, feature_dim, 16, 16)
    net = self.sa1(features)
    net = self.sa2(net)
    net = net.contiguous().view(batch_size, feature_dim, self.num_proposal)
    features = features.contiguous().view(batch_size, feature_dim, self.num_proposal)

    # Global descriptor = max-pooled seed features ++ max-pooled proposal features.
    seed_features = end_points['seed_features']
    global_features_2 = F.max_pool1d(features, kernel_size=features.size(2))  # (B, 128, 1)
    global_features_1 = F.max_pool1d(seed_features, kernel_size=seed_features.size(2))  # (B, 256, 1)
    global_features = torch.cat((global_features_1, global_features_2), 1)  # (B, 256+128, 1)
    # Broadcast it to every proposal (was hard-coded as expand(B, 384, 256)).
    global_features = torch.cat((global_features.expand(-1, -1, net.size(2)), net), 1)
    global_features = self.gs_conv1(global_features)
    # sigmoid(log|x|) == |x| / (1 + |x|). The rational form gives the same gate
    # without log(0) = -inf, whose backward pass produces NaN gradients at x == 0.
    gate = torch.abs(global_features)
    gate = gate / (1 + gate)
    net = net * gate

    net = F.relu(self.bn1(self.conv1(net)))
    net = F.relu(self.bn2(self.conv2(net)))
    net = self.conv3(net)  # (batch_size, 2+3+num_heading_bin*2+num_size_cluster*4, num_proposal)

    end_points = decode_scores(net, end_points, self.num_class, self.num_heading_bin,
                               self.num_size_cluster, self.mean_size_arr)
    return end_points
def forward(
    self,
    xyz: torch.Tensor,
    features: torch.Tensor = None,
    inds: torch.Tensor = None,
) -> (torch.Tensor, torch.Tensor):
    r"""Set abstraction that can reuse caller-supplied sampling indices.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) point coordinates.
    features : torch.Tensor or None
        Point descriptors forwarded to every grouper. The old docstring said
        "(B, C, C)", which looks like a typo for (B, C, N) -- confirm.
    inds : torch.Tensor or None
        (B, npoint) indices into xyz (values in 0..N-1); FPS is run when omitted.

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) centres, or None when ``self.npoint`` is None.
    new_features : torch.Tensor
        (B, sum_k(mlps[k][-1]), npoint) max-pooled features of every grouper.
    inds : torch.Tensor
        The indices used for the centres.
    """
    if inds is None:
        inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)

    if self.npoint is None:
        new_xyz = None
    else:
        new_xyz = pointnet2_utils.gather_operation(xyz.transpose(1, 2).contiguous(), inds)
        new_xyz = new_xyz.transpose(1, 2).contiguous()

    per_grouper = []
    for i, grouper in enumerate(self.groupers):
        grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
        grouped = self.mlps[i](grouped)  # (B, mlp[-1], npoint, nsample)
        grouped = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)])  # (B, mlp[-1], npoint, 1)
        per_grouper.append(grouped.squeeze(-1))  # (B, mlp[-1], npoint)

    return new_xyz, torch.cat(per_grouper, dim=1), inds
def forward(self, xyz, features, end_points):
    """Sample proposals, append the ``self.rnet`` feature, and decode box scores.

    Args:
        xyz: (B, K, 3) vote coordinates.
        features: (B, C, K) vote features.
        end_points: dict of intermediate results. The 'seed_fps' and 'random'
            strategies read 'seed_xyz' from it. ``self.rnet`` must add
            'rn_feature'. It is updated in place and returned.

    Returns:
        end_points with 'aggregated_vote_xyz', 'aggregated_vote_inds' and the
        entries written by ``self.rnet`` and ``decode_scores``.

    Raises:
        ValueError: if ``self.sampling`` is not 'vote_fps', 'seed_fps' or 'random'.
    """
    if self.sampling == 'vote_fps':
        # Farthest point sampling (FPS) on votes
        xyz, features, fps_inds = self.vote_aggregation(xyz, features)
        sample_inds = fps_inds
    elif self.sampling == 'seed_fps':
        # FPS on seed and choose the votes corresponding to the seeds
        # This gets us a slightly better coverage of *object* votes than vote_fps (which tends to get more cluster votes)
        sample_inds = pointnet2_utils.furthest_point_sample(end_points['seed_xyz'], self.num_proposal)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    elif self.sampling == 'random':
        # Random sampling from the votes
        seed_xyz = end_points['seed_xyz']
        num_seed = seed_xyz.shape[1]
        batch_size = seed_xyz.shape[0]
        sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal),
                                    dtype=torch.int).to(seed_xyz.device)
        xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
    else:
        # exit() here would end the whole process with status 0; fail loudly instead.
        raise ValueError('Unknown sampling strategy: %s' % (self.sampling,))

    end_points['aggregated_vote_xyz'] = xyz  # (batch_size, num_proposal, 3)
    end_points['aggregated_vote_inds'] = sample_inds  # (batch_size, num_proposal,)

    # --------- PROPOSAL GENERATION ---------
    # Concatenate the feature that self.rnet stores under end_points['rn_feature'].
    end_points = self.rnet(features, end_points)
    rn_feature = end_points['rn_feature']
    features = torch.cat((features, rn_feature), 1)

    net = F.relu(self.bn1(self.conv1(features)))
    net = F.relu(self.bn2(self.conv2(net)))
    net = self.conv3(net)  # (batch_size, 2+3+num_heading_bin*2+num_size_cluster*4, num_proposal)

    end_points = decode_scores(net, end_points, self.num_class, self.num_heading_bin,
                               self.num_size_cluster, self.mean_size_arr)
    return end_points
def forward(self, xyz, features):
    """Select ``self.num_proposal`` points by FPS and gather their coordinates and features.

    Args:
        xyz: (B, K, 3) point coordinates.
        features: (B, C, K) point features.

    Returns:
        new_xyz: (B, num_proposal, 3) sampled coordinates.
        new_features: features gathered at the sampled indices (presumably
            (B, C, num_proposal)).
        sample_inds: the FPS indices.
    """
    sample_inds = pointnet2_utils.furthest_point_sample(xyz, self.num_proposal)
    gather = pointnet2_utils.gather_operation

    coords_cf = xyz.transpose(1, 2).contiguous()
    new_xyz = gather(coords_cf, sample_inds).transpose(1, 2).contiguous()
    new_features = gather(features, sample_inds).contiguous()
    return new_xyz, new_features, sample_inds
def forward(self, code, inputs, step_ratio, num_extract=512, mean_feature=None):
    '''Generate a coarse cloud from ``code`` and upsample it step by step.

    :param code: B * C global feature
    :param inputs: B * C * N partial input points (channel-major)
    :param step_ratio: int; int(log2(step_ratio)) doubling steps are run
    :param num_extract: int; input points kept via symmetric_sample
    :param mean_feature: optional B * C feature, routed through ``self.mean_fc``
    :return: coarse (B, 512, 3) and fine (B, N_fine, 3). When step_ratio < 2
        no refinement step runs and fine is the level-0 cloud.
    '''
    # Coarse branch: code -> 512 points in [-1, 1] (tanh), channel-major (B, 3, 512).
    coarse = torch.tanh(self.coarse_mlp(code))
    coarse = coarse.view(-1, 512, 3)
    coarse = coarse.transpose(2, 1).contiguous()

    # Keep part of the input: FPS points plus their mirrored copies.
    inputs_new = inputs.transpose(2, 1).contiguous()
    input_fps = symmetric_sample(inputs_new, int(num_extract / 2))
    input_fps = input_fps.transpose(2, 1).contiguous()

    level0 = torch.cat([input_fps, coarse], 2)  # (B, 3, 2 * (num_extract // 2) + 512)
    if num_extract > 512:
        # Reduce the merged cloud back to 1024 points.
        level0_flipped = level0.transpose(2, 1).contiguous()
        level0 = pn2.gather_operation(level0, pn2.furthest_point_sample(level0_flipped, 1024))

    # Without this, step_ratio < 2 left `fine` unbound and the return raised NameError.
    fine = level0
    for i in range(int(math.log2(step_ratio))):
        num_fine = 2 ** (i + 1) * 1024
        # Build the folding grid on the data's device rather than the default GPU.
        grid = gen_grid_up(2 ** (i + 1)).to(level0.device).contiguous()
        grid = torch.unsqueeze(grid, 0)
        grid_feat = grid.repeat(level0.shape[0], 1, 1024)
        # Duplicate every current point twice -> (B, 3, num_fine).
        point_feat = torch.unsqueeze(level0, 3).repeat(1, 1, 1, 2)
        point_feat = point_feat.view(-1, 3, num_fine)
        global_feat = torch.unsqueeze(code, 2).repeat(1, 1, num_fine)  # (B, C_code, num_fine)

        if mean_feature is not None:
            mean_feature_use = F.relu(self.mean_fc(mean_feature))
            mean_feature_use = torch.unsqueeze(mean_feature_use, 2).repeat(1, 1, num_fine)
            feat = torch.cat([grid_feat, point_feat, global_feat, mean_feature_use], dim=1)
            feat1 = F.relu(self.up_branch_mlp_conv_mf(feat))
        else:
            feat = torch.cat([grid_feat, point_feat, global_feat], dim=1)
            feat1 = F.relu(self.up_branch_mlp_conv_nomf(feat))

        feat2 = self.contract_expand(feat1)
        feat = feat1 + feat2

        # Predicted offsets are added to the duplicated points.
        fine = self.fine_mlp_conv(feat) + point_feat
        level0 = fine

    return coarse.transpose(1, 2).contiguous(), fine.transpose(1, 2).contiguous()
def forward(self, xyz, normal, features=None):
    """Set abstraction that also samples and forwards per-point normals.

    Parameters
    ----------
    xyz : (B, N, 3) tensor of the xyz coordinates of the points
    normal : (B, N, 3) tensor of the normal vectors of the points
    features : descriptors of the points, forwarded to the groupers

    Returns
    -------
    new_xyz : (B, npoint, 3) sampled points, or (B, 1, 3) zeros (the origin)
        when ``self.npoint`` is None (global convolution)
    new_normal : (B, npoint, 3) normals of the sampled points, or None for
        global convolution
    new_features : outputs of ``self.mlps[i]`` concatenated on dim 1
    """
    new_features_list = []
    xyz_flipped = xyz.transpose(1, 2).contiguous()
    if self.npoint is not None:
        normal_flipped = normal.transpose(1, 2).contiguous()
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        new_xyz = pointnet2_utils.gather_operation(
            xyz_flipped, fps_idx).transpose(1, 2).contiguous()
        new_normal = pointnet2_utils.gather_operation(
            normal_flipped, fps_idx).transpose(1, 2).contiguous()
        fps_idx = fps_idx.data
    else:
        # Global convolution: one centre at the origin per batch element, created
        # on xyz's device (the old FloatTensor(...).cuda() ignored it).
        new_xyz = xyz.new_zeros(xyz.shape[0], 1, 3)
        new_normal = None
        fps_idx = None

    for i in range(len(self.groupers)):
        if self.npoint is not None:
            new_features = self.groupers[i](xyz, new_xyz, normal, features, fps_idx)
        else:
            new_features = self.groupers[i](xyz, new_xyz, features)
        # The MLP receives the grouped features together with the centre normals.
        new_features = self.mlps[i]((new_features, new_normal))
        new_features_list.append(new_features)

    return new_xyz, new_normal, torch.cat(new_features_list, dim=1)
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
    """Multi-scale set abstraction with selectable max/avg pooling.

    :param xyz: (B, N, 3) point coordinates
    :param features: point descriptors forwarded to every grouper
    :param new_xyz: optional centres; FPS-sampled when None and ``self.npoint`` is set
    :return: new_xyz (B, npoint, 3) or None, and the pooled per-grouper features
        concatenated on dim 1
    :raises NotImplementedError: for a ``self.pool_method`` other than
        'max_pool' / 'avg_pool'
    """
    if new_xyz is None and self.npoint is not None:
        coords_cf = xyz.transpose(1, 2).contiguous()
        sample_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        new_xyz = pointnet2_utils.gather_operation(coords_cf, sample_idx)
        new_xyz = new_xyz.transpose(1, 2).contiguous()

    results = []
    for i, grouper in enumerate(self.groupers):
        feats = self.mlps[i](grouper(xyz, new_xyz, features))  # (B, mlp[-1], npoint, nsample)
        if self.pool_method == 'max_pool':
            feats = F.max_pool2d(feats, kernel_size=[1, feats.size(3)])  # (B, mlp[-1], npoint, 1)
        elif self.pool_method == 'avg_pool':
            feats = F.avg_pool2d(feats, kernel_size=[1, feats.size(3)])  # (B, mlp[-1], npoint, 1)
        else:
            raise NotImplementedError
        results.append(feats.squeeze(-1))  # (B, mlp[-1], npoint)

    return new_xyz, torch.cat(results, dim=1)
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    r"""Set abstraction whose groupers receive the FPS indices as well.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) point coordinates.
    features : torch.Tensor or None
        Point descriptors forwarded to every grouper.

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) FPS centres, or None when ``self.npoint`` is None.
    new_features : torch.Tensor
        Per-grouper MLP outputs concatenated on dim 1.
    """
    fps_idx = None
    new_xyz = None
    if self.npoint is not None:
        coords_cf = xyz.transpose(1, 2).contiguous()
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
        new_xyz = pointnet2_utils.gather_operation(coords_cf, fps_idx)
        new_xyz = new_xyz.transpose(1, 2).contiguous()
        fps_idx = fps_idx.data

    outputs = []
    for i, grouper in enumerate(self.groupers):
        if self.npoint is None:
            grouped = grouper(xyz, new_xyz, features)
        else:
            grouped = grouper(xyz, new_xyz, features, fps_idx)  # (B, C, npoint, nsample)
        outputs.append(self.mlps[i](grouped))  # (B, mlp[-1], npoint)
    return new_xyz, torch.cat(outputs, dim=1)
def forward(self, xyz: torch.Tensor, points: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    r"""Multi-scale set abstraction with max pooling (gather_points variant).

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) point coordinates.
    points : torch.Tensor or None
        Point descriptors forwarded to every grouper.

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) FPS centres, or None when ``self.npoint`` is None.
    new_points : torch.Tensor
        Max-pooled per-grouper features concatenated on dim 1.
    """
    if self.npoint is None:
        new_xyz = None
    else:
        coords_cf = xyz.transpose(1, 2).contiguous()
        sample_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        new_xyz = pointnet2_utils.gather_points(coords_cf, sample_idx)
        new_xyz = new_xyz.transpose(1, 2).contiguous()

    collected = []
    for i, grouper in enumerate(self.groupers):
        grouped = grouper(xyz, new_xyz, points)  # (B, C, npoint, nsample)
        grouped = self.mlps[i](grouped)  # (B, mlp[-1], npoint, nsample)
        grouped = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)])  # (B, mlp[-1], npoint, 1)
        collected.append(grouped.squeeze(-1))  # (B, mlp[-1], npoint)
    return new_xyz, torch.cat(collected, dim=1)
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    r"""Set abstraction that sums (rather than concatenates) its groupers' outputs.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) point coordinates.
    features : torch.Tensor or None
        Point descriptors forwarded to every grouper (and to the MLP when
        ``self.pool`` is False).

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) FPS centres when ``self.pool`` is set. Otherwise every
        point in its original order. None when ``self.npoint`` is None.
    new_features : torch.Tensor
        Element-wise sum of the per-grouper MLP outputs (so they must share a
        shape); the integer 0 if there are no groupers.
    """
    all_features = 0
    xyz_flipped = xyz.transpose(1, 2).contiguous()
    if self.npoint is not None:
        if self.pool:
            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        else:
            # Keep every point in order. Built on xyz's device instead of the
            # former NumPy arange + .cuda(); still int32 indices.
            fps_idx = torch.arange(xyz.size(1), device=xyz.device,
                                   dtype=torch.int).repeat(xyz.size(0), 1)
        new_xyz = pointnet2_utils.gather_operation(
            xyz_flipped, fps_idx).transpose(1, 2).contiguous()
    else:
        new_xyz = None

    for i in range(len(self.groupers)):
        new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
        if not self.pool and self.npoint is not None:
            # Without pooling the MLP also receives the original point features.
            new_features = [new_features, features]
        new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint)
        all_features += new_features
    return new_xyz, all_features
def forward(self, global_feat, point_input):
    """Decode a coarse cloud, merge it with the input points, and refine.

    Args:
        global_feat: global feature per sample; dim 0 is the batch size.
        point_input: input points concatenated with the coarse points along
            dim 2 (presumably (B, 3, N) -- TODO confirm).

    Returns:
        (coarse, fine): coarse is (B, 3, self.num_coarse); fine is the
        ``self.conv2`` output, FPS-reduced to ``self.num_fine`` points when larger.
    """
    batch_size = global_feat.size()[0]
    hidden = F.relu(self.fc2(F.relu(self.fc1(global_feat))))
    coarse = self.fc3(hidden).view(batch_size, 3, self.num_coarse)

    # Coarse and input points are merged along the point axis before encoding.
    dense_feat = self.encoder(torch.cat((coarse, point_input), 2))
    if self.scale >= 2:
        dense_feat = self.expansion(dense_feat)

    fine = self.conv2(F.relu(self.conv1(dense_feat)))
    if fine.size()[2] > self.num_fine:
        keep_idx = pn2.furthest_point_sample(fine.transpose(1, 2).contiguous(), self.num_fine)
        fine = pn2.gather_operation(fine, keep_idx)
    return coarse, fine
def forward(self, inputs, gt, eps, iters, EMD=True, CD=True):
    """Run ``self.model`` on ``inputs`` and compute EMD, Chamfer and uniformity losses.

    Args:
        inputs: model input.
        gt: ground-truth points; only the first three channels of the last dim
            are used.
        eps, iters: forwarded to ``self.EMD``.
        EMD: compute the EMD terms (coarse output vs FPS-sampled gt, fine vs gt).
        CD: compute the Chamfer terms for both outputs.

    Returns:
        output1, output2, emd1, emd2, cd_p1, cd_p2, cd_t1, cd_t2, u1, u2.
        Disabled terms are a one-element zero tensor. ``cd_p*`` average the
        square-rooted distances of both directions; ``cd_t*`` sum the squared ones.
    """
    output1, output2 = self.model(inputs)
    gt = gt[:, :, :3]

    # Placeholder for disabled terms, created on the data's device (not via .cuda()).
    zero = torch.zeros(1, dtype=torch.float32, device=gt.device)
    emd1 = emd2 = cd_p1 = cd_p2 = cd_t1 = cd_t2 = zero

    if EMD:
        # Compare the coarse output with a gt subset of matching size.
        num_coarse = self.model.num_coarse
        gt_fps = pn2.gather_operation(
            gt.transpose(1, 2).contiguous(),
            pn2.furthest_point_sample(gt, num_coarse)).transpose(1, 2).contiguous()
        dist1, _ = self.EMD(output1, gt_fps, eps, iters)
        emd1 = torch.sqrt(dist1).mean(1)
        dist2, _ = self.EMD(output2, gt, eps, iters)
        emd2 = torch.sqrt(dist2).mean(1)

    # CD loss
    if CD:
        dist11, dist12, _, _ = chamLoss(gt, output1)
        cd_p1 = (torch.sqrt(dist11).mean(1) + torch.sqrt(dist12).mean(1)) / 2
        cd_t1 = (dist11.mean(1) + dist12.mean(1))

        dist21, dist22, _, _ = chamLoss(gt, output2)
        cd_p2 = (torch.sqrt(dist21).mean(1) + torch.sqrt(dist22).mean(1)) / 2
        cd_t2 = (dist21.mean(1) + dist22.mean(1))

    u1 = get_uniform_loss(output1)
    u2 = get_uniform_loss(output2)
    return output1, output2, emd1, emd2, cd_p1, cd_p2, cd_t1, cd_t2, u1, u2
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None,
            inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    r"""Sample centroids, group their neighbourhoods, apply the shared MLP and pool.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) tensor of the xyz coordinates of the features
    features : torch.Tensor
        (B, C, N) tensor of the descriptors of the features
    inds : torch.Tensor
        (B, npoint) tensor of indices into the xyz points (values in 0..N-1).
        When None and ``self.npoint`` is set, indices come from furthest
        point sampling. Ignored when ``self.npoint`` is None.

    Returns
    -------
    new_xyz : torch.Tensor or None
        (B, npoint, 3) tensor of the sampled centroids, or None when
        ``self.npoint`` is None (the grouper then receives ``new_xyz=None``)
    new_features : torch.Tensor
        (B, mlp[-1], npoint) pooled descriptors for pooling 'max', 'avg' or
        'rbf'; any other ``self.pooling`` value leaves the nsample axis in place
    inds : torch.Tensor or None
        (B, npoint) indices of the centroids
    unique_cnt : torch.Tensor
        (B, npoint), only returned when ``self.ret_unique_cnt`` is set
    """
    if self.npoint is None:
        # No centroid sampling: FPS needs a concrete sample count.
        new_xyz = None
    else:
        if inds is None:
            inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        else:
            assert (inds.shape[1] == self.npoint)
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        new_xyz = pointnet2_utils.gather_operation(
            xyz_flipped, inds).transpose(1, 2).contiguous()

    if not self.ret_unique_cnt:
        grouped_features, grouped_xyz = self.grouper(
            xyz, new_xyz, features)  # (B, C, npoint, nsample)
    else:
        grouped_features, grouped_xyz, unique_cnt = self.grouper(
            xyz, new_xyz, features
        )  # (B, C, npoint, nsample), (B,3,npoint,nsample), (B,npoint)

    new_features = self.mlp_module(
        grouped_features)  # (B, mlp[-1], npoint, nsample)
    if self.pooling == 'max':
        new_features = F.max_pool2d(
            new_features, kernel_size=[1, new_features.size(3)])  # (B, mlp[-1], npoint, 1)
    elif self.pooling == 'avg':
        new_features = F.avg_pool2d(
            new_features, kernel_size=[1, new_features.size(3)])  # (B, mlp[-1], npoint, 1)
    elif self.pooling == 'rbf':
        # Use radial basis function kernel for weighted sum of features (normalized by nsample and sigma)
        # Ref: https://en.wikipedia.org/wiki/Radial_basis_function_kernel
        rbf = torch.exp(-1 * grouped_xyz.pow(2).sum(1, keepdim=False)
                        / (self.sigma**2) / 2)  # (B, npoint, nsample)
        new_features = torch.sum(
            new_features * rbf.unsqueeze(1), -1, keepdim=True) / float(
                self.nsample)  # (B, mlp[-1], npoint, 1)
    new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

    if not self.ret_unique_cnt:
        return new_xyz, new_features, inds
    return new_xyz, new_features, inds, unique_cnt
def forward(self, partial_cloud, gt, is_training=True, mean_feature=None, alpha=None):
    """Gridding-based completion: 3D conv encoder/decoder, point sampling, offset refinement.

    The partial cloud is gridded into a 64^3 volume, encoded by conv1-conv4 and
    fc5/fc6, then decoded by dconv7-dconv10 with additive skip connections.
    ``gridding_rev`` + ``point_sampling`` yield a sparse cloud (2048 points, per
    the ``.view`` calls below). Per-point features sampled from three decoder
    levels predict 8 offsets per sparse point, giving a 16384-point dense
    cloud. That cloud is reduced to ``self.num_points`` by FPS when
    ``self.num_points`` is smaller.

    Args:
        partial_cloud: input cloud, passed to ``self.gridding`` and ``self.point_sampling``.
        gt: ground-truth cloud used for the losses / metrics.
        is_training: when True return training losses, otherwise evaluation metrics.
        mean_feature, alpha: accepted for interface compatibility; unused here.

    Returns:
        Training: ``(dense_cloud, loss2, total_train_loss)``, where loss2 is the
        dense-cloud loss and ``total_train_loss = loss1.mean() + loss2.mean()``.
        Evaluation: dict with keys 'out1' and 'out2' (both the dense cloud),
        'emd', 'cd_p', 'cd_t', 'f1'.

    Raises:
        NotImplementedError: in training, if ``self.train_loss`` is neither 'emd' nor 'cd'.
    """
    # Encoder over the 64^3 grid.
    pt_features_64_l = self.gridding(partial_cloud).view(-1, 1, 64, 64, 64)
    pt_features_32_l = self.conv1(pt_features_64_l)
    pt_features_16_l = self.conv2(pt_features_32_l)
    pt_features_8_l = self.conv3(pt_features_16_l)
    pt_features_4_l = self.conv4(pt_features_8_l)
    features = self.fc5(pt_features_4_l.view(-1, 16384))

    # Decoder with additive skip connections from the matching encoder level.
    pt_features_4_r = self.fc6(features).view(-1, 256, 4, 4, 4) + pt_features_4_l
    pt_features_8_r = self.dconv7(pt_features_4_r) + pt_features_8_l
    pt_features_16_r = self.dconv8(pt_features_8_r) + pt_features_16_l
    pt_features_32_r = self.dconv9(pt_features_16_r) + pt_features_32_l
    pt_features_64_r = self.dconv10(pt_features_32_r) + pt_features_64_l

    sparse_cloud = self.gridding_rev(pt_features_64_r.squeeze(dim=1))
    sparse_cloud = self.point_sampling(sparse_cloud, partial_cloud)

    point_features_32 = self.feature_sampling(sparse_cloud, pt_features_32_r).view(-1, 2048, 256)
    point_features_16 = self.feature_sampling(sparse_cloud, pt_features_16_r).view(-1, 2048, 512)
    point_features_8 = self.feature_sampling(sparse_cloud, pt_features_8_r).view(-1, 2048, 1024)
    point_features = torch.cat([point_features_32, point_features_16, point_features_8], dim=2)
    point_features = self.fc11(point_features)
    point_features = self.fc12(point_features)
    point_features = self.fc13(point_features)
    point_offset = self.fc14(point_features).view(-1, 16384, 3)

    # Each sparse point is repeated 8 times and displaced by its predicted offsets.
    dense_cloud = sparse_cloud.unsqueeze(dim=2).repeat(1, 1, 8, 1).view(-1, 16384, 3) + point_offset

    if self.num_points < 16384:
        idx_fps = pn2.furthest_point_sample(dense_cloud, self.num_points)
        # gather_operation works on channels-first (B, 3, N) tensors, so
        # transpose the (B, N, 3) cloud in and back out.
        dense_cloud = pn2.gather_operation(
            dense_cloud.transpose(1, 2).contiguous(), idx_fps).transpose(1, 2).contiguous()

    if is_training:
        if self.train_loss == 'emd':
            loss1 = calc_emd(sparse_cloud, gt)
            loss2 = calc_emd(dense_cloud, gt)
        elif self.train_loss == 'cd':
            _, loss1 = calc_cd(sparse_cloud, gt)  # cd_t
            _, loss2 = calc_cd(dense_cloud, gt)  # cd_t
        else:
            raise NotImplementedError('Train loss is either CD or EMD!')

        total_train_loss = loss1.mean() + loss2.mean()
        return dense_cloud, loss2, total_train_loss

    emd = calc_emd(dense_cloud, gt, eps=0.004, iterations=3000)
    cd_p, cd_t, f1 = calc_cd(dense_cloud, gt, calc_f1=True)
    return {
        'out1': dense_cloud,
        'out2': dense_cloud,
        'emd': emd,
        'cd_p': cd_p,
        'cd_t': cd_t,
        'f1': f1
    }
def forward(self, xyz_object, features_object, xyz_part, features_part, end_points):
    """Aggregate object and part votes, refine each with self-attention, and decode proposals.

    Args:
        xyz_object, features_object: object votes, aggregated by ``self.vote_aggregation``.
        xyz_part, features_part: part votes, aggregated by ``self.vote_aggregation``.
        end_points: dict; updated in place with the aggregated vote positions and
            sample indices, then handed to ``decode_scores``.

    Returns:
        The result of ``decode_scores(net, end_points, self.num_class)``.

    Raises:
        ValueError: if ``self.sampling`` is not 'vote_fps', 'seed_fps' or 'random'.
    """
    if self.sampling == 'vote_fps':
        # Farthest point sampling (FPS) on votes
        xyz_object, features_object, sample_inds = self.vote_aggregation(
            xyz_object, features_object)
        xyz_part, features_part, sample_inds_part = self.vote_aggregation(
            xyz_part, features_part)
    elif self.sampling == 'seed_fps':
        # FPS on seed and choose the votes corresponding to the seeds
        sample_inds = pointnet2_utils.furthest_point_sample(
            end_points['seed_xyz'], self.num_proposal)
        xyz_object, features_object, _ = self.vote_aggregation(
            xyz_object, features_object, sample_inds)
        sample_inds_part = pointnet2_utils.furthest_point_sample(
            end_points['seed_xyz'], self.num_proposal)
        xyz_part, features_part, _ = self.vote_aggregation(
            xyz_part, features_part, sample_inds_part)
    elif self.sampling == 'random':
        # Random sampling from the votes
        num_seed = end_points['seed_xyz'].shape[1]
        batch_size = end_points['seed_xyz'].shape[0]
        sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal),
                                    dtype=torch.int).cuda()
        xyz_object, features_object, _ = self.vote_aggregation(
            xyz_object, features_object, sample_inds)
        sample_inds_part = torch.randint(0, num_seed, (batch_size, self.num_proposal),
                                         dtype=torch.int).cuda()
        xyz_part, features_part, _ = self.vote_aggregation(
            xyz_part, features_part, sample_inds_part)
    else:
        log_string('Unknown sampling strategy: %s. Exiting!' % (self.sampling))
        # Raise instead of exit(): callers get a traceback, and this does not
        # depend on the site-provided exit() builtin.
        raise ValueError('Unknown sampling strategy: %s' % (self.sampling))

    end_points['aggregated_vote_object_xyz'] = xyz_object
    end_points['aggregated_vote_inds'] = sample_inds
    end_points['aggregated_vote_part_xyz'] = xyz_part
    end_points['aggregated_vote_part_inds'] = sample_inds_part

    # --------- Learning object-to-object correlation with self-attention ---------
    # The 16x16 reshape requires exactly 256 aggregated proposals.
    feature_dim = features_object.shape[1]
    batch_size = features_object.shape[0]
    features_object = features_object.contiguous().view(batch_size, feature_dim, 16, 16)
    net = self.sa_object(features_object)
    net = net.contiguous().view(batch_size, feature_dim, self.num_proposal)

    # --------- Learning part-to-part correlation with self-attention ---------
    feature_part_dim = features_part.shape[1]
    features_part = features_part.contiguous().view(batch_size, feature_part_dim, 16, 16)
    net_part = self.sa_part(features_part)
    # Reshape the part-attention output (not the object one) back to proposals.
    net_part = net_part.contiguous().view(batch_size, feature_part_dim, self.num_proposal)

    # --------- OBJECT POSE ESTIMATION ---------
    net = torch.cat((net, net_part), 1)
    net = F.relu(self.bn1(self.conv1(net)))
    net = F.relu(self.bn2(self.conv2(net)))
    net = self.conv3(net)

    end_points = decode_scores(net, end_points, self.num_class)
    return end_points
def forward(self, xyzs: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    """Point spatio-temporal convolution over a sequence of point clouds.

    Args:
        xyzs: torch.Tensor
             (B, L, N, 3) tensor of sequence of the xyz coordinates
        features: torch.Tensor
             (B, L, C, N) tensor of sequence of the features; only read when
             ``self.in_planes != 0``

    Returns:
        new_xyzs: (B, L', N // spatial_stride, 3) anchor coordinates, one entry per
            output frame (anchor frames step by ``self.temporal_stride``).
        new_features: outputs of ``self.temporal`` for each output frame, stacked on dim 1.
    """
    nframes = xyzs.size(1)  # L
    npoints = xyzs.size(2)  # N

    if self.temporal_kernel_size > 1 and self.temporal_stride > 1:
        assert ((nframes + sum(self.temporal_padding) - self.temporal_kernel_size) % self.temporal_stride == 0), "PSTConv: Temporal parameter error!"

    # Split the sequence into per-frame tensors.
    xyzs = [torch.squeeze(input=xyz, dim=1).contiguous()
            for xyz in torch.split(tensor=xyzs, split_size_or_sections=1, dim=1)]
    if self.in_planes != 0:
        features = [torch.squeeze(input=feature, dim=1).contiguous()
                    for feature in torch.split(tensor=features, split_size_or_sections=1, dim=1)]

    pad_front, pad_back = self.temporal_padding[0], self.temporal_padding[1]
    if self.padding_mode == "zeros":
        # zeros_like keeps the frames' dtype and device instead of forcing
        # float32 on the index returned by get_device().
        xyz_padding = torch.zeros_like(xyzs[0])
        xyzs = [xyz_padding] * pad_front + xyzs + [xyz_padding] * pad_back
        if self.in_planes != 0:
            feature_padding = torch.zeros_like(features[0])
            features = [feature_padding] * pad_front + features + [feature_padding] * pad_back
    else:  # "replicate"
        xyzs = [xyzs[0]] * pad_front + xyzs + [xyzs[-1]] * pad_back
        if self.in_planes != 0:
            features = [features[0]] * pad_front + features + [features[-1]] * pad_back

    new_xyzs = []
    new_features = []
    for t in range(self.temporal_radius, len(xyzs) - self.temporal_radius, self.temporal_stride):
        # temporal anchor frame t; spatial anchor point subsampling by FPS
        anchor_idx = pointnet2_utils.furthest_point_sample(xyzs[t], npoints // self.spatial_stride)  # (B, N//spatial_stride)
        anchor_xyz_flipped = pointnet2_utils.gather_operation(xyzs[t].transpose(1, 2).contiguous(), anchor_idx)  # (B, 3, N//spatial_stride)
        anchor_xyz_expanded = torch.unsqueeze(anchor_xyz_flipped, 3)  # (B, 3, N//spatial_stride, 1)
        anchor_xyz = anchor_xyz_flipped.transpose(1, 2).contiguous()  # (B, N//spatial_stride, 3)

        # spatial convolution over every frame inside the temporal kernel
        spatial_features = []
        for i in range(t - self.temporal_radius, t + self.temporal_radius + 1):
            neighbor_xyz = xyzs[i]
            idx = pointnet2_utils.ball_query(self.r, self.k, neighbor_xyz, anchor_xyz)

            neighbor_xyz_flipped = neighbor_xyz.transpose(1, 2).contiguous()  # (B, 3, N)
            neighbor_xyz_grouped = pointnet2_utils.grouping_operation(neighbor_xyz_flipped, idx)  # (B, 3, N//spatial_stride, k)

            displacement = neighbor_xyz_grouped - anchor_xyz_expanded  # (B, 3, N//spatial_stride, k)
            displacement = self.spatial_conv_d(displacement)  # (B, mid_planes, N//spatial_stride, k)

            if self.in_planes != 0:
                neighbor_feature_grouped = pointnet2_utils.grouping_operation(features[i], idx)  # (B, in_planes, N//spatial_stride, k)
                feature = self.spatial_conv_f(neighbor_feature_grouped)  # (B, mid_planes, N//spatial_stride, k)
                if self.spatial_aggregation == "addition":
                    spatial_feature = feature + displacement
                else:
                    spatial_feature = feature * displacement
            else:
                spatial_feature = displacement

            if self.spatial_pooling == 'max':
                spatial_feature, _ = torch.max(input=spatial_feature, dim=-1, keepdim=False)  # (B, mid_planes, N//spatial_stride)
            elif self.spatial_pooling == 'sum':
                spatial_feature = torch.sum(input=spatial_feature, dim=-1, keepdim=False)  # (B, mid_planes, N//spatial_stride)
            else:
                spatial_feature = torch.mean(input=spatial_feature, dim=-1, keepdim=False)  # (B, mid_planes, N//spatial_stride)

            spatial_features.append(spatial_feature)

        spatial_features = torch.cat(tensors=spatial_features, dim=1)  # (B, temporal_kernel_size*mid_planes, N//spatial_stride)

        # batch norm and relu
        if self.batch_norm:
            spatial_features = self.batch_norm(spatial_features)
        spatial_features = self.relu(spatial_features)

        # temporal convolution
        spatio_temporal_feature = self.temporal(spatial_features)

        new_xyzs.append(anchor_xyz)
        new_features.append(spatio_temporal_feature)

    new_xyzs = torch.stack(tensors=new_xyzs, dim=1)
    new_features = torch.stack(tensors=new_features, dim=1)
    return new_xyzs, new_features
def forward(self, global_feat, point_input):
    """Decode a global feature into raw, high-res, selected-coarse and fine point sets.

    Steps:
      1. fc1-fc3 predict ``coarse_raw`` of shape (B, 3, num_coarse_raw).
      2. coarse_raw and ``point_input`` are concatenated on the point axis. When
         ``self.points_label`` is set, a label channel is appended first
         (0 for predicted points, 1 for input points). The result is then
         encoded and optionally expanded.
      3. conv_cup1/conv_cup2 give ``coarse_high``; it is reduced to ``num_fps``
         points by FPS, then to ``num_coarse`` points by a learned top-k score.
      4. If ``coarse`` has fewer than ``num_fine`` points it is upsampled via
         ``expansion2`` (with per-point centers when ``local_folding``);
         otherwise its size must equal ``num_fine`` and it is returned as fine.

    Returns:
        (coarse_raw, coarse_high, coarse, fine), all channels-first.
    """
    batch_size = global_feat.size()[0]
    coarse_raw = self.fc3(self.af(self.fc2(self.af(
        self.fc1(global_feat))))).view(batch_size, 3, self.num_coarse_raw)

    org_points_input = point_input
    if self.points_label:
        # Label channels follow the inputs' device/dtype instead of the default
        # CUDA device with float32.
        id0 = torch.zeros(coarse_raw.shape[0], 1, coarse_raw.shape[2],
                          dtype=coarse_raw.dtype, device=coarse_raw.device)
        coarse_input = torch.cat((coarse_raw, id0), 1)
        id1 = torch.ones(org_points_input.shape[0], 1, org_points_input.shape[2],
                         dtype=org_points_input.dtype, device=org_points_input.device)
        org_points_input = torch.cat((org_points_input, id1), 1)
    else:
        coarse_input = coarse_raw

    points = torch.cat((coarse_input, org_points_input), 2)
    dense_feat = self.encoder(points)

    if self.up_scale >= 2:
        dense_feat = self.expansion1(dense_feat)

    coarse_features = self.af(self.conv_cup1(dense_feat))
    coarse_high = self.conv_cup2(coarse_features)

    if coarse_high.size()[2] > self.num_fps:
        idx_fps = pn2.furthest_point_sample(
            coarse_high.transpose(1, 2).contiguous(), self.num_fps)
        coarse_fps = pn2.gather_operation(coarse_high, idx_fps)
        coarse_features = pn2.gather_operation(coarse_features, idx_fps)
    else:
        coarse_fps = coarse_high

    if coarse_fps.size()[2] > self.num_coarse:
        # Keep the num_coarse points with the highest learned (softplus) scores.
        scores = F.softplus(
            self.conv_s3(
                self.af(
                    self.conv_s2(self.af(self.conv_s1(coarse_features))))))
        idx_scores = scores.topk(k=self.num_coarse, dim=2)[1].view(batch_size, -1).int()
        coarse = pn2.gather_operation(coarse_fps, idx_scores)
        coarse_features = pn2.gather_operation(coarse_features, idx_scores)
    else:
        coarse = coarse_fps

    if coarse.size()[2] < self.num_fine:
        if self.local_folding:
            up_features = self.expansion2(coarse_features, global_feat)
            # Each coarse point is repeated num_fine // num_coarse times as the
            # center its predicted fine points are offset from.
            center = coarse.transpose(2, 1).contiguous().unsqueeze(2).repeat(
                1, 1, self.num_fine // self.num_coarse,
                1).view(batch_size, self.num_fine, 3).transpose(2, 1).contiguous()
            fine = self.conv_f2(self.af(self.conv_f1(up_features))) + center
        else:
            up_features = self.expansion2(coarse_features)
            fine = self.conv_f2(self.af(self.conv_f1(up_features)))
    else:
        assert (coarse.size()[2] == self.num_fine)
        fine = coarse

    return coarse_raw, coarse_high, coarse, fine
def forward(self, x, gt, is_training=True, mean_feature=None, alpha=None):
    """Encode, sample a latent code, decode, and compute training losses or eval metrics.

    Training: the batch is doubled. The first half encodes the input ``x``; the
    second half encodes ``y``, an FPS subset of ``gt`` with as many points as
    ``x`` (``x.size()[2]``). A Normal ``q`` is inferred from the first half via
    posterior_infer1/2, and a Normal ``p`` from the second half via
    prior_infer. Samples from q and p are concatenated, and both halves are
    decoded from the features of ``x``.

    Evaluation: only ``q`` is inferred, and ``p_distribution``,
    ``p_distribution_fix`` and ``m_distribution`` all alias it.

    Args:
        x: input cloud, channels-first (points on dim 2, per ``x.size()[2]``).
        gt: ground-truth cloud.
        is_training: selects the training or evaluation return value.
        mean_feature: unused here.
        alpha: weight of the fine-output CD term; it multiplies
            ``loss4.mean()``, so it must be a number when ``is_training`` is
            True (presumably always supplied by the trainer — TODO confirm).

    Returns:
        Training: ``(fine, loss4, total_train_loss)``.
        Evaluation: dict with 'out1' (coarse_raw), 'out2' (fine), 'emd', 'cd_p',
        'cd_t', 'f1'.

    Raises:
        NotImplementedError: for an unknown ``self.distribution_loss`` or a
        ``self.train_loss`` other than 'cd' (training only).

    NOTE(review): the order of the ``rsample()`` calls determines which random
    draws each sample receives; keep it when editing for reproducibility.
    """
    num_input = x.size()[2]

    if is_training:
        # y: FPS subset of gt (channels-first) with the same point count as x.
        y = pn2.gather_operation(
            gt.transpose(1, 2).contiguous(),
            pn2.furthest_point_sample(gt, num_input))
        gt = torch.cat([gt, gt], dim=0)
        points = torch.cat([x, y], dim=0)
        x = torch.cat([x, x], dim=0)
    else:
        points = x
    feat = self.encoder(points)

    if is_training:
        feat_x, feat_y = feat.chunk(2)
        o_x = self.posterior_infer2(self.posterior_infer1(feat_x))
        q_mu, q_std = torch.split(o_x, self.size_z, dim=1)
        o_y = self.prior_infer(feat_y)
        p_mu, p_std = torch.split(o_y, self.size_z, dim=1)
        # softplus keeps the standard deviations positive.
        q_std = F.softplus(q_std)
        p_std = F.softplus(p_std)
        q_distribution = torch.distributions.Normal(q_mu, q_std)
        p_distribution = torch.distributions.Normal(p_mu, p_std)
        # Detached copy of p: gradients from losses on it do not reach prior_infer.
        p_distribution_fix = torch.distributions.Normal(
            p_mu.detach(), p_std.detach())
        # Standard normal reference distribution.
        m_distribution = torch.distributions.Normal(
            torch.zeros_like(p_mu), torch.ones_like(p_std))
        z_q = q_distribution.rsample()
        z_p = p_distribution.rsample()
        z = torch.cat([z_q, z_p], dim=0)
        # Both halves decode from the features of x.
        feat = torch.cat([feat_x, feat_x], dim=0)

    else:
        o_x = self.posterior_infer2(self.posterior_infer1(feat))
        q_mu, q_std = torch.split(o_x, self.size_z, dim=1)
        q_std = F.softplus(q_std)
        q_distribution = torch.distributions.Normal(q_mu, q_std)
        p_distribution = q_distribution
        p_distribution_fix = p_distribution
        m_distribution = p_distribution
        z = q_distribution.rsample()

    # Add the generator's mapping of the latent sample to the encoder features.
    feat += self.generator(z)

    coarse_raw, coarse_high, coarse, fine = self.decoder(feat, x)
    # Decoder outputs are channels-first; convert to (B, N, 3) for the losses.
    coarse_raw = coarse_raw.transpose(1, 2).contiguous()
    coarse_high = coarse_high.transpose(1, 2).contiguous()
    coarse = coarse.transpose(1, 2).contiguous()
    fine = fine.transpose(1, 2).contiguous()

    if is_training:
        if self.distribution_loss == 'MMD':
            z_m = m_distribution.rsample()
            z_q = q_distribution.rsample()
            z_p = p_distribution.rsample()
            z_p_fix = p_distribution_fix.rsample()
            dl_rec = self.mmd_loss(z_m, z_p)
            dl_g = self.mmd_loss2(z_q, z_p_fix)
        elif self.distribution_loss == 'KLD':
            dl_rec = torch.distributions.kl_divergence(
                m_distribution, p_distribution)
            dl_g = torch.distributions.kl_divergence(
                p_distribution_fix, q_distribution)
        else:
            raise NotImplementedError(
                'Distribution loss is either MMD or KLD')

        # First value returned by calc_cd is used for every stage.
        if self.train_loss == 'cd':
            loss1, _ = calc_cd(coarse_raw, gt)
            loss2, _ = calc_cd(coarse_high, gt)
            loss3, _ = calc_cd(coarse, gt)
            loss4, _ = calc_cd(fine, gt)
        else:
            raise NotImplementedError('Only CD is supported')

        total_train_loss = loss1.mean() * 10 + loss2.mean(
        ) * 0.5 + loss3.mean() + loss4.mean() * alpha
        total_train_loss += (dl_rec.mean() + dl_g.mean()) * 20
        return fine, loss4, total_train_loss
    else:
        emd = calc_emd(fine, gt, eps=0.004, iterations=3000)
        cd_p, cd_t, f1 = calc_cd(fine, gt, calc_f1=True)
        return {
            'out1': coarse_raw,
            'out2': fine,
            'emd': emd,
            'cd_p': cd_p,
            'cd_t': cd_t,
            'f1': f1
        }