class CoNet(torch.nn.Module):
    """Fuse three parallel graph-convolution branches with a learned weight vector.

    All three branches read the same graph ``g`` and node features ``x``.
    Their outputs are combined as a weighted sum. The weights are the entries
    of the learnable vector ``self.w`` divided by their sum.

    Parameters
    ----------
    in_channels : int
        Input node feature size, shared by all three branches.
    out_channels : int
        Output node feature size, shared by all three branches.
    model : str
        Selects the branch layers:

        * ``'AFFN'``: SAGEConv (mean aggregator), GraphConv, GATConv
        * ``'GCN'``: three GraphConv layers
        * ``'SAGE'``: three SAGEConv layers with the mean aggregator
        * any other value: three GATConv layers

        Every GATConv is built as ``GATConv(in_channels, out_channels, 1)``,
        presumably a single attention head.
    """

    def __init__(self, in_channels, out_channels, model):
        super(CoNet, self).__init__()
        if model == 'AFFN':
            self.layer1 = SAGEConv(in_channels, out_channels, 'mean')
            self.layer2 = GraphConv(in_channels, out_channels)
            self.layer3 = GATConv(in_channels, out_channels, 1)
        elif model == 'GCN':
            self.layer1 = GraphConv(in_channels, out_channels)
            self.layer2 = GraphConv(in_channels, out_channels)
            self.layer3 = GraphConv(in_channels, out_channels)
        elif model == 'SAGE':
            self.layer1 = SAGEConv(in_channels, out_channels, 'mean')
            self.layer2 = SAGEConv(in_channels, out_channels, 'mean')
            self.layer3 = SAGEConv(in_channels, out_channels, 'mean')
        else:
            self.layer1 = GATConv(in_channels, out_channels, 1)
            self.layer2 = GATConv(in_channels, out_channels, 1)
            self.layer3 = GATConv(in_channels, out_channels, 1)
        # Feature-fusion weight vector; starts with equal weight on each branch.
        self.w = Parameter(torch.tensor([1, 1, 1], dtype=torch.float))
        self.model = model

    @staticmethod
    def _gat_branches(model):
        """Return, for each of the three branches, whether ``__init__`` built a GATConv.

        This must mirror the dispatch in ``__init__``: every name other than
        'AFFN', 'GCN' and 'SAGE' falls through to three GAT layers.
        """
        if model == 'AFFN':
            return (False, False, True)
        if model in ('GCN', 'SAGE'):
            return (False, False, False)
        return (True, True, True)

    def reset_parameters(self):
        """Reinitialise all three branches and redraw the fusion weights.

        The weights are drawn from ``torch.nn.init.uniform_``'s default range [0, 1).
        """
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()
        self.layer3.reset_parameters()
        init.uniform_(self.w)

    def forward(self, g, x):
        """Run the three branches on ``(g, x)`` and return their weighted sum."""
        outputs = []
        branches = (self.layer1, self.layer2, self.layer3)
        for layer, is_gat in zip(branches, self._gat_branches(self.model)):
            h = layer(g, x)
            if is_gat:
                # DGL's built-in GATConv adds a multi-head axis at dim 1; with one
                # head it has size 1 and is dropped so all branches share a shape.
                h = h.squeeze(1)
            outputs.append(h)
        x1, x2, x3 = outputs
        # Normalise the weight vector so its entries sum to 1.
        # NOTE(review): w is unconstrained during training; if its entries
        # become negative the sum can approach zero -- confirm this is intended.
        weights = self.w / torch.sum(self.w, 0)
        return weights[0] * x1 + weights[1] * x2 + weights[2] * x3
class GATLayer(nn.Module):
    r"""A Graph Attention Network layer (`Graph Attention Networks
    <https://arxiv.org/abs/1710.10903>`__) with configurable head aggregation.

    It wraps a ``GATConv`` and then merges that convolution's per-head outputs,
    either by concatenating or by averaging them. An optional activation is
    applied afterwards.

    Parameters
    ----------
    in_feats : int
        Size of each input node feature vector.
    out_feats : int
        Size of the output feature vector produced per attention head.
    num_heads : int
        Number of attention heads used by ``GATConv``.
    feat_drop : float
        Dropout rate on the input features (forwarded to ``GATConv``).
    attn_drop : float
        Dropout rate on the edge attention values (forwarded to ``GATConv``).
    alpha : float
        Slope for negative inputs of the LeakyReLU in the attention, passed to
        ``GATConv`` as ``negative_slope``. Defaults to 0.2.
    residual : bool
        Whether ``GATConv`` uses a skip connection. Defaults to True.
    agg_mode : str
        ``'flatten'`` concatenates the outputs of all heads; ``'mean'`` averages them.
    activation : callable or None
        Applied to the aggregated result when not None. Defaults to None.
    """

    def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
                 alpha=0.2, residual=True, agg_mode='flatten', activation=None):
        super().__init__()
        self.gat_conv = GATConv(
            in_feats=in_feats,
            out_feats=out_feats,
            num_heads=num_heads,
            feat_drop=feat_drop,
            attn_drop=attn_drop,
            negative_slope=alpha,
            residual=residual,
        )
        assert agg_mode in ('flatten', 'mean')
        self.agg_mode = agg_mode
        self.activation = activation

    def reset_parameters(self):
        """Reset the parameters of the wrapped ``GATConv``."""
        self.gat_conv.reset_parameters()

    def forward(self, bg, feats):
        """Compute updated node representations.

        Parameters
        ----------
        bg : DGLGraph
            A (possibly batched) graph.
        feats : FloatTensor of shape (N, in_feats)
            Node features for the N nodes of ``bg``.

        Returns
        -------
        FloatTensor of shape (N, M)
            ``M`` is ``out_feats * num_heads`` when ``agg_mode == 'flatten'``
            and ``out_feats`` when ``agg_mode == 'mean'``.
        """
        per_head = self.gat_conv(bg, feats)
        merged = per_head.flatten(1) if self.agg_mode == 'flatten' else per_head.mean(1)
        if self.activation is None:
            return merged
        return self.activation(merged)
