Code Example #1
class SortPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, num_classes):
        super(SortPool, self).__init__()
        self.k = 10
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.lin1 = nn.Linear(self.k * hidden, hidden)
        self.lin2 = nn.Linear(hidden, num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_sort_pool(x, batch, self.k)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
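
The snippets in this section are excerpted from their source projects and omit their imports. A minimal import block covering most of them is sketched below; it assumes an older PyTorch Geometric release (roughly the 1.x / 2.0 module layout), and project-specific names such as SAGELafConv, set_default, args, dataset and self.timer come from the respective repositories rather than this block.

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Conv1d, LayerNorm, Linear, Parameter
from torch.nn.init import uniform_

from torch_geometric.nn import (
    GATConv, GlobalAttention, GraphConv, JumpingKnowledge, SAGEConv, Set2Set,
    global_add_pool, global_max_pool, global_mean_pool, global_sort_pool,
)
# topk and filter_adj lived at this path in older releases; newer versions relocate them
from torch_geometric.nn.pool.topk_pool import topk, filter_adj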
Code Example #2
class GlobalAttentionNet(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.att = GlobalAttention(Linear(hidden, 1))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.att.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = self.att(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Code Example #3
class SortPool(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, output_dim, num_layers):
        super(SortPool, self).__init__()
        self.k = 30
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.conv1d = Conv1d(hidden_channels, 32, 5)
        self.lin1 = Linear(32 * (self.k - 5 + 1), hidden_channels)
        self.lin2 = Linear(hidden_channels, output_dim)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.conv1d.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, edge_index, batch):
        # x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_sort_pool(x, batch, self.k)
        x = x.view(len(x), self.k, -1).permute(0, 2, 1)
        x = F.relu(self.conv1d(x))
        x = x.view(len(x), -1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
Code Example #4
class GraphSAGEWithJK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(GraphSAGEWithJK, self).__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Code Example #5
File: modules.py Project: zazyzaya/KDD-Autograph
class GraphSAGE(torch.nn.Module):
    def __init__(self, num_layers=2, hidden=16, features_num=16, num_class=2):
        super().__init__()

        self.sage1 = SAGEConv(features_num, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.lin2 = Linear(hidden, num_class)

    def reset_parameters(self):
        self.sage1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

        x = F.relu(self.sage1(x, edge_index, edge_weight=edge_weight))
        x = F.dropout(x, p=0.5, training=self.training)
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight=edge_weight))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Code Example #6
class SAGE(torch.nn.Module):    # already fine-tuned
    def __init__(self, num_layers=2, hidden=32, features_num=32, num_class=2):
        super(SAGE, self).__init__()
        self.conv1 = SAGEConv(hidden, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.out = Linear(hidden * 3, num_class)
        self.first_lin = Linear(features_num, hidden)
        self.fuse_weight = torch.nn.Parameter(torch.FloatTensor(num_layers),requires_grad=True)
        self.fuse_weight.data.fill_(float(1) / (num_layers + 1))

    def reset_parameters(self):
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.out.reset_parameters()
        self.fuse_weight.data.fill_(1.0 / (self.fuse_weight.numel() + 1))

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.relu(self.first_lin(x))
        x = F.dropout(x, p=0.5, training=self.training)
        xx = x
        x = self.conv1(x, edge_index, edge_weight)
        x = F.dropout(x, p=0.2, training=self.training)
        xx = torch.cat([xx, x], dim=1)
        x = self.conv2(x, edge_index, edge_weight)
        x = F.dropout(x, p=0.2, training=self.training)
        xx = torch.cat([xx, x], dim=1)
        x = self.out(xx)
        return F.log_softmax(x, dim=-1)
Code Example #7
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        hidden = args.hidden
        num_layers = 5
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.set2set = Set2Set(hidden, processing_steps=4)
        self.lin1 = Linear(2 * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.set2set.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = self.set2set(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
Code Example #8
File: GraphSage.py Project: v7labs/Gale
class GraphSAGE(torch.nn.Module):
    def __init__(self,
                 num_features,
                 output_channels,
                 num_layers=3,
                 hidden=128,
                 **kwargs):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, output_channels)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, target_size, **kwargs):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_mean_pool(x, batch, size=target_size)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Code Example #9
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        hidden = args.hidden
        num_layers = 5
        self.k = 30
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.conv1d = Conv1d(hidden, 32, 5)
        self.lin1 = Linear(32 * (self.k - 5 + 1), hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.conv1d.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_sort_pool(x, batch, self.k)
        x = x.view(len(x), self.k, -1).permute(0, 2, 1)
        x = F.relu(self.conv1d(x))
        x = x.view(len(x), -1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
Code Example #10
    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()
Code Example #11
class SAGEEncoder(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = SAGEConv(in_dim, out_dim)
        self.sigma = nn.PReLU(out_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.conv.reset_parameters()

    def forward(self, x, edge_index):
        z = self.sigma(self.conv(x, edge_index))
        return z
Code Example #12
File: rev_gnn.py Project: rusty1s/pytorch_geometric
class GNNBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm = LayerNorm(in_channels, elementwise_affine=True)
        self.conv = SAGEConv(in_channels, out_channels)

    def reset_parameters(self):
        self.norm.reset_parameters()
        self.conv.reset_parameters()

    def forward(self, x, edge_index, dropout_mask=None):
        x = self.norm(x).relu()
        if self.training and dropout_mask is not None:
            x = x * dropout_mask
        return self.conv(x, edge_index)
Code Example #13
File: modules.py Project: zazyzaya/KDD-Autograph
class BenSAGE(torch.nn.Module):
    def __init__(self, num_layers=2, hidden=16, features_num=16, num_class=2):
        super().__init__()
        # first layer
        self.conv1 = SAGEConv(features_num, hidden)

        # list of 2nd - num_layers layers
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))

        # fully connected layers
        self.lin2 = Linear(hidden, num_class)
        self.first_lin = Linear(features_num, hidden)

    def reset_parameters(self):
        # clear weights
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        # fully connected layer + relu
        x = F.relu(self.first_lin(x))

        # dropout layer
        x = F.dropout(x, p=0.5, training=self.training)

        # SAGE convolution layers
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight=edge_weight))

        # Another dropout
        x = F.dropout(x, p=0.5, training=self.training)

        # second FC layer
        x = self.lin2(x)

        # Log-softmax over classes
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Code Example #14
class GraphSAGE(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        #self.conv1 = SAGELafConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 2):
            #self.convs.append(SAGEConv(hidden, hidden))
            self.convs.append(SAGELafConv(hidden, hidden))
        #self.convn = SAGELafConv(hidden, dataset.num_classes)
        self.convn = SAGEConv(hidden, dataset.num_classes)
        #self.lin1 = Linear(hidden, hidden)
        #self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.convn.reset_parameters()
        #self.lin1.reset_parameters()
        #self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = F.relu(self.convn(x, edge_index))
        #x = F.relu(self.lin1(x))
        #x = F.dropout(x, p=0.5, training=self.training)
        #x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Code Example #15
class SAGPooling(torch.nn.Module):
    r"""The self-attention pooling operator from the `"Self-Attention Graph
    Pooling" <https://arxiv.org/abs/1904.08082>`_  paper

    .. math::
        \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A}) \\
        \mathbf{i} &= \mathrm{top}_k(\mathbf{y}) \\
        \mathbf{X}^{\prime} &= (\mathbf{X} \odot
        \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}} \\
        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on the learnable projection score
    :math:`\mathbf{y}`.
    Projection scores are computed by a graph neural network layer.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
            (default: :obj:`0.5`)
        gnn (string, optional): Specifies which graph neural network layer to
            use for calculating projection scores (one of
            :obj:`"GCN"`, :obj:`"GAT"` or :obj:`"SAGE"`). (default: :obj:`GCN`)
        **kwargs (optional): Additional parameters for initializing the graph
            neural network layer.
    """

    def __init__(self, in_channels, ratio=0.5, gnn='GCN', **kwargs):
        super(SAGPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.gnn_name = gnn

        assert gnn in ['GCN', 'GAT', 'SAGE']
        if gnn == 'GCN':
            self.gnn = GraphConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
            self.gnn = GATConv(self.in_channels, 1, **kwargs)
        else:
            self.gnn = SAGEConv(self.in_channels, 1, **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        score = torch.tanh(self.gnn(x, edge_index).view(-1))
        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

    def __repr__(self):
        return '{}({}, {}, ratio={})'.format(self.__class__.__name__,
                                             self.gnn_name, self.in_channels,
                                             self.ratio)
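
A minimal, hypothetical usage sketch for the class above (assuming x, edge_index and batch come from a PyTorch Geometric mini-batch; sizes are made up for illustration):

pool = SAGPooling(in_channels=64, ratio=0.5, gnn='SAGE')
# x: [num_nodes, 64], edge_index: [2, num_edges], batch: [num_nodes]
x, edge_index, edge_attr, batch, perm = pool(x, edge_index, edge_attr=None, batch=batch)
# roughly ceil(ratio * N) nodes survive per graph; perm holds their original node indices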
Code Example #16
class GraphSAGE(torch.nn.Module):
    def __init__(self, args):
        super(GraphSAGE, self).__init__()
        self.args = set_default(args, {
                    'num_layers': 2,
                    'hidden': 64,
                    'hidden2': 32,
                    'dropout': 0.5,
                    'lr': 0.005,
                    'epoches': 300,
                    'weight_decay': 5e-4,
                    'act': 'leaky_relu',
                    'withbn': True,
                        })
        self.timer = self.args['timer']
        self.dropout = self.args['dropout']
        self.agg = self.args['agg']
        self.withbn = self.args['withbn']
        self.conv1 = SAGEConv(self.args['hidden'], self.args['hidden'])
        self.convs = torch.nn.ModuleList()
        if self.withbn:
            self.bn1 = BatchNorm1d(self.args['hidden'])
            self.bns = torch.nn.ModuleList()
        hd = [self.args['hidden'], self.args['hidden']]
        for i in range(self.args['num_layers'] - 1):
            hd.append(self.args['hidden2'])
            self.convs.append(SAGEConv(self.args['hidden'], self.args['hidden2']))
            self.bns.append(BatchNorm1d(self.args['hidden2']))
        if self.args['agg'] == 'concat':
            outdim = sum(hd)
        elif self.args['agg'] == 'self':
            outdim = hd[-1]
        if self.args['act'] == 'leaky_relu':
            self.act = F.leaky_relu
        elif self.args['act'] == 'tanh':
            self.act = torch.tanh
        else:
            self.act = lambda x: x
        self.lin2 = Linear(outdim, self.args['num_class'])
        self.first_lin = Linear(self.args['features_num'], self.args['hidden'])

    def reset_parameters(self):
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        if self.withbn:
            self.bn1.reset_parameters()
            for bn in self.bns:
                bn.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = self.act(self.first_lin(x))
        xs = [x]
        x = self.act(self.conv1(x, edge_index, edge_weight=edge_weight))
        if self.withbn:
            x = self.bn1(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        xs.append(x)
        for conv, bn in zip(self.convs, self.bns):
            x = self.act(conv(x, edge_index, edge_weight=edge_weight))
            if self.withbn:
                x = bn(x)
            xs.append(x)
        #x = F.dropout(x, p=self.dropout, training=self.training)
        if self.agg == 'concat':
            x = torch.cat(xs, dim=1)
        elif self.agg == 'self':
            x = xs[-1]
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def train_predict(self, data, train_mask=None, val_mask=None, return_out=True):
        if train_mask is None:
            train_mask = data.train_mask
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args['lr'], weight_decay=self.args['weight_decay'])
        flag_end = False
        st = time.time()
        for epoch in range(1, self.args['epoches']):
            self.train()
            optimizer.zero_grad()
            res = self.forward(data)
            loss = F.nll_loss(res[train_mask], data.y[train_mask])
            loss.backward()
            optimizer.step()
            if epoch%50 == 0:
                cost = (time.time()-st)/epoch*50
                if max(cost*10, 5) > self.timer.remain_time():
                    flag_end = True
                    break

        test_mask = data.test_mask
        self.eval()
        with torch.no_grad():
            res = self.forward(data)
            if return_out:
                pred = res
            else:
                pred = res[test_mask]
            if val_mask is not None:
                return pred, res[val_mask], flag_end
        return pred, flag_end
 
    def __repr__(self):
        return self.__class__.__name__
Code Example #17
class SAGE(nn.Module):
    """
        SAGE model class

            Attributes:
                NN_1 : torch.nn.Linear
                    Input layer
                NN_2 : torch.nn.Linear
                    Hidden layer
                NN_3 : torch.nn.Linear
                    Output layer

                
    """
    def __init__(self, num_in, num_hid, num_out, mathcal_Z, dropout, device):
        """
            Initialization

                Parameters:
                    num_in : int
                        Number of neurons in the input layer
                    num_hid : int
                        Number of neurons in the hidden layer
                    num_out : int
                        Number of neurons in the output layer
                    mathcal_Z : torch.distributions.multivariate_normal
                        AWGN channel
                    dropout : float
                        Dropout probability
                    device : torch.device
                        Torch device

            Returns:

        """
        super(SAGE, self).__init__()

        # GraphSAGE convolution layers
        self.NN_1 = SAGEConv(num_in, num_hid)
        self.NN_2 = SAGEConv(num_hid, num_hid)
        self.NN_3 = SAGEConv(num_hid, num_out)

        # Batch Normalization layers
        self.BN_1 = nn.BatchNorm1d(num_hid)
        self.BN_2 = nn.BatchNorm1d(num_hid)

        self.mathcal_Z = mathcal_Z
        self.dropout = dropout
        self.device = device

    def reset_parameters(self):
        """
            Reset the parameters
        """
        self.NN_1.reset_parameters()
        self.NN_2.reset_parameters()
        self.NN_3.reset_parameters()

        self.BN_1.reset_parameters()
        self.BN_2.reset_parameters()

    def forward(self, x, edge_index):
        """
            Forward module of MLP

                Parameters:
                    x : torch.tensor of shape (num_examples, num_dims)
                        Input tensor
                    edge_index : torch.tensor of shape (2, num_edges)
                        Input edge index
                    
                Returns:
                    y_pred : torch.tensor of shape (num_examples)
                        Predicted labels
                    [S_1, S_2]: list of length 2
                        Outputs of the hidden layers before passing through the AWGN channels

        """
        x = self.NN_1(x, edge_index)
        x = self.BN_1(x)
        S_1 = F.relu(x)

        if self.mathcal_Z is None:
            # no AWGN channel: apply dropout instead
            T_1 = F.dropout(S_1, p=self.dropout, training=self.training)
        else:
            n_1 = torch.zeros(S_1.size(0))
            mathcal_z = self.mathcal_Z.sample(n_1.size()).to(self.device)
            T_1 = S_1 + mathcal_z if self.training else S_1

        x = self.NN_2(T_1, edge_index)
        x = self.BN_2(x)
        S_2 = F.relu(x)

        if self.mathcal_Z is None:
            # no AWGN channel: apply dropout instead
            T_2 = F.dropout(S_2, p=self.dropout, training=self.training)
        else:
            n_2 = torch.zeros(S_2.size(0))
            mathcal_z = self.mathcal_Z.sample(n_2.size()).to(self.device)
            T_2 = S_2 + mathcal_z if self.training else S_2

        z = self.NN_3(T_2, edge_index)
        y_pred = torch.log_softmax(z, dim=-1)

        return z.detach(), y_pred, [S_1.detach(), S_2.detach()]
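
A hypothetical instantiation sketch for the class above: the AWGN channel mathcal_Z is modeled here as a zero-mean MultivariateNormal over the hidden dimension, and dataset / data are assumed to be a PyTorch Geometric dataset and graph object.

num_hid = 64
noise = torch.distributions.MultivariateNormal(
    torch.zeros(num_hid), 0.1 * torch.eye(num_hid))  # AWGN with variance 0.1
model = SAGE(num_in=dataset.num_features, num_hid=num_hid,
             num_out=dataset.num_classes, mathcal_Z=noise,
             dropout=0.5, device=torch.device('cpu'))
z, y_pred, (S_1, S_2) = model(data.x, data.edge_index)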
Code Example #18
File: node_importance.py Project: youngflyasd/GSSNN
class NodeImportance(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 ratio=0.5,
                 layer=1,
                 gnn='GCN',
                 bias=True,
                 **kwargs):
        super(NodeImportance, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.layer = layer
        self.gnn_name = gnn

        assert gnn in ['GCN', 'GAT', 'SAGE']
        if gnn == 'GCN':
            if layer == 1:
                self.gnn = GraphConv(self.in_channels, 1, **kwargs)
            elif layer == 2:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn2 = GraphConv(self.in_channels, 1, **kwargs)
            elif layer == 3:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn2 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn3 = GraphConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
            self.gnn = GATConv(self.in_channels, 1, **kwargs)
        else:
            self.gnn = SAGEConv(self.in_channels, 1, **kwargs)

        self.weight_closeness = Parameter(torch.Tensor(1))
        self.weight_degree = Parameter(torch.Tensor(1))
        self.weight_score = Parameter(torch.Tensor(1))

        if bias:
            self.bias = Parameter(torch.Tensor(1))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.layer == 1:
            self.gnn.reset_parameters()
        elif self.layer == 2:
            self.gnn1.reset_parameters()
            self.gnn2.reset_parameters()
        elif self.layer == 3:
            self.gnn1.reset_parameters()
            self.gnn2.reset_parameters()
            self.gnn3.reset_parameters()

        uniform_(self.weight_closeness, a=0, b=1)
        uniform_(self.weight_degree, a=0, b=1)
        uniform_(self.weight_score, a=0, b=1)
        if self.bias is not None:
            uniform_(self.bias, a=0, b=1)

    def forward(self,
                x,
                edge_index,
                closeness,
                degree,
                edge_attr=None,
                batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        if self.layer == 1:
            score = torch.relu(self.gnn(x, edge_index).view(-1))
        elif self.layer == 2:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index).view(-1))
        elif self.layer == 3:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index))
            score = torch.relu(self.gnn3(score, edge_index).view(-1))
        # centrality adjustment
        closeness = closeness * self.weight_closeness
        degree = degree * self.weight_degree
        centrality = closeness + degree
        if self.bias is not None:
            centrality += self.bias

        score = score * self.weight_score
        score = score + centrality
        score = F.relu(score)

        perm = topk(score, self.ratio, batch)
        tmp1 = x[perm]
        tmp2 = score[perm]
        x = tmp1 * tmp2.view(-1, 1)
        batch = batch[perm]

        return x, perm, batch

    def __repr__(self):
        return '{}({}, {}, ratio={})'.format(self.__class__.__name__,
                                             self.gnn_name, self.in_channels,
                                             self.ratio)
Code Example #19
File: model.py Project: waljan/GNNpT1
class GraphSAGE(torch.nn.Module):
    def __init__(self, num_input_features, num_layers, hidden):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(num_input_features, hidden)  # SAGEConv layer
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))  # SAGEConv layers
        self.lin1 = Linear(3 * hidden, hidden)  # linear layer
        self.lin2 = Linear(hidden, 2)  # linear layer

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # data: Batch(batch=[num_nodes_in_batch],
        #               edge_attr=[2*num_nodes_in_batch,num_edge_features_per_edge],
        #               edge_index=[2,2*num_nodes_in_batch],
        #               pos=[num_nodes_in_batch,2],
        #               x=[num_nodes_in_batch, num_input_features_per_node],
        #               y=[num_graphs_in_batch, num_classes])
        # example: Batch(batch=[2490], edge_attr=[4980,1], edge_index=[2,4980], pos=[2490,2], x=[2490,33], y=[32,2])

        x, edge_index, batch = data.x, data.edge_index, data.batch
        # x.shape: torch.Size([num_nodes_in_batch, num_input_features_per_node])
        # edge_index.shape: torch.Size([2, 2*num_nodes_in_batch])
        # batch.shape: torch.Size([num_nodes_in_batch])
        # example:  x.shape = torch.Size([2490,33])
        #           edge_index.shape = torch.Size([2,4980])
        #           batch.shape = torch.Size([2490])

        x = F.relu(self.conv1(x, edge_index))
        # x.shape: torch.Size([num_nodes_in_batch, hidden])
        # example:  x.shape = torch.Size([2490,66])

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            # x.shape: torch.Size([num_nodes_in_batch, hidden])
            # example:  x.shape = torch.Size([2490,66])

        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch)
        ],
                      dim=1)
        # x.shape:  torch.Size([num_graphs_in_batch, 3 * hidden])
        # example:  x.shape = torch.Size([32, 198])

        x = F.relu(self.lin1(x))
        # x.shape:  torch.Size([num_graphs_in_batch, hidden])
        # example:  x.shape = torch.Size([32, 66])
        x = F.dropout(x, p=0.5, training=self.training)

        x = self.lin2(x)
        # x.shape:  torch.Size([num_graphs_in_batch, num_classes])
        # example:  x.shape = torch.Size([32, 2])
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__