def sample(self, pos=None, x=None, batch=None):
        """Voxel-grid subsample a sparse point set.

        Points are clustered into voxels of size ``self._subsampling_param``
        with ``voxel_grid``; ``pos`` (and ``x`` when given) are then pooled
        per voxel with ``pool_pos`` and the batch vector with ``pool_batch``.

        Args:
            pos: point positions; must be a 2-D tensor.
            x: optional per-point features, pooled like ``pos``.
            batch: optional batch vector. When ``None`` every point is treated
                as belonging to a single example (batch index 0).

        Returns:
            Tuple ``(pooled_x or None, pooled_pos, pooled_batch)``.

        Raises:
            ValueError: if ``pos`` is missing or is not 2-D.
        """
        if pos is None:
            raise ValueError("pos is required")
        if len(pos.shape) != 2:
            raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2")

        if batch is None:
            # Single example: all points in batch 0. Integer dtype and the same
            # device as ``pos`` so voxel_grid / pool_batch get a regular batch vector.
            batch = pos.new_zeros(pos.shape[0]).long()

        pool = voxel_grid(pos, batch, self._subsampling_param)
        pool, perm = consecutive_cluster(pool)
        batch = pool_batch(perm, batch)
        pooled_x = None if x is None else pool_pos(pool, x)
        return pooled_x, pool_pos(pool, pos), batch
Exemplo n.º 2
0
    def __call__(self, data):
        """Build a multi-scale point hierarchy by repeated voxel-grid subsampling.

        Starting from ``data.pos``, each voxel size in ``self.list_voxel_size``
        clusters the previous level's points into voxels (``voxel_grid``) and
        pools every voxel into one point (``pool_pos``); the batch vector is
        pooled alongside with ``pool_batch``.

        Args:
            data: graph / point-cloud object exposing ``pos``, ``x`` and ``y``
                and, for batched input, ``batch``.

        Returns:
            ``MultiScaleBatch`` if ``data`` has a ``batch`` attribute, otherwise
            ``MultiScaleData``. ``points[i]`` / ``list_batch[i]`` hold the
            positions / batch vector of level ``i``; level 0 is the input.
        """
        has_batch = hasattr(data, 'batch')
        batch = data.batch if has_batch else None
        if batch is None:
            # Un-batched input: every point belongs to example 0. Use an integer
            # vector on ``pos``'s device; a CPU float vector breaks voxel_grid
            # for GPU inputs.
            batch = torch.zeros(data.pos.shape[0], dtype=torch.long,
                                device=data.pos.device)

        points = [data.pos]
        list_batch = [batch]
        # NOTE(review): ``list_pool`` is never filled -- the per-level cluster
        # assignments computed below are discarded. Confirm whether consumers
        # expect them here.
        list_pool = []
        for voxel_size in self.list_voxel_size:
            pool = voxel_grid(points[-1], list_batch[-1], voxel_size)
            pool, perm = consecutive_cluster(pool)
            list_batch.append(pool_batch(perm, list_batch[-1]))
            points.append(pool_pos(pool, points[-1]))

        if has_batch:
            return MultiScaleBatch(
                batch=data.batch,
                list_pool=list_pool,
                points=points,
                list_batch=list_batch,
                x=data.x,
                y=data.y,
                pos=data.pos)
        return MultiScaleData(
            list_pool=list_pool,
            points=points,
            x=data.x,
            y=data.y,
            pos=data.pos)
Exemplo n.º 3
0
def community_pooling(cluster, data):
    """Pool features and edges of all cluster members.

    All members of a cluster are pooled into a single node that is assigned:
    - the max cluster value for each feature (``scatter_max`` over ``x``)
    - the average position of the cluster nodes (``scatter_mean`` over ``pos``
      and, when present, ``pos2D``)

    Edges, and ``internal_edge_index``/``internal_edge_attr`` when present, are
    re-indexed onto the pooled nodes with ``pool_edge``; ``cluster0`` and
    ``cluster1`` are copied over unchanged.

    Args:
        cluster (torch.Tensor): cluster id of every node in ``data``.
        data (Data or Batch): graph to pool; must provide ``x``,
            ``edge_index`` and ``edge_attr``.

    Returns:
        Batch (if ``data`` has a ``batch`` attribute) or Data holding the
        pooled graph. Its ``pos`` is ``None`` when ``data`` has no positions.
    """
    # Optional attributes. Positions are only pooled when actually set, since
    # PyG ``Data`` objects may expose ``pos`` as ``None``.
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos = getattr(data, 'pos', None) is not None
    has_pos2D = getattr(data, 'pos2D', None) is not None
    has_cluster = hasattr(data, 'cluster0')

    # Relabel clusters to consecutive ids; ``perm`` holds one node index per
    # cluster and is used below to pool the batch vector.
    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos and the edges
    x, _ = scatter_max(data.x, cluster, dim=0)
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)
    pos = scatter_mean(data.pos, cluster, dim=0) if has_pos else None

    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch,
                       x=x,
                       edge_index=edge_index,
                       edge_attr=edge_attr,
                       pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    # Extra attributes are attached identically for Batch and Data outputs.
    if has_internal_edges:
        pooled.internal_edge_index, pooled.internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)

    if has_pos2D:
        pooled.pos2D = scatter_mean(data.pos2D, cluster, dim=0)

    if has_cluster:
        pooled.cluster0 = data.cluster0
        pooled.cluster1 = data.cluster1

    return pooled
