def softmax_csr(x, key):
    """Softmax of ``x`` computed independently inside each CSR segment.

    Args:
        x: values to normalize; segments run along the first dimension.
        key: CSR pointer consumed by ``torch_scatter.segment_csr`` /
            ``torch_scatter.gather_csr``.

    Returns:
        Tensor shaped like ``x`` whose entries sum to ~1 within each segment.
    """
    # Subtract each segment's own maximum so exp() cannot overflow.
    seg_max = torch_scatter.segment_csr(x, key, reduce="max")
    numer = (x - torch_scatter.gather_csr(seg_max, key)).exp()
    seg_sum = torch_scatter.segment_csr(numer, key, reduce="sum")
    denom = torch_scatter.gather_csr(seg_sum, key)
    # The epsilon keeps the division finite for degenerate segments.
    return numer / (denom + 1e-16)
def softmax(src: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    r"""Computes a sparsely evaluated softmax.

    Rows of :attr:`src` are grouped along the first dimension and a softmax
    is taken separately inside every group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): Group index of every row of :attr:`src`; used
            when :attr:`ptr` is not given.
        ptr (LongTensor, optional): CSR pointer over sorted inputs; when
            given it takes precedence over :attr:`index`.
            (default: :obj:`None`)
        num_nodes (int, optional): Number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    if ptr is not None:
        group_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        numer = (src - group_max).exp()
        denom = gather_csr(segment_csr(numer, ptr, reduce='sum'), ptr)
    else:
        N = maybe_num_nodes(index, num_nodes)
        group_max = scatter(src, index, dim=0, dim_size=N, reduce='max')
        numer = (src - group_max[index]).exp()
        denom = scatter(numer, index, dim=0, dim_size=N, reduce='sum')[index]
    return numer / (denom + 1e-16)
def forward(
    self,
    x,
    ptx,
    bs,
    height,
    width,
    point_key=None,
    point_edges=None,
    point_src_dirs=None,
    point_tgt_dirs=None,
    pixel_tgt_idx=None,
    **kwargs,
):
    """Run the point network on features plus directions and pool per segment.

    Args:
        x: dense feature map; kept wherever ``list_to_map`` reports no data.
        ptx: per-point features.
        bs, height, width: output map geometry forwarded to ``list_to_map``.
        point_key: CSR pointer whose segments line up with the rows of
            ``pixel_tgt_idx``.
        point_edges: forwarded unchanged to ``self._net_forward``.
        point_src_dirs: per-point direction features, concatenated to ``ptx``.
        point_tgt_dirs: table indexed by ``pixel_tgt_idx``; each looked-up
            row is broadcast to every point of its segment.
        pixel_tgt_idx: index used for the lookup and by ``list_to_map``.

    Returns:
        ``(merged_map, feature_map)``. NOTE(review): the second value is the
        dense map from ``list_to_map``, not per-point features.
    """
    per_segment_dirs = point_tgt_dirs[pixel_tgt_idx.long()]
    per_point_dirs = torch_scatter.gather_csr(per_segment_dirs, point_key)

    net_in = torch.cat((ptx, point_src_dirs, per_point_dirs), dim=1)
    net_out = self._net_forward(net_in, point_key, point_edges)
    pooled = torch_scatter.segment_csr(net_out, point_key, reduce=self.aggr)

    # Place the pooled rows into a dense map; mask marks the filled pixels.
    feature_map, mask = ext.mytorch.list_to_map(
        pooled, pixel_tgt_idx, bs, height, width)
    merged = torch.where(mask > 0, feature_map, x)
    return merged, feature_map
def forward(self, x, point_key=None, point_edges=None, point_dirs=None, **kwargs):
    """Aggregate per-point features into one vector per CSR segment.

    ``self.avg_mode`` selects the pooling:
      * ``"mean"``    plain segment mean;
      * ``"dirs"``    weights from the clamped row-wise dot product of
                      ``point_dirs[:, :3]`` and ``point_dirs[:, 3:]``,
                      normalized per segment and computed without gradients;
      * ``"softmax"`` learned weights ``softmax_csr(self.weight_lin(.))``
                      applied to ``self.feat_lin(.)``.

    Raises:
        Exception: for any other ``avg_mode``.
    """
    feats = self.cat_dirs(x, point_dirs)
    if self.net is not None:
        feats = self.net(feats, point_edges)

    mode = self.avg_mode
    if mode == "mean":
        return torch_scatter.segment_csr(feats, point_key, reduce="mean")

    if mode == "dirs":
        with torch.no_grad():
            dot = (point_dirs[:, :3] * point_dirs[:, 3:]).sum(dim=1)
            w = torch.clamp(dot, 0.01, 1)
            w_total = torch_scatter.segment_csr(w, point_key, reduce="sum")
            w /= torch_scatter.gather_csr(w_total, point_key)
        weighted = w.view(-1, 1) * feats
    elif mode == "softmax":
        w = softmax_csr(self.weight_lin(feats), point_key)
        weighted = w * self.feat_lin(feats)
    else:
        raise Exception("invalid avg_mode")

    return torch_scatter.segment_csr(weighted, point_key, reduce="sum")
def forward(
    self,
    x,
    ptx,
    bs,
    height,
    width,
    point_key=None,
    point_edges=None,
    pixel_tgt_idx=None,
    **kwargs,
):
    """Softmax-weighted pooling of point features back into the map ``x``.

    The pixel features of ``x`` selected by ``pixel_tgt_idx`` are broadcast
    to their points, concatenated with ``ptx`` and run through
    ``self._net_forward``. ``self.sm_out`` produces per-point logits that are
    softmax-normalized per segment. If ``self.feat_out`` is truthy it replaces
    ``ptx`` with projected features; otherwise the weights are applied to
    the input ``ptx`` as given.

    Returns:
        ``(merged_map, ptx)`` where ``ptx`` is the tensor that was weighted.
    """
    assert bs == 1
    n_ch = x.shape[1]
    # (1, C, H, W) -> (H*W, C): one row per pixel.
    pixel_feat = x.permute(0, 2, 3, 1).view(height * width, n_ch)
    point_feat = torch_scatter.gather_csr(
        pixel_feat[pixel_tgt_idx.long()], point_key)

    hidden = self._net_forward(
        torch.cat((ptx, point_feat), dim=1), point_key, point_edges)
    weight = softmax_csr(self.sm_out(hidden), point_key)
    if self.feat_out:
        ptx = self.feat_out(hidden)

    pooled = torch_scatter.segment_csr(weight * ptx, point_key, reduce="sum")
    pooled_map, mask = ext.mytorch.list_to_map(
        pooled, pixel_tgt_idx, bs, height, width)
    return torch.where(mask > 0, pooled_map, x), ptx
def forward(
    self,
    x,
    ptx,
    bs,
    height,
    width,
    point_key=None,
    point_edges=None,
    pixel_tgt_idx=None,
    **kwargs,
):
    """Run the point network on ``[ptx, pixel features]`` and pool per segment.

    Pixel features of ``x`` (selected by ``pixel_tgt_idx``) are broadcast to
    their points, concatenated with ``ptx`` and processed by
    ``self._net_forward``. Results are pooled with ``self.aggr`` and written
    back into the map wherever ``list_to_map`` reports data.

    Returns:
        ``(merged_map, point_out)`` where ``point_out`` is the per-point
        network output before pooling.
    """
    assert bs == 1
    channels = x.shape[1]
    # (1, C, H, W) -> (H*W, C): one row per pixel.
    flat = x.permute(0, 2, 3, 1).view(height * width, channels)
    per_point = torch_scatter.gather_csr(flat[pixel_tgt_idx.long()], point_key)

    point_out = self._net_forward(
        torch.cat((ptx, per_point), dim=1), point_key, point_edges)
    pooled = torch_scatter.segment_csr(point_out, point_key, reduce=self.aggr)

    pooled_map, mask = ext.mytorch.list_to_map(
        pooled, pixel_tgt_idx, bs, height, width)
    return torch.where(mask > 0, pooled_map, x), point_out
def test_zero_elements(reduce, dtype, device):
    """Every scatter/segment/gather op must accept and backprop empty input."""
    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    # Executed in order; each op is checked before the next one runs.
    ops = (
        lambda: scatter(x, index, dim=0, dim_size=0, reduce=reduce),
        lambda: segment_coo(x, index, dim_size=0, reduce=reduce),
        lambda: gather_coo(x, index),
        lambda: segment_csr(x, indptr, reduce=reduce),
        lambda: gather_csr(x, indptr),
    )
    for run in ops:
        result = run()
        result.backward(torch.randn_like(result))
        assert result.size() == (0, 0, 0, 16)
def __collect__(self, edge_index, size, mp_type, kwargs):
    """Collect the keyword arguments listed in ``self.__user_args__``.

    Names ending in ``_i`` / ``_j`` are looked up in ``kwargs`` without the
    suffix and lifted from nodes to edges:
      * ``mp_type == 'edge_index'``: ``index_select`` along ``self.node_dim``
        with row ``idx`` of ``edge_index``;
      * ``mp_type == 'adj_t'``: ``gather_csr`` over the row pointer for
        ``idx == 1``, ``index_select`` by column for ``idx == 0``.
    Other names are copied from ``kwargs`` unchanged. Missing names map to
    ``inspect.Parameter.empty``; non-tensor values are passed through.

    Side effects: ``size`` is updated in place (via ``self.__set_size__`` and
    by filling a ``None`` entry from the other one).

    Args:
        edge_index: connectivity. Indexed as ``edge_index[idx]`` for
            'edge_index', accessed through ``.storage`` for 'adj_t'
            (presumably a torch_sparse SparseTensor - confirm with caller).
        size: two-entry mutable container (``size[0]``, ``size[1]``);
            entries may be ``None``.
        mp_type: ``'edge_index'`` or ``'adj_t'`` (the values branched on).
        kwargs: user-supplied keyword arguments.

    Returns:
        dict of collected arguments plus bookkeeping keys: always
        'size_i', 'size_j', 'dim_size'; for 'edge_index' / 'adj_t' also
        'edge_index_i', 'edge_index_j', 'index'; for 'adj_t' additionally
        'adj_t', 'ptr', 'edge_attr'.
    """
    # The flow decides which row of edge_index the _i / _j suffixes refer to.
    i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
    ij = {'_i': i, '_j': j}

    out = {}
    for arg in self.__user_args__:
        if arg[-2:] not in ij.keys():
            out[arg] = kwargs.get(arg, inspect.Parameter.empty)
        else:
            idx = ij[arg[-2:]]
            data = kwargs.get(arg[:-2], inspect.Parameter.empty)

            if data is inspect.Parameter.empty:
                out[arg] = data
                continue

            # A pair of tensors (one per side): record the size from the
            # other element, then keep the element this suffix refers to.
            if isinstance(data, tuple) or isinstance(data, list):
                assert len(data) == 2
                self.__set_size__(size, 1 - idx, data[1 - idx])
                data = data[idx]

            if not torch.is_tensor(data):
                out[arg] = data
                continue

            self.__set_size__(size, idx, data)

            if mp_type == 'edge_index':
                out[arg] = data.index_select(self.node_dim, edge_index[idx])
            elif mp_type == 'adj_t' and idx == 1:
                rowptr = edge_index.storage.rowptr()
                # Prepend singleton dims so rowptr lines up with node_dim.
                # NOTE(review): a negative node_dim adds no dims here.
                for _ in range(self.node_dim):
                    rowptr = rowptr.unsqueeze(0)
                out[arg] = gather_csr(data, rowptr)
            elif mp_type == 'adj_t' and idx == 0:
                col = edge_index.storage.col()
                out[arg] = data.index_select(self.node_dim, col)

    # If only one side's size is known, assume the other matches it.
    size[0] = size[1] if size[0] is None else size[0]
    size[1] = size[0] if size[1] is None else size[1]

    if mp_type == 'edge_index':
        out['edge_index_j'] = edge_index[j]
        out['edge_index_i'] = edge_index[i]
        out['index'] = out['edge_index_i']
    elif mp_type == 'adj_t':
        out['adj_t'] = edge_index
        out['edge_index_i'] = edge_index.storage.row()
        out['edge_index_j'] = edge_index.storage.col()
        out['index'] = edge_index.storage.row()
        out['ptr'] = edge_index.storage.rowptr()
        out['edge_attr'] = edge_index.storage.value()

    out['size_j'] = size[j]
    out['size_i'] = size[i]
    out['dim_size'] = out['size_i']

    return out
def test_segment_out(test, dtype, device):
    """gather_coo / gather_csr must write correctly into a preallocated out."""
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test['expected'], dtype, device)

    # Output matches src except along the gathered dimension.
    out_shape = list(src.size())
    out_shape[index.dim() - 1] = index.size(-1)
    out = src.new_full(out_shape, -2)

    for gather_fn, locator in ((gather_coo, index), (gather_csr, indptr)):
        out.fill_(-2)  # sentinel so stale values would be detected
        gather_fn(src, locator, out)
        assert torch.all(out == expected)
def forward(self, ptx, bs, height, width, **kwargs):
    """Average point features per pixel, project to a map, then refine.

    ``self.point_avg_mode`` selects the per-segment averaging:
      * ``"avg"``    plain segment mean of ``ptx``;
      * ``"diravg"`` weights from the clamped dot product of
                     ``kwargs["point_src_dirs"]`` and the broadcast target
                     directions, normalized per segment (no gradients).
    Each refinement stage optionally runs ``self.gnns[stage]`` first, then
    ``self.nets[stage]`` (added residually when ``self.nets_residual``).

    Returns:
        ``{"out": refined_map, "mask": mask}`` with ``mask`` taken from the
        initial ``list_to_map`` projection.

    Raises:
        Exception: for an unknown ``point_avg_mode``.
    """
    point_key = kwargs["point_key"]
    pixel_tgt_idx = kwargs["pixel_tgt_idx"]

    mode = self.point_avg_mode
    if mode == "avg":
        x = torch_scatter.segment_csr(ptx, point_key, reduce="mean")
    elif mode == "diravg":
        with torch.no_grad():
            tgt_dirs = kwargs["point_tgt_dirs"][pixel_tgt_idx.long()]
            tgt_dirs = torch_scatter.gather_csr(tgt_dirs, point_key)
            dot = (kwargs["point_src_dirs"] * tgt_dirs).sum(dim=1)
            w = torch.clamp(dot, 0.01, 1)
            w_total = torch_scatter.segment_csr(w, point_key, reduce="sum")
            w /= torch_scatter.gather_csr(w_total, point_key)
        x = torch_scatter.segment_csr(w.view(-1, 1) * ptx, point_key,
                                      reduce="sum")
    else:
        raise Exception("invalid avg_mode")

    x, mask = ext.mytorch.list_to_map(x, pixel_tgt_idx, bs, height, width)

    for stage, unet in enumerate(self.nets):
        if self.gnns is not None and stage < len(self.gnns):
            x, ptx = self.gnns[stage](x, ptx, bs, height, width, **kwargs)
        x = x + unet(x) if self.nets_residual else unet(x)

    if self.out_conv:
        x = self.out_conv(x)
    return {"out": x, "mask": mask}
def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
    r"""Computes a sparsely evaluated softmax.

    Entries of :attr:`src` are grouped along :attr:`dim`, either by the CSR
    pointer :attr:`ptr` (preferred when given) or by :attr:`index`, and a
    numerically stabilized softmax is taken inside every group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): Group index of every entry along
            :attr:`dim`. (default: :obj:`None`)
        ptr (LongTensor, optional): CSR pointer over sorted inputs.
            (default: :obj:`None`)
        num_nodes (int, optional): Number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension in which to normalize; negative
            values count from the end. (default: :obj:`0`)

    Raises:
        NotImplementedError: if neither :attr:`index` nor :attr:`ptr` is
            given.

    :rtype: :class:`Tensor`
    """
    if ptr is not None:
        if dim < 0:
            dim = dim + src.dim()
        # Leading singleton dims place the pointer on the normalized dim.
        ptr = ptr.view([1] * dim + [-1])
        group_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        numer = (src - group_max).exp()
        denom = gather_csr(segment_csr(numer, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        group_max = scatter(src, index, dim, dim_size=N,
                            reduce='max').index_select(dim, index)
        numer = (src - group_max).exp()
        denom = scatter(numer, index, dim, dim_size=N,
                        reduce='sum').index_select(dim, index)
    else:
        raise NotImplementedError
    return numer / (denom + 1e-16)
def test_forward(test, dtype, device):
    """gather_coo and gather_csr must both reproduce the expected output."""
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test['expected'], dtype, device)

    for gather_fn, locator in ((gather_coo, index), (gather_csr, indptr)):
        result = gather_fn(src, locator)
        assert torch.all(result == expected)
def __lift__(self, src, edge_index, dim):
    """Map node-level ``src`` onto edges along ``self.node_dim``.

    For a dense ``edge_index`` tensor, row ``dim`` of it is used as the
    index. For a SparseTensor, ``dim == 1`` gathers over the CSR row pointer
    and ``dim == 0`` selects by column.

    Raises:
        ValueError: for any other ``edge_index`` type or ``dim``.
    """
    if isinstance(edge_index, Tensor):
        return src.index_select(self.node_dim, edge_index[dim])

    if isinstance(edge_index, SparseTensor):
        if dim == 0:
            return src.index_select(self.node_dim,
                                    edge_index.storage.col())
        if dim == 1:
            # Broadcast-align the row pointer with src's node dimension.
            ptr = expand_left(edge_index.storage.rowptr(),
                              dim=self.node_dim, dims=src.dim())
            return gather_csr(src, ptr)

    raise ValueError
def lift_jump_index_select(self, src, edge_index, dim):
    """From MessagePassing.__lift__, jump edge_attr from adj

    Same as ``__lift__`` except for a SparseTensor with ``dim == 0``: the
    column index is applied to ``src``'s first dimension (``src[col]``)
    rather than to ``self.node_dim``.

    Raises:
        ValueError: for any other ``edge_index`` type or ``dim``.
    """
    if isinstance(edge_index, Tensor):
        return src.index_select(self.node_dim, edge_index[dim])

    if isinstance(edge_index, SparseTensor):
        if dim == 0:
            # Deliberately index dim 0, not self.node_dim.
            return src[edge_index.storage.col()]
        if dim == 1:
            ptr = expand_left(edge_index.storage.rowptr(),
                              dim=self.node_dim, dims=src.dim())
            return gather_csr(src, ptr)

    raise ValueError
def softmax(src: Tensor, index: Optional[Tensor],
            ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    """Computes a sparsely evaluated softmax over groups of rows of ``src``.

    Groups are given by the CSR pointer ``ptr`` (used when not ``None``) or
    by the per-row group ``index``. Each group's own maximum is subtracted
    before exponentiation.

    Args:
        src: The source tensor; groups run along the first dimension.
        index: Group index of every row; used when ``ptr`` is ``None``.
        ptr: CSR pointer over sorted inputs.
        num_nodes: Number of groups (``max(index) + 1`` if omitted).

    Returns:
        Tensor shaped like ``src`` whose entries sum to ~1 within each group.

    Raises:
        NotImplementedError: if both ``index`` and ``ptr`` are ``None``.
    """
    # Shift by the per-group maximum, not the global one. With a global
    # shift, every element of a group lying far (roughly 100 in float32)
    # below the global max underflows exp() to 0, turning that group's
    # softmax into all zeros instead of a distribution.
    if ptr is not None:
        if src.numel() > 0:
            src = src - gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        out = src.exp()
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        if src.numel() > 0:
            src = src - scatter(src, index, dim=0, dim_size=N,
                                reduce='max')[index]
        out = src.exp()
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
    else:
        raise NotImplementedError
    return out / (out_sum + 1e-16)
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """In-place add of a row- or column-broadcast dense tensor to ``src``.

    ``other`` must be shaped ``(src.size(0), 1, ...)`` (one value per row) or
    ``(1, src.size(1), ...)`` (one value per column). The broadcast values
    are added to the stored values in place; when ``src`` has no values, the
    result is ``other + 1``. The result is stored with ``set_value_`` in COO
    layout.

    Raises:
        ValueError: if ``other`` has neither of the accepted shapes.
    """
    rowptr, col, value = src.csr()

    if other.size(0) == src.size(0) and other.size(1) == 1:
        # One value per row: repeat it for every stored entry of that row.
        expanded = gather_csr(other.squeeze(1), rowptr)
    elif other.size(0) == 1 and other.size(1) == src.size(1):
        # One value per column: pick it by each entry's column.
        expanded = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is None:
        value = expanded.add_(1)
    else:
        value = value.add_(expanded.to(value.dtype))
    return src.set_value_(value, layout='coo')
def test_non_contiguous_segment(test, dtype, device):
    """Gather ops must give the same results on non-contiguous inputs."""
    def _non_contiguous(t):
        # Same values, column-major storage (only possible for 2-D and up).
        if t.dim() > 1:
            return t.transpose(0, 1).contiguous().transpose(0, 1)
        return t

    src = _non_contiguous(tensor(test['src'], dtype, device))
    index = _non_contiguous(tensor(test['index'], torch.long, device))
    indptr = _non_contiguous(tensor(test['indptr'], torch.long, device))
    expected = tensor(test['expected'], dtype, device)

    for gather_fn, locator in ((gather_coo, index), (gather_csr, indptr)):
        result = gather_fn(src, locator)
        assert torch.all(result == expected)
def add(src, other):  # noqa: F811
    """Out-of-place addition of a dense tensor or a SparseTensor to ``src``.

    Dense ``other``: shaped ``(src.size(0), 1, ...)`` (row-wise) or
    ``(1, src.size(1), ...)`` (column-wise). It is broadcast onto the stored
    entries of ``src``; a ``src`` without values is treated as all ones.

    SparseTensor ``other``: entries of both operands are concatenated and
    coalesced with ``reduce='sum'``. The result's sparse size is the
    element-wise max of both sizes. If only one operand carries values, the
    other's entries are treated as ones, the same convention as the dense
    branch. If neither carries values, the result has none.

    Raises:
        ValueError: dense ``other`` with an unsupported shape.
        NotImplementedError: ``other`` of any other type.
    """
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')
        if value is not None:
            value = other.to(value.dtype).add_(value)
        else:
            value = other.add_(1)
        return src.set_value(value, layout='coo')

    elif isinstance(other, SparseTensor):
        rowA, colA, valueA = src.coo()
        rowB, colB, valueB = other.coo()

        row = torch.cat([rowA, rowB], dim=0)
        col = torch.cat([colA, colB], dim=0)

        value: Optional[Tensor] = None
        if valueA is not None and valueB is not None:
            value = torch.cat([valueA, valueB], dim=0)
        elif valueA is not None:
            # `other` stores no values (implicit ones): materialize them
            # instead of silently dropping `src`'s values.
            onesB = valueA.new_ones([rowB.numel()] + list(valueA.size())[1:])
            value = torch.cat([valueA, onesB], dim=0)
        elif valueB is not None:
            onesA = valueB.new_ones([rowA.numel()] + list(valueB.size())[1:])
            value = torch.cat([onesA, valueB], dim=0)

        M = max(src.size(0), other.size(0))
        N = max(src.size(1), other.size(1))
        sparse_sizes = (M, N)

        out = SparseTensor(row=row, col=col, value=value,
                           sparse_sizes=sparse_sizes)
        out = out.coalesce(reduce='sum')
        return out

    else:
        raise NotImplementedError
def correctness(dataset):
    """Check gather_coo and gather_csr against ``index_select`` on one matrix.

    Args:
        dataset: ``(group, name)`` pair; the matrix is read from
            ``f'{name}.mat'`` (``['Problem'][0][0][2]``) and converted to CSR.

    Relies on the module-level ``args`` (for ``args.device``), ``sizes`` and
    ``loadmat``. For every feature size in ``sizes[1:]``, a random node
    feature matrix is gathered three ways and the results are compared. A
    RuntimeError mentioning 'out of memory' skips that size after emptying
    the CUDA cache. Any other RuntimeError propagates unchanged.
    """
    _group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes[1:]:
        try:
            x = torch.randn((dim_size, size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = x.index_select(0, row)
            out2 = gather_coo(x, row)
            out3 = gather_csr(x, rowptr)

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Re-raise the original exception (type and traceback intact)
                # instead of wrapping it in a new RuntimeError.
                raise
            torch.cuda.empty_cache()
def gat_csr(x):
    """Gather ``x`` with ``gather_csr`` using the ``rowptr`` from the
    enclosing (or module) scope."""
    pointer = rowptr
    gathered = gather_csr(x, pointer)
    return gathered
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    """Return a SparseTensor keeping only the slices ``idx`` along ``dim``.

    ``dim == 0`` selects rows (built from the CSR view), ``dim == 1`` selects
    columns (built from the CSC view), and any larger ``dim`` selects along
    dimension ``dim - 1`` of the dense value tensor. A negative ``dim``
    counts from ``src.dim()``.

    Args:
        src: input sparse tensor.
        dim: dimension to select along.
        idx: 1-D index tensor (asserted).

    Returns:
        A new SparseTensor built via ``src.from_storage`` (dims 0 / 1) or
        ``src.set_value`` (dense value dims).

    Raises:
        AssertionError: if ``idx`` is not 1-D.
        ValueError: for a dense dim when ``src`` stores no values.
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()

        # Sizes of the selected rows, in output order.
        rowcount = rowcount[idx]

        # New row pointer = exclusive prefix sum of the selected row sizes.
        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        # row[k] = output row that the k-th kept entry belongs to.
        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        # perm[k] = position of output entry k in the old col/value arrays:
        # its offset inside output row r (k - rowptr[r]) plus the start of
        # the selected old row (old_rowptr[idx[r]]).
        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = (idx.size(0), src.sparse_size(1))

        # NOTE(review): is_sorted=True relies on each copied row keeping its
        # original (presumably sorted) column order.
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # Same construction as dim 0, mirrored onto the CSC representation.
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount()

        colcount = colcount[idx]

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        # col[k] = output column of the k-th kept entry (CSC order).
        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

        row = row[perm]

        # Reorder from column-major to row-major: sort by
        # row * num_new_cols + col (col < idx.size(0), so keys are unique
        # per (row, col) pair).
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        # value is first picked in CSC order (perm), then moved to CSR order.
        if value is not None:
            value = value[perm][csc2csr]

        sparse_sizes = (src.sparse_size(0), idx.size(0))

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dense dims: value dim 0 is the nnz axis, so shift by one.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError