def test_segment_out(test, reduce, dtype, device):
    """Check segment_csr / segment_coo when writing into a caller-provided ``out``.

    ``out`` is pre-filled with -2. segment_csr must overwrite it with exactly
    ``test[reduce]``. For segment_coo, the expectations below assume the result
    is combined with the pre-filled -2 values (shifted sum, -2 for min, -2 for
    zero-valued max entries); the mean case is not checked meaningfully.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    # Same shape as src, except the segment dimension holds one slot per segment.
    out_shape = list(src.size())
    segment_dim = indptr.dim() - 1
    out_shape[segment_dim] = indptr.size(-1) - 1
    out = src.new_full(out_shape, -2)

    segment_csr(src, indptr, out, reduce=reduce)
    assert torch.all(out == expected)

    out.fill_(-2)
    segment_coo(src, index, out, reduce=reduce)

    if reduce == 'sum':
        coo_expected = expected - 2
    elif reduce == 'mean':
        # We can not really test this here; compare the output with itself.
        coo_expected = out
    elif reduce == 'min':
        coo_expected = expected.fill_(-2)
    elif reduce == 'max':
        expected[expected == 0] = -2
        coo_expected = expected
    else:
        raise ValueError
    assert torch.all(out == coo_expected)
def correctness(dataset):
    """Cross-check scatter_* against segment_coo / segment_csr on one sparse matrix.

    ``dataset`` is a ``(group, name)`` pair; the matrix is loaded from
    ``f'{name}.mat'`` and converted to CSR. For every feature width in the
    module-level ``sizes``, random features (one row per non-zero) are reduced
    over the matrix rows with sum, mean, min and max. All three implementations
    must agree within ``atol=1e-4``.

    A ``RuntimeError`` whose message mentions 'out of memory' skips the current
    width after emptying the CUDA cache. Any other ``RuntimeError`` is
    re-raised unchanged, with its original type and traceback.
    """
    _, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    # Number of matrix rows, including trailing rows without non-zeros.
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
            out3 = segment_csr(x, rowptr, reduce='add')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
            out3 = segment_csr(x, rowptr, reduce='mean')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            # Make every value non-positive so the zero-initialised ``out`` of
            # scatter_min does not change the minimum of non-empty rows.
            x = x.abs_().mul_(-1)
            out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
            # dim_size is required: without it segment_coo sizes its output
            # from row.max() + 1 and drops trailing empty rows, so the shapes
            # would no longer match out1 / out3.
            out2, _ = segment_coo(x, row, dim_size=dim_size, reduce='min')
            out3, _ = segment_csr(x, rowptr, reduce='min')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            # Non-negative values for the same reason as above, now for max.
            x = x.abs_()
            out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, dim_size=dim_size, reduce='max')
            out3, _ = segment_csr(x, rowptr, reduce='max')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Bare raise keeps the original exception and traceback.
                raise
            torch.cuda.empty_cache()
def test_non_contiguous_segment(test, reduce, dtype, device):
    """Check segment_coo / segment_csr on inputs with a non-contiguous layout.

    Every input of rank > 1 keeps its values but gets a transposed memory
    layout (transpose -> contiguous -> transpose back). Both reductions must
    still produce ``test[reduce]``. When a reduction returns an
    ``(out, arg_out)`` tuple, ``arg_out`` must equal ``test[f'arg_{reduce}']``.
    """
    def as_non_contiguous(t):
        # Same values and shape, but dims 0/1 are stored in swapped order.
        if t.dim() > 1:
            return t.transpose(0, 1).contiguous().transpose(0, 1)
        return t

    src = as_non_contiguous(tensor(test['src'], dtype, device))
    index = as_non_contiguous(tensor(test['index'], torch.long, device))
    indptr = as_non_contiguous(tensor(test['indptr'], torch.long, device))
    expected = tensor(test[reduce], dtype, device)

    for segment_fn, segment_index in ((segment_coo, index), (segment_csr, indptr)):
        result = segment_fn(src, segment_index, reduce=reduce)
        if isinstance(result, tuple):
            result, arg_result = result
            arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
            assert torch.all(arg_result == arg_expected)
        assert torch.all(result == expected)
def test_zero_elements(reduce, dtype, device):
    """Check that scatter/segment/gather ops handle a tensor with zero elements.

    Each op is applied to a ``(0, 0, 0, 16)`` tensor that requires grad, with
    empty index/indptr tensors. It must return a tensor of the same shape and
    support a backward pass.
    """
    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device, requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    # Evaluated lazily, in this order, so each op runs only after the previous
    # op's checks have passed.
    ops = (
        lambda: scatter(x, index, dim=0, dim_size=0, reduce=reduce),
        lambda: segment_coo(x, index, dim_size=0, reduce=reduce),
        lambda: gather_coo(x, index),
        lambda: segment_csr(x, indptr, reduce=reduce),
        lambda: gather_csr(x, indptr),
    )
    for op in ops:
        out = op()
        out.backward(torch.randn_like(out))
        assert out.size() == (0, 0, 0, 16)
def forward(self, data):
    """Run ``mpl1`` on the graph, sum-pool with ``segment_coo``, then the linear head.

    Steps: ``mpl1`` over (edge_index, node_attr, edge_attr), then
    ``act1(lin1(.))``, a sum over segments given by ``data.batch``, then
    ``act2(lin2(.))`` and ``lin3``.
    """
    node_feats = self.mpl1(data.edge_index, data.node_attr, data.edge_attr)
    node_feats = self.act1(self.lin1(node_feats))
    # NOTE(review): data.batch is presumably the graph id of each node;
    # segment_coo assumes this index is sorted.
    pooled = segment_coo(node_feats, data.batch, reduce="sum")
    hidden = self.act2(self.lin2(pooled))
    return self.lin3(hidden)
def test_forward(test, reduce, dtype):
    """Check scatter, segment_coo and segment_csr on the second CUDA device.

    All inputs are placed on ``cuda:1``. Each op must reproduce
    ``test[reduce]``; scatter reduces along ``test['dim']``.
    """
    device = torch.device('cuda:1')
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    runs = (
        lambda: torch_scatter.scatter(src, index, dim, reduce=reduce),
        lambda: torch_scatter.segment_coo(src, index, reduce=reduce),
        lambda: torch_scatter.segment_csr(src, indptr, reduce=reduce),
    )
    for run in runs:
        assert torch.all(run() == expected)
def test_forward(test, reduce, dtype, device):
    """Check segment_coo and segment_csr against the expected reduction.

    Both ops must produce ``test[reduce]``. When an op returns an
    ``(out, arg_out)`` tuple, ``arg_out`` must equal ``test[f'arg_{reduce}']``.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    for segment_fn, segment_index in ((segment_coo, index), (segment_csr, indptr)):
        result = segment_fn(src, segment_index, reduce=reduce)
        if isinstance(result, tuple):
            result, arg_result = result
            arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
            assert torch.all(arg_result == arg_expected)
        assert torch.all(result == expected)
