Example #1
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GraphConv


# NOTE: `num_layers` is accepted for interface compatibility but unused;
# the model is fixed at one GraphConv layer plus a linear classifier.
class GGCN(torch.nn.Module):
    def __init__(self,
                 num_layers=2,
                 hidden=200,
                 features_num=16,
                 num_class=2,
                 dropout=0.5):
        super(GGCN, self).__init__()
        self.conv1 = GraphConv(features_num, hidden, aggr='add')
        self.lin2 = Linear(hidden, num_class)
        self.dropout = dropout
        print("hidden=%d, dropout=%f" % (hidden, self.dropout))

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.relu(self.conv1(x, edge_index, edge_weight=edge_weight))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
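
A minimal smoke test, assuming random inputs (all sizes here are illustrative, not from the original training script):

from torch_geometric.data import Data

x = torch.randn(10, 16)                      # 10 nodes, 16 features
edge_index = torch.randint(0, 10, (2, 40))   # 40 random directed edges
data = Data(x=x, edge_index=edge_index, edge_weight=torch.rand(40))
model = GGCN(hidden=200, features_num=16, num_class=2)
out = model(data)                            # log-probabilities, shape [10, 2]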
Example #2
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import (ASAPooling, GraphConv, JumpingKnowledge,
                                global_mean_pool)


class ASAP(torch.nn.Module):
    def __init__(self,
                 num_classes,
                 num_features,
                 num_layers,
                 hidden,
                 ratio=0.8,
                 dropout=0):
        super(ASAP, self).__init__()
        self.conv1 = GraphConv(num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.convs.extend([
            GraphConv(hidden, hidden, aggr='mean')
            for i in range(num_layers - 1)
        ])
        self.pools.extend([
            ASAPooling(hidden, ratio, dropout=dropout)
            for i in range((num_layers) // 2)
        ])
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_weight = None
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = F.relu(x)
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                x, edge_index, edge_weight, batch, _ = pool(
                    x=x,
                    edge_index=edge_index,
                    edge_weight=edge_weight,
                    batch=batch)

        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
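
A hedged usage sketch for the graph-classification setup: the dataset name, batch size and layer sizes below are placeholders, and the DataLoader import path matches recent torch_geometric releases (older ones expose it under torch_geometric.data):

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = ASAP(num_classes=dataset.num_classes,
             num_features=dataset.num_features,
             num_layers=3,
             hidden=64)
for batch in loader:
    out = model(batch)                  # [num_graphs_in_batch, num_classes]
    loss = F.nll_loss(out, batch.y)
    break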
Example #3
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import (GraphConv, JumpingKnowledge, TopKPooling,
                                global_add_pool)


# NOTE: relies on module-level globals `args` (hyperparameters) and
# `dataset` being defined before instantiation.
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        hidden = args.hidden
        num_layers = 5
        ratio = 0.8
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='add')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.convs.extend([
            GraphConv(hidden, hidden, aggr='add')
            for i in range(num_layers - 1)
        ])
        self.pools.extend(
            [TopKPooling(hidden, ratio) for i in range((num_layers) // 2)])
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_add_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_add_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                x, edge_index, _, batch, _, _ = pool(x,
                                                     edge_index,
                                                     batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
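
Net reads the module-level globals `args` and `dataset` at construction time; a hedged setup sketch (the argument name and dataset are assumptions):

import argparse

from torch_geometric.datasets import TUDataset

parser = argparse.ArgumentParser()
parser.add_argument('--hidden', type=int, default=64)
args = parser.parse_args([])
dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')
model = Net()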
Example #4
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.data import Batch
from torch_geometric.nn import (GraphConv, JumpingKnowledge, global_mean_pool,
                                graclus, max_pool)


class Graclus(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(Graclus, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
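
A hedged training-step sketch; `graclus` needs the optional torch-cluster package, and `dataset`/`loader` follow the TUDataset/DataLoader pattern from the earlier sketches:

model = Graclus(dataset, num_layers=4, hidden=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
for batch in loader:
    optimizer.zero_grad()
    loss = F.nll_loss(model(batch), batch.y)  # model emits log-probabilities
    loss.backward()
    optimizer.step()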
Example #5
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import (GraphConv, JumpingKnowledge, TopKPooling,
                                global_mean_pool)


class TopK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(TopK, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
            self.pools.append(TopKPooling(hidden, ratio=0.8))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv, pool in zip(self.convs, self.pools):
            conv.reset_parameters()
            pool.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, (conv, pool) in enumerate(zip(self.convs, self.pools)):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0:
                # TopKPooling returns six values in recent torch_geometric
                # releases (the last one being the kept node scores).
                x, edge_index, _, batch, _, _ = pool(x, edge_index,
                                                     batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
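
A hedged evaluation sketch, reusing a dataset and loader as in the earlier sketches:

model = TopK(dataset, num_layers=4, hidden=64)
model.eval()
with torch.no_grad():
    for batch in loader:
        pred = model(batch).argmax(dim=-1)   # predicted class per graph
        break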
Example #6
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GraphConv

# NOTE: `EdgePoolingMod` is a project-specific pooling layer, not part of
# torch_geometric; it must be imported from the surrounding codebase.


class GeoGraph(torch.nn.Module):
    def __init__(self, hidden, geo, training):
        super(GeoGraph, self).__init__()

        self.geo = geo
        # NOTE: this overwrites nn.Module's built-in `training` flag, so the
        # dropout in forward() follows the constructor argument rather than
        # model.train() / model.eval().
        self.training = training
        if self.geo:
            self.conv1 = GraphConv(64 + 3, hidden, aggr='mean')
        else:
            self.conv1 = GraphConv(64, hidden, aggr='mean')
        self.conv2 = GraphConv(hidden, 32, aggr='mean')
        self.conv3 = GraphConv(32, 16, aggr='mean')
        self.lin1 = Linear(16, 16)
        self.pool1 = EdgePoolingMod(16)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.conv3.reset_parameters()
        self.lin1.reset_parameters()

    def forward(self, data):
        # Example input: Data(edge_index=[2, 210], neg_edge_index=[2, 182],
        #   pos_edge_index=[2, 28], x=[15, 512], x_bbox=[15, 4],
        #   x_heading=[15, 2], x_img_pos=[15, 2], x_pos=[15, 2], y=[210])
        if self.training:
            x = data.x.cuda()
            geo = data.geos.cuda()
            reg = data.regressions.cuda()
            edge_index = data.edge_index.cuda()
            edge_y = data.y.cuda()
            if self.geo:
                # x = torch.cat((x, reg), 1)
                x = torch.cat((x, geo), 1)
        else:
            x = data.x.cuda()
            reg = data.regressions.cuda()
            edge_index = data.edge_index.cuda()

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x, edge_index, batch, edge_scores = self.pool1(x,
                                                       edge_index,
                                                       batch=None)

        return edge_scores

    def __repr__(self):
        return self.__class__.__name__
Example #7
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import (ASAPooling, GraphConv, JumpingKnowledge,
                                global_mean_pool)


class ASAP(torch.nn.Module):
    def __init__(self, num_vocab, max_seq_len, node_encoder, emb_dim,
                 num_layers, hidden, ratio=0.8, dropout=0, num_class=0):
        super(ASAP, self).__init__()

        self.num_class = num_class
        self.max_seq_len = max_seq_len
        self.node_encoder = node_encoder

        self.conv1 = GraphConv(emb_dim, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.convs.extend([
            GraphConv(hidden, hidden, aggr='mean')
            for i in range(num_layers - 1)
        ])
        self.pools.extend([
            ASAPooling(hidden, ratio, dropout=dropout)
            for i in range((num_layers) // 2)
        ])
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        # self.lin2 = Linear(hidden, dataset.num_classes)

        if self.num_class > 0:  # classification
            self.graph_pred_linear = torch.nn.Linear(hidden, self.num_class)
        else:
            self.graph_pred_linear_list = torch.nn.ModuleList()
            for i in range(max_seq_len):
                self.graph_pred_linear_list.append(torch.nn.Linear(hidden, num_vocab))

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        self.lin1.reset_parameters()
        # `self.lin2` is commented out above; reset the prediction heads that
        # actually exist instead.
        if self.num_class > 0:
            self.graph_pred_linear.reset_parameters()
        else:
            for lin in self.graph_pred_linear_list:
                lin.reset_parameters()

    def forward(self, data):
        x, edge_index, node_depth, batch = data.x, data.edge_index, data.node_depth, data.batch

        x = self.node_encoder(x, node_depth.view(-1, ))

        edge_weight = None
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = F.relu(x)
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                x, edge_index, edge_weight, batch, _ = pool(
                    x=x, edge_index=edge_index, edge_weight=edge_weight,
                    batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        # x = self.lin2(x)
        # return F.log_softmax(x, dim=-1)

        if self.num_class > 0:
            return self.graph_pred_linear(x)

        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](x))
        return pred_list

    def __repr__(self):
        return self.__class__.__name__
Example #8
import torch
from torch_geometric.nn import GATConv, GCNConv, GraphConv, SAGEConv
# NOTE: the location of `topk` and `filter_adj` varies across
# torch_geometric versions; this path matches the 1.x releases.
from torch_geometric.nn.pool.topk_pool import filter_adj, topk


class SAGPooling(torch.nn.Module):
    r"""The self-attention pooling operator from the `"Self-Attention Graph
    Pooling" <https://arxiv.org/abs/1904.08082>`_  paper

    .. math::
        \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})

        \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

        \mathbf{X}^{\prime} &= (\mathbf{X} \odot
        \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on a learnable projection score
    :math:`\mathbf{y}`.
    Projection scores are learned by a graph neural network layer.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
            (default: :obj:`0.5`)
        gnn (string, optional): Specifies which graph neural network layer to
            use for calculating projection scores (one of
            :obj:`"GCN"`, :obj:`"GAT"`, :obj:`"SAGE"` or :obj:`"gConv"`).
            (default: :obj:`"gConv"`)
        **kwargs (optional): Additional parameters for initializing the graph
            neural network layer.
    """
    def __init__(self, in_channels, ratio=0.5, gnn='gConv', **kwargs):
        super(SAGPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.gnn_name = gnn

        assert gnn in ['GCN', 'GAT', 'SAGE', 'gConv']
        if gnn == 'GCN':
            self.gnn = GCNConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
            self.gnn = GATConv(self.in_channels, 1, **kwargs)
        elif gnn == 'SAGE':
            self.gnn = SAGEConv(self.in_channels, 1, **kwargs)
        else:
            self.gnn = GraphConv(self.in_channels, 1, **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        score = torch.tanh(self.gnn(x, edge_index).view(-1))
        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

    def __repr__(self):
        return '{}({}, {}, ratio={})'.format(self.__class__.__name__,
                                             self.gnn_name, self.in_channels,
                                             self.ratio)
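
The pooling layer can be exercised on its own; a hedged toy example (graph sizes are illustrative):

pool = SAGPooling(in_channels=32, ratio=0.5, gnn='GCN')
x = torch.randn(6, 32)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])
x, edge_index, edge_attr, batch, perm = pool(x, edge_index)
print(x.size(0))  # ceil(0.5 * 6) = 3 nodes kept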
Example #9
from collections import namedtuple

import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric as tg
from torch_geometric.nn import GraphConv
# NOTE: the location of `filter_adj` varies across torch_geometric
# versions; this path matches the 1.x releases.
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.utils import degree
from torch_scatter import scatter_add
from torch_sparse import coalesce


class StarPooling(torch.nn.Module):
    r"""The edge pooling operator from the `"Towards Graph Pooling by Edge
    Contraction" <https://graphreason.github.io/papers/17.pdf>`_ and
    `"Edge Contraction Pooling for Graph Neural Networks"
    <https://arxiv.org/abs/1905.10990>`_ papers.

    In short, a score is computed for each edge.
    Edges are contracted iteratively according to that score unless one of
    their nodes has already been part of a contracted edge.

    To duplicate the configuration from the "Towards Graph Pooling by Edge
    Contraction" paper, use either
    :func:`EdgePooling.compute_edge_score_softmax`
    or :func:`EdgePooling.compute_edge_score_tanh`, and set
    :obj:`add_to_edge_score` to :obj:`0`.

    To duplicate the configuration from the "Edge Contraction Pooling for
    Graph Neural Networks" paper, set :obj:`dropout` to :obj:`0.2`.

    Args:
        in_channels (int): Size of each input sample.
        edge_score_method (function, optional): The function to apply
            to compute the edge score from raw edge scores. By default,
            this is the softmax over all incoming edges for each node.
            This function takes in a :obj:`raw_edge_score` tensor of shape
            :obj:`[num_nodes]`, an :obj:`edge_index` tensor and the number of
            nodes :obj:`num_nodes`, and produces a new tensor of the same size
            as :obj:`raw_edge_score` describing normalized edge scores.
            Included functions are
            :func:`EdgePooling.compute_edge_score_softmax`,
            :func:`EdgePooling.compute_edge_score_tanh`, and
            :func:`EdgePooling.compute_edge_score_sigmoid`.
            (default: :func:`EdgePooling.compute_edge_score_softmax`)
        dropout (float, optional): The probability with
            which to drop edge scores during training. (default: :obj:`0`)
        add_to_edge_score (float, optional): This is added to each
            computed edge score. Adding this greatly helps with unpool
            stability. (default: :obj:`0.5`)
    """

    # `new_edge_score` defaults to None (requires Python 3.7+); it is only
    # filled in by the edge-contraction style merge routines below.
    unpool_description = namedtuple(
        "UnpoolDescription",
        ["edge_index", "cluster", "batch", "new_edge_score"],
        defaults=[None])

    def __init__(self, in_channels, node_score_method=None, dropout=0):
        super(StarPooling, self).__init__()
        self.in_channels = in_channels
        if node_score_method is None:
            node_score_method = self.compute_node_score_tanh

        self.compute_node_score = node_score_method
        self.dropout = dropout

        self.score_func = GraphConv(in_channels, 1)
        self.reset_parameters()

    def reset_parameters(self):
        self.score_func.reset_parameters()

    # @staticmethod
    # def compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes):
    #     return softmax(raw_edge_score, edge_index[1], num_nodes)

    @staticmethod
    def compute_node_score_tanh(raw_node_score):
        return torch.tanh(raw_node_score)

    @staticmethod
    def compute_node_score_sigmoid(raw_node_score):
        return torch.sigmoid(raw_node_score)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        r"""Forward computation which computes the raw edge score, normalizes
        it, and merges the edges.

        Args:
            x (Tensor): The node features.
            edge_index (LongTensor): The edge indices.
            batch (LongTensor): Batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.

        Return types:
            * **x** *(Tensor)* - The pooled node features.
            * **edge_index** *(LongTensor)* - The coarsened edge indices.
            * **batch** *(LongTensor)* - The coarsened batch vector.
            * **unpool_info** *(unpool_description)* - Information that is
              consumed by :func:`EdgePooling.unpool` for unpooling.
        """
        # e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
        # e = self.lin(e).view(-1)
        # e = F.dropout(e, p=self.dropout, training=self.training)
        # e = self.compute_edge_score(e)
        # e = e + self.add_to_edge_score

        # TODO: change linear to consider both node and edge features
        # n = self.lin(x).view(-1)

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        n = self.score_func(x, edge_index).view(-1)
        n = F.dropout(n, p=self.dropout, training=self.training)
        n = self.compute_node_score(n)

        x, edge_index, edge_attr, batch, unpool_info, perm = self.__merge_stars_with_attr_gpu2__(
            x, edge_index, batch, edge_attr, n)
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=n.size(0))

        return x, edge_index, edge_attr, batch, perm

    def __merge_stars_with_attr_gpu2__(self, x, edge_index, batch, edge_attr,
                                       node_score):

        node_argsort = torch.argsort(node_score, descending=True)
        cluster = torch.empty_like(batch, device=torch.device('cpu'))
        nodes_remain = torch.ones_like(batch,
                                       device=torch.device('cpu'),
                                       dtype=torch.bool)
        # Iterate over nodes in order of decreasing score, contracting each
        # selected node together with its remaining neighbors.
        edge_index_cpu = edge_index.cpu()
        i = 0

        # NOTE: the slicing below assumes `edge_index` is coalesced (sorted
        # by source node); `num_nodes` guards against trailing isolated nodes.
        degrees = degree(edge_index_cpu[0], num_nodes=x.size(0)).long()
        cum_num_nodes = torch.cat(
            [degrees.new_zeros(1),
             degrees.cumsum(dim=0)[:-1]], dim=0).long()

        center_nodes = set()

        for node_idx in node_argsort.tolist():
            if not nodes_remain[node_idx]:
                continue

            dests = edge_index_cpu[1][cum_num_nodes[node_idx].item(
            ):cum_num_nodes[node_idx].item() + degrees[node_idx].item()]

            nodes_remain[dests] = False
            nodes_remain[node_idx] = False

            # add node_idx to center_nodes
            center_nodes.add(node_idx)

            cluster[node_idx] = i
            cluster[dests] = i
            i += 1

        cluster = cluster.to(x.device)
        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        N = new_x.size(0)

        new_edge_index, new_edge_attr = coalesce(cluster[edge_index],
                                                 edge_attr, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster,
                                              batch=batch)

        perm = sorted(center_nodes)
        perm = torch.from_numpy(np.array(perm)).view(-1).to(x.device)

        return new_x, new_edge_index, new_edge_attr, new_batch, unpool_info, perm

    def __merge_stars_with_attr_gpu__(self, x, edge_index, edge_attr, batch,
                                      node_score):

        device = x.device

        nodes_remaining = set(range(x.size(0)))
        node_argsort = torch.argsort(node_score, descending=True)

        cluster = torch.empty_like(batch, device=torch.device('cpu'))

        # Iterate over nodes in order of decreasing score, selecting each
        # node that has not already been merged into an earlier star.
        edge_index_cpu = edge_index.cpu()
        center_nodes = set()
        i = 0

        for node_idx in node_argsort.tolist():
            if node_idx not in nodes_remaining:
                continue
            dest_bool = edge_index_cpu[0] == node_idx
            # get the connected nodes
            dests = set(edge_index_cpu[1][dest_bool].numpy())
            # remove the previous combined nodes
            dests.difference_update(center_nodes)
            nodes_remaining.difference_update(dests)
            nodes_remaining.remove(node_idx)

            # add node_idx to center_nodes
            center_nodes.add(node_idx)

            cluster[node_idx] = i
            cluster[list(dests)] = i
            i += 1

        # The remaining nodes are simply kept.
        for node_idx in nodes_remaining:
            cluster[node_idx] = i
            i += 1

        cluster = cluster.to(x.device)

        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        N = new_x.size(0)

        new_edge_index, new_edge_attr = coalesce(cluster[edge_index],
                                                 edge_attr, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster,
                                              batch=batch)

        perm = sorted(center_nodes)
        perm = torch.from_numpy(np.array(perm)).view(-1).to(device)

        return new_x, new_edge_index, new_edge_attr, new_batch, unpool_info, perm

    def unpool(self, x, unpool_info):
        r"""Unpools a previous edge pooling step.

        For unpooling, :obj:`x` should be of same shape as those produced by
        this layer's :func:`forward` function. Then, it will produce an
        unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`.

        Args:
            x (Tensor): The node features.
            unpool_info (unpool_description): Information that has been
                produced by one of this layer's merge routines.

        Return types:
            * **x** *(Tensor)* - The unpooled node features.
            * **edge_index** *(LongTensor)* - The new edge indices.
            * **batch** *(LongTensor)* - The new batch vector.
        """

        new_x = x / unpool_info.new_edge_score.view(-1, 1)
        new_x = new_x[unpool_info.cluster]
        return new_x, unpool_info.edge_index, unpool_info.batch

    def __merge_star_nodes__(self, x, edge_index, batch, node_score):
        """
        Adapted from Edge Contraction Pooling.

        NOTE: this appears to be an unused draft; `new_edge_indices` is
        never populated, so the score rescaling below degenerates to
        all-ones.

        :param x: node feature
        :param edge_index: edge index
        :param batch: batch index
        :param node_score: node score tensor
        :return:
        """

        nodes_remaining = set(range(x.size(0)))

        cluster = torch.empty_like(batch, device=torch.device('cpu'))
        # edge_argsort = torch.argsort(edge_score, descending=True)
        node_argsort = torch.argsort(node_score, descending=True)
        edge_index_cpu = edge_index.cpu()

        deg = tg.utils.degree(edge_index_cpu[0], x.size(0))

        # Iterate over nodes in order of decreasing score.
        i = 0
        new_edge_indices = []

        for node_idx in node_argsort.tolist():
            # check if the node is still in the nodes_remaining
            if node_idx not in nodes_remaining:
                continue

            dest_bool = edge_index_cpu[0] == node_idx
            dests = set(edge_index_cpu[1][dest_bool].numpy())
            dests.difference_update(nodes_remaining)

            if len(dests) == 0:
                continue

            cluster[list(dests)] = i
            nodes_remaining.remove(node_idx)
            nodes_remaining.difference_update(dests)

            i += 1

        cluster = cluster.to(x.device)

        # We compute the new features as an addition of the old ones.
        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        new_edge_score = node_score[new_edge_indices]
        if len(nodes_remaining) > 0:
            remaining_score = x.new_ones(
                (new_x.size(0) - len(new_edge_indices), ))
            new_edge_score = torch.cat([new_edge_score, remaining_score])
        new_x = new_x * new_edge_score.view(-1, 1)

        N = new_x.size(0)
        new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster,
                                              batch=batch,
                                              new_edge_score=new_edge_score)

        return new_x, new_edge_index, new_batch, unpool_info

    def __merge_edges__(self, x, edge_index, batch, edge_score):
        """
        Copy from Edge Contraction Pooling

        :param x: node feature
        :param edge_index: edge index
        :param batch: batch index
        :param edge_score: edge score tensor
        :return:
        """
        nodes_remaining = set(range(x.size(0)))

        cluster = torch.empty_like(batch, device=torch.device('cpu'))
        edge_argsort = torch.argsort(edge_score, descending=True)

        # Iterate through all edges, selecting it if it is not incident to
        # another already chosen edge.
        i = 0
        new_edge_indices = []
        edge_index_cpu = edge_index.cpu()
        for edge_idx in edge_argsort.tolist():
            source = edge_index_cpu[0, edge_idx].item()
            if source not in nodes_remaining:
                continue

            target = edge_index_cpu[1, edge_idx].item()
            if target not in nodes_remaining:
                continue

            new_edge_indices.append(edge_idx)

            cluster[source] = i
            nodes_remaining.remove(source)

            if source != target:
                cluster[target] = i
                nodes_remaining.remove(target)

            i += 1

        # The remaining nodes are simply kept.
        for node_idx in nodes_remaining:
            cluster[node_idx] = i
            i += 1
        cluster = cluster.to(x.device)

        # We compute the new features as an addition of the old ones.
        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        new_edge_score = edge_score[new_edge_indices]
        if len(nodes_remaining) > 0:
            remaining_score = x.new_ones(
                (new_x.size(0) - len(new_edge_indices), ))
            new_edge_score = torch.cat([new_edge_score, remaining_score])
        new_x = new_x * new_edge_score.view(-1, 1)

        N = new_x.size(0)
        new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster,
                                              batch=batch,
                                              new_edge_score=new_edge_score)

        return new_x, new_edge_index, new_batch, unpool_info

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.in_channels)
Example #10
import torch
from torch_geometric.nn import GATConv, GCNConv, GraphConv, SAGEConv
# NOTE: the location of `topk` and `filter_adj` varies across
# torch_geometric versions; this path matches the 1.x releases.
from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.utils import softmax


class SAGPooling(torch.nn.Module):
    r"""The self-attention pooling operator from the `"Self-Attention Graph
    Pooling" <https://arxiv.org/abs/1904.08082>`_ and `"Understanding
    Attention and Generalization in Graph Neural Networks"
    <https://arxiv.org/abs/1905.02850>`_ papers

    if min_score :math:`\tilde{\alpha}` is None:

        .. math::
            \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})

            \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

            \mathbf{X}^{\prime} &= (\mathbf{X} \odot
            \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

            \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}

    if min_score :math:`\tilde{\alpha}` is a value in [0, 1]:

        .. math::
            \mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))

            \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}

            \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}

            \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on a learnable projection score
    :math:`\mathbf{y}`.
    Projection scores are learned by a graph neural network layer.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
            This value is ignored if :obj:`min_score` is not :obj:`None`.
            (default: :obj:`0.5`)
        gnn (string, optional): Specifies which graph neural network layer to
            use for calculating projection scores (one of :obj:`"GraphConv"`,
            :obj:`"GCN"`, :obj:`"GAT"` or :obj:`"SAGE"`).
            (default: :obj:`"GraphConv"`)
        min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
            which is used to compute indices of pooled nodes
            :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
            When this value is not :obj:`None`, the :obj:`ratio` argument is
            ignored. (default: :obj:`None`)
        multiplier (float, optional): Coefficient by which features get
            multiplied after pooling. This can be useful for large graphs and
            when :obj:`min_score` is used. (default: :obj:`1`)
        **kwargs (optional): Additional parameters for initializing the graph
            neural network layer.
    """
    def __init__(self,
                 in_channels,
                 ratio=0.5,
                 gnn='GraphConv',
                 min_score=None,
                 multiplier=1,
                 **kwargs):
        super(SAGPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier
        self.gnn_name = gnn

        assert gnn in ['GraphConv', 'GCN', 'GAT', 'SAGE']
        if gnn == 'GCN':
            self.gnn = GCNConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
            self.gnn = GATConv(self.in_channels, 1, **kwargs)
        elif gnn == 'SAGE':
            self.gnn = SAGEConv(self.in_channels, 1, **kwargs)
        else:
            self.gnn = GraphConv(self.in_channels, 1, **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
        score = self.gnn(attn, edge_index).view(-1)

        if self.min_score is None:
            score = torch.tanh(score)
        else:
            score = softmax(score, batch)

        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm, score[perm]

    def __repr__(self):
        return '{}({}, {}, {}={}, multiplier={})'.format(
            self.__class__.__name__, self.gnn_name, self.in_channels,
            'ratio' if self.min_score is None else 'min_score',
            self.ratio if self.min_score is None else self.min_score,
            self.multiplier)
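
A hedged sketch contrasting the two selection modes (all numbers are illustrative):

x = torch.randn(8, 64)
edge_index = torch.randint(0, 8, (2, 20))

# ratio-based: keep the top ceil(0.5 * 8) = 4 nodes
pool_ratio = SAGPooling(64, ratio=0.5)
out = pool_ratio(x, edge_index)

# min_score-based: keep nodes whose softmax score exceeds the threshold
pool_min = SAGPooling(64, min_score=0.001, multiplier=2)
x2, ei2, ea2, b2, perm2, score2 = pool_min(x, edge_index)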
Example #11
import torch
import torch.nn.functional as F
from torch_geometric.nn import (GraphConv, global_add_pool, global_max_pool,
                                global_mean_pool)


class GraphNN(torch.nn.Module):
    def __init__(self, num_layers, num_input_features, hidden):
        super(GraphNN, self).__init__()
        self.conv1 = GraphConv(num_input_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden))

        self.lin1 = torch.nn.Linear(3 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, 2)

    def reset_parameters(self):  # reset all conv and linear layers
        self.conv1.reset_parameters()
        for conv in self.convs:
            # .reset_parameters() is a method of torch_geometric.nn.GraphConv
            conv.reset_parameters()
        # .reset_parameters() is a method of torch.nn.Linear
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # data: Batch(batch=[num_nodes_in_batch],
        #               edge_attr=[2*num_nodes_in_batch,num_edge_features_per_edge],
        #               edge_index=[2,2*num_nodes_in_batch],
        #               pos=[num_nodes_in_batch,2],
        #               x=[num_nodes_in_batch, num_input_features_per_node],
        #               y=[num_graphs_in_batch, num_classes]
        # example: Batch(batch=[2490], edge_attr=[4980,1], edge_index=[2,4980], pos=[2490,2], x=[2490,33], y=[32,2])

        x, edge_index, batch = data.x, data.edge_index, data.batch
        # x.shape: torch.Size([num_nodes_in_batch, num_input_features_per_node])
        # edge_index.shape: torch.Size([2, 2*num_nodes_in_batch])
        # batch.shape: torch.Size([num_nodes_in_batch])
        # example:  x.shape = torch.Size([2490,33])
        #           edge_index.shape = torch.Size([2,4980])
        #           batch.shape = torch.Size([2490])

        # graph convolutions and relu activation
        x = F.relu(self.conv1(x, edge_index))
        # x.shape:  torch.Size([num_nodes_in_batch, hidden])
        # example:  x.shape = torch.Size([2490, 66])

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        # x.shape:  torch.Size([num_nodes_in_batch, hidden])
        # example:  x.shape = torch.Size([2490, 66])

        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch),
        ], dim=1)
        # x.shape:  torch.Size([num_graphs_in_batch, 3*hidden])
        # example:  x.shape = torch.Size([32, 3*66])

        # linear layers, activation function, dropout
        x = F.relu(self.lin1(x))
        # x.shape:  torch.Size([num_graphs_in_batch, hidden])
        # example:  x.shape = torch.Size([32, 66])
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        # x.shape:  torch.Size([num_graphs_in_batch, num_classes])
        # example:  x.shape = torch.Size([32, 2])

        output = F.log_softmax(x, dim=-1)

        return output

    def __repr__(self):
        # for getting a printable representation of an object
        return self.__class__.__name__
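
A hedged training-step sketch for GraphNN; the dataset is a placeholder with two classes (matching the hard-coded output size of lin2), and the DataLoader path matches recent torch_geometric releases:

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphNN(num_layers=3,
                num_input_features=dataset.num_features,
                hidden=66)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
for batch in loader:
    optimizer.zero_grad()
    loss = F.nll_loss(model(batch), batch.y)  # log_softmax -> NLL loss
    loss.backward()
    optimizer.step()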
Example #12
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.init import uniform_
from torch_geometric.nn import GATConv, GraphConv, SAGEConv
# NOTE: the location of `topk` varies across torch_geometric versions;
# this path matches the 1.x releases.
from torch_geometric.nn.pool.topk_pool import topk


class NodeImportance(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 ratio=0.5,
                 layer=1,
                 gnn='GCN',
                 bias=True,
                 **kwargs):
        super(NodeImportance, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.layer = layer
        self.gnn_name = gnn  # stored for __repr__

        # NOTE: multi-layer scoring (layer > 1) is only implemented for 'GCN'.
        assert gnn in ['GCN', 'GAT', 'SAGE']
        if gnn == 'GCN':
            if layer == 1:
                self.gnn = GraphConv(self.in_channels, 1, **kwargs)
            elif layer == 2:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn2 = GraphConv(self.in_channels, 1, **kwargs)
            elif layer == 3:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn2 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn3 = GraphConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
            self.gnn = GATConv(self.in_channels, 1, **kwargs)
        else:
            self.gnn = SAGEConv(self.in_channels, 1, **kwargs)

        self.weight_closeness = Parameter(torch.Tensor(1))
        self.weight_degree = Parameter(torch.Tensor(1))
        self.weight_score = Parameter(torch.Tensor(1))

        if bias:
            self.bias = Parameter(torch.Tensor(1))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.layer == 1:
            self.gnn.reset_parameters()
        elif self.layer == 2:
            self.gnn1.reset_parameters()
            self.gnn2.reset_parameters()
        elif self.layer == 3:
            self.gnn1.reset_parameters()
            self.gnn2.reset_parameters()
            self.gnn3.reset_parameters()

        uniform_(self.weight_closeness, a=0, b=1)
        uniform_(self.weight_degree, a=0, b=1)
        uniform_(self.weight_score, a=0, b=1)
        if self.bias is not None:
            uniform_(self.bias, a=0, b=1)

    def forward(self,
                x,
                edge_index,
                closeness,
                degree,
                edge_attr=None,
                batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        if self.layer == 1:
            score = torch.relu(self.gnn(x, edge_index).view(-1))
        elif self.layer == 2:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index).view(-1))
        elif self.layer == 3:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index))
            score = torch.relu(self.gnn3(score, edge_index).view(-1))
        # centrality adjustment
        closeness = closeness * self.weight_closeness
        degree = degree * self.weight_degree
        centrality = closeness + degree
        if self.bias is not None:
            centrality += self.bias

        score = score * self.weight_score
        score = score + centrality
        score = F.relu(score)

        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]

        return x, perm, batch

    def __repr__(self):
        return '{}({}, {}, ratio={})'.format(self.__class__.__name__,
                                             self.gnn_name, self.in_channels,
                                             self.ratio)
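
NodeImportance expects per-node closeness and degree centralities as extra inputs; a hedged smoke test with random stand-in values:

module = NodeImportance(in_channels=16, ratio=0.5, layer=2, gnn='GCN')
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 30))
closeness = torch.rand(10)
node_degree = torch.rand(10)
x_pooled, perm, batch = module(x, edge_index, closeness, node_degree)
print(x_pooled.size(0))  # ceil(0.5 * 10) = 5 nodes kept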

# This variant additionally assumes Linear, JumpingKnowledge,
# global_mean_pool, graclus, max_pool and Batch from torch_geometric,
# AtomEncoder from ogb.graphproppred.mol_encoder, plus project-specific
# GCNConv and batched_negative_edges helpers.
class Graclus(torch.nn.Module):
    def __init__(self, num_features, num_classes, num_layers, hidden, pooling_type,
                 no_cat=False, encode_edge=False):
        super(Graclus, self).__init__()
        self.encode_edge = encode_edge
        if encode_edge:
            # Project-specific GCNConv wrapper (OGB-style) that takes only
            # the hidden size and consumes edge attributes in forward().
            self.conv1 = GCNConv(hidden, aggr='add')
        else:
            self.conv1 = GraphConv(num_features, hidden, aggr='add')

        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='add'))

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        if no_cat:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)
        self.pooling_type = pooling_type
        self.no_cat = no_cat

        self.atom_encoder = AtomEncoder(emb_dim=hidden)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.encode_edge:
            x = self.atom_encoder(x)
            x = self.conv1(x, edge_index, data.edge_attr)
        else:
            x = self.conv1(x, edge_index)
        x = F.relu(x)
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if self.pooling_type != 'none':
                if self.pooling_type == 'complement':
                    complement = batched_negative_edges(edge_index=edge_index, batch=batch, force_undirected=True)
                    cluster = graclus(complement, num_nodes=x.size(0))
                elif self.pooling_type == 'graclus':
                    cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch

        if not self.no_cat:
            x = self.jump(xs)
        else:
            x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
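
Unlike the earlier models, this variant returns raw logits rather than log-probabilities, so cross-entropy is the matching criterion. A hedged sketch (sizes illustrative, `batch` coming from a DataLoader as in the sketches above):

model = Graclus(num_features=16, num_classes=2, num_layers=3, hidden=64,
                pooling_type='graclus')
out = model(batch)                    # [num_graphs_in_batch, num_classes]
loss = F.cross_entropy(out, batch.y)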