Exemplo n.º 1
0
 def forward(self, x):
     """Density-weighted convolution over local point neighbourhoods.

     If ``self.num_samples > 1``:

     - Centroids are chosen with ``F.farthest_point_sample``.
     - Each centroid's ``self.kernel_size`` neighbours come from ``F.knn``.
     - Neighbour XYZ channels are shifted so they are relative to their
       centroid.
     - The centroid coordinates are concatenated in front of the output.

     Otherwise a singleton axis is added at dim 3 and nothing is prepended.

     In both cases the per-point scaling factor is ``self.scale`` applied to
     the reciprocal of ``F.kernel_density`` of the XYZ channels. Features are
     multiplied by this factor before ``self.mlp``. They are then combined with
     ``self.weight`` of the (grouped) XYZ channels by a batched matmul and
     passed through ``self.lin``.

     NOTE(review): assumes ``x`` is (batch, channels, points) with XYZ in
     channels 0-2 -- confirm against callers.
     """
     n_batch = x.shape[0]
     n_chan = x.shape[1]
     sampled = self.num_samples > 1
     xyz = x[:, :3]  # XYZ coordinates
     density = F.kernel_density(xyz, self.bandwidth).unsqueeze(1)
     scaling = self.scale(torch.reciprocal(density))  # scaling factor
     if sampled:
         centroids = F.farthest_point_sample(xyz, self.num_samples)
         _, nbr = F.knn(xyz, centroids, self.kernel_size)
         nbr = nbr.view(n_batch, -1).unsqueeze(1)
         # Point and density grouping, then unflatten the gathered axis into
         # (batch, channels, kernel_size, -1).
         x = torch.gather(x, 2, nbr.expand(-1, n_chan, -1))
         scaling = torch.gather(scaling, 2, nbr)
         x = x.view(n_batch, n_chan, self.kernel_size, -1)
         scaling = scaling.view(n_batch, 1, self.kernel_size, -1)
         # Make neighbour XYZ relative to their centroid (in place on the
         # freshly gathered tensor).
         x[:, :3] -= centroids.unsqueeze(2).expand(
             -1, -1, self.kernel_size, -1)
     else:
         # Treat the whole cloud as one group: add a singleton axis at dim 3.
         scaling = scaling.unsqueeze(3)
         x = x.unsqueeze(3)
     weights = self.weight(x[:, :3])
     feats = self.mlp(x * scaling)
     # Move the last axis next to batch so matmul contracts the
     # kernel_size axis of weights against that of the features.
     mixed = torch.matmul(weights.permute(0, 3, 1, 2),
                          feats.permute(0, 3, 2, 1))
     out = self.lin(mixed.permute(0, 3, 2, 1)).squeeze(2)
     if sampled:
         out = torch.cat([centroids, out], dim=1)
     return out
Exemplo n.º 2
0
def test_farthest_point_sample():
    """Check FPS on the eight unit-cube corners.

    With two samples, the expected result is the corner (0, 0, 0) followed
    by the corner (1, 1, 1). Coordinates are laid out as (batch, xyz, points).
    """
    num_samples = 2
    corners = [
        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
        [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0],
    ]
    points = torch.tensor([corners])
    sampled = F.farthest_point_sample(points, num_samples)
    assert sampled.shape == torch.Size([1, 3, num_samples])
    assert sampled.tolist() == [[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]]
Exemplo n.º 3
0
 def forward(self, x):
     """Set-abstraction forward pass.

     If ``self.num_samples > 1``:

     - Centroids are chosen with ``F.farthest_point_sample``.
     - Neighbour indices come from
       ``F.ball_point(xyz, centroids, self.kernel_size, self.radius)``.
     - All channels are gathered into (batch, channels, kernel_size, -1) and
       the XYZ channels are shifted to be relative to their centroid.
     - The result goes through the parent class's ``forward``, and the
       centroid coordinates are concatenated in front of the output.

     Otherwise a singleton axis is added at dim 3, the parent ``forward`` is
     applied, and nothing is prepended.

     NOTE(review): assumes ``x`` is (batch, channels, points) with XYZ in
     channels 0-2 -- confirm against callers.
     """
     n_batch, n_chan = x.shape[0], x.shape[1]
     sampled = self.num_samples > 1
     if sampled:
         xyz = x[:, :3]  # XYZ coordinates
         centroids = F.farthest_point_sample(xyz, self.num_samples)
         nbr = F.ball_point(xyz, centroids, self.kernel_size, self.radius)
         nbr = nbr.view(n_batch, -1).unsqueeze(1).expand(-1, n_chan, -1)
         grouped = torch.gather(x, 2, nbr).view(
             n_batch, n_chan, self.kernel_size, -1)
         # Centre each neighbourhood on its centroid (in place).
         grouped[:, :3] -= centroids.unsqueeze(2).expand(
             -1, -1, self.kernel_size, -1)
     else:
         grouped = x.unsqueeze(3)
     out = super(SetAbstraction, self).forward(grouped).squeeze(2)
     if sampled:
         out = torch.cat([centroids, out], dim=1)
     return out
