예제 #1
0
    def __init__(self, net_params):
        """Build the input embedding, GIN layers, prediction heads and readout.

        Args:
            net_params: Configuration object read by attribute. The fields
                used here are ``in_dim``, ``hidden_dim``, ``n_classes``,
                ``L``, ``n_mlp_GIN``, ``neighbor_aggr_GIN``, ``dropout``,
                ``graph_norm``, ``batch_norm``, ``residual``,
                ``learn_eps_GIN`` and ``readout``.

        Raises:
            NotImplementedError: If ``net_params.readout`` is not one of
                ``'sum'``, ``'mean'`` or ``'max'``.
        """
        super().__init__()
        # NOTE(review): fixed at 2 although ``net_params.L`` GIN layers are
        # built below -- confirm that this mismatch is intended.
        self.n_layers = 2
        self.embedding_h = nn.Linear(net_params.in_dim, net_params.hidden_dim)

        width = net_params.hidden_dim
        gin_stack = []
        for _ in range(net_params.L):
            node_mlp = MLP(net_params.n_mlp_GIN, width, width, width)
            gin_stack.append(
                GINLayer(ApplyNodeFunc(node_mlp),
                         net_params.neighbor_aggr_GIN,
                         net_params.dropout,
                         net_params.graph_norm,
                         net_params.batch_norm,
                         net_params.residual,
                         0,
                         net_params.learn_eps_GIN))
        self.ginlayers = torch.nn.ModuleList(gin_stack)

        # One linear head per pooled representation (n_layers + 1 in total);
        # each maps a hidden vector to class scores.
        self.linears_prediction = torch.nn.ModuleList([
            nn.Linear(net_params.hidden_dim, net_params.n_classes)
            for _ in range(self.n_layers + 1)
        ])

        readout_kind = net_params.readout
        if readout_kind == 'sum':
            self.pool = SumPooling()
        elif readout_kind == 'mean':
            self.pool = AvgPooling()
        elif readout_kind == 'max':
            self.pool = MaxPooling()
        else:
            raise NotImplementedError
예제 #2
0
    def __init__(self, in_dim, hid_dim, n_layer):
        """Create ``n_layer`` GIN convolutions, each paired with a batch norm.

        Args:
            in_dim: Input size of the first convolution's MLP.
            hid_dim: Hidden and output size used by every layer.
            n_layer: Number of (GINConv, BatchNorm1d) pairs to create.
        """
        super(GINEncoder, self).__init__()

        self.n_layer = n_layer
        self.convs = ModuleList()
        self.bns = ModuleList()

        for depth in range(n_layer):
            # Only the first MLP consumes the raw input width.
            mlp_in = in_dim if depth == 0 else hid_dim
            mlp = Sequential(Linear(mlp_in, hid_dim), ReLU(),
                             Linear(hid_dim, hid_dim))
            self.convs.append(GINConv(mlp, 'sum'))
            self.bns.append(BatchNorm1d(hid_dim))

        # Sum pooling readout.
        self.pool = SumPooling()
예제 #3
0
    def __init__(self,
                 in_edge_feats,
                 num_node_types=1,
                 hidden_feats=300,
                 n_layers=5,
                 batchnorm=True,
                 activation=F.relu,
                 dropout=0.,
                 gnn_type='gcn',
                 virtual_node=True,
                 residual=False,
                 jk=False):
        """Assemble the node encoder, the GNN layers and the optional extras.

        Args:
            in_edge_feats: Forwarded as ``in_edge_feats`` to every GNN layer.
            num_node_types: Number of rows of the node embedding table.
            hidden_feats: Width shared by the embeddings, the GNN layers and
                the batch norms.
            n_layers: Number of GNN layers (and of batch norms when enabled).
            batchnorm: When true, one ``BatchNorm1d`` is created per layer;
                otherwise ``self.batchnorms`` is ``None``.
            activation: Stored as ``self.activation``.
            dropout: Probability handed to ``nn.Dropout``.
            gnn_type: ``'gcn'`` or ``'gin'``; selects the layer class.
            virtual_node: When true, creates the virtual-node embedding,
                ``n_layers - 1`` projection MLPs and a sum-pooling readout.
            residual: Stored as ``self.residual``.
            jk: Stored as ``self.jk``.

        Raises:
            AssertionError: If ``gnn_type`` is neither ``'gcn'`` nor ``'gin'``.
        """
        super(GNNOGB, self).__init__()

        assert gnn_type in ['gcn', 'gin'], \
            "Expect gnn_type to be either 'gcn' or 'gin', got {}".format(gnn_type)

        def make_gnn_layer():
            # Both layer kinds keep the node width at ``hidden_feats``.
            if gnn_type == 'gcn':
                return GCNOGBLayer(in_node_feats=hidden_feats,
                                   in_edge_feats=in_edge_feats,
                                   out_feats=hidden_feats)
            return GINOGBLayer(node_feats=hidden_feats,
                               in_edge_feats=in_edge_feats)

        self.n_layers = n_layers
        self.gnn_type = gnn_type
        # Initial node embeddings
        self.node_encoder = nn.Embedding(num_node_types, hidden_feats)
        # Hidden layers
        self.layers = nn.ModuleList(
            [make_gnn_layer() for _ in range(n_layers)])

        self.virtual_node = virtual_node
        if virtual_node:
            wide_feats = 2 * hidden_feats
            self.virtual_node_emb = nn.Embedding(1, hidden_feats)
            # One projection MLP fewer than there are GNN layers.
            self.mlp_virtual_project = nn.ModuleList([
                nn.Sequential(nn.Linear(hidden_feats, wide_feats),
                              nn.BatchNorm1d(wide_feats),
                              nn.ReLU(),
                              nn.Linear(wide_feats, hidden_feats),
                              nn.BatchNorm1d(hidden_feats),
                              nn.ReLU())
                for _ in range(n_layers - 1)
            ])
            self.virtual_readout = SumPooling()

        self.batchnorms = (
            nn.ModuleList(
                [nn.BatchNorm1d(hidden_feats) for _ in range(n_layers)])
            if batchnorm else None)

        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
        self.jk = jk

        self.reset_parameters()
