def test_spline_weighting_forward(test, dtype, device):
    """Check the forward output of ``spline_weighting`` against a reference case.

    Args:
        test: Test-case mapping with input keys ``'x'``, ``'weight'``,
            ``'basis'``, ``'weight_index'`` and the reference output under
            ``'expected'`` (nested lists).
        dtype: Dtype used to build ``x``, ``weight`` and ``basis``
            (presumably a floating dtype — confirm against the parametrize list).
        device: Device on which all input tensors are created.
    """
    x = tensor(test['x'], dtype, device)
    weight = tensor(test['weight'], dtype, device)
    basis = tensor(test['basis'], dtype, device)
    # Indices are always integral, independent of the dtype under test.
    weight_index = tensor(test['weight_index'], torch.long, device)

    out = spline_weighting(x, weight, basis, weight_index)

    # Exact ``==`` on floats is brittle across dtypes/devices (rounding and
    # summation order may differ), so compare with a tolerance instead.
    expected = torch.tensor(test['expected'], dtype=out.dtype,
                            device=out.device)
    # ``torch.allclose`` broadcasts, so pin the shape explicitly first to keep
    # the strictness of the original list comparison.
    assert out.size() == expected.size()
    assert torch.allclose(out, expected)
def message(self, x_j, pseudo):
    """Weight ``x_j`` by the spline kernel evaluated at ``pseudo``.

    Args:
        x_j: Per-edge input features passed to ``spline_weighting``.
        pseudo: Pseudo-coordinates passed to ``spline_basis``.

    Returns:
        The result of ``spline_weighting(x_j, self.weight, basis, weight_index)``.
    """
    # Registered buffers are exposed as plain attributes by
    # ``nn.Module.__getattr__``; use that public access path instead of
    # indexing the private ``_buffers`` dict (consistent with sibling
    # ``message`` implementations).
    basis, weight_index = spline_basis(pseudo, self.kernel_size,
                                       self.is_open_spline, self.degree)
    return spline_weighting(x_j, self.weight, basis, weight_index)
def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
    """Weight ``x_j`` by the spline kernel evaluated at ``edge_attr``.

    Args:
        x_j: Per-edge input features passed to ``spline_weighting``.
        edge_attr: Pseudo-coordinates passed to ``spline_basis``; optional in
            the signature but required by this method.

    Returns:
        The result of ``spline_weighting(x_j, self.weight, basis, weight_index)``.
    """
    # NOTE(review): besides rejecting ``None``, this assert presumably narrows
    # ``OptTensor`` to ``Tensor`` for TorchScript/type checkers — confirm
    # before replacing it with an explicit raise.
    assert edge_attr is not None
    basis, weight_index = spline_basis(
        edge_attr,
        self.kernel_size,
        self.is_open_spline,
        self.degree,
    )
    return spline_weighting(x_j, self.weight, basis, weight_index)
def message(self, x_j, pseudo):
    """Apply the learned spline kernel to ``x_j`` at coordinates ``pseudo``.

    Args:
        x_j: Per-edge input features passed to ``spline_weighting``.
        pseudo: Pseudo-coordinates passed to ``spline_basis``.

    Returns:
        The result of ``spline_weighting(x_j, self.weight, basis, weight_index)``.
    """
    basis_args = (pseudo, self.kernel_size, self.is_open_spline, self.degree)
    # ``spline_basis`` yields the (basis, weight_index) pair that
    # ``spline_weighting`` expects after ``x`` and ``weight``.
    basis, weight_index = spline_basis(*basis_args)
    return spline_weighting(x_j, self.weight, basis, weight_index)