Example #1
0
 def __init__(self, features, hidden, classes):
     """Create two stacked GCN layers: ``features -> hidden -> classes``."""
     super().__init__()
     self.conv1 = GCNConv(in_channels=features, out_channels=hidden)
     self.conv2 = GCNConv(in_channels=hidden, out_channels=classes)
Example #2
0
    def __init__(self, feature_dim, mid_dim, output_dim, gcn_dropout):
        """Create a two-layer GCN: ``feature_dim -> mid_dim -> output_dim``.

        ``gcn_dropout`` is only stored as ``self.gcn_dropout`` here; this
        constructor does not build a dropout module from it.
        """
        super().__init__()
        self.gcn_dropout = gcn_dropout

        self.conv1 = GCNConv(in_channels=feature_dim, out_channels=mid_dim)
        self.conv2 = GCNConv(in_channels=mid_dim, out_channels=output_dim)
Example #3
0
 def __init__(self, dataset):
     """Create two GCN layers sized from ``dataset`` and the module-level ``args``.

     Input width is ``dataset.num_features``. Both layers output
     ``args.hidden`` channels, where ``args`` is a global, not a parameter.
     """
     super().__init__()
     width = args.hidden
     self.conv1 = GCNConv(in_channels=dataset.num_features, out_channels=width)
     self.conv2 = GCNConv(in_channels=width, out_channels=width)
Example #4
0
 def __init__(self):
     """Create two GCN layers for the module-level ``data``.

     Shape: ``data.num_features -> 16 -> data.num_classes``. Both layers are
     built with ``improved=False``.
     """
     super().__init__()
     hidden = 16
     self.conv1 = GCNConv(in_channels=data.num_features, out_channels=hidden,
                          improved=False)
     self.conv2 = GCNConv(in_channels=hidden, out_channels=data.num_classes,
                          improved=False)
Example #5
0
 def __init__(self, node_features=Nfeat, node_labels=Kc, hidden_channels=4):
     """Create two GCN layers: ``node_features -> hidden_channels -> node_labels``.

     The defaults ``Nfeat`` and ``Kc`` are module-level names. Python evaluates
     them once, when this ``def`` runs, not on each call.
     """
     super().__init__()
     self.conv1 = GCNConv(in_channels=node_features, out_channels=hidden_channels)
     self.conv2 = GCNConv(in_channels=hidden_channels, out_channels=node_labels)
Example #6
0
 def __init__(self):
     """Create two GCN layers: ``dataset.num_features -> 128 -> 64``.

     ``dataset`` is a module-level name, not a parameter.
     """
     super().__init__()
     first_width, second_width = 128, 64
     self.conv1 = GCNConv(in_channels=dataset.num_features, out_channels=first_width)
     self.conv2 = GCNConv(in_channels=first_width, out_channels=second_width)
Example #7
0
 def __init__(self):
     """Create two GCN layers for the module-level ``dataset``.

     Shape: ``dataset.num_node_features -> 16 -> dataset.num_classes``.
     """
     super().__init__()
     hidden = 16
     self.conv1 = GCNConv(in_channels=dataset.num_node_features, out_channels=hidden)
     self.conv2 = GCNConv(in_channels=hidden, out_channels=dataset.num_classes)
Example #8
0
 def __init__(self):
     """Create two GCN layers plus a linear layer back to the input width.

     Shape: ``dataset.num_features -> 32 -> 16`` (GCN), then
     ``16 -> dataset.num_features`` (Linear). ``dataset`` is a module-level name.
     """
     super().__init__()
     n_in = dataset.num_features
     self.conv1 = GCNConv(in_channels=n_in, out_channels=32)
     self.conv2 = GCNConv(in_channels=32, out_channels=16)
     self.linear = torch.nn.Linear(in_features=16, out_features=n_in)
Example #9
0
 def __init__(self, in_channels, hidden_channels):
     """Create one cached GCN layer followed by a per-channel PReLU.

     The GCN maps ``in_channels -> hidden_channels``. The PReLU learns one
     slope per hidden channel.
     """
     super().__init__()
     self.conv = GCNConv(in_channels=in_channels, out_channels=hidden_channels,
                         cached=True)
     self.prelu = nn.PReLU(num_parameters=hidden_channels)
