예제 #1
0
파일: gnn.py 프로젝트: WenjinW/PGL
    def __init__(self,
                 num_tasks=1,
                 num_layers=5,
                 emb_dim=300,
                 gnn_type='gin',
                 virtual_node=True,
                 residual=False,
                 drop_ratio=0,
                 JK="last",
                 graph_pooling="sum"):
        '''Assemble a graph-level prediction model.

        The model is a node-embedding GNN, a graph pooling layer and a final
        linear layer mapping the pooled embedding to ``num_tasks`` outputs.

        Args:
            num_tasks (int): number of labels to be predicted.
            num_layers (int): number of GNN layers; must be at least 2.
            emb_dim (int): node embedding dimensionality.
            gnn_type (str): convolution type, forwarded to the node GNN.
            virtual_node (bool): if truthy, the node GNN is
                ``GNN_node_Virtualnode``; otherwise ``GNN_node``.
            residual (bool): forwarded to the node GNN.
            drop_ratio (float): forwarded to the node GNN and stored on
                ``self.drop_ratio``.
            JK (str): forwarded to the node GNN and stored on ``self.JK``.
            graph_pooling (str): one of ``"sum"``, ``"mean"`` or ``"max"``.

        Raises:
            ValueError: if ``num_layers < 2`` or ``graph_pooling`` is not a
                supported pooling type.
        '''
        super(GNN, self).__init__()

        # Hyper-parameters kept on the instance.
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Both node encoders are constructed with the same arguments.
        node_encoder_cls = GNN_node_Virtualnode if virtual_node else GNN_node
        self.gnn_node = node_encoder_cls(num_layers,
                                         emb_dim,
                                         JK=JK,
                                         drop_ratio=drop_ratio,
                                         residual=residual,
                                         gnn_type=gnn_type)

        # Whole-graph readout.
        if self.graph_pooling not in ("sum", "mean", "max"):
            raise ValueError("Invalid graph pooling type.")
        self.pool = gnn.GraphPool(pool_type=self.graph_pooling)

        self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
예제 #2
0
 def __init__(self, input_size, output_size, num_layers=3):
     """Stack GCN convolutions and a sum graph pool.

     The first convolution maps ``input_size`` -> ``output_size``; each of
     the remaining ``num_layers - 1`` maps ``output_size`` -> ``output_size``.
     At least one convolution is always created, even if ``num_layers < 1``.
     """
     super(GNNModel, self).__init__()
     self.conv_fn = nn.LayerList()
     # Input width of every convolution, in order.
     in_widths = [input_size] + [output_size] * (num_layers - 1)
     for in_width in in_widths:
         self.conv_fn.append(gnn.GCNConv(in_width, output_size))
     self.pool_fn = gnn.GraphPool("sum")
