Пример #1
0
 def forward(self, x):
     """Density-scaled point convolution over a point cloud.

     Args:
         x: tensor indexed as (batch, channels, points); the first three
             channels are the XYZ coordinates.

     Returns:
         If ``self.num_samples > 1``: centroids are picked with
         ``F.farthest_point_sample``, each centroid is convolved over its
         ``self.kernel_size`` neighbours from ``F.knn``, and the centroid
         coordinates are concatenated (dim 1) in front of the resulting
         features. Otherwise the whole cloud is treated as a single
         neighbourhood and only the convolved features are returned.
     """
     batch_size = x.shape[0]
     in_channels = x.shape[1]
     p = x[:, :3]  # XYZ coordinates
     # Per-point scaling factor from the inverse kernel density estimate.
     # Both paths compute it identically, so it is done once here.
     s = F.kernel_density(p, self.bandwidth).unsqueeze(1)
     s = self.scale(torch.reciprocal(s))
     if self.num_samples > 1:
         q = F.farthest_point_sample(p, self.num_samples)
         _, index = F.knn(p, q, self.kernel_size)
         index = index.view(batch_size, -1).unsqueeze(1)
         # Point and density grouping -> (batch, C, kernel_size, centroids)
         x = torch.gather(x, 2, index.expand(-1, in_channels, -1))
         s = torch.gather(s, 2, index)
         x = x.view(batch_size, in_channels, self.kernel_size, -1)
         s = s.view(batch_size, 1, self.kernel_size, -1)
         # Express neighbour coordinates relative to their centroid.
         x[:, :3] -= q.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
     else:
         # Global path: no sampling or grouping. Every point belongs to one
         # implicit centroid, represented by a trailing singleton axis.
         q = None
         s = s.unsqueeze(3)
         x = x.unsqueeze(3)
     # Shared aggregation: weights from the coordinate channels, features
     # from the density-scaled input, contracted by a batched matmul.
     # Assumes self.weight / self.mlp keep axis 2 (neighbours) and axis 3
     # (centroids) of their input -- TODO confirm against __init__.
     w = self.weight(x[:, :3])
     x = self.mlp(x * s)
     x = torch.matmul(w.permute(0, 3, 1, 2), x.permute(0, 3, 2, 1))
     x = x.permute(0, 3, 2, 1)
     x = self.lin(x).squeeze(2)
     if q is not None:
         x = torch.cat([q, x], dim=1)
     return x
Пример #2
0
 def forward(self, x, y):
     """Carry features onto ``y``'s points, then run a density-scaled conv.

     Args:
         x: source points, indexed (batch, channels, points); channels 0-2
             are coordinates and the rest are features.
         y: target points in the same layout. Its coordinates become the
             output positions and its features are concatenated after the
             interpolated ones.

     Returns:
         ``y``'s coordinates concatenated (dim 1) in front of the convolved
         features at those positions.

     Raises:
         ValueError: if the interpolated features plus ``y``'s features do
             not have exactly ``self.in_channels`` channels.
     """
     batch_size = x.shape[0]
     p, x = x[:, :3], x[:, 3:]
     q, y = y[:, :3], y[:, 3:]
     # F.interpolate presumably resamples x's features from positions p onto
     # q; the target points' own features are then appended.
     x = F.interpolate(p, q, x, self.kernel_size)
     x = torch.cat([x, y], dim=1)
     in_channels = x.shape[1]
     if in_channels != self.in_channels:
         # torch.gather accepts an index narrower than its input on the
         # non-gathered dims, so indexing with self.in_channels would
         # silently drop trailing channels instead of failing.
         raise ValueError(
             "expected {} feature channels after interpolation and "
             "concatenation, got {}".format(self.in_channels, in_channels))
     s = F.kernel_density(q, self.bandwidth).unsqueeze(1)
     s = self.scale(torch.reciprocal(s))  # calculate scaling factor
     _, index = F.knn(q, q, self.kernel_size)
     index = index.view(batch_size, -1).unsqueeze(1)
     # Point and density grouping
     p = torch.gather(q, 2, index.expand(-1, 3, -1))
     x = torch.gather(x, 2, index.expand(-1, in_channels, -1))
     s = torch.gather(s, 2, index)
     p = p.view(batch_size, 3, self.kernel_size, -1)
     x = x.view(batch_size, in_channels, self.kernel_size, -1)
     s = s.view(batch_size, 1, self.kernel_size, -1)
     # Neighbour coordinates relative to each target point.
     p = p - q.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
     w = self.weight(p)
     x = self.mlp(x * s)
     x = torch.matmul(w.permute(0, 3, 1, 2), x.permute(0, 3, 2, 1))
     x = x.permute(0, 3, 2, 1)
     x = self.lin(x).squeeze(2)
     x = torch.cat([q, x], dim=1)
     return x