Example #10
0
 def __init__(self, in_channels, out_channels):
     """Create two cached GCN layers: ``in_channels -> 16 -> out_channels``."""
     super().__init__()
     hidden = 16
     self.conv1 = GCNConv(in_channels=in_channels, out_channels=hidden, cached=True)
     self.conv2 = GCNConv(in_channels=hidden, out_channels=out_channels, cached=True)
Example #11
0
 def __init__(self, in_features, num_classes):
     """Create two cached GCN layers: ``in_features -> 16 -> num_classes``."""
     super().__init__()
     hidden = 16
     self.conv1 = GCNConv(in_channels=in_features, out_channels=hidden, cached=True)
     self.conv2 = GCNConv(in_channels=hidden, out_channels=num_classes, cached=True)
Example #12
0
 def __init__(self):
     """Create two GCN layers, each mapping 2 channels to 2 channels."""
     super().__init__()
     width = 2
     self.conv1 = GCNConv(in_channels=width, out_channels=width)
     self.conv2 = GCNConv(in_channels=width, out_channels=width)
Example #13
0
 def __init__(self, x=64):
     """Create three GCN layers of width ``x`` plus a linear output layer.

     The first GCN takes 10 input channels. The output layer's size is
     ``max(y).tolist() + 1``, where ``y`` is a module-level name (presumably a
     label tensor -- TODO confirm) read at construction time.
     """
     super().__init__()
     self.conv1 = GCNConv(in_channels=10, out_channels=x)
     self.conv2 = GCNConv(in_channels=x, out_channels=x)
     self.conv3 = GCNConv(in_channels=x, out_channels=x)
     num_outputs = max(y).tolist() + 1
     self.fc = torch.nn.Linear(in_features=x, out_features=num_outputs)
Example #14
0
 def __init__(self):
     """Create two GCN layers over the module-level ``datasets``.

     The input width is the node-feature count plus the edge-feature count.
     The layers map that width to 16 and then to 1 output channel.
     """
     super().__init__()
     in_width = datasets.num_node_features + datasets.num_edge_features
     self.conv1 = GCNConv(in_channels=in_width, out_channels=16)
     self.conv2 = GCNConv(in_channels=16, out_channels=1)
Example #15
0
 def __init__(self, in_channels, out_channels):
     """Create a shared cached GCN layer plus two cached GCN heads.

     ``conv1`` maps ``in_channels -> 2 * out_channels``. ``conv_mu`` and
     ``conv_logvar`` each map ``2 * out_channels -> out_channels``.
     """
     super().__init__()
     shared_width = 2 * out_channels
     self.conv1 = GCNConv(in_channels=in_channels, out_channels=shared_width,
                          cached=True)
     self.conv_mu = GCNConv(in_channels=shared_width, out_channels=out_channels,
                            cached=True)
     self.conv_logvar = GCNConv(in_channels=shared_width,
                                out_channels=out_channels, cached=True)
Example #16
0
 def __init__(self, input_dim=89, hidden_dim=128):
     """Create two GCN layers: ``input_dim -> hidden_dim -> hidden_dim``."""
     super().__init__()
     self.conv1 = GCNConv(in_channels=input_dim, out_channels=hidden_dim)
     self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=hidden_dim)
Example #17
0
 def __init__(self, in_channels, out_channels):
     """Create two GCN layers: ``in_channels -> 2*out_channels -> out_channels``."""
     super().__init__()
     mid = 2 * out_channels
     self.conv1 = GCNConv(in_channels=in_channels, out_channels=mid)
     self.conv2 = GCNConv(in_channels=mid, out_channels=out_channels)
Example #18
0
 def __init__(self, dataset, input_dim=128, hidden_dim=128):
     """Create two GCN layers followed by a linear layer sized to the classes.

     Shape: ``input_dim -> hidden_dim -> hidden_dim`` (GCN), then
     ``hidden_dim -> dataset.num_classes``. This constructor reads only
     ``num_classes`` from ``dataset``.
     """
     super().__init__()
     self.conv1 = GCNConv(in_channels=input_dim, out_channels=hidden_dim)
     self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=hidden_dim)
     n_classes = dataset.num_classes
     self.mlp = Linear(hidden_dim, n_classes)
