def __init__(self, input_size=1, num_classes=2):
    """Assemble a two-layer GIN with a sum readout.

    Args:
        input_size: Width of the incoming node features (input of the
            first layer's linear map).
        num_classes: Output width of both GIN layers' linear maps.
    """
    super(GIN, self).__init__()

    # Layer 1: project input features to ``num_classes`` and sum-aggregate
    # neighbour messages.
    first_proj = nn.Linear(input_size, num_classes)
    self.conv1 = GINConv(first_proj, aggregator_type='sum')

    # Layer 2: square linear map at the same width, again with sum
    # aggregation.
    second_proj = nn.Linear(num_classes, num_classes)
    self.conv2 = GINConv(second_proj, aggregator_type='sum')

    # Graph-level readout.
    self.pool = SumPooling()
def __init__(self, num_tasks = 1, num_layers = 5, emb_dim = 300, gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0, JK = "last", graph_pooling = "sum"):
    '''Build a graph-level predictor: node encoder -> readout -> linear head.

    Args:
        num_tasks (int): number of labels to be predicted (output width of
            ``graph_pred_linear``).
        num_layers (int): number of message-passing layers in the node
            encoder; must be at least 2.
        emb_dim (int): node embedding dimensionality.
        gnn_type (str): layer type forwarded to the node encoder.
        virtual_node (bool): whether to add virtual node or not; selects
            ``GNN_node_Virtualnode`` instead of ``GNN_node``.
        residual (bool): forwarded to the node encoder.
        drop_ratio (float): forwarded to the node encoder.
        JK (str): forwarded to the node encoder (presumably the
            jumping-knowledge mode -- confirm against GNN_node).
        graph_pooling (str): one of "sum", "mean", "max", "attention",
            "set2set".

    Raises:
        ValueError: if ``num_layers < 2`` or ``graph_pooling`` is unknown.
    '''
    super(GNN, self).__init__()
    self.num_layers = num_layers
    self.drop_ratio = drop_ratio
    self.JK = JK
    self.emb_dim = emb_dim
    self.num_tasks = num_tasks
    self.graph_pooling = graph_pooling

    if self.num_layers < 2:
        raise ValueError("Number of GNN layers must be greater than 1.")

    ### GNN to generate node embeddings
    node_encoder_cls = GNN_node_Virtualnode if virtual_node else GNN_node
    self.gnn_node = node_encoder_cls(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio,
                                     residual=residual, gnn_type=gnn_type)

    ### Pooling function to generate whole-graph embeddings
    if self.graph_pooling == "sum":
        self.pool = SumPooling()
    elif self.graph_pooling == "mean":
        self.pool = AvgPooling()
    elif self.graph_pooling == "max":
        # Must be an instance: the original assigned the bare class, which is
        # not registered as a submodule and cannot be called on (graph, feat).
        self.pool = MaxPooling()
    elif self.graph_pooling == "attention":
        self.pool = GlobalAttentionPooling(
            gate_nn=nn.Sequential(nn.Linear(emb_dim, 2 * emb_dim),
                                  nn.BatchNorm1d(2 * emb_dim),
                                  nn.ReLU(),
                                  nn.Linear(2 * emb_dim, 1)))
    elif self.graph_pooling == "set2set":
        self.pool = Set2Set(emb_dim, n_iters=2, n_layers=2)
    else:
        raise ValueError("Invalid graph pooling type.")

    # Set2Set emits a vector twice as wide as its input; every other readout
    # keeps the embedding width.
    pooled_dim = 2 * self.emb_dim if self.graph_pooling == "set2set" else self.emb_dim
    self.graph_pred_linear = nn.Linear(pooled_dim, self.num_tasks)
