def __init__(self, args):
    """Construct the DeeperGCN graph-property model from ``args``.

    Hyper-parameters are read from attributes of ``args``. The constructor
    prints a configuration summary plus the layout of the chosen residual
    block, then builds:

    * ``self.gcns`` / ``self.norms`` -- ``args.num_layers`` GENConv layers
      and the matching ``norm_layer`` instances;
    * ``self.node_features_encoder`` and ``self.edge_encoder`` -- linear
      projections from 7 input features to ``args.hidden_channels``;
    * ``self.pool`` -- the global pooling function selected by
      ``args.graph_pooling`` ('sum', 'mean' or 'max');
    * ``self.graph_pred_linear`` -- the output head with ``args.num_tasks``
      outputs.

    Raises:
        NotImplementedError: if ``args.block`` is 'dense'.
        Exception: if ``args.block``, ``args.conv`` or ``args.graph_pooling``
            is not one of the supported values.
    """
    super(DeeperGCN, self).__init__()

    # Settings kept on the instance.
    self.num_layers = args.num_layers
    self.dropout = args.dropout
    self.block = args.block

    # Settings only needed while building the layers.
    width = args.hidden_channels
    n_tasks = args.num_tasks
    conv_type = args.conv
    aggregation = args.gcn_aggr
    t_value = args.t
    self.learn_t = args.learn_t
    p_value = args.p
    self.learn_p = args.learn_p
    self.msg_norm = args.msg_norm
    learn_scale = args.learn_msg_scale
    edge_in_conv = args.conv_encode_edge
    norm_type = args.norm
    n_mlp = args.mlp_layers
    pooling = args.graph_pooling

    print(f'The number of layers {self.num_layers}',
          f'Aggr aggregation method {aggregation}',
          f'block: {self.block}')

    # Printed layout for each supported residual block type.
    block_layouts = {
        'res+': 'LN/BN->ReLU->GraphConv->Res',
        'res': 'GraphConv->LN/BN->ReLU->Res',
        'plain': 'GraphConv->LN/BN->ReLU',
    }
    if self.block == 'dense':
        raise NotImplementedError('To be implemented')
    if self.block not in block_layouts:
        raise Exception('Unknown block Type')
    print(block_layouts[self.block])

    # Conv and norm layers are created interleaved, exactly as before, so
    # parameter initialisation order is unchanged.
    self.gcns = torch.nn.ModuleList()
    self.norms = torch.nn.ModuleList()
    for _ in range(self.num_layers):
        if conv_type != 'gen':
            raise Exception('Unknown Conv Type')
        self.gcns.append(
            GENConv(width, width,
                    aggr=aggregation,
                    t=t_value, learn_t=self.learn_t,
                    p=p_value, learn_p=self.learn_p,
                    msg_norm=self.msg_norm,
                    learn_msg_scale=learn_scale,
                    encode_edge=edge_in_conv,
                    edge_feat_dim=width,
                    norm=norm_type,
                    mlp_layers=n_mlp))
        self.norms.append(norm_layer(norm_type, width))

    self.node_features_encoder = torch.nn.Linear(7, width)
    self.edge_encoder = torch.nn.Linear(7, width)

    pool_by_name = {
        'sum': global_add_pool,
        'mean': global_mean_pool,
        'max': global_max_pool,
    }
    if pooling not in pool_by_name:
        raise Exception('Unknown Pool Type')
    self.pool = pool_by_name[pooling]

    self.graph_pred_linear = torch.nn.Linear(width, n_tasks)
