Пример #1
0
    def forward(self, data):
        """Pool ``data`` onto a voxel grid, then rebuild it via ``graph_reg``.

        Points are first bucketed with ``voxel_grid`` using cells of size
        ``self.pool_rad``. The per-voxel mean positions then serve as seeds
        and every point is re-assigned to a seed with ``nearest``. Features
        are reduced per final cluster with ``self._aggr`` and positions with
        ``'mean'``.

        Args:
            data: Object exposing ``x``, ``pos`` and ``batch``. It is modified
                in place; ``edge_index`` and ``edge_attr`` are cleared.

        Returns:
            Whatever ``self.graph_reg`` returns for the pooled ``data``.
        """
        points = data.pos
        batch_ids = data.batch
        half_cell = self.pool_rad * 0.5

        # Grid bounds are the bounding box of all points, padded by half a
        # cell on every side.
        voxel_ids = nn_geometric.voxel_grid(
            points,
            batch_ids,
            self.pool_rad,
            start=points.min(dim=0)[0] - half_cell,
            end=points.max(dim=0)[0] + half_cell)
        voxel_ids, keep = consecutive_cluster(voxel_ids)

        # One seed position and one batch id per occupied voxel.
        seeds = scatter(points, voxel_ids, dim=0, reduce='mean')
        seed_batch = batch_ids[keep]

        # Second pass: attach every point to a seed, then relabel again.
        # NOTE(review): the relabelling presumably handles seeds that end up
        # with no assigned point -- confirm.
        assignment = nearest(points, seeds, batch_ids, seed_batch)
        assignment, keep = consecutive_cluster(assignment)

        data.x = scatter(data.x, assignment, dim=0, reduce=self._aggr)
        data.pos = scatter(points, assignment, dim=0, reduce='mean')
        data.batch = batch_ids[keep]

        # The old connectivity refers to the un-pooled nodes, so drop it.
        data.edge_index = None
        data.edge_attr = None

        return self.graph_reg(data)
Пример #2
0
    def forward(self, b1, b2):
        """Pool two branches onto one shared set of voxels and fuse features.

        The positions of ``b1`` and ``b2`` are concatenated, sorted by batch
        id, clustered with ``voxel_grid`` (cell size ``self.pool_rad``) and
        re-assigned to the per-voxel mean positions with ``nearest``. Each
        branch's features are then mean-reduced per cluster and the two
        results are concatenated along the feature dimension.

        Args:
            b1: First branch, exposing ``x``, ``pos`` and ``batch``. It is
                overwritten in place and returned.
            b2: Second branch, exposing ``x``, ``pos`` and ``batch``.

        Returns:
            ``b1`` carrying the fused ``x``, pooled ``pos`` and ``batch``,
            with ``edge_attr`` and ``edge_index`` set to ``None``.
        """
        first_count = b1.batch.size(0)
        merged_pos = torch.cat([b1.pos, b2.pos], 0)
        merged_batch = torch.cat([b1.batch, b2.batch], 0)

        # Cluster in batch-sorted order; ``restore`` maps back to the
        # original concatenation order (b1 rows first, then b2 rows).
        merged_batch, order = torch.sort(merged_batch)
        restore = torch.argsort(order)
        merged_pos = merged_pos[order, :]

        half_cell = self.pool_rad * 0.5
        lower = merged_pos.min(dim=0)[0] - half_cell
        upper = merged_pos.max(dim=0)[0] + half_cell

        voxel_ids = torch_geometric.nn.voxel_grid(merged_pos,
                                                  merged_batch,
                                                  self.pool_rad,
                                                  start=lower,
                                                  end=upper)
        voxel_ids, keep = consecutive_cluster(voxel_ids)

        centroids = scatter(merged_pos, voxel_ids, dim=0, reduce='mean')
        centroid_batch = merged_batch[keep]

        voxel_ids = nearest(merged_pos, centroids, merged_batch,
                            centroid_batch)
        voxel_ids, keep = consecutive_cluster(voxel_ids)

        pooled_pos = scatter(merged_pos, voxel_ids, dim=0, reduce='mean')
        pooled_batch = merged_batch[keep]

        # True for rows that came from b1 (in concatenation order).
        # NOTE(review): the mask is allocated on the default device, not on
        # the inputs' device -- confirm this is acceptable for GPU inputs.
        from_first = torch.zeros(merged_batch.size(0)).bool()
        from_first[0:first_count] = 1

        voxel_ids = voxel_ids[restore]
        voxel_count = len(voxel_ids.unique())

        # NOTE(review): the ``out`` buffers start at one; with
        # ``reduce='mean'`` those initial values may be folded into the
        # result rather than only filling empty voxels -- confirm against
        # the torch_scatter documentation.
        feats_first = scatter(b1.x,
                              voxel_ids[from_first],
                              dim=0,
                              out=torch.ones(voxel_count,
                                             b1.x.shape[1],
                                             device=b1.x.device),
                              reduce='mean')
        feats_second = scatter(b2.x,
                               voxel_ids[~from_first],
                               dim=0,
                               out=torch.ones(voxel_count,
                                              b2.x.shape[1],
                                              device=b2.x.device),
                               reduce='mean')

        b1.x = torch.cat([feats_first, feats_second], 1)
        b1.pos = pooled_pos
        b1.batch = pooled_batch
        b1.edge_attr = None
        b1.edge_index = None

        return b1