def __init__(self, input_dim, output_dim, hidden_layers, activation, activation_params, dropout_p, pooling="avg"):
    """Build a stack of GraphConv layers followed by a pooled linear head.

    Args:
        input_dim: Width of the input node features (also the width of
            ``self.bn_layer``).
        output_dim: Output width of the final linear layer.
        hidden_layers: Sequence of hidden widths; one GraphConv per entry.
        activation: One of "relu", "leaky", "elu", "selu", "tanh".
        activation_params: Keyword arguments forwarded to the activation
            constructor (ignored for "tanh").
        dropout_p: Dropout probability.
        pooling: Graph readout, one of "avg", "sum", "max".

    Raises:
        ValueError: If ``hidden_layers`` is empty.
        NotImplementedError: On an unknown ``activation`` or ``pooling``.
    """
    super(GCN, self).__init__()
    # Checks
    num_layers = len(hidden_layers)
    if num_layers < 1:
        raise ValueError(
            f"You must have at least one hidden layer. You passed {num_layers}: {hidden_layers}"
        )

    # Activations that accept user-supplied keyword arguments.
    parametrized_activations = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "elu": nn.ELU,
        "selu": nn.SELU,
    }
    if activation in parametrized_activations:
        self.activation = parametrized_activations[activation](**activation_params)
    elif activation == "tanh":
        self.activation = nn.Tanh()
    else:
        # Report the offending activation (previously interpolated `pooling`).
        raise NotImplementedError(f"Unknown activation method: {activation}")

    self.layers = nn.ModuleList()
    # Input layer
    self.layers.append(GraphConv(input_dim, hidden_layers[0]))
    # Presumably normalises the raw input features in forward -- confirm.
    self.bn_layer = nn.BatchNorm1d(input_dim)
    # Hidden layers: chain consecutive widths pairwise.
    for in_width, out_width in zip(hidden_layers[:-1], hidden_layers[1:]):
        self.layers.append(GraphConv(in_width, out_width))

    # Additional layers
    if pooling == "avg":
        self.pool = AvgPooling()
    elif pooling == "sum":
        self.pool = SumPooling()
    elif pooling == "max":
        self.pool = MaxPooling()
    else:
        raise NotImplementedError(f"Unknown pooling method: {pooling}")
    self.linear = nn.Linear(hidden_layers[-1], output_dim)
    self.dropout = nn.Dropout(p=dropout_p)
def __init__(self, num_layers, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
    '''Build a node encoder with a learned virtual node.

    Args:
        num_layers (int): number of GNN message passing layers; must be >= 2.
        emb_dim (int): node embedding dimensionality.
        drop_ratio (float): stored as ``self.drop_ratio``.
        JK (str): stored as ``self.JK`` (presumably jumping-knowledge mode).
        residual (bool): add residual connection or not.
        gnn_type (str): 'gin' or 'gcn'; selects the conv layer type.

    Raises:
        ValueError: if ``num_layers < 2`` or ``gnn_type`` is unknown.
    '''
    super(GNN_node_Virtualnode, self).__init__()
    self.num_layers = num_layers
    self.drop_ratio = drop_ratio
    self.JK = JK
    ### add residual connection or not
    self.residual = residual

    if self.num_layers < 2:
        raise ValueError("Number of GNN layers must be greater than 1.")
    # Validate before allocating anything. The original constructed this
    # ValueError without raising it, silently leaving `self.convs` short.
    if gnn_type not in ('gin', 'gcn'):
        raise ValueError('Undefined GNN type called {}'.format(gnn_type))

    self.atom_encoder = AtomEncoder(emb_dim)

    ### set the initial virtual node embedding to 0.
    self.virtualnode_embedding = nn.Embedding(1, emb_dim)
    nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

    ### List of GNNs
    self.convs = nn.ModuleList()
    ### batch norms applied to node embeddings
    self.batch_norms = nn.ModuleList()
    ### List of MLPs to transform virtual node at every layer
    self.mlp_virtualnode_list = nn.ModuleList()

    for _ in range(num_layers):
        if gnn_type == 'gin':
            self.convs.append(GINConv(emb_dim))
        else:  # 'gcn' (validated above)
            self.convs.append(GCNConv(emb_dim))
        self.batch_norms.append(nn.BatchNorm1d(emb_dim))

    # num_layers - 1 virtual-node MLPs (presumably applied between
    # consecutive layers in forward -- confirm).
    for _ in range(num_layers - 1):
        self.mlp_virtualnode_list.append(
            nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
                          nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()))

    self.pool = SumPooling()
