def forward(self, coords, features):
    r"""
        Forward pass

        Parameters
        ----------
        coords: torch.Tensor, shape (B, N, 3)
            coordinates of the point cloud
        features: torch.Tensor, shape (B, d_in, N, 1)
            features of the point cloud

        Returns
        -------
        torch.Tensor, shape (B, 2*d_out, N, 1)
    """
    knn_output = knn(coords.cpu().contiguous(), coords.cpu().contiguous(), self.num_neighbors)

    x = self.mlp1(features)

    x = self.lse1(coords, x, knn_output)
    x = self.pool1(x)

    x = self.lse2(coords, x, knn_output)
    x = self.pool2(x)

    return self.lrelu(self.mlp2(x) + self.shortcut(features))
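
# The `knn(support, query, k)` used throughout these forward passes is assumed to come
# from an external kernel library (e.g. torch_points_kernels) and to return an
# (indices, distances) pair. A minimal pure-PyTorch stand-in with the same contract,
# useful for CPU shape checks, might look like this (`knn_bruteforce` is a hypothetical
# helper, not part of the original code):
import torch

def knn_bruteforce(support, query, k):
    # support: (B, N, 3), query: (B, M, 3)
    # returns idx (B, M, k) and dist (B, M, k) of the k nearest support points per query point
    pairwise = torch.cdist(query, support)                # (B, M, N) Euclidean distances
    dist, idx = pairwise.topk(k, dim=-1, largest=False)   # k smallest distances per query
    return idx, dist
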
def forward(self, inputs):
    # find the neighboring points of each point
    coords = inputs['points'][:, :3]
    knn_output = knn(coords.cpu().contiguous(), coords.cpu().contiguous(), self.num_neighbors)
    idx, dist = knn_output  # idx: (B, N, K), dist: (B, N, K); coords: (B, N, 3)

    B, N, K = idx.size()

    # gather neighbor coordinates:
    # neighbors[b, i, n, k] = coords[b, idx[b, n, k], i]
    #                       = extended_coords[b, i, extended_idx[b, i, n, k], k]
    extended_idx = idx.unsqueeze(1).expand(B, 3, N, K)
    extended_coords = coords.transpose(-2, -1).unsqueeze(-1).expand(B, 3, N, K)
    neighbors = torch.gather(extended_coords, 2, extended_idx)  # shape (B, 3, N, K)

    # relative point position encoding
    concat = torch.cat((
        extended_coords,
        neighbors,
        extended_coords - neighbors,
        dist.unsqueeze(-3)
    ), dim=-3).to(self.device)  # shape (B, 3+3+3+1, N, K)

    if concat.shape[0] > self.part:
        # nn.Linear performs randomly when the batch size is too large, so process in chunks
        num_parts = concat.shape[0] // self.part
        part_linear_out = [
            self.linear(concat[num_part * self.part:(num_part + 1) * self.part])
            for num_part in range(num_parts + 1)
        ]
        x = torch.cat(part_linear_out, dim=0)
    else:
        x = self.linear(concat)

    torch.backends.cudnn.enabled = False
    x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) if self.use_norm else x
    torch.backends.cudnn.enabled = True
    x = F.relu(x)

    # max pooling over the points
    x_max = torch.max(x, dim=1, keepdim=True)[0]
    if self.last_vfe:
        return x_max
    else:
        x_repeat = x_max.repeat(1, inputs['points'].shape[1], 1)
        x_concatenated = torch.cat([x, x_repeat], dim=2)
        return x_concatenated
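
# Sketch (toy shapes, not part of the module): how the expand + gather above collects
# per-point neighbor coordinates. The result matches direct advanced indexing of
# `coords` by `idx`.
import torch

B, N, K = 2, 5, 3
coords = torch.randn(B, N, 3)
idx = torch.randint(0, N, (B, N, K))                          # as returned by knn

extended_idx = idx.unsqueeze(1).expand(B, 3, N, K)            # (B, 3, N, K)
extended_coords = coords.transpose(-2, -1).unsqueeze(-1).expand(B, 3, N, K)
neighbors = torch.gather(extended_coords, 2, extended_idx)    # (B, 3, N, K)

# neighbors[b, :, n, k] == coords[b, idx[b, n, k], :]
reference = coords[torch.arange(B)[:, None, None], idx].permute(0, 3, 1, 2)
assert torch.allclose(neighbors, reference)
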
def forward(self, input): r""" Forward pass Parameters ---------- input: torch.Tensor, shape (B, N, d_in) input points Returns ------- torch.Tensor, shape (B, num_classes, N) segmentation scores for each point """ N = input.size(1) d = self.decimation coords = input[...,:3].clone().cpu() x = self.fc_start(input).transpose(-2,-1).unsqueeze(-1) x = self.bn_start(x) # shape (B, d, N, 1) decimation_ratio = 1 # <<<<<<<<<< ENCODER x_stack = [] permutation = torch.randperm(N) coords = coords[:,permutation] x = x[:,:,permutation] for lfa in self.encoder: # at iteration i, x.shape = (B, N//(d**i), d_in) x = lfa(coords[:,:N//decimation_ratio], x) x_stack.append(x.clone()) decimation_ratio *= d x = x[:,:,:N//decimation_ratio] # # >>>>>>>>>> ENCODER x = self.mlp(x) # <<<<<<<<<< DECODER for mlp in self.decoder: neighbors, _ = knn( coords[:,:N//decimation_ratio].cpu().contiguous(), # original set coords[:,:d*N//decimation_ratio].cpu().contiguous(), # upsampled set 1 ) # shape (B, N, 1) neighbors = neighbors.to(self.device) extended_neighbors = neighbors.unsqueeze(1).expand(-1, x.size(1), -1, 1) x_neighbors = torch.gather(x, -2, extended_neighbors) x = torch.cat((x_neighbors, x_stack.pop()), dim=1) x = mlp(x) decimation_ratio //= d # >>>>>>>>>> DECODER # inverse permutation x = x[:,:,torch.argsort(permutation)] scores = self.fc_end(x) return scores.squeeze(-1)
def forward(self, input): r""" Forward pass Parameters ---------- input: torch.Tensor, shape (B, N, d_in) input points Returns ------- torch.Tensor, shape (B, num_classes, N) segmentation scores for each point """ N = input.size(1) # (B,N,d_in) d = self.decimation # sample multiplier, e.g. 4 coords = input[..., :3].clone().cpu() # (B,N,3) x = self.fc_start(input).transpose(-2, -1).unsqueeze(-1) x = self.bn_start(x) # shape (B, d, N, 1) decimation_ratio = 1 # Note: at first it is 1 # <<<<<<<<<< ENCODER x_stack = [] # store the encoder results for decoder permutation = torch.randperm(N) coords = coords[:, permutation] # permute points x = x[:, :, permutation] # permute points for lfa in self.encoder: # at iteration i, x.shape = (B, N//(d**i), d_in) x = lfa( coords[:, :N // decimation_ratio], x ) # shape (B,(i+1)*4*8,N//(d*(i+1)),1), i start is current index, from 0 x_stack.append(x.clone()) decimation_ratio *= d # random sampling operation x = x[:, :, :N // decimation_ratio] # # >>>>>>>>>> ENCODER x = self.mlp(x) # (B,512,N/256=256,1) # <<<<<<<<<< DECODER for mlp in self.decoder: # the below shape only list 1st decoder layer neighbors, _ = knn( coords[:, :N // decimation_ratio].cpu().contiguous( ), # original set (B,256,1) coords[:, :d * N // decimation_ratio].cpu().contiguous( ), # upsampled set for query points (B,1024,1) 1) # shape (B, d*N//decimation_ratio, 1) # here means in the nearest nb_idx wrt query points (i.e. orginal set) neighbors = neighbors.to(self.device) # (B,1024,1) extended_neighbors = neighbors.unsqueeze(1).expand( -1, x.size(1), -1, 1) # (B,512,1024,1) x_neighbors = torch.gather( x, -2, extended_neighbors) # find x's nearest nbs (B,512,?,1) x = torch.cat((x_neighbors, x_stack.pop()), dim=1) # skip connection (global + local features) x = mlp(x) decimation_ratio //= d # >>>>>>>>>> DECODER # inverse permutation x = x[:, :, torch.argsort(permutation)] # (B,C,N,1) scores = self.fc_end(x) #(B,C,N,1) return scores.squeeze(-1) #(B,C,N)