# NOTE(review): formatting was lost - this chunk sits on one physical line and
# is not valid Python as shown; the code below is kept verbatim. In order it holds:
#  - the tail `forward` of a wrapper whose class header is outside this view
#    (presumably GeneralAddAttConv, which is registered below). It runs
#    self.model on (batch.node_feature, batch.edge_index), writes the result
#    back to batch.node_feature and returns the same batch object.
#  - GeneralMulAttConv: an nn.Module wrapper that builds
#    GeneralMulAttConvLayer(dim_in, dim_out, bias=bias). Extra **kwargs are
#    accepted but unused. Its forward follows the same pattern as above.
#  - registration of GeneralAddAttConv as 'gaddconv' and of GeneralMulAttConv
#    as 'gmulconv'.
#  - the start of GeneralEdgeAttConvv1Layer (a MessagePassing layer, "Att conv
#    with edge feature"). Its __init__ passes cfg.gnn.agg as `aggr`; the rest
#    of that class lies beyond this chunk.
def forward(self, batch): batch.node_feature = self.model(batch.node_feature, batch.edge_index) return batch class GeneralMulAttConv(nn.Module): def __init__(self, dim_in, dim_out, bias=False, **kwargs): super(GeneralMulAttConv, self).__init__() self.model = GeneralMulAttConvLayer(dim_in, dim_out, bias=bias) def forward(self, batch): batch.node_feature = self.model(batch.node_feature, batch.edge_index) return batch register_layer('gaddconv', GeneralAddAttConv) register_layer('gmulconv', GeneralMulAttConv) class GeneralEdgeAttConvv1Layer(MessagePassing): r"""Att conv with edge feature""" def __init__(self, in_channels, out_channels, task_channels=None, improved=False, cached=False, bias=True, **kwargs): super(GeneralEdgeAttConvv1Layer, self).__init__(aggr=cfg.gnn.agg, **kwargs)
import torch.nn as nn
from torch_geometric.nn import HypergraphConv
from graphgym.register import register_layer


class HypergraphConvGG(nn.Module):
    """GraphGym adapter around PyG's ``HypergraphConv``.

    GraphGym layers take a ``batch`` object and return it, instead of working
    on raw tensors. ``MessagePassing`` modules therefore have to be wrapped
    like this. The wrapping adds nothing for a single graph, but the API
    requires it.

    Args:
        dim_in: input feature size, forwarded to ``HypergraphConv``.
        dim_out: output feature size, forwarded to ``HypergraphConv``.
        bias: forwarded to ``HypergraphConv``. Defaults to ``False`` here.
        **kwargs: accepted for GraphGym API compatibility and ignored.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        # Keep the attribute name `model`: it determines state_dict keys.
        self.model = HypergraphConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Replace ``batch.node_feature`` with the convolved features.

        ``batch.edge_index`` is passed as the second positional argument of
        ``HypergraphConv`` (its ``hyperedge_index`` parameter). The same batch
        object is returned.
        """
        features = batch.node_feature
        edges = batch.edge_index
        batch.node_feature = self.model(features, edges)
        return batch


