def softmax(src: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    r"""Evaluates a softmax independently over groups of entries of
    :attr:`src`.

    Entries are grouped along the first dimension, either by the group ids
    in :attr:`index` or, if :attr:`ptr` is given, by CSR segment boundaries.
    Each entry is shifted by the maximum of its group before exponentiation
    and then divided by the sum of its group's exponentials.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The group id of each entry of :attr:`src`.
            Only used when :attr:`ptr` is :obj:`None`.
        ptr (LongTensor, optional): If given, the softmax is computed from
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    if ptr is not None:
        # CSR path: reduce every segment, then broadcast back per entry.
        group_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        numer = (src - group_max).exp()
        denom = gather_csr(segment_csr(numer, ptr, reduce='sum'), ptr)
        return numer / (denom + 1e-16)

    # Index path: scatter-reduce per group id, then look results up again.
    N = maybe_num_nodes(index, num_nodes)
    group_max = scatter(src, index, dim=0, dim_size=N, reduce='max')[index]
    numer = (src - group_max).exp()
    denom = scatter(numer, index, dim=0, dim_size=N, reduce='sum')[index]
    return numer / (denom + 1e-16)
def softmax_csr(x, key):
    """Softmax of ``x`` taken separately over each CSR segment given by ``key``.

    Each entry is shifted by its segment maximum, exponentiated, and divided
    by the sum of exponentials of its segment (plus ``1e-16``).
    """
    seg_max = torch_scatter.segment_csr(x, key, reduce="max")
    shifted = (x - torch_scatter.gather_csr(seg_max, key)).exp()
    seg_sum = torch_scatter.segment_csr(shifted, key, reduce="sum")
    return shifted / (torch_scatter.gather_csr(seg_sum, key) + 1e-16)
def test_segment_out(test, reduce, dtype, device):
    """Check ``segment_csr`` / ``segment_coo`` when writing into a prefilled ``out``.

    ``out`` is filled with ``-2`` before each call so that the effect of the
    existing buffer contents on the result can be asserted.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    # Output has the shape of `src`, except along the segmented dimension.
    out_shape = list(src.size())
    out_shape[indptr.dim() - 1] = indptr.size(-1) - 1
    out = src.new_full(out_shape, -2)

    segment_csr(src, indptr, out, reduce=reduce)
    assert torch.all(out == expected)

    out.fill_(-2)
    segment_coo(src, index, out, reduce=reduce)

    if reduce not in ('sum', 'mean', 'min', 'max'):
        raise ValueError

    # The COO expectations below account for the -2 prefill
    # (presumably the COO kernel reduces into the existing buffer).
    if reduce == 'sum':
        expected = expected - 2
    elif reduce == 'mean':
        expected = out  # We can not really test this here.
    elif reduce == 'min':
        expected = expected.fill_(-2)
    else:  # 'max'
        expected[expected == 0] = -2

    assert torch.all(out == expected)
def forward(self, x, point_key=None, point_edges=None, point_dirs=None,
            **kwargs):
    """Run the optional network on the inputs and pool them per CSR segment.

    The pooling rule is selected by ``self.avg_mode``:

    * ``"mean"``: plain segment mean.
    * ``"dirs"``: weighted segment sum; weights come from the dot product of
      the two halves of ``point_dirs`` (columns ``:3`` and ``3:``), clamped to
      ``[0.01, 1]`` and normalised per segment, computed without gradients.
    * ``"softmax"``: weighted segment sum with learned, segment-wise
      softmax weights.

    Raises:
        Exception: if ``self.avg_mode`` is none of the above.
    """
    x = self.cat_dirs(x, point_dirs)
    if self.net is not None:
        x = self.net(x, point_edges)

    mode = self.avg_mode
    if mode == "mean":
        return torch_scatter.segment_csr(x, point_key, reduce="mean")

    if mode == "dirs":
        with torch.no_grad():
            weight = (point_dirs[:, :3] * point_dirs[:, 3:]).sum(dim=1)
            weight = torch.clamp(weight, 0.01, 1)
            weight_sum = torch_scatter.segment_csr(weight, point_key,
                                                   reduce="sum")
            # Normalise so the weights of every segment sum to one.
            weight /= torch_scatter.gather_csr(weight_sum, point_key)
        weighted = weight.view(-1, 1) * x
        return torch_scatter.segment_csr(weighted, point_key, reduce="sum")

    if mode == "softmax":
        weight = softmax_csr(self.weight_lin(x), point_key)
        weighted = weight * self.feat_lin(x)
        return torch_scatter.segment_csr(weighted, point_key, reduce="sum")

    raise Exception("invalid avg_mode")
def dipole_forward(self, h, z, pos, batch):
    """Scale ``h`` by each row's position relative to its group's center of mass.

    Args:
        h: Per-row features; multiplied element-wise with ``pos - c[batch]``.
        z: Indices into ``self.atomic_mass`` selecting the mass of each row.
        pos: Per-row positions.
        batch: Group index of each row, used to look up the group's center.
            (Assumes ``get_ptr(batch)`` yields the matching CSR boundaries
            -- confirm against ``get_ptr``.)

    Returns:
        ``h * (pos - c[batch])`` where ``c`` is the mass-weighted mean
        position per group.
    """
    # Add the dipole moment.
    # The CSR pointer is identical for both reductions, so build it once.
    ptr = get_ptr(batch)

    # Get center of mass.
    mass = self.atomic_mass[z].view(-1, 1)
    c = segment_csr(mass * pos, ptr) / segment_csr(mass, ptr)
    return h * (pos - c[batch])
def correctness(dataset):
    """Cross-check scatter, ``segment_coo`` and ``segment_csr`` on a sparse matrix.

    Loads ``{name}.mat``, derives COO row indices and the CSR row pointer
    from it, and for every feature size in the module-level ``sizes`` asserts
    that the three implementations agree (``atol=1e-4``) for the add, mean,
    min and max reductions on random input.

    Args:
        dataset: A ``(group, name)`` pair; only ``name`` is used here.

    Raises:
        RuntimeError: Any runtime error other than an out-of-memory error is
            re-raised unchanged. Out-of-memory errors are swallowed after
            emptying the CUDA cache, and the loop moves on to the next size.
    """
    _, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
            out3 = segment_csr(x, rowptr, reduce='add')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
            out3 = segment_csr(x, rowptr, reduce='mean')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            # All-negative input for the min comparison.
            x = x.abs_().mul_(-1)
            out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='min')
            out3, _ = segment_csr(x, rowptr, reduce='min')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            # All-non-negative input for the max comparison.
            x = x.abs_()
            out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='max')
            out3, _ = segment_csr(x, rowptr, reduce='max')
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Re-raise the original exception so its concrete type and
                # traceback are preserved instead of being wrapped.
                raise
            torch.cuda.empty_cache()