Example #19
0
 def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1):
     """Register Conv1, a Dropout, Conv2 and a ReLU, in that order.

     ``Conv1`` maps ``in_dim -> hidden_dim`` and ``Conv2`` maps
     ``hidden_dim -> out_dim``. ``dropout`` is the Dropout probability.
     """
     super().__init__()
     self.Conv1 = GCNConv(in_dim, hidden_dim)
     self.dropout = nn.Dropout(p=dropout)
     self.Conv2 = GCNConv(hidden_dim, out_dim)
     self.relu = nn.ReLU()
Example #20
0
 def __init__(self, node_features):
     """Create two GCN layers with widths taken from the module-level ``args``.

     Shape: ``node_features -> args.hidden1_dim -> args.hidden2_dim``.
     """
     super().__init__()
     first, second = args.hidden1_dim, args.hidden2_dim
     self.conv1 = GCNConv(in_channels=node_features, out_channels=first)
     self.conv2 = GCNConv(in_channels=first, out_channels=second)
Example #21
0
 def __init__(self, nfeat, nhid, nclass, dropout, nlayer=1):
     """Create a single GCN layer mapping ``nfeat -> nclass``.

     ``nhid`` and ``nlayer`` are accepted but not used by this constructor.
     ``dropout`` is stored as ``self.dropout_p``.
     """
     super().__init__()
     self.conv1 = GCNConv(in_channels=nfeat, out_channels=nclass)
     self.dropout_p = dropout
Example #22
0
 def __init__(self, node_features):
     """Create a shared GCN layer followed by two parallel GCN heads.

     ``common_conv1`` maps ``node_features -> args.hidden1_dim``.
     ``mean_conv2`` and ``logstd_conv2`` each map
     ``args.hidden1_dim -> args.hidden2_dim``. ``args`` is a module-level name.
     """
     super().__init__()
     shared, latent = args.hidden1_dim, args.hidden2_dim
     self.common_conv1 = GCNConv(in_channels=node_features, out_channels=shared)
     self.mean_conv2 = GCNConv(in_channels=shared, out_channels=latent)
     self.logstd_conv2 = GCNConv(in_channels=shared, out_channels=latent)
Example #23
0
 def __init__(self):
     """Create two GCN layers with ``normalize=False``.

     Shape: ``dataset.num_features -> 16 -> dataset.num_classes``, where
     ``dataset`` is a module-level name.
     """
     super().__init__()
     hidden = 16
     self.conv1 = GCNConv(in_channels=dataset.num_features, out_channels=hidden,
                          normalize=False)
     self.conv2 = GCNConv(in_channels=hidden, out_channels=dataset.num_classes,
                          normalize=False)
Example #24
0
    def __init__(self, hparams=None):
        """
        Graph Convolutional Network for graph classification.

        Parameters to be included in hparams
        ----------
        n_features : int
            Number of features for each node in the graph
            Default: 75 features for each atom in the molecule dataset
        num_classes : int
            Number of classes for prediction
            Default : 2 (active or inactive)
        pool_type : str
            Type of pooling to aggregate the features after the graphconv layers
            Check : https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#module-torch_geometric.nn.glob
            Default : mean
            Options : mean, add, max
        dropout : float
            Dropout probability (a fraction in [0, 1])
            Default : 0.2
        activation : str
            One of 'relu', 'sigmoid' or 'tanh'.
            Default : 'relu'
            NOTE : prefinal layer has log_softmax
        opt : str
            One of 'adam' or 'adamax' or 'rmsprop'.
            Default : 'adam'
        batch_size: int
            Batch size for training.
            Default : 32
        lr: float
            Learning rate for optimizer.
            Default : 0.01
        weight_decay: float
            Weight decay in optimizer.
            Default : 0

        Layers registered here (forward order is defined elsewhere):
        conv1 (n_features -> 128) with bn1, conv2 (128 -> 64) with bn2,
        fc1 (64 -> 64) with bn3, fc2 (64 -> 64) and fc3 (64 -> num_classes).
        """
        # The docstring above must stay the first statement. Placed after
        # super().__init__() it is a no-op expression, not __init__.__doc__.
        super(GCN, self).__init__()
        self.__check_hparams(hparams)
        self.hparams = hparams

        # NOTE choose dataloaders appropriately
        self.dataloaders = MoleculeDataloaders()
        self.lenLoaders()
        # log everything you want to be reported via telegram here
        self.telegrad_logs = {}

        # Per-epoch buckets for each split. The predicted_* and true_* lists
        # hold predictions and true labels used to compute ROC and PRC.
        # The correct_* counters hold the number of correct predictions.
        self.predicted_train = []
        self.true_train = []
        self.correct_train = 0  # number of correct predictions in each epoch (train)
        self.predicted_val = []
        self.true_val = []
        self.correct_val = 0  # number of correct predictions in each epoch (validation)
        self.predicted_test = []
        self.true_test = []
        self.correct_test = 0  # number of correct predictions in each epoch (test)

        # self.n_features and self.num_classes are read below but not assigned
        # in this method. Presumably __check_hparams sets them (or they are
        # properties) -- TODO confirm.
        # cached=False is deliberate. With cached=True, GCNConv reuses the
        # normalized adjacency from its first call. That is only valid when
        # every call sees the same graph, and mini-batches of graphs do not.
        self.conv1 = GCNConv(self.n_features, 128, cached=False)
        self.bn1 = BatchNorm1d(128)
        self.conv2 = GCNConv(128, 64, cached=False)
        self.bn2 = BatchNorm1d(64)
        self.fc1 = Linear(64, 64)
        self.bn3 = BatchNorm1d(64)
        self.fc2 = Linear(64, 64)
        self.fc3 = Linear(64, self.num_classes)