Пример #3
0
    def _process(self, data):
        """Group the points of ``data`` by integer grid cell.

        Positions are divided by ``self._grid_size`` and cast with ``.int()``
        to obtain cell coordinates. Points sharing a cell are merged by
        ``group_data`` using ``self._mode``.

        Args:
            data: Object exposing ``pos`` and, optionally, ``batch``.

        Returns:
            The grouped data. When ``self._quantize_coords`` is set, ``pos``
            is excluded from grouping and replaced by the integer cell
            coordinates of the representative points.
        """
        quantize = self._quantize_coords

        # NOTE(review): the shuffle presumably randomises which point is
        # kept per cell in "last" mode -- confirm in shuffle_data/group_data.
        if self._mode == "last":
            data = shuffle_data(data)

        cell_coords = ((data.pos) / self._grid_size).int()
        if "batch" in data:
            voxel_ids = voxel_grid(cell_coords, data.batch, 1)
        else:
            voxel_ids = grid_cluster(cell_coords, torch.tensor([1, 1, 1]))
        voxel_ids, representatives = consecutive_cluster(voxel_ids)

        excluded = ["pos"] if quantize else []
        data = group_data(data,
                          voxel_ids,
                          representatives,
                          mode=self._mode,
                          skip_keys=excluded)

        if quantize:
            data.pos = cell_coords[representatives]

        return data
Пример #4
0
    def __call__(self, data):
        """Coarsen every per-node tensor of ``data`` onto a voxel grid.

        Nodes are clustered with ``voxel_grid`` using ``self.size``,
        ``self.start`` and ``self.end``. Tensors whose first dimension equals
        the node count are then reduced per cluster:

        * ``y`` is one-hot encoded with ``self.num_classes``, summed per
          cluster, and the arg-max class is kept,
        * ``batch`` keeps the entry of one representative node per cluster,
        * everything else is averaged with ``scatter_mean``.

        Args:
            data: Object exposing ``pos`` and ``num_nodes``, iterable as
                ``(key, item)`` pairs.

        Returns:
            The same ``data`` object, modified in place.

        Raises:
            ValueError: If any key contains the substring ``edge``.
        """
        node_count = data.num_nodes

        if "batch" in data:
            batch_ids = data.batch
        else:
            # Treat all nodes as belonging to a single example.
            batch_ids = data.pos.new_zeros(node_count, dtype=torch.long)

        voxel_ids, representatives = consecutive_cluster(
            voxel_grid(data.pos, batch_ids, self.size, self.start, self.end))

        for key, item in data:
            if re.search("edge", key) is not None:
                raise ValueError(
                    "GridSampling does not support coarsening of edges")

            per_node = torch.is_tensor(item) and item.size(0) == node_count
            if not per_node:
                continue

            if key == "y":
                votes = F.one_hot(item, num_classes=self.num_classes)
                votes = scatter_add(votes, voxel_ids, dim=0)
                data[key] = votes.argmax(dim=-1)
            elif key == "batch":
                data[key] = item[representatives]
            else:
                data[key] = scatter_mean(item, voxel_ids, dim=0)

        return data
