def forward(self, x):
    batch_size = x.shape[0]
    in_channels = x.shape[1]
    if self.num_samples > 1:
        p = x[:, :3]  # XYZ coordinates
        # Inverse kernel density as a per-point scaling factor
        s = F.kernel_density(p, self.bandwidth).unsqueeze(1)
        s = self.scale(torch.reciprocal(s))  # calculate scaling factor
        # Sample query points and group their k nearest neighbours
        q = F.farthest_point_sample(p, self.num_samples)
        _, index = F.knn(p, q, self.kernel_size)
        index = index.view(batch_size, -1).unsqueeze(1)
        # Point and density grouping
        x = torch.gather(x, 2, index.expand(-1, in_channels, -1))
        s = torch.gather(s, 2, index)
        x = x.view(batch_size, in_channels, self.kernel_size, -1)
        s = s.view(batch_size, 1, self.kernel_size, -1)
        # Express neighbour coordinates relative to the sampled points
        x[:, :3] -= q.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
        # Weights from local geometry, features rescaled by density
        w = self.weight(x[:, :3])
        x = self.mlp(x * s)
        # Per-sample matrix product over the neighbour dimension
        x = torch.matmul(w.permute(0, 3, 1, 2), x.permute(0, 3, 2, 1))
        x = x.permute(0, 3, 2, 1)
        x = self.lin(x).squeeze(2)
        # Prepend the sampled coordinates to the output features
        x = torch.cat([q, x], dim=1)
    else:
        # No sampling or grouping: treat each point as its own neighbourhood
        p = x[:, :3]
        s = F.kernel_density(p, self.bandwidth).unsqueeze(1)
        s = self.scale(torch.reciprocal(s)).unsqueeze(3)
        x = x.unsqueeze(3)
        w = self.weight(x[:, :3])
        x = self.mlp(x * s)
        x = torch.matmul(w.permute(0, 3, 1, 2), x.permute(0, 3, 2, 1))
        x = x.permute(0, 3, 2, 1)
        x = self.lin(x).squeeze(2)
    return x
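A minimal usage sketch for the layer above. Only the forward() signature and the attributes it references (kernel_size, num_samples, bandwidth) come from the listing; the variable names, constructor values, and tensor sizes below are assumptions for illustration.

# Hypothetical usage; `layer` is assumed to be an instance of the module above,
# e.g. constructed with kernel_size=32, num_samples=512, bandwidth=0.1 (illustrative values).
x = torch.rand(8, 3, 2048)   # (batch, 3 + C, N) point cloud, XYZ in the first 3 channels (C = 0 here)
out = layer(x)               # (batch, 3 + C_out, 512): sampled XYZ concatenated with new features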
def forward(self, x, y):
    batch_size = x.shape[0]
    # Split both point sets into coordinates and features
    p, x = x[:, :3], x[:, 3:]
    q, y = y[:, :3], y[:, 3:]
    # Interpolate the coarse features from p onto the target positions q,
    # then concatenate with the skip-connection features
    x = F.interpolate(p, q, x, self.kernel_size)
    x = torch.cat([x, y], dim=1)
    s = F.kernel_density(q, self.bandwidth).unsqueeze(1)
    s = self.scale(torch.reciprocal(s))  # calculate scaling factor
    _, index = F.knn(q, q, self.kernel_size)
    index = index.view(batch_size, -1).unsqueeze(1)
    # Point and density grouping
    p = torch.gather(q, 2, index.expand(-1, 3, -1))
    x = torch.gather(x, 2, index.expand(-1, self.in_channels, -1))
    s = torch.gather(s, 2, index)
    p = p.view(batch_size, 3, self.kernel_size, -1)
    x = x.view(batch_size, self.in_channels, self.kernel_size, -1)
    s = s.view(batch_size, 1, self.kernel_size, -1)
    # Relative coordinates within each neighbourhood
    p = p - q.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
    # Weights from local geometry, features rescaled by density
    w = self.weight(p)
    x = self.mlp(x * s)
    x = torch.matmul(w.permute(0, 3, 1, 2), x.permute(0, 3, 2, 1))
    x = x.permute(0, 3, 2, 1)
    x = self.lin(x).squeeze(2)
    # Keep the coordinates of the target set q
    x = torch.cat([q, x], dim=1)
    return x
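This propagation variant takes two point sets: x is the coarser set produced by a downsampling layer and y is the skip connection at the target resolution; the output keeps the coordinates of y. A hedged sketch, with assumed names and shapes:

# Hypothetical usage; `fp_layer` is assumed to be an instance of the module above.
coarse = torch.rand(8, 3 + 128, 512)   # x: coarse XYZ + features
skip = torch.rand(8, 3 + 64, 2048)     # y: finer XYZ + encoder features (skip link)
out = fp_layer(coarse, skip)           # (8, 3 + C_out, 2048), coordinates taken from y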
def forward(self, x):
    batch_size = x.shape[0]
    # k-nearest-neighbour grouping in feature space
    _, index = F.knn(x, x, self.kernel_size)
    index = index.view(batch_size, -1).unsqueeze(1)
    index = index.expand(-1, self.in_channels, -1)
    x_hat = torch.gather(x, 2, index)
    x_hat = x_hat.view(batch_size, self.in_channels, self.kernel_size, -1)
    x = x.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
    # Edge feature [x_i, x_j - x_i] for each neighbour j of point i
    x_hat = x_hat - x
    x = torch.cat([x, x_hat], dim=1)
    # Parent convolution aggregates over the neighbour dimension
    x = super(EdgeConv, self).forward(x)
    x = x.squeeze(2)
    return x
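EdgeConv groups the kernel_size nearest neighbours in feature space and builds the edge feature [x_i, x_j - x_i] before calling the parent convolution. A hypothetical call, with assumed channel counts:

# Hypothetical usage; `edge_conv` is assumed to be an EdgeConv instance with in_channels=64.
x = torch.rand(8, 64, 1024)   # (batch, in_channels, N)
out = edge_conv(x)            # (batch, out_channels, N): one aggregated feature per point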
def forward(self, x):
    batch_size = x.shape[0]
    in_channels = x.shape[1]
    p = x[:, :3]  # XYZ coordinates
    q = F.farthest_point_sample(p, self.num_samples)
    # Dilated kNN: query kernel_size * dilation neighbours, keep every dilation-th
    _, index = F.knn(p, q, self.kernel_size * self.dilation)
    index = index[:, ::self.dilation]
    index = index.reshape(batch_size, -1).unsqueeze(1)
    # Group neighbour coordinates and express them relative to the sampled points
    p_hat = torch.gather(p, 2, index.expand(-1, 3, -1))
    p_hat = p_hat.view(batch_size, 3, self.kernel_size, -1)
    p_hat = p_hat - q.unsqueeze(2).expand(-1, -1, self.kernel_size, -1)
    # Learn a kernel_size x kernel_size transformation from the local geometry
    T = self.stn(p_hat).view(batch_size, self.kernel_size, self.kernel_size, -1)
    # Group the input features and lift the relative coordinates
    x = torch.gather(x, 2, index.expand(-1, in_channels, -1))
    x = x.view(batch_size, in_channels, self.kernel_size, -1)
    x_hat = self.mlp(p_hat)
    x_hat = torch.cat([x_hat, x], 1)
    # Apply the learned transformation, then the final convolution
    x = torch.matmul(x_hat.permute(0, 3, 1, 2), T.permute(0, 3, 1, 2))
    x = x.permute(0, 2, 3, 1)
    x = self.conv(x).squeeze(2)
    x = torch.cat([q, x], dim=1)
    return x
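This last layer combines farthest point sampling, dilated kNN grouping, and a learned kernel_size x kernel_size transformation (self.stn) applied to the lifted features, in the spirit of an X-Conv-style operator. A hedged usage sketch with assumed sizes:

# Hypothetical usage; `xconv` is assumed to be an instance of the module above,
# e.g. with kernel_size=8, dilation=2, num_samples=384 (illustrative values).
x = torch.rand(8, 3, 1024)   # (batch, 3 + C, N)
out = xconv(x)               # (batch, 3 + C_out, 384): sampled XYZ plus transformed features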