class ConvLayer(nn.Module):
    """A graph convolution followed by an optional residual, dropout and batch norm.

    Parameters
    ----------
    in_feats : int
        Input node feature size.
    out_feats : int
        Output node feature size. For ``'gat'`` this is the total size across
        all heads and must be divisible by ``num_heads``.
    conv_type : str
        One of:

        * ``'gcn'``: GraphConv with ``norm='both'``
        * ``'sage'``: SAGEConv with the ``'mean'`` aggregator
        * ``'gat'``: GATConv with ``out_feats // num_heads`` features per head
    activation : callable or None
        Passed to the graph convolution. When ``residual`` is True it is also
        applied to the residual projection.
    residual : bool
        If True, add ``Linear(in_feats, out_feats)`` of the input to the
        convolution output.
    batchnorm : bool
        If True, apply ``BatchNorm1d(out_feats)`` as the last step.
    dropout : float
        Dropout probability applied to the layer output. For ``'gat'`` it is
        also used as GATConv's ``feat_drop`` and ``attn_drop``.
    num_heads : int
        Number of attention heads; only used for ``'gat'``.
    negative_slope : float
        LeakyReLU negative slope for GAT attention; only used for ``'gat'``.

    Raises
    ------
    ValueError
        If ``conv_type`` is not one of 'gcn', 'sage' or 'gat'.
    """

    def __init__(self, in_feats, out_feats, conv_type, activation=None,
                 residual=True, batchnorm=True, dropout=0., num_heads=1,
                 negative_slope=0.2):
        super(ConvLayer, self).__init__()
        # Fail fast: without this check an unknown type leaves `graph_conv`
        # unset, and the error would only surface later as an AttributeError.
        if conv_type not in ('gcn', 'sage', 'gat'):
            raise ValueError(
                "conv_type must be one of 'gcn', 'sage', 'gat'; got {!r}".format(conv_type))

        self.activation = activation
        self.conv_type = conv_type

        if conv_type == 'gcn':
            self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                        norm='both', activation=activation)
        elif conv_type == 'sage':
            self.graph_conv = SAGEConv(in_feats=in_feats, out_feats=out_feats,
                                       aggregator_type='mean', norm=None,
                                       activation=activation)
        else:  # 'gat'
            assert out_feats % num_heads == 0
            self.graph_conv = GATConv(in_feats=in_feats,
                                      out_feats=out_feats // num_heads,
                                      num_heads=num_heads,
                                      feat_drop=dropout,
                                      attn_drop=dropout,
                                      negative_slope=negative_slope,
                                      activation=activation)

        self.dropout = nn.Dropout(dropout)

        self.residual = residual
        if residual:
            self.res_connection = nn.Linear(in_feats, out_feats)

        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.graph_conv.reset_parameters()
        if self.residual:
            self.res_connection.reset_parameters()
        if self.bn:
            self.bn_layer.reset_parameters()

    def forward(self, g, feats):
        """Apply the convolution, then the residual, dropout and batch norm, in that order.

        Returns a tensor with one row per node of ``feats``.
        """
        new_feats = self.graph_conv(g, feats)
        if self.conv_type == 'gat':
            # Merge the per-head axis of the GAT output into one feature axis
            # (concatenating heads). reshape behaves like view on contiguous
            # input but does not raise on non-contiguous input.
            new_feats = new_feats.reshape(new_feats.shape[0], -1)
        if self.residual:
            res_feats = self.res_connection(feats)
            if self.activation is not None:
                res_feats = self.activation(res_feats)
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)
        if self.bn:
            new_feats = self.bn_layer(new_feats)
        return new_feats