Example #25
0
    def __init__(
            self,
            in_channels=1,
            hidden_channels=1,
            out_channels=1,
            normalize=False,
            add_loop=False,
            gnn_k=1,
            gnn_type=1,
            jump=None,  #None,max,lstm
            res=False,
            activation='leaky'):
        """Build the pair of convolution layers selected by ``gnn_type``, plus shared parts.

        Args:
            in_channels: Stored as ``self.in_channels``. It is the input width
                of ``conv1`` only for ``gnn_type`` 0 and 1 (DenseSAGEConv).
                Every other type builds ``conv1`` with ``in_channels=1``.
            hidden_channels: Size of ``bn1`` and input width of ``conv2``. For
                ARMA (``gnn_type`` 14) it is also the output width of ``conv1``.
            out_channels: Size of ``bn2`` and output width of ``conv2``. For
                every type except 14 it is also the output width of ``conv1``.
            normalize: Accepted but not used by this constructor. Types 0 and 1
                hard-code their own ``normalize`` flag.
            add_loop: Stored as ``self.add_loop``.
            gnn_k: Stored as ``self.k`` (number of repetitions of the GNN).
            gnn_type: Selects the convolution family. Supported values are
                0-19 and 25-29; 20-24 are disabled.
            jump: ``None`` or a JumpingKnowledge mode (``'max'``, ``'lstm'``).
                For ``'lstm'``, ``out_channels`` and ``gnn_k`` are also passed.
            res: Residual flag, stored as ``self.res``.
            activation: One of ``'leaky'``, ``'elu'`` or ``'relu'``. The matching
                ``torch.nn.functional`` op is stored as ``self.activ``.

        Raises:
            ValueError: In any of these cases:
                - unknown ``activation`` or ``gnn_type``;
                - ``res`` with ``gnn_type`` 10/12 or with ``gnn_k == 1``;
                - ``jump`` with ``gnn_k == 1``.
                ValueError subclasses Exception, so existing ``except Exception``
                handlers still catch these.
        """
        super(GNN, self).__init__()

        self.add_loop = add_loop

        self.in_channels = in_channels
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.bn2 = torch.nn.BatchNorm1d(out_channels)
        self.k = gnn_k  # number of repetitions of the GNN
        self.gnn_type = gnn_type

        self.jump = jump
        if jump is not None:
            if jump != 'lstm':
                self.jk = JumpingKnowledge(jump)
            else:
                self.jk = JumpingKnowledge(jump, out_channels, gnn_k)

        activations = {'leaky': F.leaky_relu, 'elu': F.elu, 'relu': F.relu}
        if activation not in activations:
            # Fail at construction rather than leave self.activ unset.
            raise ValueError(
                f'activation must be one of {sorted(activations)}, '
                f'got {activation!r}')
        self.activ = activations[activation]

        self.res = res
        if self.gnn_type in (10, 12) and self.res:
            raise ValueError('res must be false when gnn_type==10 or 12!')
        if self.k == 1 and self.res:
            raise ValueError('res must be false when gnn_k==1!')
        if self.k == 1 and self.jump is not None:
            raise ValueError(
                'jumping knowledge only serves for the case where k>1!')

        def pair(conv_cls, first_in, **kwargs):
            # conv1 maps first_in -> out_channels; conv2 maps
            # hidden_channels -> out_channels. conv1 is constructed first.
            conv1 = conv_cls(in_channels=first_in, out_channels=out_channels,
                             **kwargs)
            conv2 = conv_cls(in_channels=hidden_channels,
                             out_channels=out_channels, **kwargs)
            return conv1, conv2

        def twin(conv_cls, **kwargs):
            # Two identically configured layers that take no in_channels.
            return conv_cls(**kwargs), conv_cls(**kwargs)

        def gat_pair(heads, concat, first_dropout):
            # GAT pair. Only conv1's dropout differs between the GAT types;
            # conv2 always uses 0.6.
            common = dict(out_channels=out_channels, heads=heads,
                          concat=concat, negative_slope=0.2)
            return (GATConv(in_channels=1, dropout=first_dropout, **common),
                    GATConv(in_channels=hidden_channels, dropout=0.6, **common))

        def spline_pair(dim, kernel_size, is_open_spline, degree):
            return pair(SplineConv, 1, dim=dim, kernel_size=kernel_size,
                        is_open_spline=is_open_spline, degree=degree,
                        norm=True, root_weight=True, bias=True)

        def arma_pair():
            # ARMA is the only type whose conv1 outputs hidden_channels.
            kwargs = dict(num_stacks=1, num_layers=1, shared_weights=False,
                          act=F.relu, dropout=0.5, bias=True)
            return (ARMAConv(in_channels=1, out_channels=hidden_channels,
                             **kwargs),
                    ARMAConv(in_channels=hidden_channels,
                             out_channels=out_channels, **kwargs))

        # NOTE(review): two things to confirm are intended.
        # 1. Except for types 0/1, conv1 ignores in_channels and is built with
        #    in_channels=1.
        # 2. Except for type 14, conv1 outputs out_channels while conv2 expects
        #    hidden_channels. These only line up when
        #    hidden_channels == out_channels.
        # Builders are lambdas, so only the selected type's classes are looked
        # up and constructed.
        conv_builders = {
            0: lambda: pair(DenseSAGEConv, self.in_channels, normalize=False),
            1: lambda: pair(DenseSAGEConv, self.in_channels, normalize=True),
            2: lambda: pair(GCNConv, 1, cached=False),
            3: lambda: pair(GCNConv, 1, improved=True, cached=False),
            4: lambda: pair(ChebConv, 1, K=2),
            5: lambda: pair(ChebConv, 1, K=4),
            6: lambda: pair(GraphConv, 1, aggr='add'),
            7: lambda: twin(GatedGraphConv, out_channels=out_channels,
                            num_layers=3, aggr='add', bias=True),
            8: lambda: twin(GatedGraphConv, out_channels=out_channels,
                            num_layers=7, aggr='add', bias=True),
            9: lambda: gat_pair(heads=1, concat=True, first_dropout=0),
            10: lambda: gat_pair(heads=6, concat=False, first_dropout=0.6),
            11: lambda: gat_pair(heads=4, concat=True, first_dropout=0),
            12: lambda: gat_pair(heads=4, concat=False, first_dropout=0.6),
            13: lambda: twin(AGNNConv, requires_grad=True),
            14: arma_pair,
            15: lambda: pair(SGConv, 1, K=1, cached=True, bias=True),
            16: lambda: pair(SGConv, 1, K=3, cached=True, bias=True),
            17: lambda: twin(APPNP, K=1, alpha=0.2, bias=True),
            18: lambda: twin(APPNP, K=3, alpha=0.2, bias=True),
            19: lambda: pair(RGCNConv, 1, num_relations=3, num_bases=2,
                             bias=True),
            # 20-24 (SignedConv / GMMConv variants) are disabled.
            25: lambda: spline_pair(dim=2, kernel_size=3, is_open_spline=True,
                                    degree=1),
            26: lambda: spline_pair(dim=3, kernel_size=3, is_open_spline=False,
                                    degree=1),
            27: lambda: spline_pair(dim=3, kernel_size=6, is_open_spline=True,
                                    degree=1),
            28: lambda: spline_pair(dim=3, kernel_size=3, is_open_spline=True,
                                    degree=3),
            29: lambda: spline_pair(dim=3, kernel_size=6, is_open_spline=True,
                                    degree=3),
        }
        build = conv_builders.get(gnn_type)
        if build is None:
            # Fail at construction rather than leave conv1/conv2 unset.
            raise ValueError(f'unsupported gnn_type {gnn_type!r}')
        self.conv1, self.conv2 = build()
