コード例 #1
0
def test_spline_basis_forward(test, dtype, device):
    """Compare the outputs of ``spline_basis`` with the values in ``test``.

    ``test`` is indexed by the keys ``'pseudo'``, ``'kernel_size'``,
    ``'is_open_spline'``, ``'basis'`` and ``'weight_index'``. The first three
    are converted with ``tensor(...)`` on ``device`` (``pseudo`` using
    ``dtype``, the other two using fixed integer dtypes); the last two are
    compared against the ``.tolist()`` form of the returned tensors.
    """
    inputs = (
        tensor(test['pseudo'], dtype, device),
        tensor(test['kernel_size'], torch.long, device),
        tensor(test['is_open_spline'], torch.uint8, device),
    )

    # The spline degree is fixed to 1 for this test.
    got_basis, got_weight_index = spline_basis(*inputs, 1)

    assert got_basis.tolist() == test['basis']
    assert got_weight_index.tolist() == test['weight_index']
コード例 #2
0
def test_spline_weighting_backward(device):
    """Run ``gradcheck`` on ``spline_weighting`` with random double inputs.

    A basis / weight-index pair is produced by ``spline_basis`` from random
    pseudo-coordinates, then ``gradcheck`` is applied to ``spline_weighting``
    over ``(x, weight, basis, weight_index)`` where ``x``, ``weight`` and
    ``basis`` are flagged as requiring gradients.
    """
    # Options shared by every randomly drawn tensor in this test.
    rand_opts = dict(dtype=torch.double, device=device)

    # Random draws are kept in the order: pseudo, x, weight.
    pseudo = torch.rand((4, 2), **rand_opts)
    basis, weight_index = spline_basis(
        pseudo,
        tensor([5, 5], torch.long, device),    # kernel_size
        tensor([1, 1], torch.uint8, device),   # is_open_spline
        1,                                     # degree
    )
    basis.requires_grad_()

    x = torch.rand((4, 2), requires_grad=True, **rand_opts)
    # NOTE(review): the leading 25 presumably matches the 5 * 5 kernel
    # size used above -- confirm against spline_weighting's contract.
    weight = torch.rand((25, 2, 4), requires_grad=True, **rand_opts)

    passed = gradcheck(spline_weighting, (x, weight, basis, weight_index),
                       eps=1e-6, atol=1e-4)
    assert passed is True
コード例 #3
0
 def message(self, x_j, pseudo):
     """Weight ``x_j`` with ``self.weight`` using a basis built from ``pseudo``.

     The kernel size and open-spline flags are read from ``self._buffers``;
     whatever ``spline_basis`` returns is forwarded positionally to
     ``spline_weighting`` after ``x_j`` and ``self.weight``.
     """
     registered = self._buffers
     kernel_size = registered['kernel_size']
     is_open_spline = registered['is_open_spline']

     basis_args = spline_basis(pseudo, kernel_size, is_open_spline,
                               self.degree)
     return spline_weighting(x_j, self.weight, *basis_args)
コード例 #4
0
 def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
     """Weight ``x_j`` with ``self.weight`` using a basis built from ``edge_attr``.

     ``edge_attr`` is typed as optional but must not be ``None`` here; the
     assertion below enforces that before it is handed to ``spline_basis``.
     The result of ``spline_basis`` is forwarded positionally to
     ``spline_weighting`` after ``x_j`` and ``self.weight``.
     """
     assert edge_attr is not None

     basis_args = spline_basis(
         edge_attr,
         self.kernel_size,
         self.is_open_spline,
         self.degree,
     )
     return spline_weighting(x_j, self.weight, *basis_args)
コード例 #5
0
 def message(self, x_j, pseudo):
     """Weight ``x_j`` with ``self.weight`` using a basis built from ``pseudo``.

     ``spline_basis`` is called with this object's ``kernel_size``,
     ``is_open_spline`` and ``degree`` attributes; its two results are
     then passed to ``spline_weighting`` after ``x_j`` and ``self.weight``.
     """
     basis, weight_index = spline_basis(
         pseudo,
         self.kernel_size,
         self.is_open_spline,
         self.degree,
     )
     return spline_weighting(x_j, self.weight, basis, weight_index)