Пример #1
0
def test_ball_point():
    """Check ``F.ball_point`` on the eight corners of the unit cube.

    Every other corner is at least 1.0 away from any given corner, so a ball
    of radius 0.5 contains no neighbours. The expected result therefore lists
    each query's own index in every one of the ``k`` neighbour slots.
    """
    # One batch of 8 points laid out as (batch, xyz, num_points).
    xs = [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
    ys = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
    zs = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0]
    points = torch.tensor([[xs, ys, zs]])
    queries = points.clone()  # every point is also used as a query

    neighbours_per_query = 2
    result = F.ball_point(points, queries, neighbours_per_query, 0.5)

    n_points = points.shape[2]
    assert result.shape == torch.Size([1, neighbours_per_query, n_points])
    own_indices = list(range(n_points))
    assert result.tolist() == [[own_indices for _ in range(neighbours_per_query)]]
Пример #2
0
 def forward(self, x):
     """Run set abstraction on a point-feature tensor.

     ``x`` is indexed along dim 2 as the point axis and its first three
     channels are treated as XYZ coordinates (presumably the layout is
     ``(batch, channels, num_points)`` -- TODO confirm with callers).

     When ``self.num_samples > 1``:
     - ``self.num_samples`` centroids are picked with ``F.farthest_point_sample``.
     - Up to ``self.kernel_size`` neighbours per centroid within ``self.radius``
       are gathered via ``F.ball_point``.
     - Neighbour XYZ is made relative to its centroid.
     - The grouped ``(batch, channels, kernel_size, -1)`` tensor goes through
       the parent ``forward``.
     - The centroid coordinates are prepended along the channel axis.

     Otherwise the whole point set is fed to the parent ``forward`` as a
     single group.
     """
     n_batch = x.shape[0]
     n_chan = x.shape[1]

     if not (self.num_samples > 1):
         # Group-all path: a trailing singleton axis places every point on
         # dim 2 (the neighbour axis of the grouped path), i.e. one group.
         pooled = super(SetAbstraction, self).forward(x.unsqueeze(3))
         return pooled.squeeze(2)

     xyz = x[:, :3]  # XYZ coordinates
     centroids = F.farthest_point_sample(xyz, self.num_samples)
     # Neighbour indices into the point axis; expected shape
     # (batch, kernel_size, num_samples).
     neighbours = F.ball_point(xyz, centroids, self.kernel_size, self.radius)

     # Flatten the indices and broadcast them over every channel so a single
     # gather along the point axis collects all features of all neighbours.
     flat_idx = neighbours.view(n_batch, -1).unsqueeze(1).expand(-1, n_chan, -1)
     grouped = torch.gather(x, 2, flat_idx)
     grouped = grouped.view(n_batch, n_chan, self.kernel_size, -1)

     # Express neighbour XYZ relative to its centroid. This is done in place on
     # the freshly gathered tensor, so the caller's x is not modified.
     offsets = centroids.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
     grouped[:, :3] -= offsets

     # Presumably the parent forward collapses the kernel axis (dim 2) to
     # size 1, which squeeze(2) then removes -- confirm.
     features = super(SetAbstraction, self).forward(grouped).squeeze(2)
     return torch.cat([centroids, features], dim=1)
Пример #3
0
 def forward(self, x):
     """Apply the set convolution to a point-feature tensor.

     ``x`` is indexed as ``(batch, channels, num_points)`` (``x.shape[2]`` is
     read as the point count). Its first three channels are treated as XYZ
     coordinates.

     If ``self.radius`` is set:
     - ``num_points // self.stride`` centroids (at least one) are chosen with
       ``F.farthest_point_sample``.
     - Up to ``self.kernel_size`` neighbours per centroid within
       ``self.radius`` are gathered with ``F.ball_point``.
     - Neighbour coordinates are made relative to their centroid.
     - The grouped ``(batch, channels, kernel_size, num_samples)`` tensor is
       passed to the parent ``forward``.
     - The centroid coordinates are prepended along the channel axis.

     If ``self.radius`` is ``None``, all points are passed to the parent
     ``forward`` as a single group.
     """
     batch_size = x.shape[0]
     num_points = x.shape[2]
     if self.radius is not None:
         # Integer division yields 0 when num_points < stride; always keep at
         # least one centroid so the grouping below is never empty.
         num_samples = max(1, num_points // self.stride)
         # Gather every channel actually present in x. Expanding the index to
         # self.in_channels instead lets torch.gather silently drop trailing
         # channels beyond that count (it only requires index.size(d) <=
         # input.size(d)), which the radius-None branch never does.
         in_channels = x.shape[1]
         p = x[:, :3]  # XYZ coordinates
         q = F.farthest_point_sample(p, num_samples)
         # Neighbour indices into the point axis; expected shape
         # (batch, kernel_size, num_samples).
         index = F.ball_point(p, q, self.kernel_size, self.radius)
         index = index.view(batch_size, -1)
         index = index.unsqueeze(1).expand(-1, in_channels, -1)
         x = torch.gather(x, 2, index)
         x = x.view(batch_size, in_channels, self.kernel_size, -1)
         # Neighbour XYZ relative to its centroid. The update is in place on
         # the gathered copy, not on the caller's tensor.
         x[:, :3] -= q.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
         x = super(SetConv, self).forward(x)
         # Presumably the parent forward collapses the kernel axis (dim 2) to
         # size 1 -- confirm.
         x = x.squeeze(2)
         x = torch.cat([q, x], dim=1)
     else:
         # A trailing singleton axis places every point on dim 2 (the
         # neighbour axis of the grouped path), forming one group.
         x = x.unsqueeze(3)
         x = super(SetConv, self).forward(x)
         x = x.squeeze(2)
     return x