def __init__(self, args):
    """Construct the DeeperGCN model with atom/bond encoders from ``args``.

    Hyper-parameters are read from attributes of ``args``. After printing a
    configuration summary and the layout of the chosen residual block, the
    constructor builds:

    * optionally (when ``args.add_virtual_node`` is truthy) a zero-initialised
      one-row ``virtualnode_embedding`` and ``args.num_layers - 1`` MLPs in
      ``mlp_virtualnode_list``;
    * ``self.gcns`` / ``self.norms`` -- ``args.num_layers`` GENConv layers
      (constructed with ``bond_encoder=True``) and matching norm layers;
    * ``self.atom_encoder`` -- an ``AtomEncoder`` of width
      ``args.hidden_channels``;
    * ``self.bond_encoder`` -- a ``BondEncoder``, only when
      ``args.conv_encode_edge`` is False;
    * ``self.pool`` and ``self.graph_pred_linear`` -- the global pooling
      function chosen by ``args.graph_pooling`` and an output head with
      ``args.num_tasks`` outputs.

    Raises:
        NotImplementedError: if ``args.block`` is 'dense'.
        Exception: for an unknown block, conv or pooling type.
    """
    super(DeeperGCN, self).__init__()

    # Settings kept on the instance.
    self.num_layers = args.num_layers
    self.dropout = args.dropout
    self.block = args.block
    self.conv_encode_edge = args.conv_encode_edge
    self.add_virtual_node = args.add_virtual_node

    # Settings only needed while building the layers.
    width = args.hidden_channels
    n_tasks = args.num_tasks
    conv_type = args.conv
    aggregation = args.gcn_aggr
    t_value = args.t
    self.learn_t = args.learn_t
    p_value = args.p
    self.learn_p = args.learn_p
    self.msg_norm = args.msg_norm
    learn_scale = args.learn_msg_scale
    norm_type = args.norm
    n_mlp = args.mlp_layers
    pooling = args.graph_pooling

    print(f'The number of layers {self.num_layers}',
          f'Aggr aggregation method {aggregation}',
          f'block: {self.block}')

    # Printed layout for each supported residual block type.
    block_layouts = {
        'res+': 'LN/BN->ReLU->GraphConv->Res',
        'res': 'GraphConv->LN/BN->ReLU->Res',
        'plain': 'GraphConv->LN/BN->ReLU',
    }
    if self.block == 'dense':
        raise NotImplementedError('To be implemented')
    if self.block not in block_layouts:
        raise Exception('Unknown block Type')
    print(block_layouts[self.block])

    self.gcns = torch.nn.ModuleList()
    self.norms = torch.nn.ModuleList()

    if self.add_virtual_node:
        # Single virtual-node embedding row, overwritten with zeros after creation.
        self.virtualnode_embedding = torch.nn.Embedding(1, width)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
        # One MLP per layer except the last (num_layers - 1 in total).
        self.mlp_virtualnode_list = torch.nn.ModuleList(
            MLP([width, width], norm=norm_type, last_act=True)
            for _ in range(self.num_layers - 1))

    for _ in range(self.num_layers):
        if conv_type != 'gen':
            raise Exception('Unknown Conv Type')
        self.gcns.append(
            GENConv(width, width,
                    aggr=aggregation,
                    t=t_value, learn_t=self.learn_t,
                    p=p_value, learn_p=self.learn_p,
                    msg_norm=self.msg_norm,
                    learn_msg_scale=learn_scale,
                    encode_edge=self.conv_encode_edge,
                    bond_encoder=True,
                    norm=norm_type,
                    mlp_layers=n_mlp))
        self.norms.append(norm_layer(norm_type, width))

    self.atom_encoder = AtomEncoder(emb_dim=width)
    # A model-level bond encoder is only built when the convs do not encode edges.
    if not self.conv_encode_edge:
        self.bond_encoder = BondEncoder(emb_dim=width)

    pool_by_name = {
        'sum': global_add_pool,
        'mean': global_mean_pool,
        'max': global_max_pool,
    }
    if pooling not in pool_by_name:
        raise Exception('Unknown Pool Type')
    self.pool = pool_by_name[pooling]

    self.graph_pred_linear = torch.nn.Linear(width, n_tasks)
