def forward(self, x):
    """Apply the density-scaled, weight-modulated convolution to ``x``.

    The first three channels of ``x`` are read as XYZ coordinates. Their
    kernel density (``F.kernel_density`` with ``self.bandwidth``) is
    inverted and passed through ``self.scale``; the result multiplies the
    features before ``self.mlp``. ``self.weight`` is evaluated on the
    coordinate channels and combined with the MLP output through a batched
    matrix product, after which ``self.lin`` is applied.

    If ``self.num_samples > 1`` the input is first reduced to
    ``self.num_samples`` centres with ``F.farthest_point_sample``, the
    ``self.kernel_size`` neighbours of every centre returned by ``F.knn``
    are gathered, their coordinates are made relative to the centre, and
    the centre coordinates are concatenated in front of the result along
    dim 1. Otherwise the input is processed as a single group and nothing
    is prepended.

    NOTE(review): ``x`` is assumed to be laid out as
    (batch, channels, points) -- confirm against the callers.
    """
    batch_size = x.size(0)
    in_channels = x.size(1)
    downsample = self.num_samples > 1

    coords = x[:, :3]  # XYZ coordinates
    density = F.kernel_density(coords, self.bandwidth).unsqueeze(1)
    scale = self.scale(torch.reciprocal(density))  # inverse-density factor

    if downsample:
        centres = F.farthest_point_sample(coords, self.num_samples)
        _, neighbours = F.knn(coords, centres, self.kernel_size)
        neighbours = neighbours.view(batch_size, -1).unsqueeze(1)

        # Point and density grouping
        x = torch.gather(x, 2, neighbours.expand(-1, in_channels, -1))
        x = x.view(batch_size, in_channels, self.kernel_size, -1)
        scale = torch.gather(scale, 2, neighbours)
        scale = scale.view(batch_size, 1, self.kernel_size, -1)

        # Make the grouped coordinates relative to their centre. The
        # in-place update only touches the freshly gathered tensor.
        offset = centres.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
        x[:, :3] -= offset
    else:
        # Add a trailing singleton axis so both paths share the 4-D tail.
        scale = scale.unsqueeze(3)
        x = x.unsqueeze(3)

    kernel = self.weight(x[:, :3])
    feats = self.mlp(x * scale)
    out = torch.matmul(kernel.permute(0, 3, 1, 2), feats.permute(0, 3, 2, 1))
    out = self.lin(out.permute(0, 3, 2, 1)).squeeze(2)

    if downsample:
        out = torch.cat([centres, out], dim=1)
    return out
def test_farthest_point_sample():
    """Check ``F.farthest_point_sample`` on the eight unit-cube corners.

    Two samples are requested from a single batch; the expected result is
    the origin followed by the opposite corner (1, 1, 1).
    """
    batch_size, in_channels, num_samples = 1, 3, 2

    # One row per coordinate axis; column i is the i-th corner of the cube.
    xs = [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
    ys = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
    zs = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0]
    p = torch.tensor([[xs, ys, zs]])

    q = F.farthest_point_sample(p, num_samples)

    assert q.shape == torch.Size([batch_size, in_channels, num_samples])
    assert q.tolist() == [[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]]
def forward(self, x):
    """Group points around sampled centres and run the parent forward.

    When ``self.num_samples > 1``, ``self.num_samples`` centres are chosen
    from the first three (XYZ) channels with ``F.farthest_point_sample``
    and ``F.ball_point`` supplies ``self.kernel_size`` indices per centre
    using ``self.radius``. The indexed points are gathered, their
    coordinates are made relative to the centre, the parent class forward
    is applied, and the centre coordinates are concatenated in front of
    the result along dim 1.

    Otherwise the input only receives a trailing singleton axis before the
    parent forward, and nothing is prepended.

    NOTE(review): ``x`` is assumed to be laid out as
    (batch, channels, points) -- confirm against the callers.
    """
    batch_size = x.size(0)
    in_channels = x.size(1)

    if not self.num_samples > 1:
        # No sampling requested: treat the whole input as one group.
        out = super(SetAbstraction, self).forward(x.unsqueeze(3))
        return out.squeeze(2)

    coords = x[:, :3]  # XYZ coordinates
    centres = F.farthest_point_sample(coords, self.num_samples)
    neighbours = F.ball_point(coords, centres, self.kernel_size, self.radius)

    # Flatten the indices per batch and repeat them for every channel so a
    # single gather pulls all grouped points.
    neighbours = neighbours.view(batch_size, -1)
    neighbours = neighbours.unsqueeze(1).expand(-1, in_channels, -1)
    grouped = torch.gather(x, 2, neighbours)
    grouped = grouped.view(batch_size, in_channels, self.kernel_size, -1)

    # Make the grouped coordinates relative to their centre (in place on
    # the freshly gathered tensor).
    offset = centres.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
    grouped[:, :3] -= offset

    out = super(SetAbstraction, self).forward(grouped)
    out = out.squeeze(2)
    return torch.cat([centres, out], dim=1)
