Example #1
# Segment-wise softmax: `key` is a CSR index pointer over dim 0 of `x`.
def softmax_csr(x, key):
    x = x - torch_scatter.gather_csr(
        torch_scatter.segment_csr(x, key, reduce="max"), key)
    x = x.exp()
    x_sum = torch_scatter.gather_csr(
        torch_scatter.segment_csr(x, key, reduce="sum"), key)
    return x / (x_sum + 1e-16)
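
The helper above is reused by several of the later examples. A minimal usage sketch (not part of the original source; assumes torch_scatter is installed):

import torch
import torch_scatter

x = torch.randn(5, 8)                     # 5 rows, 8 features
key = torch.tensor([0, 2, 5])             # two segments: rows 0-1 and rows 2-4
out = softmax_csr(x, key)
# every segment now sums to ~1 per feature column
print(torch_scatter.segment_csr(out, key, reduce="sum"))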
Example #2
def softmax(src: Tensor,
            index: Tensor,
            ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    r"""Computes a sparsely evaluated softmax.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """

    if ptr is None:
        N = maybe_num_nodes(index, num_nodes)
        out = src - scatter(src, index, dim=0, dim_size=N, reduce='max')[index]
        out = out.exp()
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
        return out / (out_sum + 1e-16)
    else:
        out = src - gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        out = out.exp()
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
        return out / (out_sum + 1e-16)
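
As an illustration only (assuming the softmax above is importable), the COO path via index and the CSR path via ptr should produce the same result for sorted groups:

import torch

src = torch.randn(6, 4)
index = torch.tensor([0, 0, 1, 1, 1, 2])  # sorted group ids, one per row
ptr = torch.tensor([0, 2, 5, 6])          # the same grouping in CSR form
out_coo = softmax(src, index)
out_csr = softmax(src, index, ptr=ptr)
assert torch.allclose(out_coo, out_csr)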
Example #3
    def forward(
        self,
        x,
        ptx,
        bs,
        height,
        width,
        point_key=None,
        point_edges=None,
        point_src_dirs=None,
        point_tgt_dirs=None,
        pixel_tgt_idx=None,
        **kwargs,
    ):
        # eval network
        point_tgt_dirs = point_tgt_dirs[pixel_tgt_idx.long()]
        point_tgt_dirs = torch_scatter.gather_csr(point_tgt_dirs, point_key)
        ptx = torch.cat((ptx, point_src_dirs, point_tgt_dirs), dim=1)
        ptx = self._net_forward(ptx, point_key, point_edges)
        ptx = torch_scatter.segment_csr(ptx, point_key, reduce=self.aggr)

        # map list to map
        ptx, mask = ext.mytorch.list_to_map(ptx, pixel_tgt_idx, bs, height,
                                            width)
        x = torch.where(mask > 0, ptx, x)

        return x, ptx
Example #4
    def forward(self,
                x,
                point_key=None,
                point_edges=None,
                point_dirs=None,
                **kwargs):
        x = self.cat_dirs(x, point_dirs)

        if self.net is not None:
            x = self.net(x, point_edges)

        if self.avg_mode == "mean":
            x = torch_scatter.segment_csr(x, point_key, reduce="mean")
        elif self.avg_mode == "dirs":
            with torch.no_grad():
                weight = (point_dirs[:, :3] * point_dirs[:, 3:]).sum(dim=1)
                weight = torch.clamp(weight, 0.01, 1)
                weight_sum = torch_scatter.segment_csr(weight,
                                                       point_key,
                                                       reduce="sum")
                weight /= torch_scatter.gather_csr(weight_sum, point_key)
            x = weight.view(-1, 1) * x
            x = torch_scatter.segment_csr(x, point_key, reduce="sum")
        elif self.avg_mode == "softmax":
            weight = self.weight_lin(x)
            weight = softmax_csr(weight, point_key)
            x = self.feat_lin(x)
            x = weight * x
            x = torch_scatter.segment_csr(x, point_key, reduce="sum")
        else:
            raise Exception("invalid avg_mode")
        return x
Example #5
    def forward(
        self,
        x,
        ptx,
        bs,
        height,
        width,
        point_key=None,
        point_edges=None,
        pixel_tgt_idx=None,
        **kwargs,
    ):
        # cat ptx, x to feat
        assert bs == 1
        xch = x.shape[1]
        feat = x.permute(0, 2, 3, 1).view(height * width, xch)
        feat = feat[pixel_tgt_idx.long()]
        feat = torch_scatter.gather_csr(feat, point_key)
        feat = torch.cat((ptx, feat), dim=1)

        # eval network
        feat = self._net_forward(feat, point_key, point_edges)

        weight = softmax_csr(self.sm_out(feat), point_key)
        if self.feat_out:
            ptx = self.feat_out(feat)
        feat = weight * ptx
        feat = torch_scatter.segment_csr(feat, point_key, reduce="sum")

        # map feat to x
        feat, mask = ext.mytorch.list_to_map(feat, pixel_tgt_idx, bs, height,
                                             width)
        x = torch.where(mask > 0, feat, x)

        return x, ptx
Example #6
    def forward(
        self,
        x,
        ptx,
        bs,
        height,
        width,
        point_key=None,
        point_edges=None,
        pixel_tgt_idx=None,
        **kwargs,
    ):
        # cat ptx, x to feat
        assert bs == 1
        xch = x.shape[1]
        x_old = x
        x = x.permute(0, 2, 3, 1).view(height * width, xch)
        x = x[pixel_tgt_idx.long()]
        x = torch_scatter.gather_csr(x, point_key)
        x = torch.cat((ptx, x), dim=1)

        # eval network
        x = self._net_forward(x, point_key, point_edges)
        ptx = x
        x = torch_scatter.segment_csr(x, point_key, reduce=self.aggr)

        # map list to map
        x, mask = ext.mytorch.list_to_map(x, pixel_tgt_idx, bs, height, width)
        x = torch.where(mask > 0, x, x_old)

        return x, ptx
Example #7
def test_zero_elements(reduce, dtype, device):
    x = torch.randn(0,
                    0,
                    0,
                    16,
                    dtype=dtype,
                    device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    out = scatter(x, index, dim=0, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_coo(x, index, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_coo(x, index)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_csr(x, indptr, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_csr(x, indptr)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)
Example #8
    def __collect__(self, edge_index, size, mp_type, kwargs):
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
        ij = {'_i': i, '_j': j}

        out = {}
        for arg in self.__user_args__:
            if arg[-2:] not in ij.keys():
                out[arg] = kwargs.get(arg, inspect.Parameter.empty)
            else:
                idx = ij[arg[-2:]]
                data = kwargs.get(arg[:-2], inspect.Parameter.empty)

                if data is inspect.Parameter.empty:
                    out[arg] = data
                    continue

                if isinstance(data, tuple) or isinstance(data, list):
                    assert len(data) == 2
                    self.__set_size__(size, 1 - idx, data[1 - idx])
                    data = data[idx]

                if not torch.is_tensor(data):
                    out[arg] = data
                    continue

                self.__set_size__(size, idx, data)

                if mp_type == 'edge_index':
                    out[arg] = data.index_select(self.node_dim,
                                                 edge_index[idx])
                elif mp_type == 'adj_t' and idx == 1:
                    rowptr = edge_index.storage.rowptr()
                    for _ in range(self.node_dim):
                        rowptr = rowptr.unsqueeze(0)
                    out[arg] = gather_csr(data, rowptr)
                elif mp_type == 'adj_t' and idx == 0:
                    col = edge_index.storage.col()
                    out[arg] = data.index_select(self.node_dim, col)

        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        if mp_type == 'edge_index':
            out['edge_index_j'] = edge_index[j]
            out['edge_index_i'] = edge_index[i]
            out['index'] = out['edge_index_i']
        elif mp_type == 'adj_t':
            out['adj_t'] = edge_index
            out['edge_index_i'] = edge_index.storage.row()
            out['edge_index_j'] = edge_index.storage.col()
            out['index'] = edge_index.storage.row()
            out['ptr'] = edge_index.storage.rowptr()
            out['edge_attr'] = edge_index.storage.value()

        out['size_j'] = size[j]
        out['size_i'] = size[i]
        out['dim_size'] = out['size_i']

        return out
Example #9
def test_segment_out(test, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test['expected'], dtype, device)

    size = list(src.size())
    size[index.dim() - 1] = index.size(-1)
    out = src.new_full(size, -2)

    gather_coo(src, index, out)
    assert torch.all(out == expected)

    out.fill_(-2)

    gather_csr(src, indptr, out)
    assert torch.all(out == expected)
Example #10
    def forward(self, ptx, bs, height, width, **kwargs):  # point features
        point_key = kwargs["point_key"]
        pixel_tgt_idx = kwargs["pixel_tgt_idx"]

        # average per point
        if self.point_avg_mode == "avg":
            x = torch_scatter.segment_csr(ptx, point_key, reduce="mean")
        elif self.point_avg_mode == "diravg":
            with torch.no_grad():
                point_tgt_dirs = kwargs["point_tgt_dirs"][pixel_tgt_idx.long()]
                point_tgt_dirs = torch_scatter.gather_csr(
                    point_tgt_dirs, point_key)
                weight = (kwargs["point_src_dirs"] * point_tgt_dirs).sum(dim=1)
                weight = torch.clamp(weight, 0.01, 1)
                weight_sum = torch_scatter.segment_csr(weight,
                                                       point_key,
                                                       reduce="sum")
                weight /= torch_scatter.gather_csr(weight_sum, point_key)

            x = weight.view(-1, 1) * ptx
            x = torch_scatter.segment_csr(x, point_key, reduce="sum")
        else:
            raise Exception("invalid avg_mode")

        # project to target
        x, mask = ext.mytorch.list_to_map(x, pixel_tgt_idx, bs, height, width)

        # run refinement network
        for sidx in range(len(self.nets)):
            # process per 3D point
            if self.gnns is not None and sidx < len(self.gnns):
                gnn = self.gnns[sidx]
                x, ptx = gnn(x, ptx, bs, height, width, **kwargs)

            unet = self.nets[sidx]
            if self.nets_residual:
                x = x + unet(x)
            else:
                x = unet(x)

        if self.out_conv:
            x = self.out_conv(x)

        return {"out": x, "mask": mask}
Example #11
def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
    r"""Computes a sparsely evaluated softmax.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): The indices of elements for applying the
            softmax. (default: :obj:`None`)
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension in which to normalize.
            (default: :obj:`0`)

    :rtype: :class:`Tensor`
    """
    if ptr is not None:
        dim = dim + src.dim() if dim < 0 else dim
        size = ([1] * dim) + [-1]
        ptr = ptr.view(size)
        src_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        out = (src - src_max).exp()
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        src_max = scatter(src, index, dim, dim_size=N, reduce='max')
        src_max = src_max.index_select(dim, index)
        out = (src - src_max).exp()
        out_sum = scatter(out, index, dim, dim_size=N, reduce='sum')
        out_sum = out_sum.index_select(dim, index)
    else:
        raise NotImplementedError

    return out / (out_sum + 1e-16)
Example #12
def test_forward(test, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test['expected'], dtype, device)

    out = gather_coo(src, index)
    assert torch.all(out == expected)

    out = gather_csr(src, indptr)
    assert torch.all(out == expected)
Example #13
    def __lift__(self, src, edge_index, dim):
        if isinstance(edge_index, Tensor):
            index = edge_index[dim]
            return src.index_select(self.node_dim, index)
        elif isinstance(edge_index, SparseTensor):
            if dim == 1:
                rowptr = edge_index.storage.rowptr()
                rowptr = expand_left(rowptr, dim=self.node_dim, dims=src.dim())
                return gather_csr(src, rowptr)
            elif dim == 0:
                col = edge_index.storage.col()
                return src.index_select(self.node_dim, col)
        raise ValueError
Example #14
def lift_jump_index_select(self, src, edge_index, dim):
    """From MessagePassing.__lift__, jump edge_attr from adj"""
    if isinstance(edge_index, Tensor):
        index = edge_index[dim]
        return src.index_select(self.node_dim, index)
    elif isinstance(edge_index, SparseTensor):
        if dim == 1:
            rowptr = edge_index.storage.rowptr()
            rowptr = expand_left(rowptr, dim=self.node_dim, dims=src.dim())
            return gather_csr(src, rowptr)
        elif dim == 0:
            col = edge_index.storage.col()
            # return src.index_select(self.node_dim, col)
            return src[col]
    raise ValueError
Example #15
def softmax(src: Tensor, index: Optional[Tensor], ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    out = src
    if src.numel() > 0:
        out = out - src.max()
    out = out.exp()

    if ptr is not None:
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
    else:
        raise NotImplementedError

    return out / (out_sum + 1e-16)
Example #16
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
        other = gather_csr(other.squeeze(1), rowptr)
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add_(1)
    return src.set_value_(value, layout='coo')
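
A rough usage sketch for the row-wise branch (not from the original repo; assumes torch_sparse is installed and the add_ function above is in scope):

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 0, 2])
mat = SparseTensor(row=row, col=col, value=torch.ones(4), sparse_sizes=(3, 3))
offset = torch.tensor([[10.], [20.], [30.]])  # shape (3, 1): one offset per row
mat = add_(mat, offset)
print(mat.csr()[2])  # stored values become [11., 11., 21., 31.]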
Example #17
def test_non_contiguous_segment(test, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test['expected'], dtype, device)

    if src.dim() > 1:
        src = src.transpose(0, 1).contiguous().transpose(0, 1)
    if index.dim() > 1:
        index = index.transpose(0, 1).contiguous().transpose(0, 1)
    if indptr.dim() > 1:
        indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)

    out = gather_coo(src, index)
    assert torch.all(out == expected)

    out = gather_csr(src, indptr)
    assert torch.all(out == expected)
Example #18
def add(src, other):  # noqa: F811
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')
        if value is not None:
            value = other.to(value.dtype).add_(value)
        else:
            value = other.add_(1)
        return src.set_value(value, layout='coo')

    elif isinstance(other, SparseTensor):
        rowA, colA, valueA = src.coo()
        rowB, colB, valueB = other.coo()

        row = torch.cat([rowA, rowB], dim=0)
        col = torch.cat([colA, colB], dim=0)

        value: Optional[Tensor] = None
        if valueA is not None and valueB is not None:
            value = torch.cat([valueA, valueB], dim=0)

        M = max(src.size(0), other.size(0))
        N = max(src.size(1), other.size(1))
        sparse_sizes = (M, N)

        out = SparseTensor(row=row,
                           col=col,
                           value=value,
                           sparse_sizes=sparse_sizes)
        out = out.coalesce(reduce='sum')
        return out

    else:
        raise NotImplementedError
Example #19
def correctness(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes[1:]:
        try:
            x = torch.randn((dim_size, size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = x.index_select(0, row)
            out2 = gather_coo(x, row)
            out3 = gather_csr(x, rowptr)

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()
Example #20
    def gat_csr(x):
        return gather_csr(x, rowptr)
Example #21
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()

        rowcount = rowcount[idx]

        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = (idx.size(0), src.sparse_size(1))

        storage = SparseStorage(row=row,
                                rowptr=rowptr,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=rowcount,
                                colptr=None,
                                colcount=None,
                                csr2csc=None,
                                csc2csr=None,
                                is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount()

        colcount = colcount[idx]

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

        row = row[perm]
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[perm][csc2csr]

        sparse_sizes = (src.sparse_size(0), idx.size(0))

        storage = SparseStorage(row=row,
                                rowptr=None,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=None,
                                colptr=colptr,
                                colcount=colcount,
                                csr2csc=None,
                                csc2csr=csc2csr,
                                is_sorted=True)
        return src.from_storage(storage)

    else:
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError
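
Finally, a rough usage sketch for the dim == 0 branch (not from the original repo; assumes torch_sparse is installed and the index_select above is in scope), keeping only rows 2 and 0 of a 3 x 3 sparse matrix:

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 0, 2])
mat = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))
sub = index_select(mat, 0, torch.tensor([2, 0]))
print(sub.coo())  # rows [0, 1, 1], cols [2, 1, 2]: old row 2 first, then old row 0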