Пример #5
0
    def forward(self, data):
        """Down-sample ``data`` on a voxel grid and rebuild the graph.

        Points are bucketed with ``voxel_grid`` (cell size ``self.pool_rad``)
        and the per-voxel mean positions become the new nodes. Every original
        point is attached to one of them with ``nearest`` and features are
        reduced with ``self._aggr``.

        Args:
            data: Object exposing ``x``, ``pos`` and ``batch``. It is modified
                in place; ``edge_attr`` is cleared.

        Returns:
            Whatever ``self.graph_reg`` returns for the pooled ``data``.
        """
        points, batch_ids = data.pos, data.batch
        half_cell = self.pool_rad * 0.5

        # Down-sample: the grid spans the bounding box of the points, padded
        # by half a cell on each side. (The original note gives pool_rad as
        # 0.1 in the author's setup -- not verifiable from this block.)
        voxel_ids = nn_geometric.voxel_grid(
            points,
            batch_ids,
            self.pool_rad,
            start=points.min(dim=0)[0] - half_cell,
            end=points.max(dim=0)[0] + half_cell)
        voxel_ids, keep = consecutive_cluster(voxel_ids)

        seeds = scatter_('mean', points, voxel_ids)
        seed_batch = batch_ids[keep]

        assignment = nearest(points, seeds, batch_ids, seed_batch)

        # self._aggr selects the pooling type (the original note names max or
        # mean). dim_size pins the output to one row per seed.
        data.x = scatter_(self._aggr,
                          data.x,
                          assignment,
                          dim_size=seeds.size(0))

        data.pos = seeds
        data.batch = seed_batch
        data.edge_attr = None

        # After down-sampling, build the graph again on the pooled nodes.
        return self.graph_reg(data)
Пример #6
0
def max_pool_x(cluster, x, batch: Optional[torch.Tensor] = None, size: Optional[int] = None):
    r"""Pool node features per cluster with ``_max_pool_x``.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor, optional): Batch vector assigning each node to an
            example. When given, it is reduced to one entry per cluster.
            (default: :obj:`None`)
        size (int, optional): Accepted for signature compatibility only; this
            variant ignores it. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`LongTensor` or :obj:`None`)
    """
    # The dense ``size`` path is intentionally absent: per the original
    # author's note it failed to compile because it called ``batch.max()``
    # while ``batch`` may be None.
    dense_ids, representatives = consecutive_cluster(cluster)
    pooled = _max_pool_x(dense_ids, x)

    # Indexing replaces the ``pool_batch`` helper and is skipped when no
    # batch vector was supplied.
    if batch is not None:
        batch = batch[representatives]
    return pooled, batch
