Example #1
# (The snippet begins mid-file: imports and the GeneralAddAttConvLayer /
# GeneralMulAttConvLayer definitions are truncated above. The class head
# below is reconstructed from the register_layer('gaddconv', ...) call and
# the parallel GeneralMulAttConv wrapper.)
class GeneralAddAttConv(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(GeneralAddAttConv, self).__init__()
        self.model = GeneralAddAttConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index)
        return batch


class GeneralMulAttConv(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(GeneralMulAttConv, self).__init__()
        self.model = GeneralMulAttConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index)
        return batch


register_layer('gaddconv', GeneralAddAttConv)
register_layer('gmulconv', GeneralMulAttConv)


class GeneralEdgeAttConvv1Layer(MessagePassing):
    r"""Att conv with edge feature"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 task_channels=None,
                 improved=False,
                 cached=False,
                 bias=True,
                 **kwargs):
        super(GeneralEdgeAttConvv1Layer, self).__init__(aggr=cfg.gnn.agg,
                                                        **kwargs)
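
Every wrapper in these examples ends with a register_layer call. In GraphGym,
registration just records the class in a module-level dictionary keyed by the
string name; the model builder later instantiates whichever class the
gnn.layer_type config option names. A minimal sketch of that mechanism
(paraphrased for illustration, not GraphGym's verbatim source):

layer_dict = {}

def register_layer(key, module):
    # Refuse duplicate names so two contrib files cannot silently
    # shadow each other's layers.
    if key in layer_dict:
        raise KeyError('Layer {} already registered'.format(key))
    layer_dict[key] = module
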
Example #2
import torch.nn as nn
from torch_geometric.nn import HypergraphConv
from graphgym.register import register_layer

# For GraphGym, `MessagePassing` modules must be wrapped like this so
# that they take and produce `batch` objects. The wrapper is redundant
# for a single graph, but GraphGym's API requires it.


class HypergraphConvGG(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(HypergraphConvGG, self).__init__()
        self.model = HypergraphConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index)
        return batch


register_layer('hyperconv', HypergraphConvGG)
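
Because the wrapper only reads and writes node_feature and edge_index, it can
be smoke-tested with a bare stand-in object. A minimal sketch (FakeBatch and
all shapes are hypothetical; GraphGym itself supplies a deepsnap batch):

import torch

class FakeBatch:
    """Hypothetical stand-in exposing only node_feature and edge_index."""

batch = FakeBatch()
batch.node_feature = torch.randn(4, 8)          # 4 nodes, 8 input features
batch.edge_index = torch.tensor([[0, 1, 2, 3],  # row 0: node indices
                                 [0, 0, 1, 1]]) # row 1: hyperedge indices

layer = HypergraphConvGG(dim_in=8, dim_out=16)
out = layer(batch)
assert out.node_feature.shape == (4, 16)
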
Example #3
    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


# Remember to register your layer!
register_layer('exampleconv1', ExampleConv1)


# Example 2: First define a PyG format Conv layer
# Then wrap it to become GraphGym format
class ExampleConv2Layer(MessagePassing):
    r"""Example GNN layer

    """

    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super(ExampleConv2Layer, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
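
The snippet cuts off inside ExampleConv2Layer.__init__. Going by the comment
above it ("First define a PyG format Conv layer, then wrap it"), the wrapper
that presumably follows mirrors the others in this collection; a sketch, not
the verbatim source:

class ExampleConv2(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(ExampleConv2, self).__init__()
        self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index)
        return batch

register_layer('exampleconv2', ExampleConv2)
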
Example #4
    def update(self, aggr_out, x, res_n_id):
        # (Method head reconstructed from the classic PyG SAGEConv
        # pattern; the snippet begins mid-method.)
        if self.concat and torch.is_tensor(x):
            aggr_out = torch.cat([x, aggr_out], dim=-1)
        elif self.concat and isinstance(x, (tuple, list)):
            assert res_n_id is not None
            aggr_out = torch.cat([x[0][res_n_id], aggr_out], dim=-1)

        aggr_out = torch.matmul(aggr_out, self.weight)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias

        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)

        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class SAGEinitConv(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(SAGEinitConv, self).__init__()
        self.model = SAGEConvLayer(dim_in, dim_out, bias=bias, concat=True)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index)
        return batch


register_layer('sageinitconv', SAGEinitConv)
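
The concat branch above is why SAGE-style layers built with concat=True need a
weight of shape (2 * in_channels, out_channels): root features and aggregated
neighbor messages are joined along the feature axis before the linear map. A
toy illustration of that shape arithmetic (all names and sizes hypothetical):

import torch
import torch.nn.functional as F

n, in_ch, out_ch = 4, 8, 16
x_root = torch.randn(n, in_ch)           # features of the target nodes
aggr = torch.randn(n, in_ch)             # aggregated neighbor messages
weight = torch.randn(2 * in_ch, out_ch)  # concat => 2 * in_ch input width
out = torch.matmul(torch.cat([x_root, aggr], dim=-1), weight)
out = F.normalize(out, p=2, dim=-1)      # the optional L2 step above
assert out.shape == (n, out_ch)
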
Example #5
# (Snippet begins mid-file; the class line below is reconstructed from
# the super() call and the register_layer('gatidconv', ...) entry.)
class GATIDConv(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(GATIDConv, self).__init__()
        self.model = GATIDConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index,
                                        batch.node_id_index)
        return batch


class GINIDConv(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(GINIDConv, self).__init__()
        gin_nn = nn.Sequential(nn.Linear(dim_in, dim_out), nn.ReLU(),
                               nn.Linear(dim_out, dim_out))
        gin_nn_id = nn.Sequential(nn.Linear(dim_in, dim_out), nn.ReLU(),
                                  nn.Linear(dim_out, dim_out))
        self.model = GINIDConvLayer(gin_nn, gin_nn_id)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index,
                                        batch.node_id_index)
        return batch


register_layer('idconv', GeneralIDConv)
register_layer('gcnidconv', GCNIDConv)
register_layer('sageidconv', SAGEIDConv)
register_layer('gatidconv', GATIDConv)
register_layer('ginidconv', GINIDConv)
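
What sets these ID-aware layers apart is the extra node_id_index argument: the
rows it names are transformed by the second network (gin_nn_id in GINIDConv)
instead of the first. A toy sketch of that selection idea (my illustration of
the concept, not GINIDConvLayer's actual code):

import torch
import torch.nn as nn

dim = 16
gin_nn = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
gin_nn_id = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

h = torch.randn(5, dim)               # post-aggregation node features
node_id_index = torch.tensor([0, 3])  # the "identity" nodes

out = gin_nn(h)                                   # regular transform
out[node_id_index] = gin_nn_id(h[node_id_index])  # ID-aware override
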
Example #6
        # (Snippet begins mid-forward; the head of this propagate call is
        # reconstructed from the visible keyword arguments.)
        return self.propagate(edge_index,
                              x=x,
                              norm=norm,
                              edge_feature=edge_feature)

    def message(self, x_j, norm, edge_feature):
        # Each edge adds its feature to the gathered source-node feature,
        # optionally scaled by a per-edge normalization coefficient.
        msg = x_j + edge_feature
        return norm.view(-1, 1) * msg if norm is not None else msg

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GeneralOGBConv(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(GeneralOGBConv, self).__init__()
        self.model = GeneralOGBConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index,
                                        batch.edge_feature)
        return batch


register_layer('generalogbconv', GeneralOGBConv)
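
The message() above combines three per-edge tensors; a toy run makes the
broadcasting explicit (shapes hypothetical, one row per edge):

import torch

x_j = torch.randn(3, 8)           # source-node features gathered per edge
edge_feature = torch.randn(3, 8)  # one feature vector per edge
norm = torch.rand(3)              # e.g. GCN-style 1/sqrt(deg_i * deg_j)
msg = norm.view(-1, 1) * (x_j + edge_feature)
assert msg.shape == (3, 8)
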
Example #7
        super(Ctrl, self).__init__(**kwargs)
        check_config_compat(cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, batch):
        """Set all the nodes to the graph embedding"""
        if not hasattr(batch.G[0], 'cached_ctrl_embedding'):
            device = batch.node_feature.device

            for (i, G) in enumerate(batch.G):
                count = G.number_of_nodes()
                if not nx.is_connected(G):
                    G = add_vertex(G)
                setattr(batch.G[i], 'cached_ctrl_embedding',
                        get_embedding(G, count, device))

        graph_embeddings = [G.cached_ctrl_embedding for G in batch.G]
        batch.node_feature = torch.cat(graph_embeddings)

        return batch

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


# Remember to register your layer!
register_layer('ctrl', Ctrl)
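
The final torch.cat stacks the per-graph embeddings node-wise, so node_feature
ends up with one row per node across the whole batch. A shape sketch (toy
sizes; dim is hypothetical):

import torch

dim = 16
emb1 = torch.randn(3, dim)  # graph 1: one row per node
emb2 = torch.randn(5, dim)  # graph 2: one row per node
node_feature = torch.cat([emb1, emb2])
assert node_feature.shape == (8, dim)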