def __init__(self, args):
    """Construct the node-level DeeperGCN model from ``args``.

    Reads hyper-parameters from ``args``, decides whether to use gradient
    checkpointing, prints a configuration summary and the residual-block
    layout, then builds ``args.num_layers`` GENConv layers (stored in
    ``self.gcns``) with matching norm layers (``self.layer_norms``). It also
    loads node features from ``args.nf_path`` onto ``args.device``, creates
    linear encoders for 8-dim node and edge inputs, and creates the output
    head ``node_pred_linear`` with ``args.num_tasks`` outputs.

    Raises:
        NotImplementedError: if ``args.block`` is 'dense'.
        Exception: for an unknown block or conv type.
    """
    super(DeeperGCN, self).__init__()

    self.num_layers = args.num_layers
    self.dropout = args.dropout
    self.block = args.block
    self.checkpoint_grad = False

    width = args.hidden_channels
    n_tasks = args.num_tasks
    conv_type = args.conv
    aggregation = args.gcn_aggr
    t_value = args.t
    self.learn_t = args.learn_t
    p_value = args.p
    self.learn_p = args.learn_p
    self.msg_norm = args.msg_norm
    learn_scale = args.learn_msg_scale
    edge_in_conv = args.conv_encode_edge
    norm_type = args.norm
    n_mlp = args.mlp_layers
    features_path = args.nf_path
    self.use_one_hot_encoding = args.use_one_hot_encoding

    # Gradient-checkpointing policy. The second rule runs after the first,
    # so when both match, ckp_k ends up as 9. If neither rule fires,
    # self.ckp_k is never set.
    softmax_like = aggregation in ('softmax_sg', 'softmax', 'power')
    if softmax_like and self.num_layers > 50:
        self.checkpoint_grad = True
        self.ckp_k = 3
    if (self.learn_t or self.learn_p) and self.num_layers > 15:
        self.checkpoint_grad = True
        self.ckp_k = 9

    print(f'The number of layers {self.num_layers}',
          f'Aggregation method {aggregation}',
          f'block: {self.block}')

    block_layouts = {
        'res+': 'LN/BN->ReLU->GraphConv->Res',
        'res': 'GraphConv->LN/BN->ReLU->Res',
        'plain': 'GraphConv->LN/BN->ReLU',
    }
    if self.block == 'dense':
        raise NotImplementedError('To be implemented')
    if self.block not in block_layouts:
        raise Exception('Unknown block Type')
    print(block_layouts[self.block])

    self.gcns = torch.nn.ModuleList()
    self.layer_norms = torch.nn.ModuleList()
    for _ in range(self.num_layers):
        if conv_type != 'gen':
            raise Exception('Unknown Conv Type')
        self.gcns.append(
            GENConv(width, width,
                    aggr=aggregation,
                    t=t_value, learn_t=self.learn_t,
                    p=p_value, learn_p=self.learn_p,
                    msg_norm=self.msg_norm,
                    learn_msg_scale=learn_scale,
                    encode_edge=edge_in_conv,
                    edge_feat_dim=width,
                    norm=norm_type,
                    mlp_layers=n_mlp))
        self.layer_norms.append(norm_layer(norm_type, width))

    # NOTE: torch.load relies on pickle unless weights_only=True, so
    # args.nf_path must point to a trusted file.
    self.node_features = torch.load(features_path).to(args.device)

    if self.use_one_hot_encoding:
        # The encoder input is two 8-wide chunks (8 * 2).
        self.node_one_hot_encoder = torch.nn.Linear(8, 8)
        self.node_features_encoder = torch.nn.Linear(8 * 2, width)
    else:
        self.node_features_encoder = torch.nn.Linear(8, width)

    self.edge_encoder = torch.nn.Linear(8, width)
    self.node_pred_linear = torch.nn.Linear(width, n_tasks)