Exemplo n.º 4
0
def community_pooling(cluster, data):
    """Pool features and edges of all cluster members.

    All members of a cluster are pooled into a single node that is assigned:
    - the max cluster value for each feature (``scatter_max`` over ``x``)
    - the average position of the cluster nodes (``scatter_mean`` over ``pos``
      and, when present, ``pos2D``)

    Edges, and ``internal_edge_index``/``internal_edge_attr`` when present, are
    re-indexed onto the pooled nodes with ``pool_edge``; ``cluster0`` and
    ``cluster1`` are copied over unchanged.

    Args:
        cluster (torch.Tensor): cluster id of every node in ``data``.
        data (Data or Batch): graph to pool; must provide ``x``,
            ``edge_index`` and ``edge_attr``.

    Returns:
        Batch (if ``data`` has a ``batch`` attribute) or Data holding the
        pooled graph. Its ``pos`` is ``None`` when ``data`` has no positions.
    """
    # Optional attributes. Positions are only pooled when actually set, since
    # PyG ``Data`` objects may expose ``pos`` as ``None``.
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos = getattr(data, 'pos', None) is not None
    has_pos2D = getattr(data, 'pos2D', None) is not None
    has_cluster = hasattr(data, 'cluster0')

    # Relabel clusters to consecutive ids; ``perm`` holds one node index per
    # cluster and is used below to pool the batch vector.
    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos and the edges
    x, _ = scatter_max(data.x, cluster, dim=0)
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)
    pos = scatter_mean(data.pos, cluster, dim=0) if has_pos else None

    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch,
                       x=x,
                       edge_index=edge_index,
                       edge_attr=edge_attr,
                       pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    # Extra attributes are attached identically for Batch and Data outputs.
    if has_internal_edges:
        pooled.internal_edge_index, pooled.internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)

    if has_pos2D:
        pooled.pos2D = scatter_mean(data.pos2D, cluster, dim=0)

    if has_cluster:
        pooled.cluster0 = data.cluster0
        pooled.cluster1 = data.cluster1

    return pooled
Exemplo n.º 5
0
def community_pooling(cluster, data):
    """Pool features and edges of all cluster members.

    All members of a cluster are pooled into a single node that is assigned:
    - the max cluster value for each feature (``scatter_max`` over ``x``)
    - the average position of the cluster nodes (``scatter_mean`` over ``pos``
      and, when present, ``pos2D``)

    Edges, and ``internal_edge_index``/``internal_edge_attr`` when present, are
    re-indexed onto the pooled nodes with ``pool_edge``; ``cluster0`` and
    ``cluster1`` are copied over unchanged.

    Args:
        cluster (torch.Tensor): cluster id of every node in ``data``.
        data (Data or Batch): graph to pool; must provide ``x``,
            ``edge_index`` and ``edge_attr``.

    Returns:
        Batch (if ``data`` has a ``batch`` attribute) or Data holding the
        pooled graph. Its ``pos`` is ``None`` when ``data`` has no positions.

    Example:
        >>> import torch
        >>> from torch_geometric.data import Data, Batch
        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5],
        ...                            [1, 0, 2, 1, 4, 3, 5, 4]], dtype=torch.long)
        >>> x = torch.tensor([[0], [1], [2], [3], [4], [5]],
        ...                  dtype=torch.float)
        >>> data = Data(x=x, edge_index=edge_index)
        >>> data.pos = torch.rand(data.num_nodes, 3)
        >>> batch = Batch.from_data_list([data, data])
        >>> cluster = community_detection(batch.edge_index, batch.num_nodes)
        >>> new_batch = community_pooling(cluster, batch)
    """
    # Optional attributes. Positions are only pooled when actually set, since
    # PyG ``Data`` objects may expose ``pos`` as ``None``.
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos = getattr(data, 'pos', None) is not None
    has_pos2D = getattr(data, 'pos2D', None) is not None
    has_cluster = hasattr(data, 'cluster0')

    # Relabel clusters to consecutive ids; ``perm`` holds one node index per
    # cluster and is used below to pool the batch vector.
    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos and the edges
    x, _ = scatter_max(data.x, cluster, dim=0)
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)
    pos = scatter_mean(data.pos, cluster, dim=0) if has_pos else None

    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch,
                       x=x,
                       edge_index=edge_index,
                       edge_attr=edge_attr,
                       pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    # Extra attributes are attached identically for Batch and Data outputs.
    if has_internal_edges:
        pooled.internal_edge_index, pooled.internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)

    if has_pos2D:
        pooled.pos2D = scatter_mean(data.pos2D, cluster, dim=0)

    if has_cluster:
        pooled.cluster0 = data.cluster0
        pooled.cluster1 = data.cluster1

    return pooled