def test_non_contiguous_segment(test, reduce, dtype, device):
    """Check ``segment_coo`` and ``segment_csr`` on non-contiguous inputs."""

    def non_contiguous(t):
        # Transposing, materialising and transposing back keeps the values
        # but leaves a tensor whose memory layout is not contiguous.
        if t.dim() > 1:
            return t.transpose(0, 1).contiguous().transpose(0, 1)
        return t

    def check(result):
        # Arg-reductions return (values, arg-indices); verify both.
        if isinstance(result, tuple):
            result, arg_result = result
            arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
            assert torch.all(arg_result == arg_expected)
        assert torch.all(result == expected)

    src = non_contiguous(tensor(test['src'], dtype, device))
    index = non_contiguous(tensor(test['index'], torch.long, device))
    indptr = non_contiguous(tensor(test['indptr'], torch.long, device))
    expected = tensor(test[reduce], dtype, device)

    check(segment_coo(src, index, reduce=reduce))
    check(segment_csr(src, indptr, reduce=reduce))
def aggregate_boundary(self, inputs: Tensor, agg_boundary_index: Tensor,
                       boundary_ptr: Optional[Tensor] = None,
                       boundary_dim_size: Optional[int] = None) -> Tensor:
    r"""Aggregates messages coming from the boundary cells.

    The reduction is the one named by :obj:`self.aggr_boundary`. If
    :attr:`boundary_ptr` is given, the messages are reduced segment-wise
    with :obj:`segment_csr`; otherwise they are scattered along
    :obj:`self.node_dim` according to :attr:`agg_boundary_index`.
    """
    if boundary_ptr is None:
        return scatter(inputs, agg_boundary_index, dim=self.node_dim,
                       dim_size=boundary_dim_size,
                       reduce=self.aggr_boundary)

    ptr = expand_left(boundary_ptr, dim=self.node_dim, dims=inputs.dim())
    return segment_csr(inputs, ptr, reduce=self.aggr_boundary)