def __init__(self, args):
    """Construct the edge-masked DeeperGCN encoder from ``args``.

    Builds ``args.num_layers`` GENConv layers with matching norm layers,
    plus two all-ones edge masks of shape ``(self.edge_num, 1)``:
    ``edge_mask1_train`` (requires_grad=True) and ``edge_mask2_fixed``
    (requires_grad=False). Gradient checkpointing is enabled for more than
    7 layers. No input encoder or prediction head is created here.

    Raises:
        Exception: if ``args.conv`` is not 'gen' (and num_layers > 0).
    """
    super(DeeperGCN, self).__init__()

    # Hard-coded edge count covered by the edge masks below.
    self.edge_num = 2358104
    self.num_layers = args.num_layers
    self.dropout = args.dropout
    self.block = args.block

    width = args.hidden_channels
    conv_type = args.conv
    aggregation = args.gcn_aggr
    t_value = args.t
    self.learn_t = args.learn_t
    p_value = args.p
    self.learn_p = args.learn_p
    self.msg_norm = args.msg_norm
    learn_scale = args.learn_msg_scale
    norm_type = args.norm
    n_mlp = args.mlp_layers

    # Checkpoint gradients once the stack is deeper than 7 layers.
    self.checkpoint_grad = bool(self.num_layers > 7)

    print(f'The number of layers {self.num_layers}',
          f'Aggregation method {aggregation}',
          f'block: {self.block}')
    # NOTE(review): self.block is stored but not validated in this
    # constructor (the validation was commented out) -- confirm intended.

    self.gcns = torch.nn.ModuleList()
    self.norms = torch.nn.ModuleList()

    self.edge_mask1_train = nn.Parameter(torch.ones(self.edge_num, 1),
                                         requires_grad=True)
    self.edge_mask2_fixed = nn.Parameter(torch.ones(self.edge_num, 1),
                                         requires_grad=False)

    for _ in range(self.num_layers):
        if conv_type != 'gen':
            raise Exception('Unknown Conv Type')
        layer = GENConv(width, width,
                        aggr=aggregation,
                        t=t_value, learn_t=self.learn_t,
                        p=p_value, learn_p=self.learn_p,
                        msg_norm=self.msg_norm,
                        learn_msg_scale=learn_scale,
                        norm=norm_type,
                        mlp_layers=n_mlp)
        self.gcns.append(layer)
        self.norms.append(norm_layer(norm_type, width))
def __init__(self, args):
    """Construct the node-classification DeeperGCN model from ``args``.

    Builds a linear input encoder (``args.in_channels`` -> hidden), a linear
    output head (hidden -> ``args.num_tasks``), and ``args.num_layers``
    convolution layers (GENConv for ``args.conv == 'gen'``, GATConv with 4
    heads for ``'gat'``) with matching norm layers. Gradient checkpointing
    is enabled, with ``ckp_k = num_layers // 2``, for the 'softmax_sg',
    'softmax' and 'power' aggregations when there are more than 3 layers.

    Raises:
        NotImplementedError: if ``args.block`` is 'dense'.
        Exception: for an unknown block or conv type.
    """
    super(DeeperGCN, self).__init__()

    self.num_layers = args.num_layers
    self.dropout = args.dropout
    self.block = args.block
    self.checkpoint_grad = False

    in_dim = args.in_channels
    width = args.hidden_channels
    n_tasks = args.num_tasks
    conv_type = args.conv
    aggregation = args.gcn_aggr
    t_value = args.t
    self.learn_t = args.learn_t
    p_value = args.p
    self.learn_p = args.learn_p
    self.msg_norm = args.msg_norm
    learn_scale = args.learn_msg_scale
    norm_type = args.norm
    n_mlp = args.mlp_layers

    # self.ckp_k is only defined when checkpointing is switched on.
    if aggregation in ('softmax_sg', 'softmax', 'power') and self.num_layers > 3:
        self.checkpoint_grad = True
        self.ckp_k = self.num_layers // 2

    print(f'The number of layers {self.num_layers}',
          f'Aggregation method {aggregation}',
          f'block: {self.block}')

    block_layouts = {
        'res+': 'LN/BN->ReLU->GraphConv->Res',
        'res': 'GraphConv->LN/BN->ReLU->Res',
        'plain': 'GraphConv->LN/BN->ReLU',
    }
    if self.block == 'dense':
        raise NotImplementedError('To be implemented')
    if self.block not in block_layouts:
        raise Exception('Unknown block Type')
    print(block_layouts[self.block])

    self.gcns = torch.nn.ModuleList()
    self.norms = torch.nn.ModuleList()

    self.node_features_encoder = torch.nn.Linear(in_dim, width)
    self.node_pred_linear = torch.nn.Linear(width, n_tasks)

    def make_conv():
        # Build one convolution layer of the configured type.
        if conv_type == 'gen':
            return GENConv(width, width,
                           aggr=aggregation,
                           t=t_value, learn_t=self.learn_t,
                           p=p_value, learn_p=self.learn_p,
                           msg_norm=self.msg_norm,
                           learn_msg_scale=learn_scale,
                           norm=norm_type,
                           mlp_layers=n_mlp)
        if conv_type == 'gat':
            return GATConv(width, width, norm=norm_type, heads=4)
        raise Exception('Unknown Conv Type')

    for _ in range(self.num_layers):
        self.gcns.append(make_conv())
        self.norms.append(norm_layer(norm_type, width))