Example #26
0
 def __init__(self, in_channels, out_channels):
     """Create a single GCN layer mapping ``in_channels -> out_channels``."""
     super().__init__()
     self.conv = GCNConv(in_channels=in_channels, out_channels=out_channels)
 def __init__(self, in_feats, hid_feats, out_feats):
     """Create two GCN layers: ``in_feats -> hid_feats -> out_feats``."""
     super().__init__()
     self.conv1 = GCNConv(in_channels=in_feats, out_channels=hid_feats)
     # Disabled alternative: conv2 over hid_feats + in_feats input channels.
     # self.conv2 = GCNConv(hid_feats+in_feats, out_feats)
     self.conv2 = GCNConv(in_channels=hid_feats, out_channels=out_feats)
Example #28
0
 def __init__(self, in_channels, out_channels):
     """Create two parallel GCN heads, ``conv_mu`` and ``conv_logstd``.

     Each head maps ``in_channels -> out_channels``.
     """
     super().__init__()
     self.conv_mu = GCNConv(in_channels=in_channels, out_channels=out_channels)
     self.conv_logstd = GCNConv(in_channels=in_channels, out_channels=out_channels)
Example #29
0
 def _create_reset_gate_parameters_and_layers(self):
     """Build the reset-gate layers from the instance's configuration.

     ``conv_r`` is a GCNConv configured from ``self.in_channels``,
     ``self.out_channels``, ``self.improved``, ``self.cached`` and
     ``self.add_self_loops``. ``linear_r`` maps
     ``2 * self.out_channels -> self.out_channels``.
     """
     out = self.out_channels
     self.conv_r = GCNConv(
         in_channels=self.in_channels,
         out_channels=out,
         improved=self.improved,
         cached=self.cached,
         add_self_loops=self.add_self_loops,
     )
     self.linear_r = torch.nn.Linear(in_features=2 * out, out_features=out)
