示例#1
0
def load_model(args):
    """Build the dgllife predictor selected by ``args['model']``.

    Parameters
    ----------
    args : dict
        Experiment configuration. ``args['model']`` picks the architecture
        and must be one of ``'GCN'``, ``'GAT'`` or ``'Weave'``. The other
        keys that are read depend on the branch taken:

        * ``'GCN'``: ``node_featurizer``, ``gcn_hidden_feats``,
          ``predictor_hidden_feats``, ``n_tasks``
        * ``'GAT'``: ``node_featurizer``, ``gat_hidden_feats``,
          ``num_heads``, ``predictor_hidden_feats``, ``n_tasks``
        * ``'Weave'``: ``node_featurizer``, ``edge_featurizer``,
          ``num_gnn_layers``, ``gnn_hidden_feats``, ``graph_feats``,
          ``n_tasks``

        Featurizer entries must expose a ``feat_size()`` method; its result
        is passed to the model as the input feature size.

    Returns
    -------
    The constructed predictor instance.

    Raises
    ------
    ValueError
        If ``args['model']`` is not one of the supported names.
    """
    model_name = args['model']

    # The dgllife imports stay inside each branch so that only the selected
    # predictor class is imported.
    if model_name == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=args['node_featurizer'].feat_size(),
            hidden_feats=args['gcn_hidden_feats'],
            predictor_hidden_feats=args['predictor_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=args['node_featurizer'].feat_size(),
            hidden_feats=args['gat_hidden_feats'],
            num_heads=args['num_heads'],
            predictor_hidden_feats=args['predictor_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(
            node_in_feats=args['node_featurizer'].feat_size(),
            edge_in_feats=args['edge_featurizer'].feat_size(),
            num_gnn_layers=args['num_gnn_layers'],
            gnn_hidden_feats=args['gnn_hidden_feats'],
            graph_feats=args['graph_feats'],
            n_tasks=args['n_tasks'])
    else:
        # Without this branch an unknown name would fall through to
        # ``return model`` and fail with an opaque UnboundLocalError.
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'Weave'], "
            "got {}".format(model_name))

    return model