def __init__(self, args):
    """Construct the edge-masked node-level DeeperGCN model from ``args``.

    Builds a linear input encoder (``args.in_channels`` -> hidden) and a
    linear output head (hidden -> ``args.num_tasks``). It also creates two
    all-ones edge masks of shape ``(self.edge_num, 1)``: ``edge_mask1_train``
    (requires_grad=True) and ``edge_mask2_fixed`` (requires_grad=False).
    Finally it builds ``args.num_layers`` GENConv layers, which also receive
    ``y`` / ``learn_y``, with matching norm layers. ``self.block`` is stored
    but not validated here, and nothing is printed.

    Raises:
        Exception: if ``args.conv`` is not 'gen' (and num_layers > 0).
    """
    super(DeeperGCN, self).__init__()

    # Hard-coded edge count covered by the edge masks below.
    self.edge_num = 2484941
    self.num_layers = args.num_layers
    self.dropout = args.dropout
    self.block = args.block
    self.checkpoint_grad = False

    in_dim = args.in_channels
    width = args.hidden_channels
    n_tasks = args.num_tasks
    conv_type = args.conv
    aggregation = args.gcn_aggr
    t_value = args.t
    self.learn_t = args.learn_t
    p_value = args.p
    self.learn_p = args.learn_p
    y_value = args.y
    self.learn_y = args.learn_y
    self.msg_norm = args.msg_norm
    learn_scale = args.learn_msg_scale
    norm_type = args.norm
    n_mlp = args.mlp_layers

    # self.ckp_k is only defined when checkpointing is switched on.
    if aggregation in ('softmax_sg', 'softmax', 'power') and self.num_layers > 7:
        self.checkpoint_grad = True
        self.ckp_k = self.num_layers // 2

    self.gcns = torch.nn.ModuleList()
    self.norms = torch.nn.ModuleList()

    self.node_features_encoder = torch.nn.Linear(in_dim, width)
    self.node_pred_linear = torch.nn.Linear(width, n_tasks)

    self.edge_mask1_train = nn.Parameter(torch.ones(self.edge_num, 1),
                                         requires_grad=True)
    self.edge_mask2_fixed = nn.Parameter(torch.ones(self.edge_num, 1),
                                         requires_grad=False)

    for _ in range(self.num_layers):
        if conv_type != 'gen':
            raise Exception('Unknown Conv Type')
        layer = GENConv(width, width,
                        aggr=aggregation,
                        t=t_value, learn_t=self.learn_t,
                        p=p_value, learn_p=self.learn_p,
                        y=y_value, learn_y=self.learn_y,
                        msg_norm=self.msg_norm,
                        learn_msg_scale=learn_scale,
                        norm=norm_type,
                        mlp_layers=n_mlp)
        self.gcns.append(layer)
        self.norms.append(norm_layer(norm_type, width))