def __init__(self, in_dim, hidden_dim, out_dim, num_heads, pooling="avg"):
    """Assemble a two-layer GAT with a two-layer linear head and a readout.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Per-head hidden width of the first attention layer.
        out_dim: Output width of ``linear2``.
        num_heads: Number of attention heads in each GATConv.
        pooling: Graph readout, one of "avg", "sum", "max".

    Raises:
        NotImplementedError: If ``pooling`` is not a known readout.
    """
    super(GAT, self).__init__()

    # Width after the heads of layer1 are presumably flattened together in
    # forward -- confirm.
    multi_head_dim = hidden_dim * num_heads

    self.layer1 = GATConv(in_dim, hidden_dim, num_heads)
    self.layer2 = GATConv(multi_head_dim, multi_head_dim, num_heads)
    self.linear1 = nn.Linear(multi_head_dim, multi_head_dim)
    self.linear2 = nn.Linear(multi_head_dim, out_dim)

    # Additional layers
    if pooling not in ("avg", "sum", "max"):
        raise NotImplementedError(f"Unknown pooling method: {pooling}")
    if pooling == "avg":
        self.pool = AvgPooling()
    elif pooling == "sum":
        self.pool = SumPooling()
    else:
        self.pool = MaxPooling()
def __init__(self, node_in_feats, edge_in_feats, node_hidden_feats=500, num_encode_gnn_layers=3):
    """Assemble two WLN encoders, a sum readout and a scalar scoring MLP.

    Args:
        node_in_feats: Width of the input node features.
        edge_in_feats: Width of the input edge features.
        node_hidden_feats: Node width produced by both WLN encoders.
        num_encode_gnn_layers: Number of layers in the main encoder.
    """
    super(WLNReactionRanking, self).__init__()

    # Settings shared by both WLN encoders.
    common = dict(edge_in_feats=edge_in_feats,
                  node_out_feats=node_hidden_feats,
                  set_comparison=False)

    # Main encoder over the raw node features.
    self.gnn = WLN(node_in_feats=node_in_feats,
                   n_layers=num_encode_gnn_layers,
                   **common)
    # Single-layer encoder that consumes hidden-width node features, so no
    # input projection is needed.
    self.diff_gnn = WLN(node_in_feats=node_hidden_feats,
                        n_layers=1,
                        project_in_feats=False,
                        **common)

    self.readout = SumPooling()

    width = node_hidden_feats
    self.predict = nn.Sequential(nn.Linear(width, width),
                                 nn.ReLU(),
                                 nn.Linear(width, 1))
def __init__(self, flen, dropout_rate, intermediate_rep=128, n_hidden=256, n_layers=6):
    """Assemble a GraphConv stack, an MLP projection, dropout and sum pooling.

    The stack holds ``max(n_layers, 1)`` ReLU-activated GraphConvs (one input
    layer plus ``n_layers - 1`` hidden layers), followed by a final GraphConv
    with no activation.

    Args:
        flen: Width of the input node features.
        dropout_rate: Dropout probability.
        intermediate_rep: Output width of ``linear_model``.
        n_hidden: Hidden width of every GraphConv and of the MLP.
        n_layers: Controls how many activated GraphConvs are stacked.
    """
    super(GCN, self).__init__()

    # Input widths of the activated layers: raw features first, then hidden.
    activated_in_widths = [flen] + [n_hidden] * (n_layers - 1)
    self.layers = nn.ModuleList(
        GraphConv(width, n_hidden, activation=nn.ReLU())
        for width in activated_in_widths
    )
    # output layer (no activation)
    self.layers.append(GraphConv(n_hidden, n_hidden))

    self.linear_model = nn.Sequential(
        nn.Linear(n_hidden, n_hidden),
        nn.ReLU(),
        nn.Linear(n_hidden, intermediate_rep),
    )
    self.dropout = nn.Dropout(p=dropout_rate)
    self.pooling = SumPooling()