def coalesce(self, reduce: str = "add"):
    """Merge entries that share the same ``(row, col)`` coordinate.

    Returns ``self`` unchanged when every linearised coordinate is strictly
    greater than its predecessor. Otherwise a new :class:`SparseStorage` is
    returned that keeps the first entry of each run; if values are present,
    each run's values are combined with ``segment_csr`` using ``reduce``.
    """
    col = self._col

    # Linearised coordinates, preceded by a -1 sentinel so that the first
    # real entry can be compared against a predecessor as well.
    keys = col.new_full((col.numel() + 1, ), -1)
    keys[1:] = self._sparse_sizes[1] * self.row() + col
    is_new = keys[1:] > keys[:-1]

    if is_new.all():  # Skip if indices are already coalesced.
        return self

    kept_row = self.row()[is_new]
    kept_col = col[is_new]

    value = self._value
    if value is not None:
        # Segment boundaries: start of every run plus the total length.
        starts = is_new.nonzero().flatten()
        end = starts.new_full((1, ), value.size(0))
        value = segment_csr(value, torch.cat([starts, end]), reduce=reduce)
        if isinstance(value, tuple):
            value = value[0]

    return SparseStorage(row=kept_row, rowptr=None, col=kept_col,
                         value=value, sparse_sizes=self._sparse_sizes,
                         rowcount=None, colptr=None, colcount=None,
                         csr2csc=None, csc2csr=None, is_sorted=True)
def forward(
    self,
    x,
    ptx,
    bs,
    height,
    width,
    point_key=None,
    point_edges=None,
    point_src_dirs=None,
    point_tgt_dirs=None,
    pixel_tgt_idx=None,
    **kwargs,
):
    """Process per-point features with directions, pool them and map to an image.

    Returns:
        A pair ``(x, ptx)``: ``x`` is the input map with entries replaced by
        the pooled features wherever ``mask > 0``; ``ptx`` is the pooled
        feature map itself.
    """
    # Select the target directions and expand them along the CSR segments.
    tgt_dirs = torch_scatter.gather_csr(
        point_tgt_dirs[pixel_tgt_idx.long()], point_key)

    # Evaluate the network on features concatenated with both directions.
    feats = torch.cat((ptx, point_src_dirs, tgt_dirs), dim=1)
    feats = self._net_forward(feats, point_key, point_edges)
    pooled = torch_scatter.segment_csr(feats, point_key, reduce=self.aggr)

    # Map the pooled list back onto the image grid.
    pooled_map, mask = ext.mytorch.list_to_map(pooled, pixel_tgt_idx, bs,
                                               height, width)
    return torch.where(mask > 0, pooled_map, x), pooled_map