def __init__(
    self,
    num_layers: int,
    dropout: float,
    block: str,
    conv_encode_edge: bool,
    add_virtual_node: bool,
    hidden_channels: int,
    num_tasks: Optional[int],
    conv: str,
    gcn_aggr: str,
    t: float,
    learn_t: bool,
    p: float,
    learn_p: bool,
    y: float,
    learn_y: bool,
    msg_norm: bool,
    learn_msg_scale: bool,
    norm: str,
    mlp_layers: int,
    graph_pooling: Optional[str],
    node_encoder: bool = False,  # no pooling, produce node-level representations
    encode_atom: bool = True,
):
    """Build a DeeperGCN with atom/bond encoders from explicit hyper-parameters.

    Args:
        num_layers: number of GENConv layers (and norm layers) to stack.
        dropout: stored as ``self.dropout``.
        block: residual layout; one of 'res+', 'res', 'plain'. 'dense' is
            reserved but raises NotImplementedError.
        conv_encode_edge: forwarded to GENConv as ``encode_edge``. When
            False, a model-level ``BondEncoder`` is built instead.
        add_virtual_node: if True, also build a zero-initialised one-row
            ``virtualnode_embedding`` and ``num_layers - 1`` MLPs in
            ``mlp_virtualnode_list``.
        hidden_channels: width of every hidden representation.
        num_tasks: output size of ``graph_pred_linear``; only used when
            ``node_encoder`` is False.
        conv: convolution type; only 'gen' is supported.
        gcn_aggr: forwarded to GENConv as ``aggr``.
        t, learn_t, p, learn_p, y, learn_y: forwarded unchanged to GENConv
            under the same names.
        msg_norm, learn_msg_scale: forwarded unchanged to GENConv.
        norm: normalisation type passed to ``norm_layer``, GENConv and the
            virtual-node MLPs.
        mlp_layers: forwarded to GENConv.
        graph_pooling: 'sum', 'mean' or 'max'; ignored when ``node_encoder``
            is True.
        node_encoder: if True, no pooling function or prediction head is
            built, so the model produces node-level representations. Stored
            as ``self.encoder``.
        encode_atom: stored as ``self.encode_atom``. It is not consulted in
            this constructor; ``atom_encoder`` is always built.

    Raises:
        NotImplementedError: if ``block`` is 'dense'.
        Exception: for an unknown block, conv or pooling type.
    """
    super(DeeperGCN, self).__init__()
    self.num_layers = num_layers
    self.dropout = dropout
    self.block = block
    self.conv_encode_edge = conv_encode_edge
    self.add_virtual_node = add_virtual_node
    self.encoder = node_encoder
    self.encode_atom = encode_atom

    aggr = gcn_aggr
    self.learn_t = learn_t
    self.learn_p = learn_p
    self.learn_y = learn_y
    self.msg_norm = msg_norm

    print(
        "The number of layers {}".format(self.num_layers),
        "Aggr aggregation method {}".format(aggr),
        "block: {}".format(self.block),
    )

    if self.block == "res+":
        print("LN/BN->ReLU->GraphConv->Res")
    elif self.block == "res":
        print("GraphConv->LN/BN->ReLU->Res")
    elif self.block == "dense":
        raise NotImplementedError("To be implemented")
    elif self.block == "plain":
        print("GraphConv->LN/BN->ReLU")
    else:
        raise Exception("Unknown block Type")

    self.gcns = torch.nn.ModuleList()
    self.norms = torch.nn.ModuleList()

    if self.add_virtual_node:
        # Single virtual-node embedding row, overwritten with zeros after creation.
        self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        # One MLP per layer except the last (num_layers - 1 in total).
        self.mlp_virtualnode_list = torch.nn.ModuleList()
        for layer in range(self.num_layers - 1):
            self.mlp_virtualnode_list.append(
                MLP([hidden_channels] * 3, norm=norm))

    for layer in range(self.num_layers):
        if conv == "gen":
            gcn = GENConv(
                hidden_channels,
                hidden_channels,
                aggr=aggr,
                t=t,
                learn_t=self.learn_t,
                p=p,
                learn_p=self.learn_p,
                y=y,
                # Previously passed self.learn_p here, which silently ignored
                # the learn_y argument.
                learn_y=self.learn_y,
                msg_norm=self.msg_norm,
                learn_msg_scale=learn_msg_scale,
                encode_edge=self.conv_encode_edge,
                bond_encoder=True,
                norm=norm,
                mlp_layers=mlp_layers,
            )
        else:
            raise Exception("Unknown Conv Type")
        self.gcns.append(gcn)
        self.norms.append(norm_layer(norm, hidden_channels))

    self.atom_encoder = AtomEncoder(emb_dim=hidden_channels)

    # A model-level bond encoder is only built when the convs do not encode edges.
    if not self.conv_encode_edge:
        self.bond_encoder = BondEncoder(emb_dim=hidden_channels)

    # In node-encoder mode there is no graph-level readout or prediction head.
    if not node_encoder:
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise Exception("Unknown Pool Type")

        self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)