def __init__(self, num_layers, hidden_units, gcn_type='gcn', pooling_type='sum', node_attributes=None,
             edge_weights=None, node_embedding=None, use_embedding=False,
             num_nodes=None, dropout=0.5, max_z=1000):
    """Assemble a GCN/SAGE stack with optional frozen lookups and a 2-layer head.

    Args:
        num_layers: Number of graph convolution layers.
        hidden_units: Hidden width of every conv layer and of ``z_embedding``.
        gcn_type: 'gcn' (GraphConv) or 'sage' (SAGEConv with 'gcn' aggregator).
        pooling_type: Only 'sum' is supported.
        node_attributes: Optional tensor wrapped as a frozen embedding lookup.
        edge_weights: Optional tensor wrapped as a frozen embedding lookup.
        node_embedding: Optional pretrained tensor wrapped as a frozen lookup.
        use_embedding: Train a fresh ``num_nodes x hidden_units`` embedding
            when no ``node_embedding`` is given; also controls whether the
            node embedding width is added to the first layer's input.
        num_nodes: Row count for the freshly trained node embedding.
        dropout: Stored as ``self.dropout``.
        max_z: Row count of ``z_embedding``.

    Raises:
        ValueError: On an unknown ``gcn_type`` or ``pooling_type``.
    """
    super(GCN, self).__init__()
    self.num_layers = num_layers
    self.dropout = dropout
    self.pooling_type = pooling_type
    self.use_attribute = node_attributes is not None
    self.use_embedding = use_embedding
    self.use_edge_weight = edge_weights is not None

    self.z_embedding = nn.Embedding(max_z, hidden_units)

    # Frozen lookup tables built from caller-provided tensors.
    if self.use_attribute:
        self.node_attributes_lookup = nn.Embedding.from_pretrained(node_attributes)
        self.node_attributes_lookup.weight.requires_grad = False
    if self.use_edge_weight:
        self.edge_weights_lookup = nn.Embedding.from_pretrained(edge_weights)
        self.edge_weights_lookup.weight.requires_grad = False

    if node_embedding is not None:
        self.node_embedding = nn.Embedding.from_pretrained(node_embedding)
        self.node_embedding.weight.requires_grad = False
    elif use_embedding:
        self.node_embedding = nn.Embedding(num_nodes, hidden_units)

    # NOTE(review): a pretrained `node_embedding` passed with
    # use_embedding=False is stored but not counted in the first layer's
    # input width -- confirm this is intended.
    initial_dim = hidden_units
    if self.use_attribute:
        initial_dim += self.node_attributes_lookup.embedding_dim
    if self.use_embedding:
        initial_dim += self.node_embedding.embedding_dim

    if gcn_type == 'gcn':
        conv_cls, conv_kwargs = GraphConv, {'allow_zero_in_degree': True}
    elif gcn_type == 'sage':
        conv_cls, conv_kwargs = SAGEConv, {'aggregator_type': 'gcn'}
    else:
        raise ValueError('Gcn type error.')

    in_widths = [initial_dim] + [hidden_units] * (num_layers - 1)
    self.layers = nn.ModuleList(
        conv_cls(width, hidden_units, **conv_kwargs) for width in in_widths
    )

    self.linear_1 = nn.Linear(hidden_units, hidden_units)
    self.linear_2 = nn.Linear(hidden_units, 1)

    if pooling_type != 'sum':
        raise ValueError('Pooling type error.')
    self.pooling = SumPooling()