示例#2
0
def load_model(args):
    """Build the dgllife predictor selected by ``args['model']``.

    Parameters
    ----------
    args : dict
        Experiment configuration. ``args['model']`` picks the architecture
        and must be one of ``'GCN'``, ``'GAT'``, ``'AttentiveFP'``,
        ``'SchNet'``, ``'MGCN'`` or ``'MPNN'``. The other keys that are read
        depend on the branch taken:

        * ``'GCN'``: ``in_feats``, ``gcn_hidden_feats``,
          ``classifier_hidden_feats``, ``n_tasks``
        * ``'GAT'``: ``in_feats``, ``gat_hidden_feats``, ``num_heads``,
          ``classifier_hidden_feats``, ``n_tasks``
        * ``'AttentiveFP'``: ``node_feat_size``, ``edge_feat_size``,
          ``num_layers``, ``num_timesteps``, ``graph_feat_size``,
          ``n_tasks``, ``dropout``
        * ``'SchNet'``: ``node_feats``, ``hidden_feats``,
          ``classifier_hidden_feats``, ``n_tasks``
        * ``'MGCN'``: ``feats``, ``n_layers``,
          ``classifier_hidden_feats``, ``n_tasks``
        * ``'MPNN'``: ``node_in_feats``, ``edge_in_feats``,
          ``node_out_feats``, ``edge_hidden_feats``, ``n_tasks``

    Returns
    -------
    The constructed predictor instance.

    Raises
    ------
    ValueError
        If ``args['model']`` is not one of the supported names.

    Notes
    -----
    NOTE(review): keyword names such as ``classifier_hidden_feats`` are
    passed through unchanged; confirm they match the installed dgllife
    version's constructor signatures.
    """
    model_name = args['model']

    # The dgllife imports stay inside each branch so that only the selected
    # predictor class is imported.
    if model_name == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=args['in_feats'],
            hidden_feats=args['gcn_hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=args['in_feats'],
            hidden_feats=args['gat_hidden_feats'],
            num_heads=args['num_heads'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=args['node_feat_size'],
            edge_feat_size=args['edge_feat_size'],
            num_layers=args['num_layers'],
            num_timesteps=args['num_timesteps'],
            graph_feat_size=args['graph_feat_size'],
            n_tasks=args['n_tasks'],
            dropout=args['dropout'])
    elif model_name == 'SchNet':
        from dgllife.model import SchNetPredictor
        model = SchNetPredictor(
            node_feats=args['node_feats'],
            hidden_feats=args['hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'MGCN':
        from dgllife.model import MGCNPredictor
        model = MGCNPredictor(
            feats=args['feats'],
            n_layers=args['n_layers'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])
    elif model_name == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_out_feats=args['node_out_feats'],
            edge_hidden_feats=args['edge_hidden_feats'],
            n_tasks=args['n_tasks'])
    else:
        # Without this branch an unknown name would fall through to
        # ``return model`` and fail with an opaque UnboundLocalError.
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'AttentiveFP', "
            "'SchNet', 'MGCN', 'MPNN'], got {}".format(model_name))

    return model
示例#3
0
def load_model(exp_configure):
    """Build the dgllife predictor selected by ``exp_configure['model']``.

    Parameters
    ----------
    exp_configure : dict
        Experiment configuration. ``exp_configure['model']`` picks the
        architecture and must be one of ``'GCN'``, ``'GAT'``, ``'Weave'``,
        ``'MPNN'``, ``'AttentiveFP'``, ``'gin_supervised_contextpred'``,
        ``'gin_supervised_infomax'``, ``'gin_supervised_edgepred'`` or
        ``'gin_supervised_masking'``. The remaining keys that are read
        depend on the branch taken (see the keyword arguments below).

    Returns
    -------
    The constructed predictor instance. For the ``gin_supervised_*``
    choices, the predictor's ``gnn`` attribute is replaced by the object
    returned from ``load_pretrained``.

    Raises
    ------
    ValueError
        If ``exp_configure['model']`` is not one of the supported names.

    Notes
    -----
    The GCN branch references ``F.relu``; ``F`` must be available in the
    enclosing module scope (presumably ``torch.nn.functional`` -- confirm
    against the file's imports).
    """
    model_name = exp_configure['model']

    if model_name == 'GCN':
        from dgllife.model import GCNPredictor
        # Each per-layer hyperparameter is one scalar from the config,
        # repeated once per GNN layer.
        n_layers = exp_configure['num_gnn_layers']
        model = GCNPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * n_layers,
            activation=[F.relu] * n_layers,
            residual=[exp_configure['residual']] * n_layers,
            batchnorm=[exp_configure['batchnorm']] * n_layers,
            dropout=[exp_configure['dropout']] * n_layers,
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'GAT':
        from dgllife.model import GATPredictor
        n_layers = exp_configure['num_gnn_layers']
        model = GATPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * n_layers,
            num_heads=[exp_configure['num_heads']] * n_layers,
            # The same dropout value is used for features and attention.
            feat_drops=[exp_configure['dropout']] * n_layers,
            attn_drops=[exp_configure['dropout']] * n_layers,
            alphas=[exp_configure['alpha']] * n_layers,
            residuals=[exp_configure['residual']] * n_layers,
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            num_gnn_layers=exp_configure['num_gnn_layers'],
            gnn_hidden_feats=exp_configure['gnn_hidden_feats'],
            graph_feats=exp_configure['graph_feats'],
            gaussian_expand=exp_configure['gaussian_expand'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            node_out_feats=exp_configure['node_out_feats'],
            edge_hidden_feats=exp_configure['edge_hidden_feats'],
            num_step_message_passing=exp_configure['num_step_message_passing'],
            num_step_set2set=exp_configure['num_step_set2set'],
            num_layer_set2set=exp_configure['num_layer_set2set'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=exp_configure['in_node_feats'],
            edge_feat_size=exp_configure['in_edge_feats'],
            num_layers=exp_configure['num_layers'],
            num_timesteps=exp_configure['num_timesteps'],
            graph_feat_size=exp_configure['graph_feat_size'],
            dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name in [
            'gin_supervised_contextpred', 'gin_supervised_infomax',
            'gin_supervised_edgepred', 'gin_supervised_masking'
    ]:
        from dgllife.model import GINPredictor
        from dgllife.model import load_pretrained
        # Architecture sizes are hard-coded here rather than read from the
        # config. NOTE(review): presumably they must match the pre-trained
        # checkpoints -- confirm against the dgllife documentation.
        model = GINPredictor(num_node_emb_list=[120, 3],
                             num_edge_emb_list=[6, 3],
                             num_layers=5,
                             emb_dim=300,
                             JK=exp_configure['jk'],
                             dropout=0.5,
                             readout=exp_configure['readout'],
                             n_tasks=exp_configure['n_tasks'])
        # Swap in the pre-trained GNN, then re-apply the configured JK
        # setting on the replacement.
        model.gnn = load_pretrained(model_name)
        model.gnn.JK = exp_configure['jk']
    else:
        # Raise (not return) so callers never receive an exception object
        # in place of a model.
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'Weave', 'MPNN', 'AttentiveFP', "
            "'gin_supervised_contextpred', 'gin_supervised_infomax', "
            "'gin_supervised_edgepred', 'gin_supervised_masking'], "
            "got {}".format(model_name))

    return model