def forward(self, x):
    """Downsample by ``self.stride``, group by radius, run the parent forward.

    When ``self.radius`` is not ``None``, ``num_points // self.stride``
    centres are chosen from the first three (XYZ) channels with
    ``F.farthest_point_sample`` and ``F.ball_point`` supplies
    ``self.kernel_size`` indices per centre. The indexed points are
    gathered, their coordinates are made relative to the centre, the
    parent class forward is applied, and the centre coordinates are
    concatenated in front of the result along dim 1.

    When ``self.radius`` is ``None`` the input only receives a trailing
    singleton axis before the parent forward, and nothing is prepended.

    NOTE(review): ``x`` is assumed to be laid out as
    (batch, ``self.in_channels``, points) -- confirm against the callers.
    """
    batch_size = x.size(0)
    num_points = x.size(2)

    if self.radius is None:
        # No neighbourhood radius: treat the whole input as one group.
        out = super(SetConv, self).forward(x.unsqueeze(3))
        return out.squeeze(2)

    num_samples = num_points // self.stride
    coords = x[:, :3]  # XYZ coordinates
    centres = F.farthest_point_sample(coords, num_samples)
    neighbours = F.ball_point(coords, centres, self.kernel_size, self.radius)

    # Flatten the indices per batch and repeat them for every channel so a
    # single gather pulls all grouped points.
    neighbours = neighbours.view(batch_size, -1)
    neighbours = neighbours.unsqueeze(1).expand(-1, self.in_channels, -1)
    grouped = torch.gather(x, 2, neighbours)
    grouped = grouped.view(batch_size, self.in_channels, self.kernel_size, -1)

    # Make the grouped coordinates relative to their centre (in place on
    # the freshly gathered tensor).
    offset = centres.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
    grouped[:, :3] -= offset

    out = super(SetConv, self).forward(grouped)
    out = out.squeeze(2)
    return torch.cat([centres, out], dim=1)
def forward(self, x):
    """Convolve dilated k-NN neighbourhoods through a learned transform.

    ``self.num_samples`` centres are chosen from the first three (XYZ)
    channels with ``F.farthest_point_sample``. ``F.knn`` is queried for
    ``self.kernel_size * self.dilation`` neighbours and every
    ``self.dilation``-th entry along dim 1 is kept. The centre-relative
    neighbour coordinates feed ``self.stn`` (reshaped to a
    ``kernel_size x kernel_size`` matrix per centre) and ``self.mlp``; the
    MLP output is concatenated with the gathered input features,
    multiplied by that matrix, and passed through ``self.conv``. The
    centre coordinates are concatenated in front of the result along
    dim 1.

    NOTE(review): ``x`` is assumed to be laid out as
    (batch, channels, points) and dim 1 of the k-NN index is assumed to be
    the neighbour axis -- confirm against ``F.knn``.
    """
    batch_size = x.size(0)
    in_channels = x.size(1)
    k = self.kernel_size

    coords = x[:, :3]  # XYZ coordinates
    centres = F.farthest_point_sample(coords, self.num_samples)

    # Over-query by the dilation factor, then subsample the neighbour list.
    _, neighbours = F.knn(coords, centres, k * self.dilation)
    neighbours = neighbours[:, ::self.dilation]
    neighbours = neighbours.reshape(batch_size, -1).unsqueeze(1)

    # Neighbour coordinates relative to their centre.
    local = torch.gather(coords, 2, neighbours.expand(-1, 3, -1))
    local = local.view(batch_size, 3, k, -1)
    local = local - centres.unsqueeze(2).expand(-1, -1, k, -1)

    # One k-by-k transform per centre, predicted from the local coordinates.
    transform = self.stn(local).view(batch_size, k, k, -1)

    # Input features of the same neighbours.
    feats = torch.gather(x, 2, neighbours.expand(-1, in_channels, -1))
    feats = feats.view(batch_size, in_channels, k, -1)

    lifted = torch.cat([self.mlp(local), feats], 1)
    mixed = torch.matmul(
        lifted.permute(0, 3, 1, 2),
        transform.permute(0, 3, 1, 2),
    )
    out = self.conv(mixed.permute(0, 2, 3, 1)).squeeze(2)
    return torch.cat([centres, out], dim=1)