Пример #7
0
def max_pool(cluster, x: torch.Tensor, edge_index: torch.Tensor, batch: Optional[torch.Tensor]=None, transform: bool=None):
    r"""Pool node features and coarsen edges according to :attr:`cluster`.

    All nodes within the same cluster are represented as one node. Node
    features are reduced with ``_max_pool_x`` and the edge index is coarsened
    with ``pool_edge``.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.
        x (Tensor): Node feature matrix.
        edge_index (LongTensor): Edge indices of the input graph.
        batch (LongTensor, optional): Batch vector assigning each node to an
            example. When given, it is reduced to one entry per cluster.
            (default: :obj:`None`)
        transform: Accepted for signature compatibility only; this variant
            never uses it. (default: :obj:`None`)

    :rtype: (pooled ``x``, coarsened ``edge_index``, pooled ``batch`` or
        :obj:`None`, the ``edge_attr`` value returned by ``pool_edge``)
    """
    dense_ids, representatives = consecutive_cluster(cluster)

    pooled_x = _max_pool_x(dense_ids, x)
    # No edge attributes are passed in, so ``pool_edge`` receives None.
    pooled_edges, pooled_edge_attr = pool_edge(dense_ids, edge_index, edge_attr=None)

    # Indexing replaces the ``pool_batch`` helper and is skipped when no
    # batch vector was supplied.
    if batch is not None:
        batch = batch[representatives]

    return pooled_x, pooled_edges, batch, pooled_edge_attr
    def sample(self, pos=None, x=None, batch=None):
        """Sub-sample points (and optional features) on a voxel grid.

        Args:
            pos: Point positions; must be 2-dimensional.
            x: Optional per-point features, pooled with ``pool_pos`` exactly
                like the positions.
            batch: Batch vector passed to ``voxel_grid`` and ``pool_batch``.

        Returns:
            Tuple ``(pooled_x, pooled_pos, pooled_batch)``. ``pooled_x`` is
            ``None`` when ``x`` is ``None``.

        Raises:
            ValueError: If ``pos`` does not have exactly two dimensions.
        """
        if len(pos.shape) != 2:
            raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2")

        voxel_ids = voxel_grid(pos, batch, self._subsampling_param)
        voxel_ids, representatives = consecutive_cluster(voxel_ids)
        pooled_batch = pool_batch(representatives, batch)

        pooled_x = None if x is None else pool_pos(voxel_ids, x)
        return pooled_x, pool_pos(voxel_ids, pos), pooled_batch
Пример #9
0
    def _prepare_data(self, data):
        """Quantise ``data`` to grid cells and build a sparse ``Batch``.

        Positions are divided by ``self._grid_size``, rounded and cast to
        ``long`` to obtain cell coordinates. One representative point is kept
        per occupied cell, and features are combined by ``self._aggregate``.

        Args:
            data: Object exposing ``x``, ``pos`` and ``batch``.

        Returns:
            Tuple ``(sparse_data, cluster)``. ``sparse_data`` is a ``Batch``
            with ``x``, ``pos``, ``coords`` and ``batch``; ``cluster`` maps
            every input point to its cell index.
        """
        cell_coords = torch.round((data.pos) / self._grid_size).long()
        voxel_ids, representatives = consecutive_cluster(
            voxel_grid(cell_coords, data.batch, 1))

        kept_coords = cell_coords[representatives]
        kept_batch = data.batch[representatives]
        kept_pos = data.pos[representatives]
        pooled_x = self._aggregate(data.x, voxel_ids, representatives)

        return (Batch(x=pooled_x,
                      pos=kept_pos,
                      coords=kept_coords,
                      batch=kept_batch),
                voxel_ids)
Пример #10
0
    def _process(self, data):
        """Cluster ``data`` on a voxel grid and merge each cell by mean.

        Args:
            data: Object exposing ``pos``, ``num_nodes`` and, optionally,
                ``batch``.

        Returns:
            The result of ``group_data(..., mode="mean")`` for the clustering
            defined by ``self.size``, ``self.start`` and ``self.end``.
        """
        node_count = data.num_nodes

        # Without a batch vector, every node belongs to example 0.
        batch_ids = (data.batch if "batch" in data else
                     data.pos.new_zeros(node_count, dtype=torch.long))

        voxel_ids, representatives = consecutive_cluster(
            voxel_grid(data.pos, batch_ids, self.size, self.start, self.end))

        return group_data(data, voxel_ids, representatives, mode="mean")
