# Collected JumpingKnowledge-based models. Shared imports are gathered here;
# project-specific helpers (Block, CoarsenBlock, DenseAttentionModule,
# AGGINConv, CINConv, EdgeCINConv, SparseCINConv, OrientedConv, CochainBatch,
# ComplexBatch, AtomEncoder, get_pooling_fn, get_nonlinearity, get_graph_norm,
# batched_negative_edges, sinkhorn_loss_default, utils, ...) are defined in
# their respective repositories.
from copy import deepcopy
from math import ceil

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import lil_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from torch import optim
from torch.nn import (BatchNorm1d as BN, Embedding, Linear, Parameter, ReLU,
                      Sequential)
from torch_geometric.data import Batch
from torch_geometric.nn import (DenseGCNConv, GATConv, GCNConv, GINConv,
                                GraphConv, JumpingKnowledge, SAGEConv, Set2Set,
                                TopKPooling, dense_diff_pool, dense_mincut_pool,
                                global_add_pool, global_max_pool,
                                global_mean_pool, graclus, max_pool)


class GCNWithJK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(GCNWithJK, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
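# --- Usage sketch (illustrative, not from the original sources): drives a
# jumping-knowledge model such as GCNWithJK above on a toy batch. `ToyDataset`
# is a hypothetical stand-in for any object exposing `num_features` and
# `num_classes`; the graphs are random.
from torch_geometric.data import Batch, Data


class ToyDataset:
    num_features = 8
    num_classes = 3


toy_batch = Batch.from_data_list([
    Data(x=torch.randn(4, 8),
         edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])),
    Data(x=torch.randn(5, 8),
         edge_index=torch.tensor([[0, 1, 3, 4], [1, 0, 4, 3]])),
])
model = GCNWithJK(ToyDataset(), num_layers=3, hidden=16)
log_probs = model(toy_batch)  # shape: [2, 3], per-graph log-probabilities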
class DiffPool(torch.nn.Module):
    def __init__(self, args, num_nodes=10, num_layers=4, hidden=16, ratio=0.25):
        super(DiffPool, self).__init__()
        self.args = args
        num_features = self.args.filters_3
        self.att = DenseAttentionModule(self.args)

        num_nodes = ceil(ratio * num_nodes)
        self.embed_block1 = Block(num_features, hidden, hidden)
        self.pool_block1 = Block(num_features, hidden, num_nodes)

        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))

        self.jump = JumpingKnowledge(mode="cat")
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, num_features)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for block1, block2 in zip(self.embed_blocks, self.pool_blocks):
            block1.reset_parameters()
            block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj, mask):
        s = self.pool_block1(x, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [self.att(x, mask)]
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)
        for i, (embed, pool) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            s = pool(x, adj)
            x = F.relu(embed(x, adj))
            xs.append(self.att(x))
            if i < (len(self.embed_blocks) - 1):
                x, adj, _, _ = dense_diff_pool(x, adj, s)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
class DiffPool(torch.nn.Module):
    def __init__(self, dataset, num_pools, hidden, ratio=0.25):
        super(DiffPool, self).__init__()
        self.num_pools, self.hidden = num_pools, hidden
        num_nodes = ceil(ratio * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)
        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(num_pools - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks,
                                           self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # x: [batch_size, num_nodes, in_channels]
        x, adj, mask = data.x, data.adj, data.mask
        # s: [batch_size, num_nodes, c_num_nodes]
        s = self.pool_block1(x, adj, mask, add_loop=True)
        # x: [batch_size, num_nodes, hidden]
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        # x: [batch_size, c_num_nodes, hidden]
        # adj: [batch_size, c_num_nodes, c_num_nodes]
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)
        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            # s: [batch_size, c_num_nodes, cc_num_nodes]
            s = pool_block(x, adj)
            # x: [batch_size, c_num_nodes, hidden]
            x = F.relu(embed_block(x, adj))
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks) - 1:
                # x: [batch_size, cc_num_nodes, hidden]
                # adj: [batch_size, cc_num_nodes, cc_num_nodes]
                x, adj, _, _ = dense_diff_pool(x, adj, s)
        x = self.jump(xs)  # x: [batch_size, (len(self.embed_blocks) + 1) * hidden]
        x = F.relu(self.lin1(x))  # x: [batch_size, hidden]
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)  # x: [batch_size, dataset.num_classes]
        return F.log_softmax(x, dim=-1)
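# --- Usage sketch (illustrative, not from the original sources): the dense
# models in this file (DiffPool, Coarsening, MincutPool, ...) expect dense
# batches with x: [B, N, F], adj: [B, N, N] and mask: [B, N]. In PyTorch
# Geometric such batches are typically produced with the ToDense transform and
# a DenseDataLoader (torch_geometric.loader in PyG >= 2.0, torch_geometric.data
# in earlier versions):
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DenseDataLoader
from torch_geometric.transforms import ToDense

max_nodes = 150  # illustrative cap on graph size
dense_dataset = TUDataset('data/PROTEINS', name='PROTEINS',
                          pre_transform=ToDense(max_nodes),
                          pre_filter=lambda d: d.num_nodes <= max_nodes)
dense_loader = DenseDataLoader(dense_dataset, batch_size=32)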
class Coarsening(torch.nn.Module):
    def __init__(self, dataset, hidden, ratio=0.25):
        # We only use one layer for coarsening.
        super(Coarsening, self).__init__()
        self.embed_block1 = DenseGCNConv(dataset.num_features, hidden)
        self.coarse_block1 = CoarsenBlock(hidden, ratio)
        self.embed_block2 = DenseGCNConv(hidden, dataset.num_features)
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(hidden + dataset.num_features, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.coarse_block1.reset_parameters()
        self.embed_block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, epsilon=0.01, opt_epochs=100):
        x, adj, mask = data.x, data.adj, data.mask
        batch_num_nodes = data.mask.sum(-1)
        x1 = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        coarse_x, new_adj, S = self.coarse_block1(x1, adj, batch_num_nodes)
        xs = [coarse_x.mean(dim=1)]
        x2 = torch.tanh(self.embed_block2(coarse_x, new_adj, mask,
                                          add_loop=True))
        xs.append(x2.mean(dim=1))

        opt_loss = 0.0
        for i in range(len(x)):
            x3 = self.get_nonzero_rows(x[i])
            opt_loss += sinkhorn_loss_default(x3, x2[i], epsilon,
                                              niter=opt_epochs)
        return xs, new_adj, S, opt_loss

    def predict(self, xs):
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def get_nonzero_rows(self, M):  # M is a matrix
        # torch.nonzero() was buggy in PyTorch 1.2.0, so instead we sort rows
        # by row sum (descending) and keep the rows with a positive sum.
        MM, MM_ind = M.sum(-1).sort(descending=True)
        N = (M.sum(-1) > 0).sum()
        return M[MM_ind[:N]]

    def __repr__(self):
        return self.__class__.__name__
class DiffPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, ratio=0.25):
        super(DiffPool, self).__init__()
        num_nodes = ceil(ratio * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)
        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks,
                                           self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, adj, mask = data.x, data.adj, data.mask
        link_losses = 0.
        ent_losses = 0.
        s = self.pool_block1(x, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        x, adj, link_loss, ent_loss = dense_diff_pool(x, adj, s, mask)
        link_losses += link_loss
        ent_losses += ent_loss
        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            s = pool_block(x, adj)
            x = F.relu(embed_block(x, adj))
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks) - 1:
                x, adj, link_loss, ent_loss = dense_diff_pool(x, adj, s)
                link_losses += link_loss
                ent_losses += ent_loss
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1), link_losses + ent_losses

    def __repr__(self):
        return self.__class__.__name__
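# --- Usage sketch (illustrative, not from the original sources): one training
# step for the DiffPool variant above, whose forward() returns both the
# log-probabilities and the summed auxiliary link/entropy losses. `batch` is
# assumed to be a dense batch (see the DenseDataLoader sketch earlier) with a
# graph-label attribute `y`.
import torch.nn.functional as F


def diffpool_train_step(model, batch, optimizer):
    model.train()
    optimizer.zero_grad()
    log_probs, aux_loss = model(batch)
    loss = F.nll_loss(log_probs, batch.y.view(-1)) + aux_loss
    loss.backward()
    optimizer.step()
    return float(loss)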
class GATJK(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5, heads=2, jk_type='max'):
        super(GATJK, self).__init__()
        self.convs = nn.ModuleList()
        self.convs.append(
            GATConv(in_channels, hidden_channels, heads=heads, concat=True))
        self.bns = nn.ModuleList()
        self.bns.append(nn.BatchNorm1d(hidden_channels * heads))
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(hidden_channels * heads, hidden_channels, heads=heads,
                        concat=True))
            self.bns.append(nn.BatchNorm1d(hidden_channels * heads))
        self.convs.append(
            GATConv(hidden_channels * heads, hidden_channels, heads=heads))
        self.dropout = dropout
        self.activation = F.elu  # note: uses ELU
        self.jump = JumpingKnowledge(jk_type,
                                     channels=hidden_channels * heads,
                                     num_layers=1)
        if jk_type == 'cat':
            self.final_project = nn.Linear(
                hidden_channels * heads * num_layers, out_channels)
        else:  # max or lstm
            self.final_project = nn.Linear(hidden_channels * heads,
                                           out_channels)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.jump.reset_parameters()
        self.final_project.reset_parameters()

    def forward(self, data):
        x = data.graph['node_feat']
        xs = []
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, data.graph['edge_index'])
            x = self.bns[i](x)
            x = self.activation(x)
            xs.append(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, data.graph['edge_index'])
        xs.append(x)
        x = self.jump(xs)
        x = self.final_project(x)
        return x
class GIN(nn.Module):
    def __init__(self, dataset, num_layers, hidden, train_eps=False,
                 mode='cat'):
        super().__init__()
        self.conv1 = GINConv(nn.Sequential(
            nn.Linear(dataset.num_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
        ), train_eps=train_eps)
        self.convs = nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(nn.Sequential(
                    nn.Linear(hidden, hidden),
                    nn.ReLU(),
                    nn.Linear(hidden, hidden),
                    nn.ReLU(),
                    nn.BatchNorm1d(hidden),
                ), train_eps=train_eps))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = nn.Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = nn.Linear(hidden, hidden)
        self.lin2 = nn.Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        xs = [x]
        for conv in self.convs:
            x = conv(x, edge_index)
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
class DiffPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(DiffPool, self).__init__()
        num_nodes = ceil(0.25 * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)
        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(0.25 * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for block1, block2 in zip(self.embed_blocks, self.pool_blocks):
            block1.reset_parameters()
            block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, adj, mask = data.x, data.adj, data.mask
        s = self.pool_block1(x, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        # dense_diff_pool returns (x, adj, link_loss, ent_loss).
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)
        for embed, pool in zip(self.embed_blocks, self.pool_blocks):
            s = pool(x, adj)
            x = F.relu(embed(x, adj))
            xs.append(x.mean(dim=1))
            x, adj, _, _ = dense_diff_pool(x, adj, s)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class JKNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, mode='cat'):  # JumpingKnowledge modes: 'cat', 'max', 'lstm'
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=False))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=False))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = Linear(num_layers * hidden_channels, hidden_channels)
        else:
            self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj_t):
        xs = []
        for i, conv in enumerate(self.convs):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs += [x]
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
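# --- Usage sketch (illustrative, not from the original sources): JKNet on a
# node-classification graph. The `adj_t` argument of forward() may be either
# an edge_index tensor or a torch_sparse SparseTensor; GCNConv accepts both.
# All sizes below are made up.
import torch

jknet = JKNet(in_channels=16, hidden_channels=32, out_channels=7,
              num_layers=4, dropout=0.5, mode='cat')
x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 400))
log_probs = jknet(x, edge_index)  # shape: [100, 7]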
class ModelEnsemble(torch.nn.Module):
    def __init__(self, num_features, num_class):
        super(ModelEnsemble, self).__init__()
        self.JK = JumpingKnowledge(mode='lstm', channels=num_class,
                                   num_layers=1)
        # self.linear = Linear(num_features, num_class)

    def reset_parameters(self):
        self.JK.reset_parameters()

    def forward(self, x):
        # x is a list of tensors of shape [num_nodes, num_class], one per
        # ensemble member, as expected by JumpingKnowledge.
        x = self.JK(x)
        return F.log_softmax(x, dim=-1)
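# --- Illustrative sketch (not from the original sources) of the
# JumpingKnowledge aggregation modes used throughout this file: 'cat'
# concatenates per-layer node features, while 'max' and 'lstm' keep the
# channel count fixed ('lstm' scores the layers with a bidirectional LSTM).
import torch
from torch_geometric.nn import JumpingKnowledge

layer_feats = [torch.randn(10, 4) for _ in range(3)]  # 3 layers, 10 nodes, 4 channels
print(JumpingKnowledge('cat')(layer_feats).shape)  # torch.Size([10, 12])
print(JumpingKnowledge('max')(layer_feats).shape)  # torch.Size([10, 4])
print(JumpingKnowledge('lstm', channels=4, num_layers=3)(layer_feats).shape)  # torch.Size([10, 4])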
class GraphSAGEWithJK(torch.nn.Module):
    def __init__(self, num_input_features, num_layers, hidden, mode='cat'):
        super(GraphSAGEWithJK, self).__init__()
        self.conv1 = SAGEConv(num_input_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.jump = JumpingKnowledge(mode)
        # The factor of 3 accounts for the concatenation of add, mean and max
        # readouts in forward().
        if mode == 'cat':
            self.lin1 = Linear(3 * num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(3 * hidden, hidden)
        self.lin2 = Linear(hidden, 2)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]
        x = self.jump(xs)
        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch),
        ], dim=1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class GCNWithJK(torch.nn.Module):
    def __init__(self, num_features, output_channels, num_layers=3,
                 nb_neurons=128, mode='cat', **kwargs):
        super(GCNWithJK, self).__init__()
        self.conv1 = GCNConv(num_features, nb_neurons)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(nb_neurons, nb_neurons))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = Linear(num_layers * nb_neurons, nb_neurons)
        else:
            self.lin1 = Linear(nb_neurons, nb_neurons)
        self.lin2 = Linear(nb_neurons, output_channels)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, target_size, **kwargs):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch, size=target_size)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class Graclus(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(Graclus, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
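# --- Illustrative sketch (not from the original sources) of the
# graclus + max_pool coarsening step used by the Graclus models in this file:
# graclus() greedily matches neighbouring nodes into clusters (it requires the
# torch-cluster package), and max_pool() merges each cluster into a single
# node, max-pooling features and rewiring edges. All sizes are made up.
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import graclus, max_pool

toy = Batch.from_data_list([Data(
    x=torch.randn(6, 8),
    edge_index=torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]))])
cluster = graclus(toy.edge_index, num_nodes=toy.x.size(0))
coarse = max_pool(cluster, toy)  # fewer nodes, same feature dimension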
class TopK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(TopK, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
            self.pools.append(TopKPooling(hidden, ratio=0.8))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv, pool in zip(self.convs, self.pools):
            conv.reset_parameters()
            pool.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, (conv, pool) in enumerate(zip(self.convs, self.pools)):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0:
                # This 5-way unpacking matches older PyG; recent versions of
                # TopKPooling also return the pooling scores (a 6-tuple).
                x, edge_index, _, batch, _ = pool(x, edge_index, batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class EdgeCIN0(torch.nn.Module):
    """A variant of CIN0 operating up to edge level.

    It may optionally ignore two_cell features. This model is based on
    https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py
    """
    def __init__(self, num_input_features, num_classes, num_layers, hidden,
                 dropout_rate: float = 0.5, jump_mode=None, nonlinearity='relu',
                 include_top_features=True, update_top_features=True,
                 readout='sum'):
        super(EdgeCIN0, self).__init__()
        self.max_dim = 1
        self.include_top_features = include_top_features
        # If the top features are included, they can be updated by a network.
        self.update_top_features = include_top_features and update_top_features
        self.dropout_rate = dropout_rate
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.update_top_nns = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        conv_nonlinearity = get_nonlinearity(nonlinearity, return_module=True)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            v_conv_update = Sequential(Linear(layer_dim, hidden),
                                       conv_nonlinearity(),
                                       Linear(hidden, hidden),
                                       conv_nonlinearity(), BN(hidden))
            e_conv_update = Sequential(Linear(layer_dim, hidden),
                                       conv_nonlinearity(),
                                       Linear(hidden, hidden),
                                       conv_nonlinearity(), BN(hidden))
            v_conv_up = Sequential(Linear(layer_dim * 2, layer_dim),
                                   conv_nonlinearity(), BN(layer_dim))
            e_conv_down = Sequential(Linear(layer_dim * 2, layer_dim),
                                     conv_nonlinearity(), BN(layer_dim))
            e_conv_inp_dim = layer_dim * 2 if include_top_features else layer_dim
            e_conv_up = Sequential(Linear(e_conv_inp_dim, layer_dim),
                                   conv_nonlinearity(), BN(layer_dim))
            self.convs.append(
                EdgeCINConv(layer_dim, layer_dim, v_conv_up, e_conv_down,
                            e_conv_up, v_conv_update, e_conv_update,
                            train_eps=False))
            if self.update_top_features and i < num_layers - 1:
                self.update_top_nns.append(
                    Sequential(Linear(layer_dim, hidden), conv_nonlinearity(),
                               Linear(hidden, hidden), conv_nonlinearity(),
                               BN(hidden)))
        self.jump = JumpingKnowledge(jump_mode) if jump_mode is not None else None
        if jump_mode == 'cat':
            self.lin1 = Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        for net in self.update_top_nns:
            net.reset_parameters()

    def pool_complex(self, xs, data):
        # All complexes have nodes, so we can extract the batch size from
        # cochains[0].
        batch_size = data.cochains[0].batch.max() + 1
        # The MP output is of shape [message_passing_dim, batch_size, feature_dim].
        pooled_xs = torch.zeros(self.max_dim + 1, batch_size, xs[0].size(-1),
                                device=batch_size.device)
        for i in range(len(xs)):
            # It is very important that size is supplied.
            pooled_xs[i, :, :] = self.pooling_fn(xs[i], data.cochains[i].batch,
                                                 size=batch_size)
        return pooled_xs

    def jump_complex(self, jump_xs):
        # Perform JumpingKnowledge at each level of the complex.
        xs = []
        for jumpx in jump_xs:
            xs += [self.jump(jumpx)]
        return xs

    def forward(self, data: ComplexBatch):
        model_nonlinearity = get_nonlinearity(self.nonlinearity,
                                              return_module=False)
        xs, jump_xs = None, None
        for c, conv in enumerate(self.convs):
            params = data.get_all_cochain_params(
                max_dim=self.max_dim,
                include_top_features=self.include_top_features)
            xs = conv(*params)
            # At the last convolutional layer we do not need to update the top
            # features. We also check that two_cell features do exist in this
            # batch before doing so.
            if self.update_top_features and c < len(self.convs) - 1 \
                    and 2 in data.cochains:
                top_x = self.update_top_nns[c](data.cochains[2].x)
                data.set_xs(xs + [top_x])
            else:
                data.set_xs(xs)
            if self.jump_mode is not None:
                if jump_xs is None:
                    jump_xs = [[] for _ in xs]
                for i, x in enumerate(xs):
                    jump_xs[i] += [x]

        if self.jump_mode is not None:
            xs = self.jump_complex(jump_xs)
        pooled_xs = self.pool_complex(xs, data)
        x = pooled_xs.sum(dim=0)
        x = model_nonlinearity(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
class EdgeOrient(torch.nn.Module):
    """A model for edge-defined signals taking edge orientation into account."""
    def __init__(self, num_input_features, num_classes, num_layers, hidden,
                 dropout_rate: float = 0.0, jump_mode=None, nonlinearity='id',
                 readout='sum', fully_invar=False):
        super(EdgeOrient, self).__init__()
        self.max_dim = 1
        self.fully_invar = fully_invar
        orient = not self.fully_invar
        self.dropout_rate = dropout_rate
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            # NOTE: biases must be set to False; otherwise the model is not
            # equivariant.
            update_up = Linear(layer_dim, hidden, bias=False)
            update_down = Linear(layer_dim, hidden, bias=False)
            update = Linear(layer_dim, hidden, bias=False)
            self.convs.append(
                OrientedConv(dim=1, up_msg_size=layer_dim,
                             down_msg_size=layer_dim, update_up_nn=update_up,
                             update_down_nn=update_down, update_nn=update,
                             act_fn=get_nonlinearity(nonlinearity,
                                                     return_module=False),
                             orient=orient))
        self.jump = JumpingKnowledge(jump_mode) if jump_mode is not None else None
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data: CochainBatch, include_partial=False):
        if self.fully_invar:
            data.x = torch.abs(data.x)
        for c, conv in enumerate(self.convs):
            x = conv(data)
            data.x = x
        cell_pred = x

        # To obtain orientation invariance, we take the absolute value of the
        # features, unless we already did so before the first layer.
        batch_size = data.batch.max() + 1
        if not self.fully_invar:
            x = torch.abs(x)
        x = self.pooling_fn(x, data.batch, size=batch_size)

        # At this point we have invariance: we can use any non-linearity we
        # like. Here, independently of previous non-linearities, we choose
        # ReLU. Note that this makes the model non-linear even when the
        # identity is employed in previous layers.
        x = torch.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)
        if include_partial:
            return x, cell_pred
        return x

    def __repr__(self):
        return self.__class__.__name__
class Net(torch.nn.Module):
    def __init__(self, num_classes, gnn_layers, embed_dim, hidden_dim,
                 jk_layer, process_step, dropout):
        super(Net, self).__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.embedding = Embedding(6, embed_dim)
        for i in range(gnn_layers):
            if i == 0:
                self.convs.append(
                    AGGINConv(Sequential(Linear(2 * embed_dim + 2, hidden_dim),
                                         ReLU(),
                                         Linear(hidden_dim, hidden_dim),
                                         ReLU(), BN(hidden_dim)),
                              train_eps=True))
            else:
                self.convs.append(
                    AGGINConv(Sequential(Linear(hidden_dim, hidden_dim),
                                         ReLU(),
                                         Linear(hidden_dim, hidden_dim),
                                         ReLU(), BN(hidden_dim)),
                              train_eps=True))
        if jk_layer.isdigit():
            jk_layer = int(jk_layer)
            # JumpingKnowledge takes `num_layers` for its LSTM aggregator.
            self.jk = JumpingKnowledge(mode='lstm', channels=hidden_dim,
                                       num_layers=jk_layer)
            self.s2s = Set2Set(hidden_dim, processing_steps=process_step)
            self.fc1 = Linear(2 * hidden_dim, hidden_dim)
            self.fc2 = Linear(hidden_dim, int(hidden_dim / 2))
            self.fc3 = Linear(int(hidden_dim / 2), num_classes)
        elif jk_layer == 'cat':
            self.jk = JumpingKnowledge(mode=jk_layer)
            self.s2s = Set2Set(gnn_layers * hidden_dim,
                               processing_steps=process_step)
            self.fc1 = Linear(2 * gnn_layers * hidden_dim, hidden_dim)
            self.fc2 = Linear(hidden_dim, int(hidden_dim / 2))
            self.fc3 = Linear(int(hidden_dim / 2), num_classes)
        elif jk_layer == 'max':
            self.jk = JumpingKnowledge(mode=jk_layer)
            self.s2s = Set2Set(hidden_dim, processing_steps=process_step)
            self.fc1 = Linear(2 * hidden_dim, hidden_dim)
            self.fc2 = Linear(hidden_dim, int(hidden_dim / 2))
            self.fc3 = Linear(int(hidden_dim / 2), num_classes)

    def reset_parameters(self):
        self.embedding.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jk.reset_parameters()
        self.s2s.reset_parameters()
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.fc3.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Embed the categorical values (gene expression and node type).
        xc = x[:, :2].type(torch.long)
        ems = self.embedding(xc)
        ems = ems.view(-1, ems.shape[1] * ems.shape[2])
        x = torch.cat((ems, x[:, 2:]), dim=1)
        xs = []
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            xs += [x]
        x = self.jk(xs)
        x = self.s2s(x, batch)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits
class EdgeMPNN(torch.nn.Module):
    """An MPNN operating in the line graph."""
    def __init__(self, num_input_features, num_classes, num_layers, hidden,
                 dropout_rate: float = 0.0, jump_mode=None, nonlinearity='relu',
                 readout='sum', fully_invar=True):
        super(EdgeMPNN, self).__init__()
        self.max_dim = 1
        self.dropout_rate = dropout_rate
        self.fully_invar = fully_invar
        orient = not self.fully_invar
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            # We pass this lambda function to discard upper adjacencies.
            update_up = lambda x: 0
            update_down = Linear(layer_dim, hidden, bias=False)
            update = Linear(layer_dim, hidden, bias=False)
            self.convs.append(
                OrientedConv(dim=1, up_msg_size=layer_dim,
                             down_msg_size=layer_dim, update_up_nn=update_up,
                             update_down_nn=update_down, update_nn=update,
                             act_fn=get_nonlinearity(nonlinearity,
                                                     return_module=False),
                             orient=orient))
        self.jump = JumpingKnowledge(jump_mode) if jump_mode is not None else None
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data: CochainBatch, include_partial=False):
        if self.fully_invar:
            data.x = torch.abs(data.x)
        for c, conv in enumerate(self.convs):
            x = conv(data)
            data.x = x
        cell_pred = x

        batch_size = data.batch.max() + 1
        if not self.fully_invar:
            x = torch.abs(x)
        x = self.pooling_fn(x, data.batch, size=batch_size)

        # At this point we have invariance: we can use any non-linearity we
        # like. Here, independently from previous non-linearities, we choose
        # ReLU. Note that this makes the model non-linear even when employing
        # identity in previous layers.
        x = torch.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)
        if include_partial:
            return x, cell_pred
        return x

    def __repr__(self):
        return self.__class__.__name__
class MultiLayerCoarsening(torch.nn.Module):
    def __init__(self, dataset, hidden, num_layers=2, ratio=0.5):
        super(MultiLayerCoarsening, self).__init__()
        self.embed_block1 = DenseGCNConv(dataset.num_features, hidden)
        self.coarse_block1 = CoarsenBlock(hidden, ratio)
        self.embed_block2 = DenseGCNConv(hidden, dataset.num_features)
        self.num_layers = num_layers
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(hidden + dataset.num_features, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.coarse_block1.reset_parameters()
        self.embed_block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, epsilon=0.01, opt_epochs=100):
        x, adj, mask = data.x, data.adj, data.mask
        batch_num_nodes = data.mask.sum(-1)
        new_adjs = [adj]
        Ss = []

        x1 = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x1.mean(dim=1)]
        new_adj = adj
        coarse_x = x1
        for i in range(self.num_layers):
            coarse_x, new_adj, S = self.coarse_block1(coarse_x, new_adj,
                                                      batch_num_nodes)
            new_adjs.append(new_adj)
            Ss.append(S)
        # No ReLU here, otherwise x2 could be all zeros.
        x2 = self.embed_block2(coarse_x, new_adj, mask, add_loop=True)
        xs.append(x2.mean(dim=1))

        opt_loss = 0.0
        for i in range(len(x)):
            x3 = self.get_nonzero_rows(x[i])
            x4 = self.get_nonzero_rows(x2[i])
            if x3.size()[0] == 0:
                continue
            if x4.size()[0] == 0:
                continue
            opt_loss += sinkhorn_loss_default(x3, x4, epsilon,
                                              niter=opt_epochs).float()
        return xs, new_adjs, Ss, opt_loss

    def get_nonzero_rows(self, M):  # M is a matrix
        # torch.nonzero() was buggy in PyTorch 1.2.0, so instead we sort rows
        # by absolute row sum (descending) and keep the nonzero ones.
        MM, MM_ind = torch.abs(M.sum(-1)).sort(descending=True)
        N = (torch.abs(M.sum(-1)) > 0).sum()
        return M[MM_ind[:N]]

    def predict(self, xs):
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class MultiLayerCoarsening(torch.nn.Module):
    def __init__(self, dataset, hidden, num_layers=2, ratio=0.5):
        super(MultiLayerCoarsening, self).__init__()
        self.embed_block1 = DenseGCNConv(dataset.num_features, hidden)
        self.coarse_block1 = CoarsenBlock(hidden, ratio)
        self.embed_block2 = DenseGCNConv(hidden, dataset.num_features)
        self.num_layers = num_layers
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(hidden * num_layers, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.coarse_block1.reset_parameters()
        self.embed_block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, epsilon=0.01, opt_epochs=100):
        x, adj, mask = data.x, data.adj, data.mask
        batch_num_nodes = data.mask.sum(-1)
        new_adjs = [adj]
        Ss = []

        x1 = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = []
        coarse_x, new_adj, S = self.coarse_block1(x1, adj, batch_num_nodes)
        new_adjs.append(new_adj)
        Ss.append(S)
        xs.append(coarse_x.mean(dim=1))
        # No ReLU here, otherwise x2 could be all zeros.
        x2 = self.embed_block2(coarse_x, new_adj, mask, add_loop=True)
        for i in range(self.num_layers - 1):
            x1 = F.relu(self.embed_block1(F.relu(x2), new_adj, mask,
                                          add_loop=True))
            coarse_x, new_adj, S = self.coarse_block1(x1, new_adj,
                                                      batch_num_nodes)
            new_adjs.append(new_adj)
            Ss.append(S)
            xs.append(coarse_x.mean(dim=1))
            # No ReLU here either, for the same reason.
            x2 = self.embed_block2(coarse_x, new_adj, mask, add_loop=True)

        opt_loss = 0.0
        for i in range(len(x)):
            x3 = self.get_nonzero_rows(x[i])
            x4 = self.get_nonzero_rows(x2[i])
            if x3.size()[0] == 0:
                continue
            if x4.size()[0] == 0:
                opt_loss += sinkhorn_loss_default(x3, x2[i], epsilon,
                                                  niter=opt_epochs).float()
                continue
            opt_loss += sinkhorn_loss_default(x3, x4, epsilon,
                                              niter=opt_epochs).float()
        return xs, new_adjs, Ss, opt_loss

    def get_nonzero_rows(self, M):  # M is a matrix
        # torch.nonzero() was buggy in PyTorch 1.2.0, so instead we sort rows
        # by absolute row sum (descending) and keep the nonzero ones.
        MM, MM_ind = torch.abs(M.sum(-1)).sort(descending=True)
        N = (torch.abs(M.sum(-1)) > 0).sum()
        return M[MM_ind[:N]]

    def predict(self, xs):
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression
        downstream task."""
        clf = LogisticRegression(solver=solver, multi_class=multi_class,
                                 *args, **kwargs).fit(
                                     train_z.detach().cpu().numpy(),
                                     train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self):
        return self.__class__.__name__
class JK(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01,
                 weight_decay=5e-4, n_edge=1, with_relu=True, drop=False,
                 with_bias=True, device=None):
        super(JK, self).__init__()
        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = int(nclass)
        self.dropout = dropout
        self.lr = lr
        self.drop = drop
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.n_edge = n_edge
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.gate = Parameter(torch.rand(1))  # learnable gate in [0, 1]

        """JK from torch-geometric"""
        num_features = nfeat
        dim = nhid
        nn1 = Sequential(Linear(num_features, dim), ReLU())
        self.gc1 = GINConv(nn1)
        self.bn1 = torch.nn.BatchNorm1d(dim)
        nn2 = Sequential(Linear(dim, dim), ReLU())
        self.gc2 = GINConv(nn2)
        nn3 = Sequential(Linear(dim, dim), ReLU())
        self.gc3 = GINConv(nn3)
        self.jump = JumpingKnowledge(mode='cat')  # 'cat', 'lstm' or 'max'
        self.bn2 = torch.nn.BatchNorm1d(dim)
        # self.fc1 = Linear(dim * 3, dim)
        self.fc2 = Linear(dim * 2, int(nclass))

    def forward(self, x, adj):
        """We do not change edge_index, only the edge weights; an edge whose
        weight equals zero is regarded as removed."""
        x = x.to_dense()
        edge_index = adj._indices()

        if self.attention:  # self.attention is set in fit()
            adj = self.att_coef(x, adj, i=0)
        x1 = F.relu(self.gc1(x, edge_index=edge_index,
                             edge_weight=adj._values()))
        if self.attention:
            adj_2 = self.att_coef(x1, adj, i=1)
            adj_values = self.gate * adj._values() + \
                (1 - self.gate) * adj_2._values()
        else:
            adj_values = adj._values()
        x1 = F.dropout(x1, self.dropout, training=self.training)
        x2 = F.relu(self.gc2(x1, edge_index=edge_index,
                             edge_weight=adj_values))
        x2 = F.dropout(x2, self.dropout, training=self.training)
        x_last = self.jump([x1, x2])
        x_last = F.dropout(x_last, self.dropout, training=self.training)
        x_last = self.fc2(x_last)
        return F.log_softmax(x_last, dim=1)

    def initialize(self):
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()
        self.fc2.reset_parameters()
        try:
            self.jump.reset_parameters()
            self.fc1.reset_parameters()
            self.gc3.reset_parameters()
        except AttributeError:
            pass

    def att_coef(self, fea, edge_index, is_lil=False, i=0):
        if not is_lil:
            edge_index = edge_index._indices()
        else:
            edge_index = edge_index.tocoo()
        n_node = fea.shape[0]
        row = edge_index[0].cpu().data.numpy()[:]
        col = edge_index[1].cpu().data.numpy()[:]

        fea_copy = fea.cpu().data.numpy()
        sim_matrix = cosine_similarity(X=fea_copy, Y=fea_copy)
        sim = sim_matrix[row, col]
        sim[sim < 0.1] = 0
        # (A commented-out variant used Jaccard similarity for binary features
        # and cosine similarity for numeric ones.)

        # Build the attention matrix.
        att_dense = lil_matrix((n_node, n_node), dtype=np.float32)
        att_dense[row, col] = sim
        if att_dense[0, 0] == 1:
            att_dense = att_dense - sp.diags(att_dense.diagonal(), offsets=0,
                                             format="lil")
        # Normalize so that each row sums to one.
        att_dense_norm = normalize(att_dense, axis=1, norm='l1')

        # Learnable dropout: build a per-edge character vector.
        if self.drop:
            character = np.vstack((att_dense_norm[row, col].A1,
                                   att_dense_norm[col, row].A1))
            character = torch.from_numpy(character.T)
            # self.drop_learn_1 (e.g. a Linear(2, 1)) must be defined when
            # drop=True.
            drop_score = self.drop_learn_1(character)
            # Do not use softmax, since we only have one element.
            drop_score = torch.sigmoid(drop_score)
            mm = torch.nn.Threshold(0.5, 0)
            drop_score = mm(drop_score)
            mm_2 = torch.nn.Threshold(-0.49, 1)
            drop_score = mm_2(-drop_score)
            drop_decision = drop_score.clone().requires_grad_()
            drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32)
            drop_matrix[row, col] = drop_decision.cpu().data.numpy().squeeze(-1)
            # Update the attention matrix, removing the zeroed edges.
            att_dense_norm = att_dense_norm.multiply(drop_matrix.tocsr())

        # Add self-loop weights only at the first layer.
        if att_dense_norm[0, 0] == 0:
            degree = (att_dense_norm != 0).sum(1).A1
            lam = 1 / (degree + 1)  # degree + 1 counts the node itself
            self_weight = sp.diags(np.array(lam), offsets=0, format="lil")
            att = att_dense_norm + self_weight  # add the self-loops
        else:
            att = att_dense_norm
        att_adj = edge_index
        att_edge_weight = att[row, col]
        att_edge_weight = np.exp(att_edge_weight)  # exponent, a softmax-like rescaling
        att_edge_weight = torch.tensor(np.array(att_edge_weight)[0],
                                       dtype=torch.float32).to(self.device)
        shape = (n_node, n_node)
        new_adj = torch.sparse.FloatTensor(att_adj, att_edge_weight, shape)
        return new_adj

    def add_loop_sparse(self, adj, fill_value=1):
        # Build a sparse identity matrix and add it to adj (self-loops).
        # torch.range is deprecated; torch.arange covers the same indices.
        row = torch.arange(adj.shape[0], dtype=torch.int64)
        i = torch.stack((row, row), dim=0)
        v = torch.ones(adj.shape[0], dtype=torch.float32)
        shape = adj.shape
        I_n = torch.sparse.FloatTensor(i, v, shape)
        return adj + I_n.to(self.device)

    def fit(self, features, adj, labels, idx_train, idx_val=None,
            idx_test=None, train_iters=81, att_0=None, attention=False,
            model_name=None, initialize=True, verbose=False, normalize=False,
            patience=500):
        '''Train the model. When idx_val is not None, pick the best model
        according to the validation loss.'''
        self.sim = None
        self.attention = attention
        self.idx_test = idx_test

        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels,
                                                    device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        # No normalization needed here; it is handled inside the GCN layers.
        # Make the coefficient D^{-1/2} (A + I) D^{-1/2}.
        adj = self.add_loop_sparse(adj)
        self.adj_norm = adj
        self.features = features
        self.labels = labels

        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            if patience < train_iters:
                self._train_with_early_stopping(labels, idx_train, idx_val,
                                                train_iters, patience, verbose)
            else:
                self._train_with_val(labels, idx_train, idx_val, train_iters,
                                     verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr,
                               weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            # weight=None: all training nodes are weighted equally.
            loss_train = F.nll_loss(output[idx_train], labels[idx_train],
                                    weight=None)
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(
                    i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters,
                        verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr,
                               weight_decay=self.weight_decay)
        best_loss_val = 100
        best_acc_val = 0
        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])
            if verbose and i % 20 == 0:
                print('Epoch {}, training loss: {}, val acc: {}'.format(
                    i, loss_train.item(), acc_val))

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance '
                  'on validation ===')
        self.load_state_dict(weights)

    def _train_with_early_stopping(self, labels, idx_train, idx_val,
                                   train_iters, patience, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr,
                               weight_decay=self.weight_decay)
        early_stopping = patience
        best_loss_val = 100
        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(
                    i, loss_train.item()))

            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1
            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(
                i, best_loss_val))
        self.load_state_dict(weights)

    def test(self, idx_test, model_name=None):
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test, output

    def _set_parameters(self):
        # TODO
        pass

    def predict(self, features=None, adj=None):
        '''By default, inputs are unnormalized data.'''
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj,
                                                device=self.device)
            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)
class Graclus(torch.nn.Module):
    def __init__(self, num_features, num_classes, num_layers, hidden,
                 pooling_type, no_cat=False, encode_edge=False):
        super(Graclus, self).__init__()
        self.encode_edge = encode_edge
        if encode_edge:
            # Note: this single-argument GCNConv (consuming edge attributes in
            # forward) appears to be a project-specific variant for OGB
            # molecule graphs, not torch_geometric.nn.GCNConv.
            self.conv1 = GCNConv(hidden, aggr='add')
        else:
            self.conv1 = GraphConv(num_features, hidden, aggr='add')
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='add'))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        if no_cat:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)
        self.pooling_type = pooling_type
        self.no_cat = no_cat
        self.atom_encoder = AtomEncoder(emb_dim=hidden)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if self.encode_edge:
            x = self.atom_encoder(x)
            x = self.conv1(x, edge_index, data.edge_attr)
        else:
            x = self.conv1(x, edge_index)
        x = F.relu(x)
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if self.pooling_type != 'none':
                if self.pooling_type == 'complement':
                    complement = batched_negative_edges(
                        edge_index=edge_index, batch=batch,
                        force_undirected=True)
                    cluster = graclus(complement, num_nodes=x.size(0))
                elif self.pooling_type == 'graclus':
                    cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch
        if not self.no_cat:
            x = self.jump(xs)
        else:
            x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
class MincutPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, ratio=0.1):
        super(MincutPool, self).__init__()
        num_nodes = ceil(ratio * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Linear(hidden, num_nodes)
        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Linear(hidden, num_nodes))
        self.embed_final = Block(hidden, hidden, hidden)
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(hidden * (num_layers + 1), hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks,
                                           self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.embed_final.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, adj, mask = data.x, data.adj, data.mask
        mincut_losses = 0.
        ortho_losses = 0.
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        s = self.pool_block1(x)
        xs = [x.mean(dim=1)]
        x, adj, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s, mask)
        mincut_losses += mincut_loss
        ortho_losses += ortho_loss
        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            x = F.relu(embed_block(x, adj))
            s = pool_block(x)
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks):
                x, adj, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s)
                mincut_losses += mincut_loss
                ortho_losses += ortho_loss
        x = F.relu(self.embed_final(x, adj, add_loop=True))
        xs.append(x.mean(dim=1))
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1), mincut_losses + ortho_losses

    def __repr__(self):
        return self.__class__.__name__
class GCNWithJK(torch.nn.Module):
    def __init__(self, num_layers, num_input_features, hidden, mode='cat'):
        super(GCNWithJK, self).__init__()
        self.conv1 = GCNConv(num_input_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':  # concatenation
            self.lin1 = Linear(3 * num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(3 * hidden, hidden)
        self.lin2 = Linear(hidden, 2)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # data: Batch(batch=[num_nodes_in_batch],
        #             edge_attr=[2*num_nodes_in_batch, num_edge_features_per_edge],
        #             edge_index=[2, 2*num_nodes_in_batch],
        #             pos=[num_nodes_in_batch, 2],
        #             x=[num_nodes_in_batch, num_input_features_per_node],
        #             y=[num_graphs_in_batch, num_classes])
        # example: Batch(batch=[2490], edge_attr=[4980, 1],
        #                edge_index=[2, 4980], pos=[2490, 2], x=[2490, 33],
        #                y=[32, 2])
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # x.shape: torch.Size([num_nodes_in_batch, num_input_features_per_node])
        # edge_index.shape: torch.Size([2, 2*num_nodes_in_batch])
        # batch.shape: torch.Size([num_nodes_in_batch])
        # example: x.shape = torch.Size([2490, 33]),
        #          edge_index.shape = torch.Size([2, 4980]),
        #          batch.shape = torch.Size([2490])
        x = F.relu(self.conv1(x, edge_index))
        # x.shape: torch.Size([num_nodes_in_batch, hidden])
        # example: x.shape = torch.Size([2490, 66])
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]  # xs: list containing the layer-wise representations
        # Aggregate information across the layers (here: concatenation).
        x = self.jump(xs)
        # x.shape: torch.Size([num_nodes_in_batch, num_layers * hidden])
        # example: x.shape = torch.Size([2490, 3*66])
        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch),
        ], dim=1)
        # x.shape: torch.Size([num_graphs_in_batch, 3 * num_layers * hidden])
        # example: x.shape = torch.Size([32, 3*3*66])
        x = F.relu(self.lin1(x))
        # x.shape: torch.Size([num_graphs_in_batch, hidden])
        # example: x.shape = torch.Size([32, 66])
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        # x.shape: torch.Size([num_graphs_in_batch, num_classes])
        # example: x.shape = torch.Size([32, 2])
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class CIN0(torch.nn.Module):
    """A cellular version of GIN.

    This model is based on
    https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py
    """
    def __init__(self, num_input_features, num_classes, num_layers, hidden,
                 dropout_rate: float = 0.5, max_dim: int = 2, jump_mode=None,
                 nonlinearity='relu', readout='sum'):
        super(CIN0, self).__init__()
        self.max_dim = max_dim
        self.dropout_rate = dropout_rate
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        conv_nonlinearity = get_nonlinearity(nonlinearity, return_module=True)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            conv_update = Sequential(Linear(layer_dim, hidden),
                                     conv_nonlinearity(),
                                     Linear(hidden, hidden),
                                     conv_nonlinearity(), BN(hidden))
            conv_up = Sequential(Linear(layer_dim * 2, layer_dim),
                                 conv_nonlinearity(), BN(layer_dim))
            conv_down = Sequential(Linear(layer_dim * 2, layer_dim),
                                   conv_nonlinearity(), BN(layer_dim))
            self.convs.append(
                CINConv(layer_dim, layer_dim, conv_up, conv_down, conv_update,
                        train_eps=False, max_dim=self.max_dim))
        self.jump = JumpingKnowledge(jump_mode) if jump_mode is not None else None
        if jump_mode == 'cat':
            self.lin1 = Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def pool_complex(self, xs, data):
        # All complexes have nodes, so we can extract the batch size from
        # cochains[0].
        batch_size = data.cochains[0].batch.max() + 1
        # The MP output is of shape [message_passing_dim, batch_size, feature_dim].
        pooled_xs = torch.zeros(self.max_dim + 1, batch_size, xs[0].size(-1),
                                device=batch_size.device)
        for i in range(len(xs)):
            # It is very important that size is supplied.
            pooled_xs[i, :, :] = self.pooling_fn(xs[i], data.cochains[i].batch,
                                                 size=batch_size)
        return pooled_xs

    def jump_complex(self, jump_xs):
        # Perform JumpingKnowledge at each level of the complex.
        xs = []
        for jumpx in jump_xs:
            xs += [self.jump(jumpx)]
        return xs

    def forward(self, data: ComplexBatch):
        model_nonlinearity = get_nonlinearity(self.nonlinearity,
                                              return_module=False)
        xs, jump_xs = None, None
        for c, conv in enumerate(self.convs):
            params = data.get_all_cochain_params(max_dim=self.max_dim)
            xs = conv(*params)
            data.set_xs(xs)
            if self.jump_mode is not None:
                if jump_xs is None:
                    jump_xs = [[] for _ in xs]
                for i, x in enumerate(xs):
                    jump_xs[i] += [x]

        if self.jump_mode is not None:
            xs = self.jump_complex(jump_xs)
        pooled_xs = self.pool_complex(xs, data)
        x = pooled_xs.sum(dim=0)
        x = model_nonlinearity(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import JumpingKnowledge

# SparseCINConv, ComplexBatch, get_pooling_fn, get_nonlinearity and
# get_graph_norm are project-specific utilities from the CWN codebase.


class SparseCIN(torch.nn.Module):
    """A cellular version of GIN.

    This model is based on
    https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py
    """
    def __init__(self, num_input_features, num_classes, num_layers, hidden,
                 dropout_rate: float = 0.5, max_dim: int = 2, jump_mode=None,
                 nonlinearity='relu', readout='sum', train_eps=False,
                 final_hidden_multiplier: int = 2, use_coboundaries=False,
                 readout_dims=(0, 1, 2), final_readout='sum',
                 apply_dropout_before='lin2', graph_norm='bn'):
        super(SparseCIN, self).__init__()
        self.max_dim = max_dim
        if readout_dims is not None:
            self.readout_dims = tuple(
                [dim for dim in readout_dims if dim <= max_dim])
        else:
            self.readout_dims = list(range(max_dim + 1))
        self.final_readout = final_readout
        self.dropout_rate = dropout_rate
        self.apply_dropout_before = apply_dropout_before
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        self.graph_norm = get_graph_norm(graph_norm)
        act_module = get_nonlinearity(nonlinearity, return_module=True)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            self.convs.append(
                SparseCINConv(up_msg_size=layer_dim, down_msg_size=layer_dim,
                              boundary_msg_size=layer_dim,
                              passed_msg_boundaries_nn=None,
                              passed_msg_up_nn=None,
                              passed_update_up_nn=None,
                              passed_update_boundaries_nn=None,
                              train_eps=train_eps, max_dim=self.max_dim,
                              hidden=hidden, act_module=act_module,
                              layer_dim=layer_dim, graph_norm=self.graph_norm,
                              use_coboundaries=use_coboundaries))
        self.jump = JumpingKnowledge(jump_mode) if jump_mode is not None else None
        self.lin1s = torch.nn.ModuleList()
        for _ in range(max_dim + 1):
            if jump_mode == 'cat':
                # These layers don't use a bias. Then, in case a level is not
                # present, the output is just zero and is not given by the biases.
                self.lin1s.append(Linear(num_layers * hidden,
                                         final_hidden_multiplier * hidden,
                                         bias=False))
            else:
                self.lin1s.append(Linear(hidden,
                                         final_hidden_multiplier * hidden))
        self.lin2 = Linear(final_hidden_multiplier * hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        for lin in self.lin1s:  # ModuleList itself has no reset_parameters()
            lin.reset_parameters()
        self.lin2.reset_parameters()

    def pool_complex(self, xs, data):
        # All complexes have nodes, so the batch size can be read off cochains[0].
        batch_size = data.cochains[0].batch.max() + 1
        # The MP output is of shape [message_passing_dim, batch_size, feature_dim].
        pooled_xs = torch.zeros(self.max_dim + 1, batch_size, xs[0].size(-1),
                                device=batch_size.device)
        for i in range(len(xs)):
            # It's very important that size is supplied.
            pooled_xs[i, :, :] = self.pooling_fn(xs[i], data.cochains[i].batch,
                                                 size=batch_size)
        new_xs = []
        for i in range(self.max_dim + 1):
            new_xs.append(pooled_xs[i])
        return new_xs

    def jump_complex(self, jump_xs):
        # Perform JumpingKnowledge at each level of the complex.
        xs = []
        for jumpx in jump_xs:
            xs += [self.jump(jumpx)]
        return xs

    def forward(self, data: ComplexBatch, include_partial=False):
        act = get_nonlinearity(self.nonlinearity, return_module=False)
        xs, jump_xs = None, None
        res = {}
        for c, conv in enumerate(self.convs):
            params = data.get_all_cochain_params(max_dim=self.max_dim,
                                                 include_down_features=False)
            start_to_process = 0
            xs = conv(*params, start_to_process=start_to_process)
            data.set_xs(xs)
            if include_partial:
                for k in range(len(xs)):
                    res[f"layer{c}_{k}"] = xs[k]
            if self.jump_mode is not None:
                if jump_xs is None:
                    jump_xs = [[] for _ in xs]
                for i, x in enumerate(xs):
                    jump_xs[i] += [x]
        if self.jump_mode is not None:
            xs = self.jump_complex(jump_xs)
        xs = self.pool_complex(xs, data)
        # Select the dimensions we want at the end.
        xs = [xs[i] for i in self.readout_dims]
        if include_partial:
            for k in range(len(xs)):
                res[f"pool_{k}"] = xs[k]
        new_xs = []
        for i, x in enumerate(xs):
            if self.apply_dropout_before == 'lin1':
                x = F.dropout(x, p=self.dropout_rate, training=self.training)
            new_xs.append(act(self.lin1s[self.readout_dims[i]](x)))
        x = torch.stack(new_xs, dim=0)
        if self.apply_dropout_before == 'final_readout':
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        if self.final_readout == 'mean':
            x = x.mean(0)
        elif self.final_readout == 'sum':
            x = x.sum(0)
        else:
            raise NotImplementedError
        if self.apply_dropout_before not in ['lin1', 'final_readout']:
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)
        if include_partial:
            res['out'] = x
            return x, res
        return x

    def __repr__(self):
        return self.__class__.__name__
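The `size` argument that pool_complex insists on can be shown in isolation. The sketch below uses made-up toy shapes with PyG's global_add_pool: without `size`, a cochain level in which the highest-indexed complexes have no cells would yield fewer rows than the batch size, and the assignment into `pooled_xs` would fail.

import torch
from torch_geometric.nn import global_add_pool

batch_size, feat_dim = 3, 8
# Toy cochain features for two levels (e.g. nodes and edges); the second
# level deliberately has no cells belonging to complex 2.
xs = [torch.randn(12, feat_dim), torch.randn(7, feat_dim)]
batches = [torch.randint(0, batch_size, (12,)),
           torch.randint(0, 2, (7,))]

pooled_xs = torch.zeros(len(xs), batch_size, feat_dim)
for i, (x, b) in enumerate(zip(xs, batches)):
    # `size` forces one row per complex, padding missing complexes with zeros.
    pooled_xs[i] = global_add_pool(x, b, size=batch_size)
print(pooled_xs.shape)  # torch.Size([2, 3, 8])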
import torch
import torch.nn.functional as F
from torch_geometric.nn import JumpingKnowledge, TAGConv


class DEA_GNN_JK(torch.nn.Module):
    def __init__(self, num_nodes, embed_dim, gnn_in_dim, gnn_hidden_dim,
                 gnn_out_dim, gnn_num_layers, mlp_in_dim, mlp_hidden_dim,
                 mlp_out_dim=1, mlp_num_layers=2, dropout=0.5,
                 gnn_batchnorm=False, mlp_batchnorm=False, K=2, jk_mode='max'):
        super(DEA_GNN_JK, self).__init__()
        assert jk_mode in ['max', 'sum', 'mean', 'lstm', 'cat']

        # Embedding
        self.emb = torch.nn.Embedding(num_nodes, embedding_dim=embed_dim)

        # GNN
        convs_list = [TAGConv(gnn_in_dim, gnn_hidden_dim, K)]
        for i in range(gnn_num_layers - 2):
            convs_list.append(TAGConv(gnn_hidden_dim, gnn_hidden_dim, K))
        convs_list.append(TAGConv(gnn_hidden_dim, gnn_out_dim, K))
        self.convs = torch.nn.ModuleList(convs_list)

        # MLP
        lins_list = [torch.nn.Linear(mlp_in_dim, mlp_hidden_dim)]
        for i in range(mlp_num_layers - 2):
            lins_list.append(torch.nn.Linear(mlp_hidden_dim, mlp_hidden_dim))
        lins_list.append(torch.nn.Linear(mlp_hidden_dim, mlp_out_dim))
        self.lins = torch.nn.ModuleList(lins_list)

        # Batchnorm (note: every BatchNorm1d uses gnn_hidden_dim, including the
        # one after the last conv, so this assumes gnn_out_dim == gnn_hidden_dim)
        self.gnn_batchnorm = gnn_batchnorm
        self.mlp_batchnorm = mlp_batchnorm
        if self.gnn_batchnorm:
            self.gnn_bns = torch.nn.ModuleList(
                [torch.nn.BatchNorm1d(gnn_hidden_dim)
                 for i in range(gnn_num_layers)])
        if self.mlp_batchnorm:
            self.mlp_bns = torch.nn.ModuleList(
                [torch.nn.BatchNorm1d(mlp_hidden_dim)
                 for i in range(mlp_num_layers - 1)])

        # JumpingKnowledge natively supports 'max', 'lstm' and 'cat';
        # 'mean' and 'sum' are handled manually in forward().
        self.jk_mode = jk_mode
        if self.jk_mode in ['max', 'lstm', 'cat']:
            self.jk = JumpingKnowledge(mode=self.jk_mode,
                                       channels=gnn_hidden_dim,
                                       num_layers=gnn_num_layers)

        self.dropout = dropout
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.emb.weight)
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        if self.gnn_batchnorm:
            for bn in self.gnn_bns:
                bn.reset_parameters()
        if self.mlp_batchnorm:
            for bn in self.mlp_bns:
                bn.reset_parameters()
        if self.jk_mode in ['max', 'lstm', 'cat']:
            self.jk.reset_parameters()

    def forward(self, x_feature, adj_t, edge_label_index):
        if x_feature is not None:
            out = torch.cat([self.emb.weight, x_feature], dim=1)
        else:
            out = self.emb.weight

        out_list = []
        for i in range(len(self.convs)):
            out = self.convs[i](out, adj_t)
            if self.gnn_batchnorm:
                out = self.gnn_bns[i](out)
            out = F.relu(out)
            out = F.dropout(out, p=self.dropout, training=self.training)
            out_list += [out]

        if self.jk_mode in ['max', 'lstm', 'cat']:
            out = self.jk(out_list)
        elif self.jk_mode == 'mean':
            out_stack = torch.stack(out_list, dim=0)
            out = torch.mean(out_stack, dim=0)
        elif self.jk_mode == 'sum':
            out_stack = torch.stack(out_list, dim=0)
            out = torch.sum(out_stack, dim=0)

        # Score each candidate edge via the elementwise product of its two
        # endpoint embeddings, fed through the MLP.
        gnn_embed = out[edge_label_index, :]  # [2, num_edges, dim]
        embed_product = gnn_embed[0, :, :] * gnn_embed[1, :, :]

        out = embed_product
        for i in range(len(self.lins) - 1):
            out = self.lins[i](out)
            if self.mlp_batchnorm:
                out = self.mlp_bns[i](out)
            out = F.relu(out)
            out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.lins[-1](out).squeeze(1)
        return out

    def loss(self, y_pred, y_true):
        return self.loss_fn(y_pred, y_true)
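DEA_GNN_JK falls back to torch.stack for 'mean' and 'sum' precisely because PyG's JumpingKnowledge module only implements 'cat', 'max' and 'lstm'. A quick, self-contained check of the output widths of the three native modes (the sizes here are illustrative):

import torch
from torch_geometric.nn import JumpingKnowledge

xs = [torch.randn(100, 64) for _ in range(3)]  # three layer-wise embeddings

print(JumpingKnowledge(mode='cat')(xs).shape)  # torch.Size([100, 192])
print(JumpingKnowledge(mode='max')(xs).shape)  # torch.Size([100, 64])
jk = JumpingKnowledge(mode='lstm', channels=64, num_layers=3)
print(jk(xs).shape)                            # torch.Size([100, 64])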
import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d as BN
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import JumpingKnowledge, Set2Set

# AGGINConv is a project-specific GIN variant with built-in dropout.


class Net(torch.nn.Module):
    def __init__(self, num_classes, num_layers, feat_dim, embed_dim,
                 jk_layer, process_step, dropout):
        super(Net, self).__init__()
        self.dropout = dropout
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers):
            in_dim = feat_dim if i == 0 else embed_dim
            self.convs.append(
                AGGINConv(Sequential(Linear(in_dim, embed_dim), ReLU(),
                                     Linear(embed_dim, embed_dim), ReLU(),
                                     BN(embed_dim)),
                          train_eps=True, dropout=self.dropout))
        if jk_layer.isdigit():  # e.g. '2': LSTM-based JK with that many layers
            jk_layer = int(jk_layer)
            self.jump = JumpingKnowledge(mode='lstm', channels=embed_dim,
                                         num_layers=jk_layer)
            self.gpl = Set2Set(embed_dim, processing_steps=process_step)
            self.fc1 = Linear(2 * embed_dim, embed_dim)
            self.fc2 = Linear(embed_dim, num_classes)
        elif jk_layer == 'cat':
            self.jump = JumpingKnowledge(mode=jk_layer)
            self.gpl = Set2Set(num_layers * embed_dim,
                               processing_steps=process_step)
            # Set2Set returns 2 * in_channels features per graph.
            self.fc1 = Linear(2 * num_layers * embed_dim, embed_dim)
            self.fc2 = Linear(embed_dim, num_classes)
        elif jk_layer == 'max':
            self.jump = JumpingKnowledge(mode=jk_layer)
            self.gpl = Set2Set(embed_dim, processing_steps=process_step)
            self.fc1 = Linear(2 * embed_dim, embed_dim)
            self.fc2 = Linear(embed_dim, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        self.gpl.reset_parameters()
        self.jump.reset_parameters()
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        xs = []
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            xs += [x]
        x = self.jump(xs)       # aggregate the layer-wise representations
        x = self.gpl(x, batch)  # Set2Set global pooling
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits
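The fc1 width in the 'cat' branch above follows from Set2Set's output size: the layer returns 2 * in_channels features per graph, so concatenated JK features of width num_layers * embed_dim come out at twice that width. A small check with illustrative sizes:

import torch
from torch_geometric.nn import Set2Set

num_layers, embed_dim = 3, 32
x = torch.randn(50, num_layers * embed_dim)  # JK-concatenated node features
batch = torch.zeros(50, dtype=torch.long)    # all nodes belong to one graph

s2s = Set2Set(num_layers * embed_dim, processing_steps=4)
print(s2s(x, batch).shape)  # torch.Size([1, 192]) == [1, 2 * 3 * 32]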
import torch
import torch.nn.functional as F
from math import ceil, sqrt
from minisom import MiniSom
from torch.nn import Linear
from torch_geometric.nn import JumpingKnowledge, dense_diff_pool

# Block is the dense GNN building block defined earlier in this document.


class SOMPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, ratio=0.25):
        super(SOMPool, self).__init__()
        num_nodes = ceil(ratio * dataset[0].num_nodes)
        # Note: the SOM grid size in forward() is derived from `hidden`.
        self.num_nodes = hidden
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)
        self.dimnum = dataset.num_features
        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks, self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, adj, mask = data.x, data.adj, data.mask

        # Quantize the node features with a self-organizing map before
        # computing the first pooling assignment.
        somnum = ceil(sqrt(self.num_nodes / 1.5))
        som = MiniSom(somnum, somnum, self.dimnum, sigma=0.3, learning_rate=0.5)
        tempdata = x.reshape(-1, self.dimnum).cpu().numpy()
        som.train_batch(tempdata, 100)
        qnt = som.quantization(tempdata)
        qnt = torch.from_numpy(qnt).float().to(x.device)
        qnt = qnt.reshape(adj.size()[0], -1, self.dimnum)

        s = self.pool_block1(qnt, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)

        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            s = pool_block(x, adj)
            x = F.relu(embed_block(x, adj))
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks) - 1:
                x, adj, _, _ = dense_diff_pool(x, adj, s)

        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
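The SOM quantization step inside SOMPool.forward can also be run standalone. A sketch with made-up data, using minisom's train_batch and quantization calls:

import numpy as np
from minisom import MiniSom

feats = np.random.rand(200, 33).astype(np.float32)  # flattened node features
som = MiniSom(7, 7, 33, sigma=0.3, learning_rate=0.5)
som.train_batch(feats, 100)          # 100 training iterations
quantized = som.quantization(feats)  # each row snapped to its BMU's weights
print(quantized.shape)               # (200, 33)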