Example #30
0
    def __init__(self,
                 num_nodes,
                 num_edges,
                 board_input_dim,
                 global_input_dim,
                 hidden_global_dim=32,
                 num_global_layers=4,
                 hidden_conv_dim=16,
                 num_conv_layers=4,
                 hidden_pick_dim=32,
                 num_pick_layers=4,
                 out_pick_dim=1,
                 hidden_place_dim=32,
                 num_place_layers=4,
                 out_place_dim=1,
                 hidden_attack_dim=32,
                 num_attack_layers=4,
                 out_attack_dim=1,
                 hidden_fortify_dim=32,
                 num_fortify_layers=4,
                 out_fortify_dim=1,
                 hidden_value_dim=32,
                 num_value_layers=4,
                 dropout=0.4):
        """Build a shared GCN trunk plus pick, place, attack, fortify and value heads.

        Trunk: ``conv_init`` maps ``board_input_dim -> hidden_conv_dim``. It is
        followed by ``num_conv_layers`` residual GCN blocks in ``deep_convs``.

        Node heads (``pick_layers``, ``placeArmies_layers``):
        - one residual block from ``hidden_conv_dim``;
        - ``num_*_layers - 2`` more residual blocks (none if that is <= 0);
        - a plain GCNConv down to ``out_*_dim``.
        Each node head has a two-Linear ``*_final`` stack,
        ``num_nodes -> 64 -> num_nodes``.

        Edge heads (``attack_layers``, ``fortify_layers``):
        - one residual block from ``hidden_conv_dim``;
        - ``num_*_layers - 1`` more residual blocks (none if that is <= 0);
        - an ``EdgeNet``;
        - a ``*_final`` stack, ``num_edges -> 64 -> num_edges + 1``.

        Value head:
        - residual blocks built the same way as the edge heads;
        - a ``GlobalAttention`` pooling layer built from ``gate_nn`` and
          ``other_nn``;
        - ``value_fc_1`` (``hidden_value_dim -> hidden_value_dim``);
        - ``value_fc_2`` (``hidden_value_dim -> 6``).

        ``global_input_dim``, ``hidden_global_dim`` and ``num_global_layers``
        are accepted but unused, because the global-feature branch is
        disabled. ``dropout`` is only stored on the instance here.
        """
        super().__init__()

        def res_block(in_dim, out_dim):
            # One GCNConv + BatchNorm1d pair wrapped in ResGCN.
            return ResGCN(GCNConv(in_dim, out_dim), nn.BatchNorm1d(out_dim))

        def res_stack(in_dim, width, extra):
            # First block maps in_dim -> width, then `extra` more width -> width
            # blocks. range() yields nothing when extra <= 0.
            return [res_block(in_dim, width)] + [
                res_block(width, width) for _ in range(extra)
            ]

        def bottleneck(in_dim, out_dim):
            # Two Linear layers: in_dim -> 64 -> out_dim.
            return torch.nn.ModuleList(
                [nn.Linear(in_dim, 64), nn.Linear(64, out_dim)])

        # TODO: add a node encoder?
        self.num_nodes = num_nodes
        self.num_edges = num_edges

        # Shared board trunk.
        self.num_conv_layers = num_conv_layers
        self.conv_init = GCNConv(board_input_dim, hidden_conv_dim)
        self.deep_convs = nn.ModuleList([
            res_block(hidden_conv_dim, hidden_conv_dim)
            for _ in range(num_conv_layers)
        ])

        self.softmax = nn.LogSoftmax(dim=1)

        # Pick-country head.
        self.num_pick_layers = num_pick_layers
        self.pick_layers = torch.nn.ModuleList(
            res_stack(hidden_conv_dim, hidden_pick_dim, num_pick_layers - 2)
            + [GCNConv(hidden_pick_dim, out_pick_dim)])
        self.pick_final = bottleneck(num_nodes, num_nodes)

        # Place-armies head (distribution over the nodes).
        self.num_place_layers = num_place_layers
        self.placeArmies_layers = torch.nn.ModuleList(
            res_stack(hidden_conv_dim, hidden_place_dim, num_place_layers - 2)
            + [GCNConv(hidden_place_dim, out_place_dim)])
        self.place_final = bottleneck(num_nodes, num_nodes)

        # Attack head. EdgeNet's positional arguments 28 and 3 are passed
        # unchanged; their meaning is defined by EdgeNet -- TODO confirm.
        self.num_attack_layers = num_attack_layers
        self.hidden_attack_dim = hidden_attack_dim
        self.attack_layers = torch.nn.ModuleList(
            res_stack(hidden_conv_dim, hidden_attack_dim, num_attack_layers - 1))
        self.attack_edge = EdgeNet(hidden_attack_dim, 28, 3, out_attack_dim)
        self.attack_final = bottleneck(num_edges, num_edges + 1)

        # Fortify head (same layout as the attack head).
        self.num_fortify_layers = num_fortify_layers
        self.hidden_fortify_dim = hidden_fortify_dim
        self.fortify_layers = torch.nn.ModuleList(
            res_stack(hidden_conv_dim, hidden_fortify_dim, num_fortify_layers - 1))
        self.fortify_edge = EdgeNet(hidden_fortify_dim, 28, 3, out_fortify_dim)
        self.fortify_final = bottleneck(num_edges, num_edges + 1)

        # Value head.
        self.num_value_layers = num_value_layers
        self.value_layers = torch.nn.ModuleList(
            res_stack(hidden_conv_dim, hidden_value_dim, num_value_layers - 1))
        self.gate_nn = nn.Linear(hidden_value_dim, 1)
        self.other_nn = nn.Linear(hidden_value_dim, hidden_value_dim)
        self.global_pooling_layer = torch_geometric.nn.GlobalAttention(
            self.gate_nn, self.other_nn)
        self.value_fc_1 = nn.Linear(hidden_value_dim, hidden_value_dim)
        self.value_fc_2 = nn.Linear(hidden_value_dim, 6)

        self.dropout = dropout