Пример #11
0
    def _process(self, data):
        """Group the points of ``data`` by rounded grid cell.

        Positions are divided by ``self._grid_size`` and rounded to obtain
        cell coordinates. Points sharing a cell are merged by ``group_data``
        using ``self._mode``.

        Args:
            data: Object exposing ``pos`` and, optionally, ``batch``.

        Returns:
            The grouped data. When ``self._quantize_coords`` is set, it also
            carries ``coords``: the integer cell coordinates of the
            representative points.
        """
        # NOTE(review): the shuffle presumably randomises which point is
        # kept per cell in "last" mode -- confirm in shuffle_data/group_data.
        if self._mode == "last":
            data = shuffle_data(data)

        cell_coords = torch.round((data.pos) / self._grid_size)
        if "batch" in data:
            voxel_ids = voxel_grid(cell_coords, data.batch, 1)
        else:
            voxel_ids = grid_cluster(cell_coords, torch.tensor([1, 1, 1]))
        voxel_ids, representatives = consecutive_cluster(voxel_ids)

        data = group_data(data, voxel_ids, representatives, mode=self._mode)
        if self._quantize_coords:
            data.coords = cell_coords[representatives].int()

        return data
Пример #12
0
    def __call__(self, data):
        """Coarsen every per-node tensor of ``data`` onto a voxel grid.

        Nodes are clustered with ``voxel_grid`` using ``self.size``,
        ``self.start`` and ``self.end``. Tensors whose first dimension equals
        the node count are reduced per cluster: ``batch`` keeps the entry of
        one representative node, everything else is averaged.

        Args:
            data: Object exposing ``pos`` and ``num_nodes``, iterable as
                ``(key, item)`` pairs.

        Returns:
            The same ``data`` object, modified in place.

        Raises:
            ValueError: If any key contains the substring ``edge``.
        """
        node_count = data.num_nodes

        if 'batch' in data:
            batch_ids = data.batch
        else:
            # Treat all nodes as belonging to a single example.
            batch_ids = data.pos.new_zeros(node_count, dtype=torch.long)

        voxel_ids, representatives = consecutive_cluster(
            voxel_grid(data.pos, batch_ids, self.size, self.start, self.end))

        for key, item in data:
            if re.search('edge', key) is not None:
                raise ValueError(
                    'GridSampling does not support coarsening of edges')

            per_node = torch.is_tensor(item) and item.size(0) == node_count
            if not per_node:
                continue

            if key == 'batch':
                data[key] = item[representatives]
            else:
                data[key] = scatter_mean(item, voxel_ids, dim=0)

        return data
Пример #13
0
    def __call__(self, data):
        """Build a multi-scale pyramid of ``data`` by repeated voxel pooling.

        For every entry of ``self.list_voxel_size`` the most recent level is
        clustered with ``voxel_grid``. Its points are pooled with
        ``pool_pos`` and its batch vector with ``pool_batch``.

        Args:
            data: Object exposing ``x``, ``y``, ``pos`` and, optionally,
                ``batch``.

        Returns:
            A ``MultiScaleBatch`` when ``data`` has a ``batch`` attribute,
            otherwise a ``MultiScaleData``. Both receive ``points`` (level 0
            is the input positions) and ``list_pool`` (one cluster
            assignment per pooling step). Only ``MultiScaleBatch`` also
            receives ``list_batch``.
        """
        try:
            batch = data.batch
        except AttributeError:
            # Single example: every point belongs to batch 0. Allocate an
            # integer vector on the same device as the positions, so it can
            # be combined with them inside ``voxel_grid``.
            batch = data.pos.new_zeros(data.x.shape[0], dtype=torch.long)

        points = [data.pos]
        list_pool = []
        list_batch = [batch]
        for voxel_size in self.list_voxel_size:
            pool = voxel_grid(points[-1], list_batch[-1], voxel_size)
            pool, perm = consecutive_cluster(pool)
            # Record the assignment of this level. Previously it was
            # computed but never stored, leaving ``list_pool`` always empty.
            list_pool.append(pool)
            list_batch.append(pool_batch(perm, list_batch[-1]))
            points.append(pool_pos(pool, points[-1]))

        try:
            # Reading ``data.batch`` raises AttributeError for un-batched
            # input, which selects the plain MultiScaleData container.
            res = MultiScaleBatch(
                batch=data.batch,
                list_pool=list_pool,
                points=points,
                list_batch=list_batch,
                x=data.x,
                y=data.y,
                pos=data.pos)
        except AttributeError:
            res = MultiScaleData(
                list_pool=list_pool,
                points=points,
                x=data.x,
                y=data.y,
                pos=data.pos)

        return res