Exemplo n.º 6
0
def mgpool(x, pos, edge_index, batch, mask=None):
    """Graclus-based graph coarsening with sparse assignment-matrix algebra.

    Nodes are matched with ``graclus`` and relabelled to consecutive cluster
    ids. With ``P`` the (origsize x newsize) 0/1 assignment matrix and ``A``
    the unit-weight adjacency given by ``edge_index``, this computes:

    - ``new_feat = P^T X`` and ``new_pos = P^T pos``,
    - ``P^T A P`` as the sparse pair ``(new_adj, new_adj_val)``,
    - ``L P`` as the sparse pair ``(index, values)``, where ``L`` is the
      Laplacian from ``get_laplacian(..., normalization='rw')``.

    Args:
        x: node features, one row per node.
        pos: node positions, one row per node.
        edge_index: graph connectivity.
        batch: batch vector of the input nodes.
        mask: unused; kept for interface compatibility.

    Returns:
        ``(new_adj, new_adj_val... )`` exactly as
        ``(new_adj, new_feat, new_pos, new_batch, index, values, origsize, newsize)``.
    """
    # Allocate on the inputs' device instead of hard-coding ``.cuda()`` so the
    # function also runs on CPU (or a non-default GPU).
    device = x.device
    origsize = x.shape[0]
    adj_values = torch.ones(edge_index.shape[1], device=device)

    # num_nodes keeps trailing isolated nodes in the clustering so that
    # ``cluster`` has exactly one entry per row of ``x``.
    cluster = graclus(edge_index, num_nodes=origsize)
    cluster, perm = consecutive_cluster(cluster)
    newsize = torch.unique(cluster).shape[0]
    new_batch = pool_batch(perm, batch)

    # P^T in COO form: entry (cluster[i], i) = 1 for every node i.
    index = torch.stack([cluster, torch.arange(0, origsize, device=device)], dim=0)
    values = torch.ones(cluster.shape[0], dtype=torch.float, device=device)

    # Random-walk graph Laplacian, sized explicitly to origsize x origsize.
    laplacian_index, laplacian_weights = get_laplacian(edge_index,
                                                       normalization='rw',
                                                       num_nodes=origsize)
    laplacian_index, laplacian_weights = torch_sparse.coalesce(
        laplacian_index, laplacian_weights, m=origsize, n=origsize)
    index, values = torch_sparse.coalesce(index, values, m=newsize,
                                          n=origsize)  # P^T matrix

    # NOTE(review): P^T X / P^T POS are per-cluster sums, not means; confirm
    # callers expect summed positions.
    new_feat = torch_sparse.spmm(index, values, m=newsize, n=origsize,
                                 matrix=x)  # P^T X
    new_pos = torch_sparse.spmm(index, values, m=newsize, n=origsize,
                                matrix=pos)  # P^T POS

    new_adj, new_adj_val = torch_sparse.spspmm(index, values,
                                               edge_index, adj_values,
                                               m=newsize, k=origsize, n=origsize,
                                               coalesced=True)  # P^T A
    index, values = torch_sparse.transpose(index, values,
                                           m=newsize, n=origsize,
                                           coalesced=True)  # P
    new_adj, new_adj_val = torch_sparse.spspmm(new_adj, new_adj_val,
                                               index, values,
                                               m=newsize, k=origsize, n=newsize,
                                               coalesced=True)  # (P^T A) P

    # Precompute QP = L P (origsize x newsize); P's nonzeros are all ones.
    values = torch.ones(cluster.shape[0], dtype=torch.float, device=device)
    index, values = torch_sparse.spspmm(laplacian_index, laplacian_weights,
                                        index, values,
                                        m=origsize, k=origsize, n=newsize,
                                        coalesced=True)
    return new_adj, new_feat, new_pos, new_batch, index, values, origsize, newsize