def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the stored entries of ``src``, globally or along one dimension.

    Args:
        src: The sparse tensor to reduce.
        dim: Dimension to reduce over. ``None`` reduces over all stored
            entries. Negative values are counted from the end.
        reduce: One of ``'sum'``/``'add'``, ``'mean'``, ``'min'``, ``'max'``.

    Returns:
        The reduced tensor. When ``src`` has no values, the result is derived
        from the sparsity pattern alone (entry counts for sum, ones
        otherwise).

    Raises:
        ValueError: For an unsupported ``reduce`` or ``dim``.
    """
    value = src.storage.value()

    if dim is None:
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError("unsupported reduce operation")
        else:
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(),
                                    device=src.device())
            else:
                raise ValueError("unsupported reduce operation")
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            # Reducing over rows yields one result per column, so the output
            # is sized by src.size(1) and must honour the requested `reduce`.
            col = src.storage.col()
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # Created on src's device, like the dim=None branch above.
                return torch.ones(src.size(1), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError("unsupported reduce operation")
        elif dim == 1 and value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError("unsupported reduce operation")
        elif dim > 1 and value is not None:
            # Dense trailing dimensions live in `value`, shifted by one since
            # the two sparse dimensions share value's first dimension.
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError("unsupported reduce operation")
        else:
            raise ValueError("unsupported dim for this sparse tensor")
def forward(
    self,
    x,
    ptx,
    bs,
    height,
    width,
    point_key=None,
    point_edges=None,
    pixel_tgt_idx=None,
    **kwargs,
):
    """Concatenate image features to the point features, process and pool them.

    Only a batch size of one is supported (asserted).

    Returns:
        A pair ``(x, ptx)``: ``x`` is the input map with entries replaced by
        the pooled features wherever ``mask > 0``; ``ptx`` holds the
        per-point network outputs before pooling.
    """
    assert bs == 1

    # Flatten the map to one feature row per pixel, pick the referenced
    # pixels and expand them along the CSR segments.
    channels = x.shape[1]
    pix = x.permute(0, 2, 3, 1).view(height * width, channels)
    pix = torch_scatter.gather_csr(pix[pixel_tgt_idx.long()], point_key)

    # Evaluate the network on the concatenated features.
    point_feats = self._net_forward(torch.cat((ptx, pix), dim=1), point_key,
                                    point_edges)
    pooled = torch_scatter.segment_csr(point_feats, point_key,
                                       reduce=self.aggr)

    # Map the pooled list back onto the image grid.
    pooled_map, mask = ext.mytorch.list_to_map(pooled, pixel_tgt_idx, bs,
                                               height, width)
    return torch.where(mask > 0, pooled_map, x), point_feats
def test_zero_elements(reduce, dtype, device):
    """Forward and backward must work on tensors without any elements."""
    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    def check(out):
        # Backward must run and the (empty) shape must be preserved.
        out.backward(torch.randn_like(out))
        assert out.size() == (0, 0, 0, 16)

    check(scatter(x, index, dim=0, dim_size=0, reduce=reduce))
    check(segment_coo(x, index, dim_size=0, reduce=reduce))
    check(gather_coo(x, index))
    check(segment_csr(x, indptr, reduce=reduce))
    check(gather_csr(x, indptr))
def to_symmetric(self, reduce: str = "sum"):
    """Build a symmetric ``N x N`` sparse tensor from this one.

    Every stored ``(row, col)`` entry is combined with its mirrored
    ``(col, row)`` counterpart. Coordinates that occur more than once are
    kept once; if values are present, the values of each duplicate run are
    merged with ``segment_csr`` using ``reduce``. ``N`` is the larger of the
    two sparse sizes.
    """
    N = max(self.size(0), self.size(1))

    row, col, value = self.coo()

    # Linearised coordinates of the original entries (first half) and of the
    # mirrored entries (second half), preceded by a single -1 sentinel.
    idx = col.new_full((2 * col.numel() + 1, ), -1)
    idx[1:row.numel() + 1] = row
    idx[row.numel() + 1:] = col
    idx[1:] *= N
    idx[1:row.numel() + 1] += col  # first half:  row * N + col
    idx[row.numel() + 1:] += row  # second half: col * N + row

    idx, perm = idx.sort()
    # True where a sorted key is larger than its predecessor, i.e. at the
    # first occurrence of every distinct coordinate.
    mask = idx[1:] > idx[:-1]
    # Drop the first sorted slot and shift by one so that `perm` indexes the
    # concatenated data without the sentinel. (Assumes all real keys are
    # >= 0, so the -1 sentinel sorts first -- TODO confirm.)
    perm = perm[1:].sub_(1)
    idx = perm[mask]

    if value is not None:
        # Segment boundaries: start of each run plus the total entry count.
        ptr = mask.nonzero().flatten()
        ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
        value = torch.cat([value, value])[perm]
        value = segment_csr(value, ptr, reduce=reduce)

    # `perm` is passed as the output buffer of both concatenations; `idx`
    # was derived from it beforehand.
    new_row = torch.cat([row, col], dim=0, out=perm)[idx]
    new_col = torch.cat([col, row], dim=0, out=perm)[idx]

    out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
                       sparse_sizes=(N, N), is_sorted=True)
    return out
def forward(
    self,
    x,
    ptx,
    bs,
    height,
    width,
    point_key=None,
    point_edges=None,
    pixel_tgt_idx=None,
    **kwargs,
):
    """Pool point features with learned softmax weights and map them to an image.

    Only a batch size of one is supported (asserted).

    Returns:
        A pair ``(x, ptx)``: ``x`` is the input map with entries replaced by
        the pooled features wherever ``mask > 0``; ``ptx`` is the output of
        ``self.feat_out`` when that is truthy, otherwise the ``ptx`` passed in.
    """
    assert bs == 1

    # Flatten the map to one feature row per pixel, pick the referenced
    # pixels and expand them along the CSR segments.
    channels = x.shape[1]
    pix = x.permute(0, 2, 3, 1).view(height * width, channels)
    pix = torch_scatter.gather_csr(pix[pixel_tgt_idx.long()], point_key)

    # Evaluate the network on the concatenated features.
    hidden = self._net_forward(torch.cat((ptx, pix), dim=1), point_key,
                               point_edges)

    weight = softmax_csr(self.sm_out(hidden), point_key)
    if self.feat_out:
        ptx = self.feat_out(hidden)
    pooled = torch_scatter.segment_csr(weight * ptx, point_key, reduce="sum")

    # Map the pooled list back onto the image grid.
    pooled_map, mask = ext.mytorch.list_to_map(pooled, pixel_tgt_idx, bs,
                                               height, width)
    return torch.where(mask > 0, pooled_map, x), ptx
def forward(self, h, batch):
    """Apply ``lin1``/``s1`` per row, pool per group, then ``lin2``/``s2``/``lin3``.

    Pooling uses ``segment_csr`` with the reduction named by ``self.readout``
    over the pointer obtained from ``get_ptr(batch)``.
    """
    per_row = self.s1(self.lin1(h))
    pooled = segment_csr(per_row, get_ptr(batch), reduce=self.readout)
    return self.lin3(self.s2(self.lin2(pooled)))
def forward(self, h, batch):
    """Pool ``h`` per group, with optional modules before and after pooling.

    ``modules_list_1`` (if non-empty) runs before the ``segment_csr``
    readout, ``modules_list_2`` (if non-empty) and ``last_layer`` (if not
    ``None``) run after it.
    """
    if len(self.modules_list_1):
        h = self.modules_list_1(h)

    pooled = segment_csr(h, get_ptr(batch), reduce=self.readout)

    if len(self.modules_list_2):
        pooled = self.modules_list_2(pooled)
    if self.last_layer is None:
        return pooled
    return self.last_layer(pooled)
def forward(self, ptx, bs, height, width, **kwargs):
    """Pool point features, project them to an image and refine it.

    ``kwargs`` must contain ``"point_key"`` and ``"pixel_tgt_idx"``; the
    ``"diravg"`` mode additionally reads ``"point_tgt_dirs"`` and
    ``"point_src_dirs"``. All ``kwargs`` are forwarded to the per-stage
    modules in ``self.gnns``.

    Returns:
        A dict with the refined map under ``"out"`` and the projection mask
        under ``"mask"``.

    Raises:
        Exception: if ``self.point_avg_mode`` is neither ``"avg"`` nor
            ``"diravg"``.
    """
    # point features
    point_key = kwargs["point_key"]
    pixel_tgt_idx = kwargs["pixel_tgt_idx"]

    # average per point
    mode = self.point_avg_mode
    if mode == "avg":
        x = torch_scatter.segment_csr(ptx, point_key, reduce="mean")
    elif mode == "diravg":
        with torch.no_grad():
            tgt_dirs = kwargs["point_tgt_dirs"][pixel_tgt_idx.long()]
            tgt_dirs = torch_scatter.gather_csr(tgt_dirs, point_key)
            weight = (kwargs["point_src_dirs"] * tgt_dirs).sum(dim=1)
            weight = torch.clamp(weight, 0.01, 1)
            weight_sum = torch_scatter.segment_csr(weight, point_key,
                                                   reduce="sum")
            # Normalise so the weights of every segment sum to one.
            weight /= torch_scatter.gather_csr(weight_sum, point_key)
        x = torch_scatter.segment_csr(weight.view(-1, 1) * ptx, point_key,
                                      reduce="sum")
    else:
        raise Exception("invalid avg_mode")

    # project to target
    x, mask = ext.mytorch.list_to_map(x, pixel_tgt_idx, bs, height, width)

    # run refinement network
    gnns = self.gnns
    for sidx, unet in enumerate(self.nets):
        # process per 3D point
        if gnns is not None and sidx < len(gnns):
            x, ptx = gnns[sidx](x, ptx, bs, height, width, **kwargs)
        x = x + unet(x) if self.nets_residual else unet(x)

    if self.out_conv:
        x = self.out_conv(x)

    return {"out": x, "mask": mask}
def aggregate(self, inputs, index, ptr=None, dim_size=None):
    """Aggregate messages with both mean and max and concatenate the results.

    Args:
        inputs: The messages to aggregate.
        index: Target index of every message (used when ``ptr`` is ``None``).
        ptr: Optional CSR pointer; if given, ``segment_csr`` is used.
        dim_size: Output size along ``self.node_dim`` for the scatter path.

    Returns:
        The mean and max aggregations concatenated along the last dimension.
    """
    if ptr is not None:
        # segment_csr reduces along the last dimension of `ptr`, so `ptr`
        # needs one leading singleton dimension per dimension in front of
        # the node dimension. Resolve a negative `node_dim` first: otherwise
        # range() would be empty and this branch would reduce a different
        # dimension than the scatter branch below.
        node_dim = self.node_dim
        if node_dim < 0:
            node_dim += inputs.dim()
        for _ in range(node_dim):
            ptr = ptr.unsqueeze(0)
        aggr_mean = segment_csr(inputs, ptr, reduce='mean')
        aggr_max = segment_csr(inputs, ptr, reduce='max')
    else:
        aggr_mean = scatter(inputs, index, dim=self.node_dim,
                            dim_size=dim_size, reduce='mean')
        aggr_max = scatter(inputs, index, dim=self.node_dim,
                           dim_size=dim_size, reduce='max')
    return torch.cat([aggr_mean, aggr_max], dim=-1)
def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
    r"""Evaluates a softmax independently over groups of entries of
    :attr:`src`.

    Entries are grouped along dimension :attr:`dim`, either by CSR segment
    boundaries (:attr:`ptr`, which takes precedence) or by group ids
    (:attr:`index`). Each entry is shifted by the maximum of its group
    before exponentiation and then divided by the sum of its group's
    exponentials.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): The group id of each entry along
            :attr:`dim`. (default: :obj:`None`)
        ptr (LongTensor, optional): If given, the softmax is computed from
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension in which to normalize.
            (default: :obj:`0`)

    Raises:
        NotImplementedError: If neither :attr:`ptr` nor :attr:`index` is
            given.

    :rtype: :class:`Tensor`
    """
    if ptr is None and index is None:
        raise NotImplementedError

    if ptr is not None:
        if dim < 0:
            dim = dim + src.dim()
        # Prepend singleton dimensions so `ptr` segments dimension `dim`.
        ptr = ptr.view([1] * dim + [-1])
        group_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        numer = (src - group_max).exp()
        denom = gather_csr(segment_csr(numer, ptr, reduce='sum'), ptr)
    else:
        N = maybe_num_nodes(index, num_nodes)
        group_max = scatter(src, index, dim, dim_size=N, reduce='max')
        numer = (src - group_max.index_select(dim, index)).exp()
        denom = scatter(numer, index, dim, dim_size=N, reduce='sum')
        denom = denom.index_select(dim, index)

    return numer / (denom + 1e-16)
def reduce(self, x: Tensor, index: Optional[Tensor] = None,
           ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
           dim: int = -2, reduce: str = 'add') -> Tensor:
    """Reduce ``x`` along ``dim`` by CSR pointer or by index.

    ``ptr`` takes precedence: when given, ``x`` is reduced segment-wise with
    ``segment_csr``. Otherwise ``index`` is required (asserted) and ``x`` is
    reduced with ``scatter``.
    """
    if ptr is None:
        assert index is not None
        return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

    return segment_csr(x, expand_left(ptr, dim, dims=x.dim()), reduce=reduce)
def object_selection(self, features, n_obj):
    """Blend per-object features into one output per group using soft selection.

    The first feature column, scaled by ``self.scale_selector``, is turned
    into softmax weights within each group of objects. The weighted features
    are then summed per group.

    Args:
        features: One feature row per object; column 0 is the selector and
            columns ``1 .. self.nc_out`` form the image part.
        n_obj: Number of objects in each group, in order.

    Returns:
        ``(image, depth)``: ``image`` holds columns ``1 .. self.nc_out`` of
        the blended features; ``depth`` is column 0 when ``self.depth`` is
        truthy and ``None`` otherwise.
    """
    # Group id of every object: id i repeated n_obj[i] times. A flat
    # comprehension avoids the quadratic cost of sum(list_of_lists, []).
    index = torch.LongTensor(
        [i for i, n in enumerate(n_obj) for _ in range(n)])

    selector = features[:, :1] * self.scale_selector
    selector = scatter_softmax(selector, index.to(self.device), dim=0)

    # CSR boundaries of the groups: [0, n0, n0 + n1, ...].
    indptr = torch.LongTensor([0] + np.cumsum(n_obj).tolist()).to(
        self.device)
    selected_features = segment_csr(selector * features, indptr,
                                    reduce="sum")

    image = selected_features.narrow(1, 1, self.nc_out)
    depth = selected_features.narrow(1, 0, 1) if self.depth else None
    return image, depth
def test_forward(test, reduce, dtype):
    """Check scatter, ``segment_coo`` and ``segment_csr`` on device ``cuda:1``."""
    device = torch.device('cuda:1')

    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    result = torch_scatter.scatter(src, index, dim, reduce=reduce)
    assert torch.all(result == expected)

    result = torch_scatter.segment_coo(src, index, reduce=reduce)
    assert torch.all(result == expected)

    result = torch_scatter.segment_csr(src, indptr, reduce=reduce)
    assert torch.all(result == expected)
def softmax(src: Tensor, index: Optional[Tensor],
            ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    r"""Evaluates a softmax independently over groups of entries of
    :attr:`src`.

    Entries are grouped along the first dimension, either by CSR segment
    boundaries (:attr:`ptr`, which takes precedence) or by group ids
    (:attr:`index`).

    Each entry is shifted by the maximum of its *own group* before
    exponentiation. Shifting by the maximum over all of :attr:`src` gives
    the same result mathematically, but a group whose values lie far below
    that global maximum underflows to all zeros.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): The group id of each entry.
        ptr (LongTensor, optional): If given, the softmax is computed from
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    Raises:
        NotImplementedError: If neither :attr:`ptr` nor :attr:`index` is
            given.

    :rtype: :class:`Tensor`
    """
    if ptr is not None:
        src_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        out = (src - src_max).exp()
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        src_max = scatter(src, index, dim=0, dim_size=N, reduce='max')[index]
        out = (src - src_max).exp()
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
    else:
        raise NotImplementedError

    return out / (out_sum + 1e-16)
def forward(self, h, edge_index, edge_weight, edge_attr, data=None):
    """Alternate linear/interaction layers with a per-group state update.

    In every layer the current state is looked up per row via ``data.batch``
    and concatenated in front of ``h``. After the interaction, the state is
    recomputed from the per-group mean of ``h`` (summed over dimension 1 and
    expanded to the shape of ``data.state_attr``). The final ``h`` is passed
    through ``self.out`` and a ReLU.

    NOTE(review): ``data`` defaults to ``None`` but is dereferenced
    unconditionally, so callers must always pass it.
    """
    state_attr = data.state_attr
    for lin1, interaction in zip(self.lin_list1, self.interactions):
        per_row_state = state_attr[data.batch]
        h = lin1(torch.cat((per_row_state, h), dim=1))
        h = h + interaction(h, edge_index, edge_weight, edge_attr,
                            data=data)

        group_mean = segment_csr(h, get_ptr(data.batch), reduce="mean")
        state_attr = torch.sum(group_mean, dim=1, keepdim=True)
        state_attr = state_attr.expand(data.state_attr.shape)

    return F.relu(self.out(h))
def forward(self, x, batch):
    """Iteratively attend over the rows of each group and return the final query.

    For ``self.processing_steps`` steps, an LSTM produces one query per
    group, the rows of ``x`` are scored against their group's query, the
    scores are normalised with a group-wise softmax, and the weighted rows
    are summed per group. Query and readout are concatenated and fed back
    into the LSTM.

    Args:
        x: One feature row per element.
        batch: Group index of every row of ``x``; the number of groups is
            taken as ``batch.max() + 1``.

    Returns:
        The concatenation of the last query and readout, one row per group.
    """
    batch_size = batch.max().item() + 1

    h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
         x.new_zeros((self.num_layers, batch_size, self.in_channels)))
    q_star = x.new_zeros(batch_size, self.out_channels)

    # `batch` does not change between steps, so build the pointer once.
    ptr = get_ptr(batch)

    for _ in range(self.processing_steps):
        q, h = self.lstm(q_star.unsqueeze(0), h)
        q = q.view(batch_size, self.in_channels)
        e = (x * q[batch]).sum(dim=-1, keepdim=True)
        a = softmax(e, batch, num_nodes=batch_size)
        r = segment_csr(a * x, ptr)
        q_star = torch.cat([q, r], dim=-1)

    return q_star
def test_forward(test, reduce, dtype, device):
    """Check ``segment_coo`` and ``segment_csr`` against the expected results."""
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    def check(result):
        # Arg-reductions return (values, arg-indices); verify both.
        if isinstance(result, tuple):
            result, arg_result = result
            arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
            assert torch.all(arg_result == arg_expected)
        assert torch.all(result == expected)

    check(segment_coo(src, index, reduce=reduce))
    check(segment_csr(src, indptr, reduce=reduce))
def aggregate(self, inputs: torch.Tensor, index: torch.Tensor,
              ptr: Optional[torch.Tensor] = None,
              dim_size: Optional[int] = None) -> torch.Tensor:
    r"""Aggregates messages from neighbors as
    :math:`\square_{j \in \mathcal{N}(i)}`.

    The reduction is the one named by :obj:`self.aggr`. If :attr:`ptr` is
    given, the messages are reduced segment-wise with :obj:`segment_csr`;
    otherwise they are scattered along :obj:`self.node_dim` according to
    :attr:`index`.
    """
    if ptr is None:
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce=self.aggr)

    expanded_ptr = unsqueeze(ptr, dim=0, length=self.node_dim)
    return segment_csr(inputs, expanded_ptr, reduce=self.aggr)
def aggregate(self, inputs: Tensor, index: Tensor,
              ptr: Optional[Tensor] = None,
              dim_size: Optional[int] = None,
              aggr: Optional[str] = None) -> Tensor:
    r"""Aggregates messages from neighbors as
    :math:`\square_{j \in \mathcal{N}(i)}`.

    The reduction is :attr:`aggr` if given and :obj:`self.aggr` otherwise;
    one of the two must be set (asserted). If :attr:`ptr` is given, the
    messages are reduced segment-wise with :obj:`segment_csr`; otherwise
    they are scattered along :obj:`self.node_dim` according to
    :attr:`index`.
    """
    if aggr is None:
        aggr = self.aggr
    assert aggr is not None

    if ptr is None:
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce=aggr)

    expanded_ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
    return segment_csr(inputs, expanded_ptr, reduce=aggr)
def aggregate_up(self, inputs: Tensor, agg_up_index: Tensor,
                 up_ptr: Optional[Tensor] = None,
                 up_dim_size: Optional[int] = None) -> Tensor:
    r"""Aggregates messages coming from upper adjacent cells.

    The reduction is the one named by :obj:`self.aggr_up`. If
    :attr:`up_ptr` is given, the messages are reduced segment-wise with
    :obj:`segment_csr`; otherwise they are scattered along
    :obj:`self.node_dim` according to :attr:`agg_up_index`.
    """
    if up_ptr is None:
        return scatter(inputs, agg_up_index, dim=self.node_dim,
                       dim_size=up_dim_size, reduce=self.aggr_up)

    expanded_ptr = expand_left(up_ptr, dim=self.node_dim, dims=inputs.dim())
    return segment_csr(inputs, expanded_ptr, reduce=self.aggr_up)