Пример #14
0
    def __call__(self, data):
        """Sub-sample every per-node tensor of ``data`` on a voxel grid.

        Nodes are clustered with ``voxel_grid`` using
        ``self._subsampling_param``. Keys containing ``edge`` are skipped.
        Tensors whose first dimension equals the node count are reduced per
        cluster:

        * ``y`` is one-hot encoded over ``self._num_classes`` classes, summed
          per cluster, and the arg-max class is kept,
        * everything else goes through ``pool_pos`` and is cast back to its
          original dtype.

        Args:
            data: Object exposing ``pos``, ``batch`` and ``num_nodes``,
                iterable as ``(key, item)`` pairs.

        Returns:
            The same ``data`` object, modified in place.
        """
        node_count = data.num_nodes

        voxel_ids = voxel_grid(data.pos, data.batch, self._subsampling_param)
        voxel_ids, _ = consecutive_cluster(voxel_ids)

        for key, item in data:
            if re.search('edge', key) is not None:
                continue

            if not (torch.is_tensor(item) and item.size(0) == node_count):
                continue

            if key != 'y':
                data[key] = pool_pos(voxel_ids, item).to(item.dtype)
                continue

            # Label voting: one row of class indicators per node, summed
            # per voxel.
            votes = torch.zeros((item.shape[0], self._num_classes))
            votes = votes.scatter(1, item.unsqueeze(-1), 1)
            votes = scatter_add(votes, voxel_ids, dim=0)
            data[key] = torch.argmax(votes, -1)
        return data
Пример #15
0
def mgpool(x, pos, edge_index, batch, mask=None):
    """Coarsen a graph with ``graclus`` and precompute multigrid operators.

    Nodes are clustered with ``graclus`` and a sparse assignment matrix is
    built with one unit entry per node. Features, positions and adjacency
    are projected through that matrix. Helper tensors are allocated on the
    devices of the inputs they are combined with, rather than on the current
    CUDA device, so the function is no longer tied to that device.

    Args:
        x: Node features; the first dimension is the number of nodes.
        pos: Node positions, one row per node.
        edge_index: Edge indices with shape ``(2, num_edges)``.
        batch: Batch vector, reduced with ``pool_batch``.
        mask: Unused; kept for signature compatibility.

    Returns:
        Tuple ``(new_adj, new_feat, new_pos, new_batch, index, values,
        origsize, newsize)``. ``new_adj`` holds the indices of the coarsened
        adjacency (its values are not returned). ``index`` and ``values``
        hold the sparse product of the random-walk Laplacian with the
        assignment matrix.
    """
    adj_values = torch.ones(edge_index.shape[1], device=edge_index.device)
    cluster = graclus(edge_index)
    cluster, perm = consecutive_cluster(cluster)

    # Row = cluster id, column = node id: the (newsize x origsize) matrix P^T.
    index = torch.stack(
        [cluster, torch.arange(0, x.shape[0], device=cluster.device)], dim=0)
    values = torch.ones(cluster.shape[0],
                        dtype=torch.float,
                        device=cluster.device)
    newsize = torch.unique(cluster).shape[0]

    origsize = x.shape[0]

    new_batch = pool_batch(perm, batch)
    # Compute random walk graph laplacian:
    laplacian_index, laplacian_weights = get_laplacian(edge_index,
                                                       normalization='rw')
    laplacian_index, laplacian_weights = torch_sparse.coalesce(
        laplacian_index, laplacian_weights, m=origsize, n=origsize)
    index, values = torch_sparse.coalesce(index, values, m=newsize,
                                          n=origsize)  # P^T matrix
    # Entries of P^T are ones, so these products add up the rows of each
    # cluster's members.
    new_feat = torch_sparse.spmm(index,
                                 values,
                                 m=newsize,
                                 n=origsize,
                                 matrix=x)  # P^T X
    new_pos = torch_sparse.spmm(index,
                                values,
                                m=newsize,
                                n=origsize,
                                matrix=pos)  # P^T POS

    new_adj, new_adj_val = torch_sparse.spspmm(index,
                                               values,
                                               edge_index,
                                               adj_values,
                                               m=newsize,
                                               k=origsize,
                                               n=origsize,
                                               coalesced=True)  # P^T A
    index, values = torch_sparse.transpose(index,
                                           values,
                                           m=newsize,
                                           n=origsize,
                                           coalesced=True)  # P
    new_adj, new_adj_val = torch_sparse.spspmm(new_adj,
                                               new_adj_val,
                                               index,
                                               values,
                                               m=newsize,
                                               k=origsize,
                                               n=newsize,
                                               coalesced=True)  # (P^T A) P
    # Precompute QP: Laplacian times P, with the entries of P reset to ones.
    values = torch.ones(cluster.shape[0],
                        dtype=torch.float,
                        device=cluster.device)
    index, values = torch_sparse.spspmm(laplacian_index,
                                        laplacian_weights,
                                        index,
                                        values,
                                        m=origsize,
                                        k=origsize,
                                        n=newsize,
                                        coalesced=True)
    return new_adj, new_feat, new_pos, new_batch, index, values, origsize, newsize
