def load_model(args):
    """Construct a predictor from the configuration dictionary ``args``."""
    if args['model'] == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=args['in_feats'],
            hidden_feats=args['gcn_hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=args['in_feats'],
            hidden_feats=args['gat_hidden_feats'],
            num_heads=args['num_heads'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=args['node_feat_size'],
            edge_feat_size=args['edge_feat_size'],
            num_layers=args['num_layers'],
            num_timesteps=args['num_timesteps'],
            graph_feat_size=args['graph_feat_size'],
            n_tasks=args['n_tasks'],
            dropout=args['dropout'])

    if args['model'] == 'SchNet':
        from dgllife.model import SchNetPredictor
        model = SchNetPredictor(
            node_feats=args['node_feats'],
            hidden_feats=args['hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'MGCN':
        from dgllife.model import MGCNPredictor
        model = MGCNPredictor(
            feats=args['feats'],
            n_layers=args['n_layers'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_out_feats=args['node_out_feats'],
            edge_hidden_feats=args['edge_hidden_feats'],
            n_tasks=args['n_tasks'])

    return model
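# Illustrative call of the dispatcher above (not from the original source).
# The key names mirror the constructor arguments used in the function; the
# concrete values are assumed placeholders, e.g. 74 is the usual output size
# of dgllife's CanonicalAtomFeaturizer.
gcn_args = {
    'model': 'GCN',
    'in_feats': 74,                  # atom feature size (assumed)
    'gcn_hidden_feats': [64, 64],    # one entry per GCN layer
    'classifier_hidden_feats': 128,
    'n_tasks': 1,
}
gcn_model = load_model(gcn_args)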
def load_model(args):
    if args['model'] == 'SchNet':
        from dgllife.model import SchNetPredictor
        model = SchNetPredictor(
            node_feats=args['node_feats'],
            hidden_feats=args['hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'MGCN':
        from dgllife.model import MGCNPredictor
        model = MGCNPredictor(
            feats=args['feats'],
            n_layers=args['n_layers'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_out_feats=args['node_out_feats'],
            edge_hidden_feats=args['edge_hidden_feats'],
            n_tasks=args['n_tasks'])

    return model
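# A matching sketch for the MPNN branch of the variant above; every numeric
# value is an assumed placeholder rather than a setting from the original code.
mpnn_args = {
    'model': 'MPNN',
    'node_in_feats': 74,        # atom feature size (assumed)
    'edge_in_feats': 12,        # bond feature size (assumed)
    'node_out_feats': 64,
    'edge_hidden_feats': 128,
    'n_tasks': 1,
}
mpnn_model = load_model(mpnn_args)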
import torch.nn.functional as F


def load_model(exp_configure):
    """Construct a predictor from the experiment configuration dictionary."""
    if exp_configure['model'] == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * exp_configure['num_gnn_layers'],
            activation=[F.relu] * exp_configure['num_gnn_layers'],
            residual=[exp_configure['residual']] * exp_configure['num_gnn_layers'],
            batchnorm=[exp_configure['batchnorm']] * exp_configure['num_gnn_layers'],
            dropout=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * exp_configure['num_gnn_layers'],
            num_heads=[exp_configure['num_heads']] * exp_configure['num_gnn_layers'],
            feat_drops=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            attn_drops=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            alphas=[exp_configure['alpha']] * exp_configure['num_gnn_layers'],
            residuals=[exp_configure['residual']] * exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            num_gnn_layers=exp_configure['num_gnn_layers'],
            gnn_hidden_feats=exp_configure['gnn_hidden_feats'],
            graph_feats=exp_configure['graph_feats'],
            gaussian_expand=exp_configure['gaussian_expand'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            node_out_feats=exp_configure['node_out_feats'],
            edge_hidden_feats=exp_configure['edge_hidden_feats'],
            num_step_message_passing=exp_configure['num_step_message_passing'],
            num_step_set2set=exp_configure['num_step_set2set'],
            num_layer_set2set=exp_configure['num_layer_set2set'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=exp_configure['in_node_feats'],
            edge_feat_size=exp_configure['in_edge_feats'],
            num_layers=exp_configure['num_layers'],
            num_timesteps=exp_configure['num_timesteps'],
            graph_feat_size=exp_configure['graph_feat_size'],
            dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] in [
            'gin_supervised_contextpred', 'gin_supervised_infomax',
            'gin_supervised_edgepred', 'gin_supervised_masking'
    ]:
        from dgllife.model import GINPredictor
        from dgllife.model import load_pretrained
        model = GINPredictor(
            num_node_emb_list=[120, 3],
            num_edge_emb_list=[6, 3],
            num_layers=5,
            emb_dim=300,
            JK=exp_configure['jk'],
            dropout=0.5,
            readout=exp_configure['readout'],
            n_tasks=exp_configure['n_tasks'])
        # Replace the freshly initialized GIN encoder with the pre-trained one,
        # keeping only the new prediction head for fine-tuning.
        model.gnn = load_pretrained(exp_configure['model'])
        model.gnn.JK = exp_configure['jk']
    else:
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'Weave', 'MPNN', 'AttentiveFP', "
            "'gin_supervised_contextpred', 'gin_supervised_infomax', "
            "'gin_supervised_edgepred', 'gin_supervised_masking'], "
            "got {}".format(exp_configure['model']))

    return model
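# Hedged usage sketch for the configuration-driven variant above. Only the key
# names and the model name come from the function; the numeric values are
# placeholders chosen for illustration.
exp_configure = {
    'model': 'GCN',
    'in_node_feats': 74,
    'gnn_hidden_feats': 64,
    'num_gnn_layers': 2,
    'residual': True,
    'batchnorm': False,
    'dropout': 0.1,
    'predictor_hidden_feats': 128,
    'n_tasks': 1,
}
gcn_model = load_model(exp_configure)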
import torch
from torch.nn import CrossEntropyLoss
from dgllife.model import MPNNPredictor

# `model`, `train_one`, `n_feats`, `e_feats` and `device` are defined in
# earlier cells of the script.
# loss_fn = nn.MSELoss(reduction='none')
loss_fn = CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**(-2.5),
    weight_decay=10**(-5.0),
)
train_roc_auc, test_roc_auc, train_prc_auc, test_prc_auc = train_one(
    0, model, loss_fn, optimizer, show_acc=True, EDGE=True)

#%%
model = MPNNPredictor(node_in_feats=n_feats,
                      edge_in_feats=e_feats,
                      node_out_feats=64,
                      edge_hidden_feats=128,
                      n_tasks=2,
                      num_step_message_passing=6,
                      num_step_set2set=6,
                      num_layer_set2set=3)
model = model.to(device)

# loss_fn = nn.MSELoss(reduction='none')
loss_fn = CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**(-2.5),
    weight_decay=10**(-5.0),
)
train_roc_auc, test_roc_auc, train_prc_auc, test_prc_auc = train_one(
    0, model, loss_fn, optimizer, show_acc=True, EDGE=True)
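# The body of train_one() is not shown here. The function below is a minimal,
# assumed sketch of a single optimization step with the MPNN classifier above:
# the batched DGLGraph layout, the feature field names 'h'/'e', and integer
# class labels are assumptions, not the original pipeline.
def train_step(model, bg, labels, loss_fn, optimizer):
    bg = bg.to(device)
    labels = labels.to(device)
    node_feats = bg.ndata.pop('h')               # assumed atom feature field
    edge_feats = bg.edata.pop('e')               # assumed bond feature field
    logits = model(bg, node_feats, edge_feats)   # shape (batch_size, n_tasks=2)
    loss = loss_fn(logits, labels)               # CrossEntropyLoss over 2 classes
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()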