Exemplo n.º 4
0
 def forward(self, x):
     """Set-convolution forward pass.

     If ``self.radius`` is not None:

     - ``x.shape[2] // self.stride`` centroids are chosen with
       ``F.farthest_point_sample``.
     - Neighbour indices come from ``F.ball_point``.
     - ``self.in_channels`` channels are gathered into
       (batch, in_channels, kernel_size, -1) and the XYZ channels are made
       relative to their centroid.
     - The result goes through the parent ``forward``, and the centroid
       coordinates are prepended to the output.

     Otherwise a singleton axis is added at dim 3, the parent ``forward`` is
     applied, and nothing is prepended.

     NOTE(review): assumes ``x`` is (batch, channels, points) with XYZ in
     channels 0-2 and ``x.shape[1] >= self.in_channels`` -- confirm.
     """
     n_batch, n_points = x.shape[0], x.shape[2]
     sampled = self.radius is not None
     if sampled:
         xyz = x[:, :3]  # XYZ coordinates
         centroids = F.farthest_point_sample(xyz, n_points // self.stride)
         nbr = F.ball_point(xyz, centroids, self.kernel_size, self.radius)
         nbr = nbr.view(n_batch, -1).unsqueeze(1)
         nbr = nbr.expand(-1, self.in_channels, -1)
         grouped = torch.gather(x, 2, nbr).view(
             n_batch, self.in_channels, self.kernel_size, -1)
         # Centre each neighbourhood on its centroid (in place).
         grouped[:, :3] -= centroids.unsqueeze(2).expand(
             -1, -1, self.kernel_size, -1)
     else:
         grouped = x.unsqueeze(3)
     out = super(SetConv, self).forward(grouped).squeeze(2)
     if sampled:
         out = torch.cat([centroids, out], dim=1)
     return out
Exemplo n.º 5
0
 def forward(self, x):
     """Convolution over dilated k-nearest-neighbour groups with a learned
     per-group transform.

     Steps:

     1. ``self.num_samples`` centroids are chosen with
        ``F.farthest_point_sample``.
     2. ``kernel_size * dilation`` neighbours per centroid are queried with
        ``F.knn``, and every ``dilation``-th one is kept.
     3. Neighbour XYZ are made relative to their centroid. They feed
        ``self.stn``, whose output is viewed as (batch, kernel_size,
        kernel_size, -1), and ``self.mlp``.
     4. The mlp output is concatenated with the grouped input features and
        multiplied by the ``self.stn`` transform.
     5. The product goes through ``self.conv``, and the centroid coordinates
        are prepended to the output.

     NOTE(review): this resembles PointCNN's X-Conv; assumes ``x`` is
     (batch, channels, points) with XYZ in channels 0-2 -- confirm.
     """
     n_batch, n_chan = x.shape[0], x.shape[1]
     k = self.kernel_size
     xyz = x[:, :3]  # XYZ coordinates
     centroids = F.farthest_point_sample(xyz, self.num_samples)
     _, nbr = F.knn(xyz, centroids, k * self.dilation)
     # Keep every `dilation`-th neighbour; reshape (not view) because the
     # strided slice is non-contiguous.
     nbr = nbr[:, ::self.dilation].reshape(n_batch, -1).unsqueeze(1)
     local_xyz = torch.gather(xyz, 2, nbr.expand(-1, 3, -1)).view(
         n_batch, 3, k, -1)
     local_xyz = local_xyz - centroids.unsqueeze(2).expand(-1, -1, k, -1)
     transform = self.stn(local_xyz).view(n_batch, k, k, -1)
     grouped = torch.gather(x, 2, nbr.expand(-1, n_chan, -1)).view(
         n_batch, n_chan, k, -1)
     lifted = torch.cat([self.mlp(local_xyz), grouped], 1)
     mixed = torch.matmul(lifted.permute(0, 3, 1, 2),
                          transform.permute(0, 3, 1, 2))
     out = self.conv(mixed.permute(0, 2, 3, 1)).squeeze(2)
     return torch.cat([centroids, out], dim=1)