Пример #16
0
def community_pooling(cluster, data):
    """Pools features and edges of all cluster members

    All cluster members are pooled into a single node that is assigned:
    - the max cluster value for each feature
    - the average cluster nodes position

    Optional attributes (``internal_edge_index``/``internal_edge_attr``,
    ``pos2D``, ``cluster0``/``cluster1``) are carried over when present on
    ``data``. ``pos2D`` is now propagated for batched input too, and a
    missing or ``None`` ``pos`` yields ``pos=None`` instead of an error.

    Args:
        cluster ([type]): clusters
        data ([type]): graph object exposing ``x``, ``edge_index`` and
            ``edge_attr``

    Returns:
        A ``Batch`` when ``data`` has a ``batch`` attribute, otherwise a
        ``Data``, holding the pooled graph.
    """
    # determine what the batches has as attributes
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)

    # pool the pos; it stays None when the attribute is missing or unset so
    # the constructors below never reference an unbound name
    pos = getattr(data, 'pos', None)
    if pos is not None:
        pos = scatter_mean(pos, cluster, dim=0)

    # gather the optional attributes once instead of once per branch
    extras = {}
    if has_internal_edges:
        internal_edge_index, internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)
        extras['internal_edge_index'] = internal_edge_index
        extras['internal_edge_attr'] = internal_edge_attr
    if has_pos2D:
        extras['pos2D'] = scatter_mean(data.pos2D, cluster, dim=0)
    if has_cluster:
        extras['cluster0'] = data.cluster0
        extras['cluster1'] = data.cluster1

    # pool batch
    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch,
                       x=x,
                       edge_index=edge_index,
                       edge_attr=edge_attr,
                       pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    for name, value in extras.items():
        setattr(pooled, name, value)

    return pooled
Пример #17
0
def community_pooling(cluster, data):
    """Pool the features, edges and positions of each cluster into one node.

    Node features are reduced with ``scatter_max`` and positions with
    ``scatter_mean``; edges are coarsened with ``pool_edge``. Optional
    attributes (``internal_edge_index``/``internal_edge_attr``, ``pos2D``,
    ``cluster0``/``cluster1``) are carried over when present on ``data``.
    ``pos2D`` is now propagated for batched input too, and a missing or
    ``None`` ``pos`` yields ``pos=None`` instead of an error.

    Args:
        cluster: Cluster assignment, one entry per node.
        data: Graph object exposing ``x``, ``edge_index`` and ``edge_attr``.

    Returns:
        A ``Batch`` when ``data`` has a ``batch`` attribute, otherwise a
        ``Data``, holding the pooled graph.
    """
    # determine what the batches has as attributes
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)

    # pool the pos; it stays None when the attribute is missing or unset so
    # the constructors below never reference an unbound name
    pos = getattr(data, 'pos', None)
    if pos is not None:
        pos = scatter_mean(pos, cluster, dim=0)

    # gather the optional attributes once instead of once per branch
    extras = {}
    if has_internal_edges:
        internal_edge_index, internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)
        extras['internal_edge_index'] = internal_edge_index
        extras['internal_edge_attr'] = internal_edge_attr
    if has_pos2D:
        extras['pos2D'] = scatter_mean(data.pos2D, cluster, dim=0)
    if has_cluster:
        extras['cluster0'] = data.cluster0
        extras['cluster1'] = data.cluster1

    # pool batch
    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch,
                       x=x,
                       edge_index=edge_index,
                       edge_attr=edge_attr,
                       pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    for name, value in extras.items():
        setattr(pooled, name, value)

    return pooled