Пример #3
0
 def forward(self, x):
     """Form k-NN edge features and run them through the parent layer.

     Neighbour indices come from ``F.knn(x, x, self.kernel_size)``, so the
     search runs in the feature space of ``x`` itself. For centre ``x_i``
     and neighbour ``x_j`` the edge feature is ``[x_i, x_j - x_i]``, stacked
     on the channel axis.

     Args:
         x: tensor indexed as (batch, channels, points); the channel count
             is expected to equal ``self.in_channels``.

     Returns:
         The parent class's ``forward`` applied to the
         (batch, 2 * channels, kernel_size, points) edge tensor, with dim 2
         squeezed. Presumably the parent reduces the neighbour axis to size
         1 -- TODO confirm.
     """
     n_batch = x.shape[0]
     k = self.kernel_size
     chans = self.in_channels
     _, nbr_idx = F.knn(x, x, k)
     # Flatten to (batch, 1, k * points) and broadcast over channels so a
     # single gather along the point axis fetches every neighbour's features.
     flat_idx = nbr_idx.view(n_batch, -1).unsqueeze(1).expand(-1, chans, -1)
     neighbours = torch.gather(x, 2, flat_idx).view(n_batch, chans, k, -1)
     centres = x.unsqueeze(2).expand(-1, -1, k, -1)
     edges = torch.cat([centres, neighbours - centres], dim=1)
     return super(EdgeConv, self).forward(edges).squeeze(2)
Пример #4
0
 def forward(self, x):
     """Dilated X-Conv-style convolution on farthest-point-sampled centroids.

     A per-centroid (kernel_size x kernel_size) matrix is predicted by
     ``self.stn`` from relative neighbour coordinates. That matrix is
     multiplied with the lifted neighbour features before ``self.conv``.
     This appears to follow PointCNN's X-Conv -- confirm against the paper.

     Args:
         x: tensor indexed as (batch, channels, points); the first three
             channels are the point coordinates.

     Returns:
         Centroid coordinates concatenated (dim 1) in front of the
         convolved features for each centroid.
     """
     n_batch, n_chan = x.shape[0], x.shape[1]
     k, dil = self.kernel_size, self.dilation
     xyz = x[:, :3]
     centroids = F.farthest_point_sample(xyz, self.num_samples)
     # Ask for k * dilation neighbours, then keep every dil-th one along
     # dim 1 so that k neighbours remain per centroid.
     _, nbr = F.knn(xyz, centroids, k * dil)
     # The strided slice may be non-contiguous, hence reshape, not view.
     nbr = nbr[:, ::dil].reshape(n_batch, -1).unsqueeze(1)

     def group(src, channels):
         # Gather each centroid's neighbours along the point axis and
         # unflatten to (batch, channels, k, centroids).
         picked = torch.gather(src, 2, nbr.expand(-1, channels, -1))
         return picked.view(n_batch, channels, k, -1)

     # Neighbour coordinates relative to their centroid.
     local_xyz = group(xyz, 3) - centroids.unsqueeze(2).expand(-1, -1, k, -1)
     # Per-centroid k x k transform predicted from the local geometry.
     transform = self.stn(local_xyz).view(n_batch, k, k, -1)
     neighbour_feats = group(x, n_chan)
     lifted = torch.cat([self.mlp(local_xyz), neighbour_feats], 1)
     mixed = torch.matmul(lifted.permute(0, 3, 1, 2),
                          transform.permute(0, 3, 1, 2))
     out = self.conv(mixed.permute(0, 2, 3, 1)).squeeze(2)
     return torch.cat([centroids, out], dim=1)