Example 1
def softmax(src: Tensor,
            index: Tensor,
            ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    r"""Computes a sparsely evaluated softmax.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """

    if ptr is None:
        N = maybe_num_nodes(index, num_nodes)
        out = src - scatter(src, index, dim=0, dim_size=N, reduce='max')[index]
        out = out.exp()
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
        return out / (out_sum + 1e-16)
    else:
        out = src - gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        out = out.exp()
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
        return out / (out_sum + 1e-16)
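
A minimal usage sketch of the COO branch above, with hypothetical values (assumes the snippet's own imports, i.e. scatter from torch_scatter and maybe_num_nodes, are in scope): the entries of each group defined by index should sum to roughly one.

import torch

# Three entries in group 0, two in group 1; each group normalizes to ~1.
src = torch.tensor([0.5, 1.0, 2.0, 3.0, 0.1])
index = torch.tensor([0, 0, 0, 1, 1])
out = softmax(src, index)
print(out[:3].sum().item(), out[3:].sum().item())  # both ~1.0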
Example 2
def softmax_csr(x, key):
    x = x - torch_scatter.gather_csr(
        torch_scatter.segment_csr(x, key, reduce="max"), key)
    x = x.exp()
    x_sum = torch_scatter.gather_csr(
        torch_scatter.segment_csr(x, key, reduce="sum"), key)
    return x / (x_sum + 1e-16)
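
A quick sanity check of softmax_csr under assumed toy inputs: key is a CSR pointer [0, 3, 5] describing segments of sizes 3 and 2, so each segment should normalize to sum ~1.

import torch
import torch_scatter

x = torch.tensor([0.5, 1.0, 2.0, 3.0, 0.1])
key = torch.tensor([0, 3, 5])  # segment boundaries in CSR form
out = softmax_csr(x, key)
print(out[:3].sum().item(), out[3:].sum().item())  # both ~1.0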
Example 3
def test_segment_out(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    size = list(src.size())
    size[indptr.dim() - 1] = indptr.size(-1) - 1
    out = src.new_full(size, -2)

    segment_csr(src, indptr, out, reduce=reduce)
    assert torch.all(out == expected)

    out.fill_(-2)

    segment_coo(src, index, out, reduce=reduce)

    if reduce == 'sum':
        expected = expected - 2
    elif reduce == 'mean':
        expected = out  # We cannot really test this here.
    elif reduce == 'min':
        expected = expected.fill_(-2)
    elif reduce == 'max':
        expected[expected == 0] = -2
    else:
        raise ValueError

    assert torch.all(out == expected)
Example 4
    def forward(self,
                x,
                point_key=None,
                point_edges=None,
                point_dirs=None,
                **kwargs):
        x = self.cat_dirs(x, point_dirs)

        if self.net is not None:
            x = self.net(x, point_edges)

        if self.avg_mode == "mean":
            x = torch_scatter.segment_csr(x, point_key, reduce="mean")
        elif self.avg_mode == "dirs":
            with torch.no_grad():
                weight = (point_dirs[:, :3] * point_dirs[:, 3:]).sum(dim=1)
                weight = torch.clamp(weight, 0.01, 1)
                weight_sum = torch_scatter.segment_csr(weight,
                                                       point_key,
                                                       reduce="sum")
                weight /= torch_scatter.gather_csr(weight_sum, point_key)
            x = weight.view(-1, 1) * x
            x = torch_scatter.segment_csr(x, point_key, reduce="sum")
        elif self.avg_mode == "softmax":
            weight = self.weight_lin(x)
            weight = softmax_csr(weight, point_key)
            x = self.feat_lin(x)
            x = weight * x
            x = torch_scatter.segment_csr(x, point_key, reduce="sum")
        else:
            raise Exception("invalid avg_mode")
        return x
Example 5
    def dipole_forward(self, h, z, pos, batch):
        # Include the dipole moment.
        # Get center of mass.
        mass = self.atomic_mass[z].view(-1, 1)
        c = segment_csr(mass * pos, get_ptr(batch)) / segment_csr(
            mass, get_ptr(batch))
        h = h * (pos - c[batch])

        return h
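
The center-of-mass line above is just a ratio of two CSR segment sums. A standalone sketch with hypothetical sizes (two molecules of 2 and 3 atoms; get_ptr replaced by an explicit pointer):

import torch
from torch_scatter import segment_csr

mass = torch.tensor([[1.0], [2.0], [1.0], [1.0], [2.0]])  # shape (5, 1)
pos = torch.rand(5, 3)                                    # shape (5, 3)
ptr = torch.tensor([0, 2, 5])                             # two graphs
c = segment_csr(mass * pos, ptr) / segment_csr(mass, ptr)  # reduce='sum' by default
print(c.shape)  # torch.Size([2, 3]), one center of mass per graph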
Example 6
def correctness(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
            out3 = segment_csr(x, rowptr, reduce='add')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
            out3 = segment_csr(x, rowptr, reduce='mean')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            x = x.abs_().mul_(-1)

            out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='min')
            out3, _ = segment_csr(x, rowptr, reduce='min')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            x = x.abs_()

            out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='max')
            out3, _ = segment_csr(x, rowptr, reduce='max')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            torch.cuda.empty_cache()
Example 7
def test_non_contiguous_segment(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    if src.dim() > 1:
        src = src.transpose(0, 1).contiguous().transpose(0, 1)
    if index.dim() > 1:
        index = index.transpose(0, 1).contiguous().transpose(0, 1)
    if indptr.dim() > 1:
        indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)

    out = segment_coo(src, index, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)

    out = segment_csr(src, indptr, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)
Example 8
    def aggregate_boundary(self,
                           inputs: Tensor,
                           agg_boundary_index: Tensor,
                           boundary_ptr: Optional[Tensor] = None,
                           boundary_dim_size: Optional[int] = None) -> Tensor:
        r"""Aggregates messages from the boundary cells.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if boundary_ptr is not None:
            down_ptr = expand_left(boundary_ptr,
                                   dim=self.node_dim,
                                   dims=inputs.dim())
            return segment_csr(inputs, down_ptr, reduce=self.aggr_boundary)
        else:
            return scatter(inputs,
                           agg_boundary_index,
                           dim=self.node_dim,
                           dim_size=boundary_dim_size,
                           reduce=self.aggr_boundary)
Example 9
    def coalesce(self, reduce: str = "add"):
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_sizes[1] * self.row() + self._col
        mask = idx[1:] > idx[:-1]

        if mask.all():  # Skip if indices are already coalesced.
            return self

        row = self.row()[mask]
        col = self._col[mask]

        value = self._value
        if value is not None:
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
            value = segment_csr(value, ptr, reduce=reduce)
            value = value[0] if isinstance(value, tuple) else value

        return SparseStorage(row=row,
                             rowptr=None,
                             col=col,
                             value=value,
                             sparse_sizes=self._sparse_sizes,
                             rowcount=None,
                             colptr=None,
                             colcount=None,
                             csr2csc=None,
                             csc2csr=None,
                             is_sorted=True)
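
The value merge above relies on segment_csr reducing runs of duplicate indices that become adjacent after sorting. A standalone sketch of that pattern with hypothetical values:

import torch
from torch_scatter import segment_csr

value = torch.tensor([1., 2., 3., 4.])  # the first two entries share an index
ptr = torch.tensor([0, 2, 3, 4])        # one segment per unique index
print(segment_csr(value, ptr, reduce='sum'))  # tensor([3., 3., 4.])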
Example 10
    def forward(
        self,
        x,
        ptx,
        bs,
        height,
        width,
        point_key=None,
        point_edges=None,
        point_src_dirs=None,
        point_tgt_dirs=None,
        pixel_tgt_idx=None,
        **kwargs,
    ):
        # eval network
        point_tgt_dirs = point_tgt_dirs[pixel_tgt_idx.long()]
        point_tgt_dirs = torch_scatter.gather_csr(point_tgt_dirs, point_key)
        ptx = torch.cat((ptx, point_src_dirs, point_tgt_dirs), dim=1)
        ptx = self._net_forward(ptx, point_key, point_edges)
        ptx = torch_scatter.segment_csr(ptx, point_key, reduce=self.aggr)

        # map list to map
        ptx, mask = ext.mytorch.list_to_map(ptx, pixel_tgt_idx, bs, height,
                                            width)
        x = torch.where(mask > 0, ptx, x)

        return x, ptx
Example 11
def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    value = src.storage.value()

    if dim is None:
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError
        else:
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(), device=src.device())
            else:
                raise ValueError
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            col = src.storage.col()
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(1), dtype=src.dtype())
            else:
                raise ValueError
        elif dim == 1 and value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype())
            else:
                raise ValueError
        elif dim > 1 and value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError
        else:
            raise ValueError
Example 12
    def forward(
        self,
        x,
        ptx,
        bs,
        height,
        width,
        point_key=None,
        point_edges=None,
        pixel_tgt_idx=None,
        **kwargs,
    ):
        # cat ptx, x to feat
        assert bs == 1
        xch = x.shape[1]
        x_old = x
        x = x.permute(0, 2, 3, 1).view(height * width, xch)
        x = x[pixel_tgt_idx.long()]
        x = torch_scatter.gather_csr(x, point_key)
        x = torch.cat((ptx, x), dim=1)

        # eval network
        x = self._net_forward(x, point_key, point_edges)
        ptx = x
        x = torch_scatter.segment_csr(x, point_key, reduce=self.aggr)

        # map list to map
        x, mask = ext.mytorch.list_to_map(x, pixel_tgt_idx, bs, height, width)
        x = torch.where(mask > 0, x, x_old)

        return x, ptx
Example 13
def test_zero_elements(reduce, dtype, device):
    x = torch.randn(0,
                    0,
                    0,
                    16,
                    dtype=dtype,
                    device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    out = scatter(x, index, dim=0, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_coo(x, index, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_coo(x, index)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_csr(x, indptr, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_csr(x, indptr)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)
Example 14
    def to_symmetric(self, reduce: str = "sum"):
        N = max(self.size(0), self.size(1))

        row, col, value = self.coo()
        idx = col.new_full((2 * col.numel() + 1, ), -1)
        idx[1:row.numel() + 1] = row
        idx[row.numel() + 1:] = col
        idx[1:] *= N
        idx[1:row.numel() + 1] += col
        idx[row.numel() + 1:] += row

        idx, perm = idx.sort()
        mask = idx[1:] > idx[:-1]
        perm = perm[1:].sub_(1)
        idx = perm[mask]

        if value is not None:
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
            value = torch.cat([value, value])[perm]
            value = segment_csr(value, ptr, reduce=reduce)

        new_row = torch.cat([row, col], dim=0, out=perm)[idx]
        new_col = torch.cat([col, row], dim=0, out=perm)[idx]

        out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
                           sparse_sizes=(N, N), is_sorted=True)
        return out
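
A hypothetical round trip for to_symmetric (assuming torch_sparse is installed): the mirrored entries (0, 1) and (1, 0) are combined by the chosen reduction.

import torch
from torch_sparse import SparseTensor

mat = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]),
                   value=torch.tensor([1.0, 2.0]), sparse_sizes=(2, 2))
sym = mat.to_symmetric(reduce="sum")
print(sym.to_dense())  # both off-diagonal entries become 1.0 + 2.0 = 3.0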
Example 15
    def forward(
        self,
        x,
        ptx,
        bs,
        height,
        width,
        point_key=None,
        point_edges=None,
        pixel_tgt_idx=None,
        **kwargs,
    ):
        # cat ptx, x to feat
        assert bs == 1
        xch = x.shape[1]
        feat = x.permute(0, 2, 3, 1).view(height * width, xch)
        feat = feat[pixel_tgt_idx.long()]
        feat = torch_scatter.gather_csr(feat, point_key)
        feat = torch.cat((ptx, feat), dim=1)

        # eval network
        feat = self._net_forward(feat, point_key, point_edges)

        weight = softmax_csr(self.sm_out(feat), point_key)
        if self.feat_out:
            ptx = self.feat_out(feat)
        feat = weight * ptx
        feat = torch_scatter.segment_csr(feat, point_key, reduce="sum")

        # map feat to x
        feat, mask = ext.mytorch.list_to_map(feat, pixel_tgt_idx, bs, height,
                                             width)
        x = torch.where(mask > 0, feat, x)

        return x, ptx
Example 16
    def forward(self, h, batch):
        h = self.lin1(h)
        h = self.s1(h)
        h = segment_csr(h, get_ptr(batch), reduce=self.readout)
        h = self.lin2(h)
        h = self.s2(h)
        h = self.lin3(h)
        return h
Example 17
    def forward(self, h, batch):
        if len(self.modules_list_1) > 0:
            h = self.modules_list_1(h)
        h = segment_csr(h, get_ptr(batch), reduce=self.readout)
        if len(self.modules_list_2) > 0:
            h = self.modules_list_2(h)
        if self.last_layer is not None:
            h = self.last_layer(h)
        return h
Example 18
    def forward(self, ptx, bs, height, width, **kwargs):  # point features
        point_key = kwargs["point_key"]
        pixel_tgt_idx = kwargs["pixel_tgt_idx"]

        # average per point
        if self.point_avg_mode == "avg":
            x = torch_scatter.segment_csr(ptx, point_key, reduce="mean")
        elif self.point_avg_mode == "diravg":
            with torch.no_grad():
                point_tgt_dirs = kwargs["point_tgt_dirs"][pixel_tgt_idx.long()]
                point_tgt_dirs = torch_scatter.gather_csr(
                    point_tgt_dirs, point_key)
                weight = (kwargs["point_src_dirs"] * point_tgt_dirs).sum(dim=1)
                weight = torch.clamp(weight, 0.01, 1)
                weight_sum = torch_scatter.segment_csr(weight,
                                                       point_key,
                                                       reduce="sum")
                weight /= torch_scatter.gather_csr(weight_sum, point_key)

            x = weight.view(-1, 1) * ptx
            x = torch_scatter.segment_csr(x, point_key, reduce="sum")
        else:
            raise Exception("invalid avg_mode")

        # project to target
        x, mask = ext.mytorch.list_to_map(x, pixel_tgt_idx, bs, height, width)

        # run refinement network
        for sidx in range(len(self.nets)):
            # process per 3D point
            if self.gnns is not None and sidx < len(self.gnns):
                gnn = self.gnns[sidx]
                x, ptx = gnn(x, ptx, bs, height, width, **kwargs)

            unet = self.nets[sidx]
            if self.nets_residual:
                x = x + unet(x)
            else:
                x = unet(x)

        if self.out_conv:
            x = self.out_conv(x)

        return {"out": x, "mask": mask}
Example 19
    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        if ptr is not None:
            for _ in range(self.node_dim):
                ptr = ptr.unsqueeze(0)
            aggr_mean = segment_csr(inputs, ptr, reduce='mean')
            aggr_max = segment_csr(inputs, ptr, reduce='max')
        else:
            aggr_mean = scatter(inputs,
                                index,
                                dim=self.node_dim,
                                dim_size=dim_size,
                                reduce='mean')
            aggr_max = scatter(inputs,
                               index,
                               dim=self.node_dim,
                               dim_size=dim_size,
                               reduce='max')

        return torch.cat([aggr_mean, aggr_max], dim=-1)
Example 20
def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
    r"""Computes a sparsely evaluated softmax.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): The indices of elements for applying the
            softmax. (default: :obj:`None`)
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension in which to normalize.
            (default: :obj:`0`)

    :rtype: :class:`Tensor`
    """
    if ptr is not None:
        dim = dim + src.dim() if dim < 0 else dim
        size = ([1] * dim) + [-1]
        ptr = ptr.view(size)
        src_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        out = (src - src_max).exp()
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        src_max = scatter(src, index, dim, dim_size=N, reduce='max')
        src_max = src_max.index_select(dim, index)
        out = (src - src_max).exp()
        out_sum = scatter(out, index, dim, dim_size=N, reduce='sum')
        out_sum = out_sum.index_select(dim, index)
    else:
        raise NotImplementedError

    return out / (out_sum + 1e-16)
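
A sketch of the CSR branch with assumed inputs: ptr = [0, 2, 5] splits src into two groups along dim=0, each normalized independently.

import torch

src = torch.tensor([1.0, 2.0, 0.5, 0.5, 0.5])
out = softmax(src, ptr=torch.tensor([0, 2, 5]))
print(out[:2].sum().item(), out[2:].sum().item())  # both ~1.0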
Example 21
    def reduce(self,
               x: Tensor,
               index: Optional[Tensor] = None,
               ptr: Optional[Tensor] = None,
               dim_size: Optional[int] = None,
               dim: int = -2,
               reduce: str = 'add') -> Tensor:

        if ptr is not None:
            ptr = expand_left(ptr, dim, dims=x.dim())
            return segment_csr(x, ptr, reduce=reduce)

        assert index is not None
        return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)
Example 22
    def object_selection(self, features, n_obj):
        index = torch.LongTensor(
            sum([[i] * n for i, n in enumerate(n_obj)], []))
        selector = features[:, :1] * self.scale_selector
        selector = scatter_softmax(selector, index.to(self.device), dim=0)
        indptr = torch.LongTensor([0] + np.cumsum(n_obj).tolist()).to(
            self.device)
        selected_features = segment_csr(selector * features,
                                        indptr,
                                        reduce="sum")
        image = selected_features.narrow(1, 1, self.nc_out)
        depth = None
        if self.depth:
            depth = selected_features.narrow(1, 0, 1)
        return image, depth
Example 23
def test_forward(test, reduce, dtype):
    device = torch.device('cuda:1')
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    out = torch_scatter.scatter(src, index, dim, reduce=reduce)
    assert torch.all(out == expected)

    out = torch_scatter.segment_coo(src, index, reduce=reduce)
    assert torch.all(out == expected)

    out = torch_scatter.segment_csr(src, indptr, reduce=reduce)
    assert torch.all(out == expected)
Example 24
def softmax(src: Tensor, index: Optional[Tensor], ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    out = src
    if src.numel() > 0:
        out = out - src.max()
    out = out.exp()

    if ptr is not None:
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
    else:
        raise NotImplementedError

    return out / (out_sum + 1e-16)
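
Design note: unlike the earlier variants, this one subtracts the global src.max() rather than a per-group maximum, saving one segmented reduction; the result is unchanged because softmax is invariant to a constant shift. A toy check of that invariance:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
a = (x - x.max()).exp()
b = (x - 1.0).exp()  # any constant shift gives the same softmax
print(torch.allclose(a / a.sum(), b / b.sum()))  # True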
Example 25
    def forward(self, h, edge_index, edge_weight, edge_attr, data=None):
        state_attr = data.state_attr

        for lin1, interaction in zip(self.lin_list1, self.interactions):
            state_attr = state_attr[data.batch]
            hs = torch.cat((state_attr, h), dim=1)
            h = lin1(hs)
            h = h + interaction(
                h, edge_index, edge_weight, edge_attr, data=data)

            r = segment_csr(h, get_ptr(data.batch), reduce="mean")

            state_attr = torch.sum(r, dim=1, keepdim=True)

            state_attr = state_attr.expand(data.state_attr.shape)

        h = F.relu(self.out(h))
        return h
Example 26
    def forward(self, x, batch):
        """"""
        batch_size = batch.max().item() + 1

        h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
             x.new_zeros((self.num_layers, batch_size, self.in_channels)))
        q_star = x.new_zeros(batch_size, self.out_channels)

        for i in range(self.processing_steps):
            q, h = self.lstm(q_star.unsqueeze(0), h)
            q = q.view(batch_size, self.in_channels)
            e = (x * q[batch]).sum(dim=-1, keepdim=True)
            a = softmax(e, batch, num_nodes=batch_size)

            r = segment_csr((a * x), get_ptr(batch))
            q_star = torch.cat([q, r], dim=-1)

        return q_star
Example 27
def test_forward(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    out = segment_coo(src, index, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)

    out = segment_csr(src, indptr, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)
Example 28
    def aggregate(self, inputs: torch.Tensor, index: torch.Tensor,
                  ptr: Optional[torch.Tensor] = None,
                  dim_size: Optional[int] = None) -> torch.Tensor:
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if ptr is not None:
            ptr = unsqueeze(ptr, dim=0, length=self.node_dim)
            return segment_csr(inputs, ptr, reduce=self.aggr)
        else:
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)
Example 29
    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                  aggr: Optional[str] = None) -> Tensor:
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to scatter functions
        that support "add", "mean", "min", "max" and "mul" operations as
        specified in :meth:`__init__` by the :obj:`aggr` argument.
        """
        aggr = self.aggr if aggr is None else aggr
        assert aggr is not None
        if ptr is not None:
            ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, ptr, reduce=aggr)
        else:
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=aggr)
Example 30
    def aggregate_up(self,
                     inputs: Tensor,
                     agg_up_index: Tensor,
                     up_ptr: Optional[Tensor] = None,
                     up_dim_size: Optional[int] = None) -> Tensor:
        r"""Aggregates messages from upper adjacent cells.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if up_ptr is not None:
            up_ptr = expand_left(up_ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, up_ptr, reduce=self.aggr_up)
        else:
            return scatter(inputs,
                           agg_up_index,
                           dim=self.node_dim,
                           dim_size=up_dim_size,
                           reduce=self.aggr_up)