def test_spline_basis_forward(test, dtype, device):
    """Check the forward outputs of ``spline_basis`` against a reference case.

    Args:
        test: Mapping read with the keys ``'pseudo'``, ``'kernel_size'``,
            ``'is_open_spline'``, ``'basis'`` and ``'weight_index'``; the
            first three are inputs, the last two the expected outputs.
        dtype: dtype used to build ``pseudo`` (presumably a floating dtype,
            supplied by test parametrization -- confirm against the caller).
        device: Device on which the input tensors are created.
    """
    # NOTE(review): assumes the project helper ``tensor(data, dtype, device)``
    # builds a torch tensor from nested lists.
    pseudo = tensor(test['pseudo'], dtype, device)
    kernel_size = tensor(test['kernel_size'], torch.long, device)
    is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
    degree = 1

    basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
                                       degree)

    # The basis values are floating-point results, so exact equality with
    # reference floats is fragile across dtypes/devices; compare with a
    # tolerance instead. Build the reference with the output's own dtype and
    # device so allclose never fails on a dtype/device mismatch, and check the
    # shape explicitly because allclose would otherwise accept broadcasting.
    expected_basis = tensor(test['basis'], basis.dtype, basis.device)
    assert basis.size() == expected_basis.size()
    assert torch.allclose(basis, expected_basis)
    # Weight indices are compared exactly (presumably integer-valued).
    assert weight_index.tolist() == test['weight_index']
def test_spline_weighting_backward(device):
    """Run ``gradcheck`` on ``spline_weighting`` with random double inputs.

    The basis and weight indices come from ``spline_basis`` evaluated on
    random pseudo-coordinates with a 5 x 5 kernel. ``basis``, ``x`` and
    ``weight`` are marked as requiring gradients before the check.

    Args:
        device: Device on which every tensor is created.
    """
    # Random pseudo-coordinates: 4 rows, 2 columns, double precision.
    coords = torch.rand((4, 2), dtype=torch.double, device=device)
    sizes = tensor([5, 5], torch.long, device)
    open_flags = tensor([1, 1], torch.uint8, device)
    spline_degree = 1
    basis, weight_index = spline_basis(coords, sizes, open_flags,
                                       spline_degree)
    basis.requires_grad_()

    # Draw order (coords, then features, then weight) matches the original.
    features = torch.rand((4, 2), dtype=torch.double, device=device)
    features.requires_grad_()
    # 25 == 5 * 5, presumably one (2 x 4) weight slice per kernel cell.
    weight = torch.rand((25, 2, 4), dtype=torch.double, device=device)
    weight.requires_grad_()

    inputs = (features, weight, basis, weight_index)
    passed = gradcheck(spline_weighting, inputs, eps=1e-6, atol=1e-4)
    assert passed is True
def message(self, x_j, pseudo):
    """Compute per-edge messages by spline-weighting ``x_j``.

    Args:
        x_j: Features passed in for each edge (presumably source-node
            features gathered by the message-passing framework -- confirm).
        pseudo: Pseudo-coordinates fed to ``spline_basis``.

    Returns:
        The result of ``spline_weighting`` applied to ``x_j`` with
        ``self.weight`` and the basis/weight-index pair from ``spline_basis``.
    """
    # Read the registered buffer through normal attribute access instead of
    # poking at the private ``_buffers`` dict; nn.Module exposes registered
    # buffers as attributes, which is how ``is_open_spline`` is read here.
    basis, weight_index = spline_basis(pseudo, self.kernel_size,
                                       self.is_open_spline, self.degree)
    return spline_weighting(x_j, self.weight, basis, weight_index)
def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
    """Compute per-edge messages by spline-weighting ``x_j``.

    Args:
        x_j: Per-edge input features (presumably gathered source-node
            features -- confirm against the message-passing caller).
        edge_attr: Pseudo-coordinates for ``spline_basis``; must not be None.

    Returns:
        ``spline_weighting`` of ``x_j`` with ``self.weight`` and the
        basis/weight-index pair produced by ``spline_basis``.
    """
    # Fails with AssertionError when edge_attr is missing; presumably also
    # refines the Optional annotation for static/TorchScript typing.
    assert edge_attr is not None
    basis, weight_index = spline_basis(
        edge_attr,
        self.kernel_size,
        self.is_open_spline,
        self.degree,
    )
    return spline_weighting(x_j, self.weight, basis, weight_index)
def message(self, x_j, pseudo):
    """Compute per-edge messages by spline-weighting ``x_j``.

    Args:
        x_j: Per-edge input features (presumably gathered source-node
            features -- confirm against the message-passing caller).
        pseudo: Pseudo-coordinates passed to ``spline_basis``.

    Returns:
        ``spline_weighting`` applied to ``x_j``, ``self.weight`` and the two
        outputs of ``spline_basis``.
    """
    basis, weight_index = spline_basis(
        pseudo,
        self.kernel_size,
        self.is_open_spline,
        self.degree,
    )
    return spline_weighting(x_j, self.weight, basis, weight_index)