Example #1
File: gnn.py  Project: WenjinW/PGL
    def __init__(self,
                 num_tasks=1,
                 num_layers=5,
                 emb_dim=300,
                 gnn_type='gin',
                 virtual_node=True,
                 residual=False,
                 drop_ratio=0,
                 JK="last",
                 graph_pooling="sum"):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''
        super(GNN, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layers,
                                                 emb_dim,
                                                 JK=JK,
                                                 drop_ratio=drop_ratio,
                                                 residual=residual,
                                                 gnn_type=gnn_type)
        else:
            self.gnn_node = GNN_node(num_layers,
                                     emb_dim,
                                     JK=JK,
                                     drop_ratio=drop_ratio,
                                     residual=residual,
                                     gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = gnn.GraphPool(pool_type="sum")
        elif self.graph_pooling == "mean":
            self.pool = gnn.GraphPool(pool_type="mean")
        elif self.graph_pooling == "max":
            self.pool = gnn.GraphPool(pool_type="max")
        else:
            raise ValueError("Invalid graph pooling type.")

        self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
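A minimal construction sketch for the PGL variant above; the argument values are hypothetical and it only exercises the constructor shown here:

    # Hypothetical usage: same class, built with mean readout instead of sum.
    model = GNN(num_tasks=1,
                num_layers=5,
                emb_dim=300,
                gnn_type='gin',
                virtual_node=True,
                graph_pooling="mean")
    # model.pool becomes gnn.GraphPool(pool_type="mean"), and model.graph_pred_linear
    # maps each 300-d graph embedding to a single task logit.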
Example #2
    def __init__(self, num_tasks = 1, num_layers = 5, emb_dim = 300, gnn_type = 'gin',
                 virtual_node = True, residual = False, drop_ratio = 0, JK = "last",
                 graph_pooling = "sum"):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''
        super(GNN, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layers, emb_dim, JK = JK,
                                                 drop_ratio = drop_ratio,
                                                 residual = residual,
                                                 gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layers, emb_dim, JK = JK, drop_ratio = drop_ratio,
                                     residual = residual, gnn_type = gnn_type)


        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = SumPooling()
        elif self.graph_pooling == "mean":
            self.pool = AvgPooling()
        elif self.graph_pooling == "max":
            self.pool = MaxPooling()
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttentionPooling(
                gate_nn = nn.Sequential(nn.Linear(emb_dim, 2*emb_dim),
                                        nn.BatchNorm1d(2*emb_dim),
                                        nn.ReLU(),
                                        nn.Linear(2*emb_dim, 1)))

        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, n_iters = 2, n_layers = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
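A small shape check for the DGL readouts chosen above (a sketch, assuming DGL and PyTorch are installed; the random graph and features are purely illustrative). It also shows why the set2set branch needs a Linear layer of width 2*emb_dim:

    import dgl
    import torch
    from dgl.nn.pytorch.glob import AvgPooling, Set2Set

    g = dgl.rand_graph(num_nodes=6, num_edges=12)   # toy graph
    h = torch.randn(6, 300)                         # fake 300-d node embeddings

    print(AvgPooling()(g, h).shape)                          # (1, 300) -> Linear(emb_dim, num_tasks)
    print(Set2Set(300, n_iters=2, n_layers=2)(g, h).shape)   # (1, 600) -> Linear(2*emb_dim, num_tasks)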
Example #3
File: gnn.py  Project: rpatil524/ogb
    def __init__(self, num_vocab, max_seq_len, node_encoder, num_layer = 5, emb_dim = 300, 
                    gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            num_vocab (int): size of the output token vocabulary
            max_seq_len (int): maximum length of the predicted token sequence
            virtual_node (bool): whether to add virtual node or not
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_vocab = num_vocab
        self.max_seq_len = max_seq_len
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, node_encoder, JK = JK,
                                                 drop_ratio = drop_ratio, residual = residual,
                                                 gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, node_encoder, JK = JK,
                                     drop_ratio = drop_ratio, residual = residual,
                                     gnn_type = gnn_type)


        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim),
                torch.nn.BatchNorm1d(2*emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        self.graph_pred_linear_list = torch.nn.ModuleList()

        if graph_pooling == "set2set":
            for i in range(max_seq_len):
                 self.graph_pred_linear_list.append(torch.nn.Linear(2*emb_dim, self.num_vocab))

        else:
            for i in range(max_seq_len):
                 self.graph_pred_linear_list.append(torch.nn.Linear(emb_dim, self.num_vocab))
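The prediction heads here are one Linear layer per output position; the forward pass is not reproduced above, so the following is only a sketch of the usual pattern for applying them, assuming h_graph is the pooled graph embedding of shape (batch_size, emb_dim):

    # Hypothetical application of the per-position heads inside forward():
    pred_list = [linear(h_graph) for linear in self.graph_pred_linear_list]
    # pred_list holds max_seq_len tensors, each of shape (batch_size, num_vocab),
    # i.e. one token-logit vector per position of the predicted sequence.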
Example #4
    def __init__(self,
                 num_tasks,
                 num_layer=5,
                 emb_dim=300,
                 gnn_type='gin',
                 virtual_node=True,
                 residual=False,
                 drop_ratio=0.5,
                 drop_path_p=0.01,
                 JK="last",
                 graph_pooling="mean",
                 net_linear=False,
                 net_seed=47,
                 edge_p=0.6):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if gnn_type == 'randomgin':
            if virtual_node:
                self.gnn_node = RandomGNN_node_Virtualnode(
                    num_layer,
                    emb_dim,
                    JK=JK,
                    drop_ratio=drop_ratio,
                    residual=residual,
                    gnn_type=gnn_type,
                    drop_path_p=drop_path_p,
                    net_linear=net_linear,
                    net_seed=net_seed,
                    edge_p=edge_p)
            else:
                self.gnn_node = RandomGNN_node(num_layer,
                                               emb_dim,
                                               JK=JK,
                                               drop_ratio=drop_ratio,
                                               residual=residual,
                                               gnn_type=gnn_type,
                                               drop_path_p=drop_path_p,
                                               net_linear=net_linear,
                                               net_seed=net_seed,
                                               edge_p=edge_p)
        else:
            if virtual_node:
                self.gnn_node = GNN_node_Virtualnode(num_layer,
                                                     emb_dim,
                                                     JK=JK,
                                                     drop_ratio=drop_ratio,
                                                     residual=residual,
                                                     gnn_type=gnn_type)
            else:
                self.gnn_node = GNN_node(num_layer,
                                         emb_dim,
                                         JK=JK,
                                         drop_ratio=drop_ratio,
                                         residual=residual,
                                         gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim,
                                                     self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim,
                                                     self.num_tasks)
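Unlike the module-style readouts, global_add_pool / global_mean_pool / global_max_pool are plain functions that take node features plus a batch-assignment vector; a sketch of the shared call convention (assumes torch_geometric is installed; the tensors are illustrative):

    import torch
    from torch_geometric.nn import global_mean_pool

    x = torch.randn(7, 300)                       # 7 nodes with 300-d embeddings
    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])   # nodes 0-2 in graph 0, nodes 3-6 in graph 1
    h_graph = global_mean_pool(x, batch)          # shape (2, 300): one embedding per graph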
Example #5
    def __init__(self,
                 num_tasks,
                 num_layer=5,
                 emb_dim=300,
                 gnn_type='gin',
                 virtual_node=True,
                 residual=False,
                 drop_ratio=0.5,
                 JK="last",
                 graph_pooling="mean",
                 transformers=False,
                 controller=False):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling
        if transformers:
            self.transformer_layers = torch.nn.ModuleList([
                torch.nn.TransformerEncoderLayer(d_model=emb_dim,
                                                 dim_feedforward=emb_dim * 4,
                                                 nhead=4) for _ in range(2)
            ])
        if controller:
            self.controller_layers = torch.nn.ModuleList([
                ControllerTransformer(depth=1,
                                      expansion_ratio=4,
                                      n_heads=4,
                                      s2g_sharing=True,
                                      in_features=300,
                                      out_features=1,
                                      set_fn_feats=[128, 128, 128, 128, 5],
                                      method='lin2',
                                      hidden_mlp=[128],
                                      predict_diagonal=False,
                                      attention=True) for _ in range(1)
            ])
        self.transformers = transformers
        self.controller = controller

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer,
                                                 emb_dim,
                                                 JK=JK,
                                                 drop_ratio=drop_ratio,
                                                 residual=residual,
                                                 gnn_type=gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer,
                                     emb_dim,
                                     JK=JK,
                                     drop_ratio=drop_ratio,
                                     residual=residual,
                                     gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim,
                                                     self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim,
                                                     self.num_tasks)
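When transformers=True, the extra torch.nn.TransformerEncoderLayer stack is built without batch_first=True, so it expects sequence-first input of shape (seq_len, batch, emb_dim); a shape sketch with illustrative tensors:

    import torch

    layer = torch.nn.TransformerEncoderLayer(d_model=300, dim_feedforward=1200, nhead=4)
    tokens = torch.randn(10, 4, 300)   # (seq_len, batch, emb_dim), sequence-first by default
    out = layer(tokens)                # same shape: (10, 4, 300)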
Example #6
File: gnn.py  Project: giannipele/ogb
    def __init__(self,
                 num_tasks,
                 num_layer=5,
                 emb_dim=300,
                 gnn_type='gin',
                 virtual_node=True,
                 residual=False,
                 drop_ratio=0.5,
                 JK="last",
                 graph_pooling="mean",
                 laf_fun='mean',
                 laf_layers='false',
                 device='cuda',
                 lafgrad=True):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer,
                                                 emb_dim,
                                                 JK=JK,
                                                 drop_ratio=drop_ratio,
                                                 residual=residual,
                                                 gnn_type=gnn_type,
                                                 laf_fun=laf_fun,
                                                 laf_layers=laf_layers,
                                                 device=device)
        else:
            self.gnn_node = GNN_node(num_layer,
                                     emb_dim,
                                     JK=JK,
                                     drop_ratio=drop_ratio,
                                     residual=residual,
                                     gnn_type=gnn_type,
                                     laf_fun=laf_fun,
                                     laf_layers=laf_layers,
                                     device=device)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        elif self.graph_pooling == "laf":
            if laf_fun == 'exp':
                self.pool = ScatterExponentialLAF(device=device)
            else:
                self.pool = ScatterAggregationLayer(function=laf_fun,
                                                    grad=lafgrad,
                                                    device=device)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim,
                                                     self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim,
                                                     self.num_tasks)
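The "attention" readout used in Examples #3 to #6 scores every node with the small gate MLP (emb_dim -> 2*emb_dim -> 1), softmax-normalizes the scores within each graph, and sums the weighted node embeddings; a standalone sketch of that behaviour (assumes torch_geometric; the tensors are illustrative):

    import torch
    from torch_geometric.nn import GlobalAttention

    gate_nn = torch.nn.Sequential(torch.nn.Linear(300, 600),
                                  torch.nn.BatchNorm1d(600),
                                  torch.nn.ReLU(),
                                  torch.nn.Linear(600, 1))
    pool = GlobalAttention(gate_nn=gate_nn)

    x = torch.randn(5, 300)                    # 5 nodes, all in a single graph
    batch = torch.zeros(5, dtype=torch.long)   # batch-assignment: every node -> graph 0
    h_graph = pool(x, batch)                   # (1, 300) attention-weighted graph embedding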