예제 #4
0
    def __init__(self,
                 embed="gpls",
                 dim=64,
                 hidden_dim=64,
                 num_gaussians=64,
                 cutoff=5.0,
                 output_dim=1,
                 n_conv=3,
                 act=ShiftedSoftplus(),
                 aggregation_mode='avg',
                 norm=False):
        """Build the embedding, RBF, interaction, dense and readout layers.

        Args:
            embed: Which atom embedding to use: ``'gpls'``, ``'atom'`` or
                ``'gp'``.
            dim: dimension of features
            hidden_dim: Accepted but not used by this constructor.
            num_gaussians: dimension in the RBF function
            cutoff: radius cutoff (upper bound handed to the RBF layer)
            output_dim: dimension of prediction
            n_conv: number of interaction layers
            act: Activation module stored as ``self.activation``. ``None``
                selects a fresh ``ShiftedSoftplus``. Note that the default
                instance is created once, when the method is defined, and is
                therefore shared by every model built with the default.
            aggregation_mode: Graph readout, ``'sum'`` or ``'avg'``.
            norm: normalization flag, stored as ``self.norm``

        Raises:
            AssertionError: If ``embed`` or ``aggregation_mode`` is not one of
                the supported values.
        """
        super().__init__()
        self.name = "SchNet"
        self._dim = dim
        self.cutoff = cutoff
        self.n_conv = n_conv
        self.norm = norm
        self.output_dim = output_dim
        self.aggregation_mode = aggregation_mode

        # Identity test: ``== None`` would go through the module's ``__eq__``.
        self.activation = ShiftedSoftplus() if act is None else act

        assert embed in ['gpls', 'atom', 'gp'], \
            "Expect mode to be 'gpls' or 'atom' or 'gp', got {}".format(embed)
        # Fail fast: previously an unsupported mode silently left
        # ``self.readout`` undefined.
        assert aggregation_mode in ['sum', 'avg'], \
            "Expect mode to be 'sum' or 'avg', got {}".format(aggregation_mode)

        if embed == "gpls":
            self.embedding_layer = GPLSEmbedding(dim)
        elif embed == "atom":
            self.embedding_layer = AtomEmbedding(dim)
        elif embed == "gp":
            self.embedding_layer = GPEmbedding(dim)

        self.rbf_layer = RBFLayer(0, cutoff, num_gaussians)
        self.conv_layers = nn.ModuleList([
            SchInteraction(self.rbf_layer._fan_out, dim)
            for _ in range(n_conv)
        ])
        # Two dense layers: dim -> dim // 2 -> output_dim.
        self.atom_dense_layer1 = nn.Linear(dim, int(dim / 2))
        self.atom_dense_layer2 = nn.Linear(int(dim / 2), output_dim)

        if self.aggregation_mode == 'sum':
            self.readout = SumPooling()
        elif self.aggregation_mode == "avg":
            self.readout = AvgPooling()
예제 #5
0
    def __init__(self, features_dim, h_dim, num_rels, num_layers, num_bases=-1, gcn_dropout=0):
        """Store the hyper-parameters, build the layers and add sum pooling.

        Args:
            features_dim: Stored as ``self.features_dim``.
            h_dim: Stored as ``self.h_dim``.
            num_rels: Stored as ``self.num_rels``.
            num_layers: Stored as ``self.num_layers``.
            num_bases: Stored as ``self.num_bases``.
            gcn_dropout: Stored as ``self.p``.
        """
        super(RGCN, self).__init__()

        self.features_dim = features_dim
        self.h_dim = h_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.num_layers = num_layers
        self.p = gcn_dropout

        # create rgcn layers; runs only after every setting above is stored
        self.build_model()
        self.pool = SumPooling()
예제 #6
0
    def __init__(self,
                 dataset,
                 node_feat_dim,
                 edge_feat_dim,
                 hid_dim,
                 out_dim,
                 num_layers,
                 dropout=0.,
                 norm='batch',
                 pooling='mean',
                 beta=1.0,
                 learn_beta=False,
                 aggr='softmax',
                 mlp_layers=1):
        """Build the GENConv stack, the feature encoders, pooling and output.

        Args:
            dataset: ``'ogbg-molhiv'`` or ``'ogbg-ppa'``; selects the
                encoders and is forwarded to every ``GENConv``.
            node_feat_dim: Input size of the node encoder (``'ogbg-ppa'``
                only).
            edge_feat_dim: Input size of the edge encoder (``'ogbg-ppa'``
                only).
            hid_dim: Width of every convolution, norm layer and encoder.
            out_dim: Output size of the final linear layer.
            num_layers: Number of (GENConv, norm) pairs.
            dropout: Stored as ``self.dropout``.
            norm: Normalisation kind, passed to ``GENConv`` and
                ``norm_layer``.
            pooling: Graph readout: ``'sum'``, ``'mean'`` or ``'max'``.
            beta: Forwarded to ``GENConv``.
            learn_beta: Forwarded to ``GENConv``.
            aggr: Forwarded to ``GENConv`` as ``aggregator``.
            mlp_layers: Forwarded to ``GENConv``.

        Raises:
            ValueError: If ``dataset`` is not supported.
            NotImplementedError: If ``pooling`` is not supported.
        """
        super(DeeperGCN, self).__init__()

        self.dataset = dataset
        self.num_layers = num_layers
        self.dropout = dropout
        self.gcns = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Every convolution shares the same configuration.
        conv_config = dict(dataset=dataset,
                           in_dim=hid_dim,
                           out_dim=hid_dim,
                           aggregator=aggr,
                           beta=beta,
                           learn_beta=learn_beta,
                           mlp_layers=mlp_layers,
                           norm=norm)
        for _ in range(self.num_layers):
            self.gcns.append(GENConv(**conv_config))
            self.norms.append(norm_layer(norm, hid_dim))

        if self.dataset == 'ogbg-ppa':
            self.node_encoder = nn.Linear(node_feat_dim, hid_dim)
            self.edge_encoder = nn.Linear(edge_feat_dim, hid_dim)
        elif self.dataset == 'ogbg-molhiv':
            # No edge encoder is created here for this dataset.
            self.node_encoder = AtomEncoder(hid_dim)
        else:
            raise ValueError(f'Dataset {dataset} is not supported.')

        if pooling == 'mean':
            self.pooling = AvgPooling()
        elif pooling == 'sum':
            self.pooling = SumPooling()
        elif pooling == 'max':
            self.pooling = MaxPooling()
        else:
            raise NotImplementedError(f'{pooling} is not supported.')

        self.output = nn.Linear(hid_dim, out_dim)
예제 #7
0
    def __init__(self,
                 embed="atom",
                 dim=64,
                 cutoff=5.,
                 output_dim=1,
                 num_gaussians=64,
                 n_conv=3,
                 act="ssp",
                 aggregation_mode="avg",
                 norm=False):
        """Build the embedding, RBF, interaction, dense and readout layers.

        Args:
            embed: Which atom embedding to use: ``'atom'``, ``'gp'``,
                ``'gpls'`` or ``'fakeatom'``.
            dim: dimension of features
            cutoff: radius cutoff (upper bound handed to the RBF layer)
            output_dim: dimension of prediction
            num_gaussians: dimension in the RBF function
            n_conv: number of interaction layers
            act: Accepted but not used by this constructor; the activation
                is always a ``ShiftedSoftplus``.
            aggregation_mode: Graph readout, ``'sum'`` or ``'avg'``.
            norm: normalization flag, stored as ``self.norm``

        Raises:
            AssertionError: If ``embed`` or ``aggregation_mode`` is not one of
                the supported values.
        """
        super().__init__()
        self.name = "NMPEUModel"
        self._dim = dim
        self.cutoff = cutoff
        self.num_gaussians = num_gaussians
        self.n_conv = n_conv
        self.norm = norm
        self.activation = ShiftedSoftplus()
        self.aggregation_mode = aggregation_mode

        assert embed in ["atom", "gp", "gpls", "fakeatom"], \
            "Expect embed to be 'atom', 'gp', 'gpls' or 'fakeatom', " \
            "got {}".format(embed)
        # Fail fast: previously an unsupported mode silently left
        # ``self.readout`` undefined.
        assert aggregation_mode in ['sum', 'avg'], \
            "Expect mode to be 'sum' or 'avg', got {}".format(aggregation_mode)

        if embed == "gpls":
            self.embedding_layer = GPLSEmbedding(dim)
        elif embed == "gp":
            self.embedding_layer = GPEmbedding(dim)
        elif embed == "atom":
            self.embedding_layer = AtomEmbedding(dim)
        elif embed == "fakeatom":
            self.embedding_layer = FakeAtomEmbedding(dim)

        self.rbf_layer = RBFLayer(0, cutoff, num_gaussians)
        # Every interaction layer shares the single activation module.
        self.conv_layers = nn.ModuleList([
            NMPEUInteraction(self.rbf_layer._fan_out, dim, act=self.activation)
            for _ in range(n_conv)
        ])

        # Two dense layers: dim -> dim // 2 -> output_dim.
        self.atom_dense_layer1 = nn.Linear(dim, int(dim / 2))
        self.atom_dense_layer2 = nn.Linear(int(dim / 2), output_dim)

        if self.aggregation_mode == 'sum':
            self.readout = SumPooling()
        elif self.aggregation_mode == "avg":
            self.readout = AvgPooling()
예제 #8
0
 def __init__(self, features_dim, h_dim, out_dim, num_rels, num_bases=-1, num_hidden_layers=2, classifier=False):
     """Store the settings, build the RGCN layers and add the output modules.

     Args:
         features_dim: Stored as ``self.features_dim``.
         h_dim: Stored as ``self.h_dim``.
         out_dim: Stored as ``self.out_dim``; also the width of the
             attention layer and the input size of the dense layer.
         num_rels: Stored as ``self.num_rels``.
         num_bases: Stored as ``self.num_bases``.
         num_hidden_layers: Stored as ``self.num_hidden_layers``.
         classifier: Stored as ``self.is_classifier``.
     """
     super(Model, self).__init__()

     self.features_dim = features_dim
     self.h_dim = h_dim
     self.out_dim = out_dim
     self.num_rels = num_rels
     self.num_bases = num_bases
     self.num_hidden_layers = num_hidden_layers

     # create rgcn layers
     self.build_model()

     # Single-head attention that keeps the feature width unchanged.
     self.attn = GATConv(in_feats=self.out_dim,
                         out_feats=self.out_dim,
                         num_heads=1)
     self.dense = nn.Linear(self.out_dim, 1)
     self.pool = SumPooling()
     self.is_classifier = classifier
예제 #9
0
    def __init__(self,
                 embed="gpls",
                 dim=64,
                 hidden_dim=128,
                 output_dim=1,
                 n_conv=3,
                 cutoff=12.,
                 num_gaussians=64,
                 aggregation_mode='avg',
                 norm=False):
        """Build the embedding, RBF, convolution, readout and dense layers.

        Args:
            embed: Which atom embedding to use: ``'atom'``, ``'gp'``,
                ``'gpls'`` or ``'fakeatom'``.
            dim: Feature width of the embedding and the convolutions.
            hidden_dim: Width of the fully connected layer after pooling.
            output_dim: Size of the prediction.
            n_conv: Number of convolution layers.
            cutoff: Upper bound handed to the RBF layer.
            num_gaussians: Number of Gaussians handed to the RBF layer.
            aggregation_mode: Graph readout, ``'sum'`` or ``'avg'``.
            norm: Stored as ``self.norm``.

        Raises:
            AssertionError: If ``embed`` or ``aggregation_mode`` is not one of
                the supported values.
        """
        super(CGCNN, self).__init__()
        self.name = "CGCNN"
        self.dim = dim
        self._dim = hidden_dim
        self.cutoff = cutoff
        self.n_conv = n_conv
        self.norm = norm
        self.num_gaussians = num_gaussians
        self.activation = nn.Softplus()
        self.aggregation_mode = aggregation_mode

        assert embed in ["atom", "gp", "gpls", "fakeatom"]
        if embed == "atom":
            self.embedding_layer = AtomEmbedding(dim)
        elif embed == "gp":
            self.embedding_layer = GPEmbedding(dim)
        elif embed == "gpls":
            self.embedding_layer = GPLSEmbedding(dim)
        elif embed == "fakeatom":
            self.embedding_layer = FakeAtomEmbedding(dim)

        self.rbf_layer = RBFLayer(0, cutoff, num_gaussians)
        conv_stack = []
        for _ in range(n_conv):
            conv_stack.append(CGCNNConv(self.dim, self.rbf_layer._fan_out))
        self.conv_layers = nn.ModuleList(conv_stack)

        assert aggregation_mode in ['sum', 'avg'], \
            "Expect mode to be 'sum' or 'avg', got {}".format(aggregation_mode )
        if self.aggregation_mode == "avg":
            self.readout = AvgPooling()
        elif self.aggregation_mode == 'sum':
            self.readout = SumPooling()

        # Dense head: dim -> hidden_dim -> output_dim.
        self.conv_to_fc = nn.Linear(dim, hidden_dim)
        self.conv_to_fc_softplus = nn.Softplus()
        self.fc_out = nn.Linear(hidden_dim, output_dim)
예제 #10
0
    def __init__(self, net_params):
        """Build the input embedding, GIN layers, prediction heads and readout.

        Args:
            net_params: Mapping of hyper-parameters. The keys read here are
                ``'num_node_type'``, ``'hidden_dim'``, ``'n_classes'``,
                ``'dropout'``, ``'L'``, ``'n_mlp_GIN'``, ``'learn_eps_GIN'``,
                ``'neighbor_aggr_GIN'``, ``'readout'``, ``'batch_norm'``,
                ``'residual'`` and ``'pos_enc'``, plus ``'pos_enc_dim'`` when
                ``'pos_enc'`` is truthy.

        Raises:
            NotImplementedError: If ``net_params['readout']`` is not one of
                ``'sum'``, ``'mean'`` or ``'max'``.
        """
        super().__init__()
        # The value is not used below, but the key is still required.
        _ = net_params['num_node_type']
        width = net_params['hidden_dim']
        n_classes = net_params['n_classes']
        dropout = net_params['dropout']
        self.n_layers = net_params['L']
        mlp_depth = net_params['n_mlp_GIN']  # GIN
        learn_eps = net_params['learn_eps_GIN']  # GIN
        neighbor_aggr = net_params['neighbor_aggr_GIN']  # GIN
        readout_kind = net_params['readout']  # this is graph_pooling_type
        batch_norm = net_params['batch_norm']
        residual = net_params['residual']
        self.pos_enc = net_params['pos_enc']

        if self.pos_enc:
            self.embedding_pos_enc = nn.Linear(net_params['pos_enc_dim'], width)
        else:
            # Embedding table with a single row.
            self.embedding_h = nn.Embedding(1, width)

        # List of MLPs
        gin_stack = []
        for _ in range(self.n_layers):
            node_mlp = MLP(mlp_depth, width, width, width)
            gin_stack.append(
                GINLayer(ApplyNodeFunc(node_mlp), neighbor_aggr, dropout,
                         batch_norm, residual, 0, learn_eps))
        self.ginlayers = torch.nn.ModuleList(gin_stack)

        # One linear head per pooled representation (n_layers + 1 in total);
        # each maps a hidden vector to class scores.
        self.linears_prediction = torch.nn.ModuleList(
            [nn.Linear(width, n_classes) for _ in range(self.n_layers + 1)])

        if readout_kind == 'sum':
            self.pool = SumPooling()
        elif readout_kind == 'mean':
            self.pool = AvgPooling()
        elif readout_kind == 'max':
            self.pool = MaxPooling()
        else:
            raise NotImplementedError
예제 #11
0
    def __init__(self,
                 num_node_emb_list,
                 num_edge_emb_list,
                 num_layers=5,
                 emb_dim=300,
                 JK='last',
                 dropout=0.5,
                 readout='mean',
                 n_tasks=1):
        """Build the GIN encoder, the graph readout and the prediction layer.

        Args:
            num_node_emb_list: Forwarded to ``GIN``.
            num_edge_emb_list: Forwarded to ``GIN``.
            num_layers: Number of GNN layers; must be at least 2.
            emb_dim: Embedding size, forwarded to ``GIN``.
            JK: Forwarded to ``GIN``. With ``'concat'`` the readout gate and
                the prediction layer take ``(num_layers + 1) * emb_dim``
                inputs, otherwise ``emb_dim``.
            dropout: Forwarded to ``GIN``.
            readout: ``'sum'``, ``'mean'``, ``'max'``, ``'attention'`` or
                ``'set2set'``.
            n_tasks: Output size of the prediction layer.

        Raises:
            ValueError: If ``num_layers`` is below 2 or ``readout`` is not
                supported.
        """
        super(GINPredictor, self).__init__()

        if num_layers < 2:
            raise ValueError('Number of GNN layers must be greater '
                             'than 1, got {:d}'.format(num_layers))

        self.gnn = GIN(num_node_emb_list=num_node_emb_list,
                       num_edge_emb_list=num_edge_emb_list,
                       num_layers=num_layers,
                       emb_dim=emb_dim,
                       JK=JK,
                       dropout=dropout)

        # Width of the representation entering the gate and the predictor.
        if JK == 'concat':
            graph_feats = (num_layers + 1) * emb_dim
        else:
            graph_feats = emb_dim

        if readout == 'mean':
            self.readout = AvgPooling()
        elif readout == 'sum':
            self.readout = SumPooling()
        elif readout == 'max':
            self.readout = MaxPooling()
        elif readout == 'attention':
            self.readout = GlobalAttentionPooling(
                gate_nn=nn.Linear(graph_feats, 1))
        elif readout == 'set2set':
            # NOTE(review): called without arguments -- confirm that the
            # imported Set2Set can be constructed this way.
            self.readout = Set2Set()
        else:
            raise ValueError(
                "Expect readout to be 'sum', 'mean', "
                "'max', 'attention' or 'set2set', got {}".format(readout))

        self.predict = nn.Linear(graph_feats, n_tasks)
예제 #12
0
    def __init__(self,
                 dims,
                 device,
                 attributor_dims,
                 num_rels,
                 pool='att',
                 num_bases=-1,
                 pos_weight=0,
                 nucs=True,
                 clustered=False,
                 num_clusts=8):
        """
        Store the configuration and build the pooling, embedder and attributor.

        :param dims: the embeddings dimensions, forwarded to the Embedder
        :param device: stored as ``self.device``
        :param attributor_dims: dimensions forwarded to the Attributor; the
            first entry is the input size of the attention-pooling gate
        :param num_rels: the number of possible edge types
        :param pool: ``'att'`` selects attention pooling, any other value
            selects sum pooling
        :param num_bases: technical rGCN option
        :param pos_weight: stored as ``self.pos_weight``
        :param nucs: stored as ``self.nucs``
        :param clustered: stored and forwarded to the Attributor
        :param num_clusts: stored as ``self.num_clusts``
        """
        super(Model, self).__init__()
        self.dims = dims
        self.attributor_dims = attributor_dims
        self.device = device
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.pos_weight = pos_weight
        self.nucs = nucs
        self.clustered = clustered
        self.num_clusts = num_clusts

        if pool != 'att':
            self.pool = SumPooling()
        else:
            # The gate maps each feature vector to one attention score.
            self.pool = GlobalAttentionPooling(
                nn.Linear(attributor_dims[0], 1))

        self.embedder = Embedder(dims=dims,
                                 num_rels=num_rels,
                                 num_bases=num_bases)
        self.attributor = Attributor(attributor_dims, clustered=clustered)
예제 #13
0
 def __init__(self, _final_dimension: int, hidden_dimension: int,
              output_dimension: int, _act: _typing.Optional[str],
              _dropout: _typing.Optional[float],
              num_graph_features: _typing.Optional[int]):
     """Create the sum pooling and the two fully connected layers.

     Args:
         _final_dimension: Input size of the first linear layer, before any
             graph-level features are added.
         hidden_dimension: Output size of the first linear layer.
         output_dimension: Output size of the second linear layer.
         _act: Stored as ``self._act``.
         _dropout: Stored as ``self._dropout``.
         num_graph_features: When a positive int, it widens the first
             linear layer by that amount and is remembered; otherwise
             ``None`` is remembered.
     """
     super(_SumPoolMLPDecoder, self).__init__()
     has_graph_features = (isinstance(num_graph_features, int)
                           and num_graph_features > 0)
     if has_graph_features:
         _final_dimension += num_graph_features
     self.__num_graph_features: _typing.Optional[int] = (
         num_graph_features if has_graph_features else None)
     self._sumpool = SumPooling()
     self._fc1: torch.nn.Linear = torch.nn.Linear(
         _final_dimension, hidden_dimension)
     self._fc2: torch.nn.Linear = torch.nn.Linear(
         hidden_dimension, output_dimension)
     self._act: _typing.Optional[str] = _act
     self._dropout: _typing.Optional[float] = _dropout
예제 #14
0
    def __init__(self,
                 g,
                 rel_names,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 type_aggreg,
                 activation,
                 pooling,
                 dropout):
        """Build the heterogeneous convolution stack, pooling and classifier.

        Args:
            g: Stored as ``self.g``.
            rel_names: Relation names; each gets its own ``GraphConv`` in
                every layer.
            in_feats: Input size of the first layer.
            n_hidden: Output size of every layer.
            n_classes: Output size of the final linear layer.
            n_layers: Total number of convolution layers; the first one is
                always created.
            type_aggreg: Passed to ``HeteroGraphConv`` as ``aggregate``.
            activation: Stored as ``self.activation``.
            pooling: Graph readout: ``'sum'``, ``'mean'`` or ``'max'``.
            dropout: Probability handed to ``nn.Dropout``.

        Raises:
            NotImplementedError: If ``pooling`` is not supported.
        """
        super(Classifier, self).__init__()
        self.g = g
        self.activation = activation

        def make_hetero_layer(n_in):
            # One GraphConv per relation, combined with ``type_aggreg``.
            per_relation = {
                rel: GraphConv(n_in, n_hidden,
                               allow_zero_in_degree=True, norm='both')
                for rel in rel_names
            }
            return HeteroGraphConv(per_relation, aggregate=type_aggreg)

        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(make_hetero_layer(in_feats))
        # hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(make_hetero_layer(n_hidden))

        self.dropout = nn.Dropout(p=dropout)

        if pooling == 'mean':
            self.pool = AvgPooling()
        elif pooling == 'sum':
            self.pool = SumPooling()
        elif pooling == 'max':
            self.pool = MaxPooling()
        else:
            raise NotImplementedError

        # output layer
        self.classify = nn.Linear(n_hidden, n_classes)
예제 #15
0
 def __init__(self, input_dimensions: _typing.Sequence[int],
              output_dimension: int, dropout: float,
              graph_pooling_type: str):
     """Create one linear transform per input, the dropout and the pooling.

     Args:
         input_dimensions: Input size of each linear transform, in order.
         output_dimension: Output size shared by all linear transforms.
         dropout: Probability handed to ``torch.nn.Dropout``.
         graph_pooling_type: ``'sum'``, ``'mean'`` or ``'max'``, compared
             case-insensitively.

     Raises:
         TypeError: If ``graph_pooling_type`` is not a string.
         NotImplementedError: If the pooling type is not supported.
     """
     super(_JKSumPoolDecoder, self).__init__()
     self._linear_transforms: torch.nn.ModuleList = torch.nn.ModuleList([
         torch.nn.Linear(width, output_dimension)
         for width in input_dimensions
     ])
     self._dropout: torch.nn.Dropout = torch.nn.Dropout(dropout)

     if not isinstance(graph_pooling_type, str):
         raise TypeError
     pooling_kind = graph_pooling_type.lower()
     if pooling_kind == 'sum':
         self.__pool = SumPooling()
     elif pooling_kind == 'mean':
         self.__pool = AvgPooling()
     elif pooling_kind == 'max':
         self.__pool = MaxPooling()
     else:
         raise NotImplementedError
예제 #16
0
    def __init__(self,
                 in_edge_feats,
                 num_node_types=1,
                 hidden_feats=300,
                 n_layers=5,
                 n_tasks=1,
                 batchnorm=True,
                 activation=F.relu,
                 dropout=0.,
                 gnn_type='gcn',
                 virtual_node=True,
                 residual=False,
                 jk=False,
                 readout='mean'):
        """Wrap a ``GNNOGB`` encoder with a graph readout and a linear head.

        Args:
            in_edge_feats: Forwarded to ``GNNOGB``.
            num_node_types: Forwarded to ``GNNOGB``.
            hidden_feats: Forwarded to ``GNNOGB``; also the input size of
                the prediction layer.
            n_layers: Forwarded to ``GNNOGB``.
            n_tasks: Output size of the prediction layer.
            batchnorm: Forwarded to ``GNNOGB``.
            activation: Forwarded to ``GNNOGB``.
            dropout: Forwarded to ``GNNOGB``.
            gnn_type: ``'gcn'`` or ``'gin'``; forwarded to ``GNNOGB``.
            virtual_node: Forwarded to ``GNNOGB``.
            residual: Forwarded to ``GNNOGB``.
            jk: Forwarded to ``GNNOGB``.
            readout: Graph readout: ``'mean'``, ``'sum'`` or ``'max'``.

        Raises:
            AssertionError: If ``gnn_type`` or ``readout`` is not supported.
        """
        super(GNNOGBPredictor, self).__init__()

        assert gnn_type in ['gcn', 'gin'], \
            "Expect gnn_type to be 'gcn' or 'gin', got {}".format(gnn_type)
        assert readout in ['mean', 'sum', 'max'], \
            "Expect readout to be in ['mean', 'sum', 'max'], got {}".format(readout)

        encoder_config = dict(in_edge_feats=in_edge_feats,
                              num_node_types=num_node_types,
                              hidden_feats=hidden_feats,
                              n_layers=n_layers,
                              batchnorm=batchnorm,
                              activation=activation,
                              dropout=dropout,
                              gnn_type=gnn_type,
                              virtual_node=virtual_node,
                              residual=residual,
                              jk=jk)
        self.gnn = GNNOGB(**encoder_config)

        # The three values are mutually exclusive, so a chain is equivalent
        # to three independent checks.
        if readout == 'mean':
            self.readout = AvgPooling()
        elif readout == 'sum':
            self.readout = SumPooling()
        elif readout == 'max':
            self.readout = MaxPooling()

        self.predict = nn.Linear(hidden_feats, n_tasks)
예제 #17
0
    def __init__(self, in_dim, out_dim, num_layers, norm):
        """Stack ``GraphConv`` layers with PReLU activations and sum pooling.

        Args:
            in_dim: Input size of the first layer.
            out_dim: Output size of every layer.
            num_layers: Total number of layers; the first one is always
                created.
            norm: Passed to every ``GraphConv`` as ``norm``.
        """
        super(GCN, self).__init__()

        def make_conv(n_in):
            # Bias-free convolution; each layer gets its own PReLU.
            return GraphConv(n_in,
                             out_dim,
                             bias=False,
                             norm=norm,
                             activation=nn.PReLU())

        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        self.layers.append(make_conv(in_dim))
        self.pooling = SumPooling()

        remaining = num_layers - 1
        for _ in range(remaining):
            self.layers.append(make_conv(out_dim))
예제 #18
0
파일: dgl_gcc.py 프로젝트: jkx19/cogdl
    def __init__(
        self,
        num_layers,
        num_mlp_layers,
        input_dim,
        hidden_dim,
        output_dim,
        final_dropout,
        learn_eps,
        graph_pooling_type,
        neighbor_pooling_type,
        use_selayer,
    ):
        """Build the GIN layers, norm layers, prediction heads and readout.

        Parameters
        ----------
        num_layers: int
            Number of prediction heads; ``num_layers - 1`` GIN layers are
            created
        num_mlp_layers: int
            The number of linear layers in mlps
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            Output size of every prediction head
        final_dropout: float
            Dropout probability of ``self.drop``
        learn_eps: boolean
            Passed to every ``GINConv``; also stored as ``self.learn_eps``
        graph_pooling_type: str
            how to aggregate entire nodes in a graph (sum, mean or max)
        neighbor_pooling_type: str
            how to aggregate neighbors; passed to every ``GINConv``
        use_selayer: boolean
            If True, an ``SELayer`` is used in place of ``BatchNorm1d`` after
            each GIN layer; also forwarded to ``MLP`` and ``ApplyNodeFunc``

        Raises
        ------
        NotImplementedError
            If ``graph_pooling_type`` is not "sum", "mean" or "max"
        """
        super(UnsupervisedGIN, self).__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of MLPs
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for depth in range(self.num_layers - 1):
            # Only the first MLP consumes the raw input width.
            mlp_in = input_dim if depth == 0 else hidden_dim
            node_mlp = MLP(num_mlp_layers, mlp_in, hidden_dim, hidden_dim, use_selayer)
            gin_conv = GINConv(
                ApplyNodeFunc(node_mlp, use_selayer),
                neighbor_pooling_type,
                0,
                self.learn_eps,
            )
            self.ginlayers.append(gin_conv)

            if use_selayer:
                norm = SELayer(hidden_dim, int(np.sqrt(hidden_dim)))
            else:
                norm = nn.BatchNorm1d(hidden_dim)
            self.batch_norms.append(norm)

        # One linear head per layer output; head 0 takes the raw input
        # width, the others take the hidden width.
        self.linears_prediction = torch.nn.ModuleList([
            nn.Linear(input_dim if depth == 0 else hidden_dim, output_dim)
            for depth in range(num_layers)
        ])

        self.drop = nn.Dropout(final_dropout)

        if graph_pooling_type == "mean":
            self.pool = AvgPooling()
        elif graph_pooling_type == "sum":
            self.pool = SumPooling()
        elif graph_pooling_type == "max":
            self.pool = MaxPooling()
        else:
            raise NotImplementedError
예제 #19
0
 def __init__(self):
     """Initialise the module and create its sum-pooling operator."""
     super().__init__()
     self.sum_pooler = SumPooling()
예제 #20
0
파일: gin.py 프로젝트: MitchellTesla/AutoGL
    def __init__(self, args):
        """Build the GIN layers, prediction heads and readout from ``args``.

        Parameters
        ----------
        args: dict
            Hyper-parameters. The keys below are checked up front; a missing
            one raises ``Exception``.

            features_num: int
                The dimensionality of input features
            num_class: int
                Output size of every prediction head
            num_graph_features: int
                Stored as ``self.num_graph_features``
            num_layers: int
                Number of prediction heads (must be greater than 2);
                ``num_layers - 1`` GIN layers are created
            hidden: sequence of int
                Hidden sizes, one per GIN layer
            dropout: float
                Dropout probability of ``self.drop``
            act: str
                Must be present; not consumed by this constructor
            mlp_layers: int
                The number of linear layers in mlps
            eps: bool or str
                ``True`` or ``"True"`` passes ``True`` to every ``GINConv``;
                anything else passes ``False``

            Two more keys are read without the up-front check:

            neighbor_pooling_type: str
                how to aggregate neighbors; passed to every ``GINConv``
            graph_pooling_type: str
                how to aggregate entire nodes in a graph (sum, mean or max)

        Raises
        ------
        Exception
            If one of the checked keys is missing
        AssertionError
            If ``num_layers`` is not greater than 2
        KeyError
            If ``neighbor_pooling_type`` or ``graph_pooling_type`` is missing
        NotImplementedError
            If ``graph_pooling_type`` is not 'sum', 'mean' or 'max'
        """
        super(GIN, self).__init__()
        self.args = args

        required_keys = (
            "features_num",
            "num_class",
            "num_graph_features",
            "num_layers",
            "hidden",
            "dropout",
            "act",
            "mlp_layers",
            "eps",
        )
        present_keys = set(self.args.keys())
        # Keep the declared order so that the message is deterministic.
        missing_keys = [key for key in required_keys if key not in present_keys]
        if len(missing_keys) > 0:
            raise Exception("Missing keys: %s." % ",".join(missing_keys))

        self.num_graph_features = self.args["num_graph_features"]
        self.num_layers = self.args["num_layers"]
        assert self.num_layers > 2, "Number of layers in GIN should not less than 3"
        if not self.num_layers == len(self.args["hidden"]) + 1:
            # ``warning`` replaces the deprecated ``warn`` alias.
            # NOTE(review): assumes LOGGER is a standard logging.Logger.
            LOGGER.warning("Warning: layer size does not match the length of hidden units")

        # Accept a real boolean as well as the string form; previously a
        # boolean True compared unequal to "True" and disabled the option.
        self.eps = str(self.args["eps"]) == "True"
        self.num_mlp_layers = self.args["mlp_layers"]
        input_dim = self.args["features_num"]
        hidden = self.args["hidden"]
        neighbor_pooling_type = self.args["neighbor_pooling_type"]
        graph_pooling_type = self.args["graph_pooling_type"]
        final_dropout = self.args["dropout"]
        output_dim = self.args["num_class"]

        # List of MLPs
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            # The first MLP consumes the raw features; later ones consume
            # the previous layer's hidden size.
            mlp_in = input_dim if layer == 0 else hidden[layer - 1]
            mlp = MLP(self.num_mlp_layers, mlp_in, hidden[layer], hidden[layer])

            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden[layer]))

        # One linear head per layer output; head 0 takes the raw input
        # width, head i > 0 takes hidden[i - 1].
        self.linears_prediction = torch.nn.ModuleList()

        for layer in range(self.num_layers):
            head_in = input_dim if layer == 0 else hidden[layer - 1]
            self.linears_prediction.append(nn.Linear(head_in, output_dim))

        self.drop = nn.Dropout(final_dropout)

        if graph_pooling_type == 'sum':
            self.pool = SumPooling()
        elif graph_pooling_type == 'mean':
            self.pool = AvgPooling()
        elif graph_pooling_type == 'max':
            self.pool = MaxPooling()
        else:
            raise NotImplementedError