Пример #18
0
def community_pooling(cluster, data):
    """Pools features and edges of all cluster members

    All cluster members are pooled into a single node that is assigned:
    - the max cluster value for each feature
    - the average cluster nodes position

    Optional attributes (``internal_edge_index``/``internal_edge_attr``,
    ``pos2D``, ``cluster0``/``cluster1``) are carried over when present on
    ``data``. ``pos2D`` is now propagated for batched input too, and a
    missing or ``None`` ``pos`` yields ``pos=None`` instead of an error.

    Args:
        cluster ([type]): clusters
        data ([type]): graph object exposing ``x``, ``edge_index`` and
            ``edge_attr``

    Returns:
        A ``Batch`` when ``data`` has a ``batch`` attribute, otherwise a
        ``Data``, holding the pooled graph.

    Example:
        >>> import torch
        >>> from torch_geometric.data import Data, Batch
        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5],
        >>>                            [1, 0, 2, 1, 4, 3, 5, 4]], dtype=torch.long)
        >>> x = torch.tensor([[0], [1], [2], [3], [4], [5]],
        >>>                  dtype=torch.float)
        >>> data = Data(x=x, edge_index=edge_index)
        >>> data.pos = torch.tensor(np.random.rand(data.num_nodes, 3))
        >>> c = community_detection(data.edge_index, data.num_nodes)
        >>> batch = Batch().from_data_list([data, data])
        >>> cluster = community_detection(batch.edge_index, batch.num_nodes)
        >>> new_batch = community_pooling(cluster, batch)
    """

    # determine what the batches has as attributes
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)

    # pool the pos; it stays None when the attribute is missing or unset so
    # the constructors below never reference an unbound name
    pos = getattr(data, 'pos', None)
    if pos is not None:
        pos = scatter_mean(pos, cluster, dim=0)

    # gather the optional attributes once instead of once per branch
    extras = {}
    if has_internal_edges:
        internal_edge_index, internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)
        extras['internal_edge_index'] = internal_edge_index
        extras['internal_edge_attr'] = internal_edge_attr
    if has_pos2D:
        extras['pos2D'] = scatter_mean(data.pos2D, cluster, dim=0)
    if has_cluster:
        extras['cluster0'] = data.cluster0
        extras['cluster1'] = data.cluster1

    # pool batch
    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch,
                       x=x,
                       edge_index=edge_index,
                       edge_attr=edge_attr,
                       pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    for name, value in extras.items():
        setattr(pooled, name, value)

    return pooled
Пример #19
0
def test_consecutive_cluster():
    """Check the relabelling and representatives from ``consecutive_cluster``.

    The seven raw ids contain six distinct values, so the relabelled ids
    must span 0..5 in the order of the sorted raw values. ``perm`` must hold
    one source index per new id; for the duplicated value 100 the expected
    index is 6, its last occurrence.
    """
    raw_ids = torch.tensor([8, 2, 10, 15, 100, 1, 100])
    expected_ids = [2, 1, 3, 4, 5, 0, 5]
    expected_perm = [5, 1, 0, 2, 3, 6]

    relabelled, representatives = consecutive_cluster(raw_ids)

    assert relabelled.tolist() == expected_ids
    assert representatives.tolist() == expected_perm