Пример #1
0
def load_model(args):
    """Construct the dgllife predictor selected by ``args['model']``.

    Each branch imports its predictor class from ``dgllife.model`` lazily, so
    only the selected architecture's class is imported, and reads only the
    configuration keys that architecture needs.

    Args:
        args: Configuration dict. Must contain ``'model'`` (one of ``'GCN'``,
            ``'GAT'``, ``'AttentiveFP'``, ``'SchNet'``, ``'MGCN'``,
            ``'MPNN'``) and ``'n_tasks'``, plus the architecture-specific keys
            read in the matching branch below.

    Returns:
        The predictor instance built for the selected architecture.

    Raises:
        ValueError: If ``args['model']`` names no supported architecture.
            (Previously an unknown name fell through every branch and the
            final ``return model`` raised ``UnboundLocalError``.)
        KeyError: If a key required by the selected architecture is missing.
    """
    model_name = args['model']

    if model_name == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=args['in_feats'],
            hidden_feats=args['gcn_hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=args['in_feats'],
            hidden_feats=args['gat_hidden_feats'],
            num_heads=args['num_heads'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(node_feat_size=args['node_feat_size'],
                                     edge_feat_size=args['edge_feat_size'],
                                     num_layers=args['num_layers'],
                                     num_timesteps=args['num_timesteps'],
                                     graph_feat_size=args['graph_feat_size'],
                                     n_tasks=args['n_tasks'],
                                     dropout=args['dropout'])
    elif model_name == 'SchNet':
        from dgllife.model import SchNetPredictor
        model = SchNetPredictor(
            node_feats=args['node_feats'],
            hidden_feats=args['hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'MGCN':
        from dgllife.model import MGCNPredictor
        model = MGCNPredictor(
            feats=args['feats'],
            n_layers=args['n_layers'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(node_in_feats=args['node_in_feats'],
                              edge_in_feats=args['edge_in_feats'],
                              node_out_feats=args['node_out_feats'],
                              edge_hidden_feats=args['edge_hidden_feats'],
                              n_tasks=args['n_tasks'])
    else:
        # Fail loudly with a descriptive error instead of letting `model`
        # stay unbound.
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'AttentiveFP', 'SchNet', "
            "'MGCN', 'MPNN'], got {}".format(model_name))

    return model
Пример #2
0
def load_model(args):
    """Construct the dgllife predictor selected by ``args['model']``.

    Supports the ``'SchNet'``, ``'MGCN'`` and ``'MPNN'`` architectures. The
    predictor class is imported lazily from ``dgllife.model`` inside the
    selected branch.

    Args:
        args: Configuration dict. Must contain ``'model'`` and ``'n_tasks'``,
            plus the architecture-specific keys read in the matching branch.

    Returns:
        The predictor instance built for the selected architecture.

    Raises:
        ValueError: If ``args['model']`` names no supported architecture.
            (Previously an unknown name fell through every branch and the
            final ``return model`` raised ``UnboundLocalError``.)
        KeyError: If a key required by the selected architecture is missing.
    """
    model_name = args['model']

    if model_name == 'SchNet':
        from dgllife.model import SchNetPredictor
        model = SchNetPredictor(node_feats=args['node_feats'],
                                hidden_feats=args['hidden_feats'],
                                classifier_hidden_feats=args['classifier_hidden_feats'],
                                n_tasks=args['n_tasks'])
    elif model_name == 'MGCN':
        from dgllife.model import MGCNPredictor
        model = MGCNPredictor(feats=args['feats'],
                              n_layers=args['n_layers'],
                              classifier_hidden_feats=args['classifier_hidden_feats'],
                              n_tasks=args['n_tasks'])
    elif model_name == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(node_in_feats=args['node_in_feats'],
                              edge_in_feats=args['edge_in_feats'],
                              node_out_feats=args['node_out_feats'],
                              edge_hidden_feats=args['edge_hidden_feats'],
                              n_tasks=args['n_tasks'])
    else:
        # Fail loudly with a descriptive error instead of letting `model`
        # stay unbound.
        raise ValueError(
            "Expect model to be from ['SchNet', 'MGCN', 'MPNN'], "
            "got {}".format(model_name))

    return model
Пример #3
0
def load_model(exp_configure):
    """Construct the dgllife predictor selected by ``exp_configure['model']``.

    Each branch imports its predictor class from ``dgllife.model`` lazily and
    reads only the configuration keys that architecture needs. For GCN and
    GAT, per-layer hyperparameters are given as scalars in the config and
    broadcast to lists of length ``exp_configure['num_gnn_layers']``.

    Args:
        exp_configure: Experiment configuration dict. Always needs
            ``'model'`` and ``'n_tasks'``; the other keys read depend on the
            selected architecture (see each branch).

    Returns:
        The predictor instance built for the selected architecture. For the
        ``gin_supervised_*`` names, its ``gnn`` attribute is replaced by the
        module returned from ``dgllife.model.load_pretrained``.

    Raises:
        ValueError: If ``exp_configure['model']`` is not a supported name.
            (Previously the exception object was *returned* rather than
            raised, so callers silently received it in place of a model.)
        KeyError: If a key required by the selected architecture is missing.
    """
    model_name = exp_configure['model']

    if model_name == 'GCN':
        # Imported locally so this function does not rely on a module-level
        # `F` alias.
        import torch.nn.functional as F
        from dgllife.model import GCNPredictor
        num_layers = exp_configure['num_gnn_layers']
        model = GCNPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * num_layers,
            activation=[F.relu] * num_layers,
            residual=[exp_configure['residual']] * num_layers,
            batchnorm=[exp_configure['batchnorm']] * num_layers,
            dropout=[exp_configure['dropout']] * num_layers,
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'GAT':
        from dgllife.model import GATPredictor
        num_layers = exp_configure['num_gnn_layers']
        dropout = exp_configure['dropout']
        model = GATPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * num_layers,
            num_heads=[exp_configure['num_heads']] * num_layers,
            # The single configured dropout rate is used for both feature
            # and attention dropout, and for the predictor head.
            feat_drops=[dropout] * num_layers,
            attn_drops=[dropout] * num_layers,
            alphas=[exp_configure['alpha']] * num_layers,
            residuals=[exp_configure['residual']] * num_layers,
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=dropout,
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            num_gnn_layers=exp_configure['num_gnn_layers'],
            gnn_hidden_feats=exp_configure['gnn_hidden_feats'],
            graph_feats=exp_configure['graph_feats'],
            gaussian_expand=exp_configure['gaussian_expand'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            node_out_feats=exp_configure['node_out_feats'],
            edge_hidden_feats=exp_configure['edge_hidden_feats'],
            num_step_message_passing=exp_configure['num_step_message_passing'],
            num_step_set2set=exp_configure['num_step_set2set'],
            num_layer_set2set=exp_configure['num_layer_set2set'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=exp_configure['in_node_feats'],
            edge_feat_size=exp_configure['in_edge_feats'],
            num_layers=exp_configure['num_layers'],
            num_timesteps=exp_configure['num_timesteps'],
            graph_feat_size=exp_configure['graph_feat_size'],
            dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name in ('gin_supervised_contextpred', 'gin_supervised_infomax',
                        'gin_supervised_edgepred', 'gin_supervised_masking'):
        from dgllife.model import GINPredictor
        from dgllife.model import load_pretrained
        # NOTE(review): these hard-coded sizes presumably have to match the
        # architecture of the pretrained checkpoints loaded below — confirm
        # before changing any of them.
        model = GINPredictor(num_node_emb_list=[120, 3],
                             num_edge_emb_list=[6, 3],
                             num_layers=5,
                             emb_dim=300,
                             JK=exp_configure['jk'],
                             dropout=0.5,
                             readout=exp_configure['readout'],
                             n_tasks=exp_configure['n_tasks'])
        model.gnn = load_pretrained(model_name)
        # The constructor's JK setting went to the GNN that was just
        # replaced, so re-apply the configured mode on the loaded one.
        model.gnn.JK = exp_configure['jk']
    else:
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'Weave', 'MPNN', 'AttentiveFP', "
            "'gin_supervised_contextpred', 'gin_supervised_infomax', "
            "'gin_supervised_edgepred', 'gin_supervised_masking'], "
            "got {}".format(model_name))

    return model
Пример #4
0
# Train the model built earlier in the script (its construction is not
# visible in this chunk) with a cross-entropy loss and the Adam optimizer.
# NOTE(review): `model`, `CrossEntropyLoss`, `torch`, `train_one`,
# `MPNNPredictor`, `n_feats`, `e_feats` and `device` all come from earlier
# in the script — presumably torch.nn.CrossEntropyLoss and
# dgllife.model.MPNNPredictor; confirm against the file's imports.
# Alternative loss left disabled:
# loss_fn = nn.MSELoss(reduction='none')
loss_fn = CrossEntropyLoss()
# lr = 10**-2.5 (about 3.16e-3), weight_decay = 10**-5.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**(-2.5),
    weight_decay=10**(-5.0),
)

# NOTE(review): the meaning of the leading positional `0` and of the
# `show_acc` / `EDGE` flags is defined by `train_one`, which is not visible
# here — confirm against its definition. Judging only by the names the
# result is unpacked into, it returns ROC-AUC and PR-AUC for the train and
# test splits.
train_roc_auc, test_roc_auc, train_prc_auc, test_prc_auc = train_one(
    0, model, loss_fn, optimizer, show_acc=True, EDGE=True)

#%%
# Next cell: replace `model` with an MPNNPredictor and train it with the
# same loss/optimizer settings as the cell above.
# NOTE(review): n_tasks=2 with CrossEntropyLoss presumably means two class
# logits per graph — confirm.
model = MPNNPredictor(node_in_feats=n_feats,
                      edge_in_feats=e_feats,
                      node_out_feats=64,
                      edge_hidden_feats=128,
                      n_tasks=2,
                      num_step_message_passing=6,
                      num_step_set2set=6,
                      num_layer_set2set=3)
# Move the model before building the optimizer so the optimizer is handed
# the parameters of the moved model.
model = model.to(device)
# Alternative loss left disabled:
# loss_fn = nn.MSELoss(reduction='none')
loss_fn = CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**(-2.5),
    weight_decay=10**(-5.0),
)

# Rebinds the same four metric names, overwriting the previous cell's
# results.
train_roc_auc, test_roc_auc, train_prc_auc, test_prc_auc = train_one(
    0, model, loss_fn, optimizer, show_acc=True, EDGE=True)