Exemplo n.º 1
0
 def forward(self, x, y):
     """Aggregate features onto the target points ``y`` from the source points ``x``.

     Both inputs are sliced along dim 1: the first three rows are used as
     coordinates and the remaining rows as features. The source features go
     through ``F.interpolate`` onto the target coordinates and are concatenated
     with the target features.

     Each target point then gathers ``self.kernel_size`` neighbours, with
     indices coming from ``F.knn`` on the target coordinates. The gathered
     features are multiplied by ``self.scale`` applied to the reciprocal of
     ``F.kernel_density``, passed through ``self.mlp``, and combined with
     weights that ``self.weight`` computes from the neighbour offsets relative
     to each centre point.

     Returns the target coordinates concatenated (dim 1) with the squeezed
     output of ``self.lin``.
     """
     n_batch = x.shape[0]
     k = self.kernel_size
     c_in = self.in_channels

     src_xyz, src_feat = x[:, :3], x[:, 3:]
     dst_xyz, dst_feat = y[:, :3], y[:, 3:]

     # Carry source features over to the target points, then append target features.
     lifted = F.interpolate(src_xyz, dst_xyz, src_feat, k)
     feats = torch.cat([lifted, dst_feat], dim=1)

     # Reciprocal of the kernel density at each target point, fed through the scaling module.
     density = F.kernel_density(dst_xyz, self.bandwidth).unsqueeze(1)
     inv_density = self.scale(torch.reciprocal(density))

     # Flatten neighbour indices to (batch, 1, -1) so a single gather along
     # dim 2 can serve any channel count via expand().
     _, nbr_idx = F.knn(dst_xyz, dst_xyz, k)
     nbr_idx = nbr_idx.view(n_batch, -1).unsqueeze(1)

     # Group coordinates, features and densities by neighbourhood, then
     # unflatten each to (batch, channels, kernel_size, -1).
     nbr_xyz = torch.gather(dst_xyz, 2, nbr_idx.expand(-1, 3, -1))
     nbr_feat = torch.gather(feats, 2, nbr_idx.expand(-1, c_in, -1))
     nbr_dens = torch.gather(inv_density, 2, nbr_idx)
     nbr_xyz = nbr_xyz.view(n_batch, 3, k, -1)
     nbr_feat = nbr_feat.view(n_batch, c_in, k, -1)
     nbr_dens = nbr_dens.view(n_batch, 1, k, -1)

     # Offsets relative to each centre point; unsqueeze(2) broadcasts over the kernel axis.
     offsets = nbr_xyz - dst_xyz.unsqueeze(2)
     weights = self.weight(offsets)
     transformed = self.mlp(nbr_feat * nbr_dens)

     # Batched per-point product of weights and features. This presumably
     # contracts the kernel axis, assuming self.weight / self.mlp keep dims 2-3
     # laid out as (kernel_size, points) -- TODO confirm.
     combined = weights.permute(0, 3, 1, 2) @ transformed.permute(0, 3, 2, 1)
     out = self.lin(combined.permute(0, 3, 2, 1)).squeeze(2)
     return torch.cat([dst_xyz, out], dim=1)
Exemplo n.º 2
0
 def forward(self, x, y):
     """Lift source features onto the target points, then run the parent layer.

     ``x`` and ``y`` are each split along dim 1 into their first three rows
     (coordinates) and the remaining rows (features). The source features are
     passed through ``F.interpolate`` onto the target coordinates and
     concatenated with the target features. The result is handed to the
     parent class's ``forward``, and the target coordinates are prepended
     (dim 1) to its output.
     """
     src_xyz, src_feat = x[:, :3], x[:, 3:]
     dst_xyz, dst_feat = y[:, :3], y[:, 3:]
     lifted = F.interpolate(src_xyz, dst_xyz, src_feat, self.kernel_size)
     merged = torch.cat([lifted, dst_feat], dim=1)
     # Explicit two-argument super() is kept to match the original MRO lookup.
     convolved = super(SetDeconv, self).forward(merged)
     return torch.cat([dst_xyz, convolved], dim=1)
Exemplo n.º 3
0
def test_interpolate():
    """Interpolating onto the source points themselves must return the input features.

    ``p`` holds the 8 corners of the unit cube as the columns of a (1, 3, 8)
    tensor, and ``q`` is an exact copy of it. ``F.interpolate`` with ``k = 2``
    neighbours is therefore expected to reproduce ``x`` unchanged, in both
    shape and values.
    """
    k = 2
    p = torch.tensor([[
        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
        [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0],
    ]])
    q = p.clone()
    x = torch.tensor([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]]])
    y = F.interpolate(p, q, x, k)
    # torch.allclose broadcasts its arguments, so a wrongly shaped result
    # (e.g. (8,) or (1, 8)) could still pass the value check; pin the shape.
    assert y.shape == x.shape
    assert torch.allclose(x, y)