예제 #3
0
파일: conv.py 프로젝트: WenjinW/PGL
    def __init__(self, config):
        """Build a virtual-node GNN whose layers are described by ``config``.

        ``config`` is read through attributes: ``atom_enc_type`` (looked up
        in ``ME``, falling back to ``ME.AtomEncoder``), ``emb_dim``,
        ``num_layers`` and ``layer_type`` (looked up in ``L`` and called with
        ``config``).
        """
        super(GNNVirt, self).__init__()
        log.info("gnn_type is %s" % self.__class__.__name__)
        self.config = config

        encoder_cls = getattr(ME, self.config.atom_enc_type, ME.AtomEncoder)
        dim = self.config.emb_dim
        self.atom_encoder = encoder_cls(dim)

        # Learnable virtual-node embedding, initialised to zeros.
        self.virtualnode_embedding = self.create_parameter(
            shape=[1, dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(value=0.0))

        self.convs = paddle.nn.LayerList()
        self.batch_norms = paddle.nn.LayerList()
        self.mlp_virtualnode_list = paddle.nn.LayerList()

        # One conv and one batch norm per layer.
        for _ in range(self.config.num_layers):
            self.convs.append(getattr(L, self.config.layer_type)(self.config))
            self.batch_norms.append(L.batch_norm_1d(dim))

        # One fewer virtual-node MLP than there are conv layers.
        for _ in range(self.config.num_layers - 1):
            self.mlp_virtualnode_list.append(nn.Sequential(
                L.Linear(dim, dim),
                L.batch_norm_1d(dim),
                nn.Swish(),
                L.Linear(dim, dim),
                L.batch_norm_1d(dim),
                nn.Swish()))

        self.pool = gnn.GraphPool(pool_type="sum")
예제 #4
0
    def __init__(self,
                 num_layers,
                 emb_dim,
                 drop_ratio=0.5,
                 JK="last",
                 residual=False,
                 gnn_type='gin'):
        '''Build a node-embedding GNN that also maintains a virtual node.

        Args:
            num_layers (int): number of GNN message passing layers; must be
                at least 2.
            emb_dim (int): node embedding dimensionality.
            drop_ratio (float): stored on ``self.drop_ratio``; not used in
                this constructor (presumably applied in ``forward``).
            JK (str): stored on ``self.JK``; not used in this constructor.
            residual (bool): whether to add residual connections; stored on
                ``self.residual``.
            gnn_type (str): ``'gin'`` selects ``GINConv``, ``'gcn'`` selects
                ``GCNConv``.

        Raises:
            ValueError: if ``num_layers < 2`` or ``gnn_type`` is unknown.
        '''
        super(GNN_node_Virtualnode, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Fail fast on an unknown conv type. Otherwise ``convs`` would stay
        # empty while ``batch_norms`` still got one entry per layer.
        if gnn_type not in ('gin', 'gcn'):
            raise ValueError('Undefined GNN type called {}'.format(gnn_type))

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = self.create_parameter(
            shape=[1, emb_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(value=0.0))

        conv_cls = GINConv if gnn_type == 'gin' else GCNConv
        ### one conv and one BatchNorm1D per layer
        convs = []
        batch_norms = []
        for _ in range(num_layers):
            convs.append(conv_cls(emb_dim))
            batch_norms.append(paddle.nn.BatchNorm1D(emb_dim))

        ### MLPs that transform the virtual node (one fewer than conv layers)
        mlp_virtualnode_list = [
            nn.Sequential(nn.Linear(emb_dim, emb_dim),
                          nn.BatchNorm1D(emb_dim), nn.ReLU(),
                          nn.Linear(emb_dim, emb_dim),
                          nn.BatchNorm1D(emb_dim), nn.ReLU())
            for _ in range(num_layers - 1)
        ]

        self.pool = gnn.GraphPool(pool_type="sum")

        self.convs = nn.LayerList(convs)
        self.batch_norms = nn.LayerList(batch_norms)
        self.mlp_virtualnode_list = nn.LayerList(mlp_virtualnode_list)
예제 #5
0
파일: conv.py 프로젝트: WenjinW/PGL
    def __init__(self, config):
        """Build a virtual-node GNN with an additional ``junc_*`` branch.

        Alongside the per-layer ``convs`` (built from ``config.layer_type``
        looked up in ``L``), one ``gnn.GINConv`` per layer is stored in
        ``junc_convs`` and a 6000-row embedding table in ``junc_embed``.

        Raises:
            ValueError: if ``config.num_layers < 2``.
        """
        super(JuncGNNVirt, self).__init__()
        log.info("gnn_type is %s" % self.__class__.__name__)
        self.config = config
        # Mirror these config attributes onto the instance, in this order.
        for attr_name in ("num_layers", "drop_ratio", "JK", "residual",
                          "emb_dim", "gnn_type", "layer_type"):
            setattr(self, attr_name, getattr(config, attr_name))

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        dim = self.emb_dim
        encoder_cls = getattr(ME, self.config.atom_enc_type, ME.AtomEncoder)
        self.atom_encoder = encoder_cls(dim)

        self.junc_embed = paddle.nn.Embedding(6000, dim)

        # Learnable virtual-node embedding, initialised to zeros.
        self.virtualnode_embedding = self.create_parameter(
            shape=[1, dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(value=0.0))

        self.convs = nn.LayerList()
        self.batch_norms = nn.LayerList()
        self.mlp_virtualnode_list = nn.LayerList()
        self.junc_convs = nn.LayerList()

        # Per layer: main conv, junction GIN conv, batch norm.
        for _ in range(self.num_layers):
            self.convs.append(getattr(L, self.layer_type)(self.config))
            self.junc_convs.append(gnn.GINConv(dim, dim))
            self.batch_norms.append(L.batch_norm_1d(dim))

        # One fewer virtual-node MLP than there are conv layers.
        for _ in range(self.num_layers - 1):
            self.mlp_virtualnode_list.append(nn.Sequential(
                L.Linear(dim, dim),
                L.batch_norm_1d(dim),
                nn.Swish(),
                L.Linear(dim, dim),
                L.batch_norm_1d(dim),
                nn.Swish()))

        self.pool = gnn.GraphPool(pool_type="sum")
예제 #6
0
파일: conv.py 프로젝트: WenjinW/PGL
    def __init__(self, config, with_efeat=False):
        """Build the LiteGEM encoder from an attribute-style ``config``.

        Args:
            config: object read via attributes: ``num_layers``,
                ``drop_ratio``, ``virtual_node``, ``emb_dim``, ``norm``,
                ``atom_enc_type``, ``exfeat``, ``appnp_k``, ``graphnorm``;
                also ``bond_enc_type`` when ``with_efeat`` is true and
                ``appnp_a`` when ``appnp_k`` is not None.
            with_efeat (bool): if true, a bond encoder is built here. Each
                ``LiteGEMConv`` receives the negated flag.
        """
        super(LiteGEM, self).__init__()
        log.info("gnn_type is %s" % self.__class__.__name__)

        self.config = config
        self.with_efeat = with_efeat
        self.num_layers = config.num_layers
        self.drop_ratio = config.drop_ratio
        self.virtual_node = config.virtual_node
        self.emb_dim = config.emb_dim
        self.norm = config.norm

        self.gnns = paddle.nn.LayerList()
        self.norms = paddle.nn.LayerList()

        if self.virtual_node:
            log.info("using virtual node in %s" % self.__class__.__name__)
            self.mlp_virtualnode_list = paddle.nn.LayerList()
            # Learnable virtual-node embedding, initialised to zeros.
            self.virtualnode_embedding = self.create_parameter(
                shape=[1, self.emb_dim],
                dtype='float32',
                default_initializer=nn.initializer.Constant(value=0.0))
            for _ in range(self.num_layers - 1):
                self.mlp_virtualnode_list.append(
                    L.MLP([self.emb_dim] * 3, norm=self.norm))

        # The conv-level edge-feature flag is the negation of ``with_efeat``.
        conv_with_efeat = not self.with_efeat
        for _ in range(self.num_layers):
            self.gnns.append(L.LiteGEMConv(config, with_efeat=conv_with_efeat))
            self.norms.append(L.norm_layer(self.norm, self.emb_dim))

        atom_encoder_cls = getattr(ME, self.config.atom_enc_type,
                                   ME.AtomEncoder)
        self.atom_encoder = atom_encoder_cls(emb_dim=self.emb_dim)
        if self.config.exfeat:
            self.atom_encoder_float = ME.AtomEncoderFloat(emb_dim=self.emb_dim)

        if self.with_efeat:
            bond_encoder_cls = getattr(ME, self.config.bond_enc_type,
                                       ME.BondEncoder)
            self.bond_encoder = bond_encoder_cls(emb_dim=self.emb_dim)

        self.pool = gnn.GraphPool(pool_type="sum")

        # Optional components, enabled by non-None config values.
        if self.config.appnp_k is not None:
            self.appnp = gnn.APPNP(k_hop=self.config.appnp_k,
                                   alpha=self.config.appnp_a)
        if self.config.graphnorm is not None:
            self.gn = gnn.GraphNorm()
예제 #7
0
    def __init__(self,
                 num_layers,
                 emb_dim,
                 drop_ratio=0.5,
                 JK="last",
                 residual=False,
                 gnn_type='gin'):
        '''Build a stack of node-level GNN layers.

        Args:
            num_layers (int): number of GNN message passing layers; must be
                at least 2.
            emb_dim (int): node embedding dimensionality.
            drop_ratio (float): stored on ``self.drop_ratio``; not used in
                this constructor (presumably applied in ``forward``).
            JK (str): stored on ``self.JK``; not used in this constructor.
            residual (bool): whether to add residual connections; stored on
                ``self.residual``.
            gnn_type (str): ``'gin'`` selects ``GINConv``, ``'gcn'`` selects
                ``GCNConv``.

        Raises:
            ValueError: if ``num_layers < 2`` or ``gnn_type`` is unknown.
        '''
        super(GNN_node, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Fail fast on an unknown conv type. Otherwise ``convs`` would stay
        # empty while ``batch_norms`` still got one entry per layer.
        if gnn_type not in ('gin', 'gcn'):
            raise ValueError('Undefined GNN type called {}'.format(gnn_type))

        self.atom_encoder = AtomEncoder(emb_dim)

        conv_cls = GINConv if gnn_type == 'gin' else GCNConv
        ### one conv and one BatchNorm1D per layer
        convs = []
        batch_norms = []
        for _ in range(num_layers):
            convs.append(conv_cls(emb_dim))
            batch_norms.append(paddle.nn.BatchNorm1D(emb_dim))

        self.pool = gnn.GraphPool(pool_type="sum")
        self.convs = nn.LayerList(convs)
        self.batch_norms = nn.LayerList(batch_norms)
예제 #8
0
    def __init__(self, config, with_efeat=True):
        """Build the LiteGEM encoder and prediction head from a dict ``config``.

        Keys read: ``num_layers``, ``dropout_rate``, ``virtual_node``,
        ``emb_dim``, ``norm``, ``num_tasks``, ``atom_names``,
        ``atom_float_names``, ``bond_names``, ``graphnorm``, ``clf_layers``.

        ``graph_pred_linear`` is a 3-Linear ``nn.Sequential`` when
        ``config["clf_layers"] == 3``, a 2-Linear one when it is 2, and a
        single ``Linear(emb_dim, num_tasks)`` otherwise.

        Args:
            config (dict): model configuration (keys above).
            with_efeat (bool): if true, an initial bond embedding is built
                here. Each ``LiteGEMConv`` receives the negated flag.
        """
        super(LiteGEM, self).__init__()
        log.info("gnn_type is %s" % self.__class__.__name__)

        self.config = config
        self.with_efeat = with_efeat
        self.num_layers = config["num_layers"]
        self.drop_ratio = config["dropout_rate"]
        self.virtual_node = config["virtual_node"]
        self.emb_dim = config["emb_dim"]
        self.norm = config["norm"]
        self.num_tasks = config["num_tasks"]

        self.atom_names = config["atom_names"]
        self.atom_float_names = config["atom_float_names"]
        self.bond_names = config["bond_names"]
        self.gnns = paddle.nn.LayerList()
        self.norms = paddle.nn.LayerList()

        if self.virtual_node:
            log.info("using virtual node in %s" % self.__class__.__name__)
            self.mlp_virtualnode_list = paddle.nn.LayerList()
            # Learnable virtual-node embedding, initialised to zeros.
            self.virtualnode_embedding = self.create_parameter(
                shape=[1, self.emb_dim],
                dtype='float32',
                default_initializer=nn.initializer.Constant(value=0.0))
            for _ in range(self.num_layers - 1):
                self.mlp_virtualnode_list.append(
                    MLP([self.emb_dim] * 3, norm=self.norm))

        # The conv-level edge-feature flag is the negation of ``with_efeat``.
        conv_with_efeat = not self.with_efeat
        for _ in range(self.num_layers):
            self.gnns.append(LiteGEMConv(config, with_efeat=conv_with_efeat))
            self.norms.append(norm_layer(self.norm, self.emb_dim))

        self.atom_embedding = AtomEmbedding(self.atom_names, self.emb_dim)
        self.atom_float_embedding = AtomFloatEmbedding(self.atom_float_names,
                                                       self.emb_dim)

        if self.with_efeat:
            self.init_bond_embedding = BondEmbedding(self.config["bond_names"],
                                                     self.emb_dim)

        self.pool = gnn.GraphPool(pool_type="sum")

        # NOTE(review): GraphNorm is created only when config["graphnorm"] is
        # falsy, which reads inverted relative to the key name (the
        # attribute-config LiteGEM creates it when `graphnorm is not None`).
        # Confirm against forward(), not visible here, before changing it.
        if not self.config["graphnorm"]:
            self.gn = gnn.GraphNorm()

        hidden_size = self.emb_dim
        clf_layers = self.config["clf_layers"]
        # Widths of the hidden stages of the prediction head (None = no MLP).
        if clf_layers == 3:
            log.info("clf_layers is 3")
            widths = [hidden_size, hidden_size // 2, hidden_size // 4]
        elif clf_layers == 2:
            log.info("clf_layers is 2")
            widths = [hidden_size, hidden_size // 2]
        else:
            widths = None

        if widths is None:
            self.graph_pred_linear = Linear(hidden_size, self.num_tasks)
        else:
            head = []
            for width_in, width_out in zip(widths, widths[1:]):
                head += [Linear(width_in, width_out),
                         batch_norm_1d(width_out),
                         nn.Swish()]
            head.append(Linear(widths[-1], self.num_tasks))
            self.graph_pred_linear = nn.Sequential(*head)
예제 #9
0
파일: model.py 프로젝트: WenjinW/PGL
 def __init__(self, pool_type=None):
     """Record ``pool_type`` and build a sum-pooling layer.

     ``pool_type`` is only stored on the instance here; the ``GraphPool``
     built in this constructor is always ``"sum"``.
     """
     super().__init__()
     self.pool_fun = gnn.GraphPool("sum")
     self.pool_type = pool_type
예제 #10
0
파일: model.py 프로젝트: Yelrose/PGL
 def __init__(self, fun):
     """Record ``fun`` as ``pool_type`` and build a default ``GraphPool``.

     ``GraphPool`` is constructed without arguments, so its pool type is
     the library default.
     """
     super().__init__()
     self.pool_fun = gnn.GraphPool()
     self.pool_type = fun