register_layer('hyperconv', HypergraphConvGG)
# NOTE(review): formatting was lost - this chunk sits on one physical line and
# is not valid Python as shown. Everything after the first inline `#` below is
# comment text until the original line breaks are restored. The code is kept
# verbatim. In order it holds:
#  - tail methods of a MessagePassing layer whose class header is outside this
#    view (presumably ExampleConv1Layer):
#      message() returns the neighbour features x_j unchanged;
#      update() adds self.bias when it is not None;
#      __repr__ prints "ClassName(in_channels, out_channels)".
#  - registration of ExampleConv1 under the name 'exampleconv1'.
#  - the start of ExampleConv2Layer (MessagePassing, aggr taken from
#    cfg.gnn.agg). Its __init__ stores in_channels/out_channels and continues
#    beyond this chunk.
def message(self, x_j): return x_j def update(self, aggr_out): if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) # Remember to register your layer! register_layer('exampleconv1', ExampleConv1) # Example 2: First define a PyG format Conv layer # Then wrap it to become GraphGym format class ExampleConv2Layer(MessagePassing): r"""Example GNN layer """ def __init__(self, in_channels, out_channels, bias=True, **kwargs): super(ExampleConv2Layer, self).__init__(aggr=cfg.gnn.agg, **kwargs) self.in_channels = in_channels self.out_channels = out_channels
# NOTE(review): formatting was lost - this chunk sits on one physical line and
# is not valid Python as shown; the code is kept verbatim. In order it holds:
#  - the tail of an update() method whose start is outside this view
#    (presumably SAGEConvLayer). When self.concat is set and x is a tuple or
#    list, the rows x[0][res_n_id] are concatenated with aggr_out on the last
#    dim. After that comes a matmul with self.weight, an optional bias add
#    (self.bias not None) and an optional L2 normalisation (self.normalize).
#    Indentation is lost, so whether those steps sit inside this elif branch
#    cannot be read from here; presumably they run after the if/elif chain.
#  - __repr__ printing "ClassName(in_channels, out_channels)".
#  - SAGEinitConv: a GraphGym wrapper building
#    SAGEConvLayer(dim_in, dim_out, bias=bias, concat=True), registered as
#    'sageinitconv'. Extra **kwargs are accepted but unused.
# NOTE(review): `assert res_n_id is not None` is stripped under `python -O`;
# if this is real input validation, raising ValueError would be safer.
elif self.concat and (isinstance(x, tuple) or isinstance(x, list)): assert res_n_id is not None aggr_out = torch.cat([x[0][res_n_id], aggr_out], dim=-1) aggr_out = torch.matmul(aggr_out, self.weight) if self.bias is not None: aggr_out = aggr_out + self.bias if self.normalize: aggr_out = F.normalize(aggr_out, p=2, dim=-1) return aggr_out def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) class SAGEinitConv(nn.Module): def __init__(self, dim_in, dim_out, bias=False, **kwargs): super(SAGEinitConv, self).__init__() self.model = SAGEConvLayer(dim_in, dim_out, bias=bias, concat=True) def forward(self, batch): batch.node_feature = self.model(batch.node_feature, batch.edge_index) return batch register_layer('sageinitconv', SAGEinitConv)
# NOTE(review): formatting was lost - this chunk sits on one physical line and
# is not valid Python as shown; the code is kept verbatim. In order it holds:
#  - __init__/forward of the wrapper named GATIDConv in its super() call (class
#    header outside this view). It builds GATIDConvLayer(dim_in, dim_out,
#    bias=bias) and feeds it (node_feature, edge_index, node_id_index). The
#    result replaces batch.node_feature and the batch is returned.
#  - GINIDConv: builds two separate MLPs of the form
#    Linear(dim_in, dim_out) -> ReLU -> Linear(dim_out, dim_out), gin_nn and
#    gin_nn_id, and passes both to GINIDConvLayer. Its forward also feeds
#    batch.node_id_index.
#  - registration of the ID-aware layers: 'idconv', 'gcnidconv',
#    'sageidconv', 'gatidconv', 'ginidconv'.
# NOTE(review): GINIDConv accepts `bias` but never uses it (its Linear layers
# use their own default); confirm this is intended.
def __init__(self, dim_in, dim_out, bias=False, **kwargs): super(GATIDConv, self).__init__() self.model = GATIDConvLayer(dim_in, dim_out, bias=bias) def forward(self, batch): batch.node_feature = self.model(batch.node_feature, batch.edge_index, batch.node_id_index) return batch class GINIDConv(nn.Module): def __init__(self, dim_in, dim_out, bias=False, **kwargs): super(GINIDConv, self).__init__() gin_nn = nn.Sequential(nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out)) gin_nn_id = nn.Sequential(nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out)) self.model = GINIDConvLayer(gin_nn, gin_nn_id) def forward(self, batch): batch.node_feature = self.model(batch.node_feature, batch.edge_index, batch.node_id_index) return batch register_layer('idconv', GeneralIDConv) register_layer('gcnidconv', GCNIDConv) register_layer('sageidconv', SAGEIDConv) register_layer('gatidconv', GATIDConv) register_layer('ginidconv', GINIDConv)
# NOTE(review): formatting was lost - this chunk sits on one physical line and
# is not valid Python as shown; the code is kept verbatim. In order it holds:
#  - the closing arguments (x, norm, edge_feature) of a call whose start is
#    outside this view (presumably self.propagate).
#  - message(): adds edge_feature to the neighbour features x_j. When norm is
#    not None, each message is scaled by norm reshaped to a column via
#    view(-1, 1).
#  - update(): adds self.bias when it is not None.
#  - __repr__ printing "ClassName(in_channels, out_channels)".
#  - GeneralOGBConv: a GraphGym wrapper around GeneralOGBConvLayer that also
#    passes batch.edge_feature, registered as 'generalogbconv'. Extra **kwargs
#    are accepted but unused.
x=x, norm=norm, edge_feature=edge_feature) def message(self, x_j, norm, edge_feature): return norm.view(-1, 1) * ( x_j + edge_feature) if norm is not None else (x_j + edge_feature) def update(self, aggr_out): if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) class GeneralOGBConv(nn.Module): def __init__(self, dim_in, dim_out, bias=False, **kwargs): super(GeneralOGBConv, self).__init__() self.model = GeneralOGBConvLayer(dim_in, dim_out, bias=bias) def forward(self, batch): batch.node_feature = self.model(batch.node_feature, batch.edge_index, batch.edge_feature) return batch register_layer('generalogbconv', GeneralOGBConv)
# NOTE(review): formatting was lost - this chunk sits on one physical line and
# is not valid Python as shown. Everything after the inline `#` near the end
# (including register_layer) is comment text until the line breaks are
# restored. The code is kept verbatim. In order it holds:
#  - the tail of Ctrl.__init__ (header outside this view). It calls
#    check_config_compat(cfg) and stores in_channels/out_channels.
#  - Ctrl.forward. When batch.G[0] has no `cached_ctrl_embedding` attribute, it
#    loops over batch.G, records count = G.number_of_nodes(), and replaces G
#    with add_vertex(G) if nx.is_connected(G) is False. The count is taken
#    before that step. It then stores get_embedding(G, count, device) on
#    batch.G[i] as `cached_ctrl_embedding`, with device taken from
#    batch.node_feature. The cached embeddings of all graphs are concatenated
#    with torch.cat and replace batch.node_feature. Indentation is lost, so
#    the exact nesting is presumed from token order - confirm.
#  - __repr__ printing "ClassName(in_channels, out_channels)" and registration
#    under the name 'ctrl'.
# NOTE(review): only batch.G[0] is checked for the cache. If a batch mixes
# graphs that were and were not seen before, the comprehension would reach a
# graph without `cached_ctrl_embedding` and raise AttributeError. Confirm
# that batches always contain the same graphs.
super(Ctrl, self).__init__(**kwargs) check_config_compat(cfg) self.in_channels = in_channels self.out_channels = out_channels def forward(self, batch): """Set all the nodes to the graph embedding""" if not hasattr(batch.G[0], 'cached_ctrl_embedding'): device = batch.node_feature.device for (i, G) in enumerate(batch.G): count = G.number_of_nodes() if not nx.is_connected(G): G = add_vertex(G) setattr(batch.G[i], 'cached_ctrl_embedding', get_embedding(G, count, device)) graph_embeddings = [G.cached_ctrl_embedding for G in batch.G] batch.node_feature = torch.cat(graph_embeddings) return batch def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) # Remember to register your layer! register_layer('ctrl', Ctrl)