def seg_coo(x):
    """Reduce ``x`` with ``segment_coo`` over the module-level ``row`` index.

    The reduction name comes from ``args.reduce``.
    """
    reduction = args.reduce
    return segment_coo(x, row, reduce=reduction)
def get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_threshold):
    """
    Give a mask that filters out edges so that each atom has at most
    `max_num_neighbors_threshold` neighbors.
    Assumes that `index` is sorted.

    Args:
        natoms: Per-image atom counts. Their cumulative sum gives the atom
            range of each image.
        index: Atom index of every edge (one entry per edge), sorted.
        atom_distance: Per-edge distance, aligned with `index`. Smaller values
            are treated as closer; the closest neighbors are kept.
        max_num_neighbors_threshold: Maximum number of neighbors kept per
            atom. A value <= 0 disables the limit.

    Returns:
        mask_num_neighbors: Boolean tensor with one entry per edge, True for
            kept edges. When nothing has to be dropped this is an expanded
            view of a single True value (do not write to it in place).
        num_neighbors_image: Number of kept neighbors summed per image.
    """
    device = natoms.device
    num_atoms = natoms.sum()

    # Get number of neighbors
    # segment_coo assumes sorted index
    ones = index.new_ones(1).expand_as(index)
    num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
    max_num_neighbors = num_neighbors.max()

    # A non-positive threshold disables the limit (see the early return
    # below). Clamping with it would report 0 or negative neighbor counts per
    # image while the mask keeps every edge, so clamp only when the limit is
    # active.
    if max_num_neighbors_threshold > 0:
        num_neighbors_thresholded = num_neighbors.clamp(
            max=max_num_neighbors_threshold)
    else:
        num_neighbors_thresholded = num_neighbors

    # Get number of (thresholded) neighbors per image
    image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
    image_indptr[1:] = torch.cumsum(natoms, dim=0)
    num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)

    # Nothing to drop: every atom is already within the limit, or the limit
    # is disabled.
    if (max_num_neighbors <= max_num_neighbors_threshold
            or max_num_neighbors_threshold <= 0):
        mask_num_neighbors = torch.tensor(
            [True], dtype=bool, device=device).expand_as(index)
        return mask_num_neighbors, num_neighbors_image

    # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
    # Fill with infinity so we can easily remove unused distances later.
    distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device)

    # Create an index map to map distances from atom_distance to distance_sort
    # index_sort_map assumes index to be sorted
    index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
    index_neighbor_offset_expand = torch.repeat_interleave(
        index_neighbor_offset, num_neighbors)
    # Row = atom, column = position of the edge within that atom's run of edges.
    index_sort_map = (index * max_num_neighbors
                      + torch.arange(len(index), device=device)
                      - index_neighbor_offset_expand)
    distance_sort.index_copy_(0, index_sort_map, atom_distance)
    distance_sort = distance_sort.view(num_atoms, max_num_neighbors)

    # Sort neighboring atoms based on distance.
    # NOTE: torch.sort is not stable by default, so which of several
    # equidistant neighbors survive the cut is not guaranteed.
    distance_sort, index_sort = torch.sort(distance_sort, dim=1)
    # Select the max_num_neighbors_threshold neighbors that are closest
    distance_sort = distance_sort[:, :max_num_neighbors_threshold]
    index_sort = index_sort[:, :max_num_neighbors_threshold]

    # Offset index_sort so that it indexes into index
    index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand(
        -1, max_num_neighbors_threshold)
    # Remove "unused pairs" with infinite distances (padding slots)
    mask_finite = torch.isfinite(distance_sort)
    index_sort = torch.masked_select(index_sort, mask_finite)

    # At this point index_sort contains the index into index of the
    # closest max_num_neighbors_threshold neighbors per atom
    # Create a mask to remove all pairs not in index_sort
    mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool)
    mask_num_neighbors.index_fill_(0, index_sort, True)

    return mask_num_neighbors, num_neighbors_image