# Example No. 1 (score: 0)
def main(params):
    """Evaluate a saved graph classifier on freshly generated test subgraphs.

    Runs ``params.runs`` independent evaluation rounds; each round regenerates
    the test subgraph database using the relation vocabulary and node-label cap
    stored in the trained model, then logs per-run and aggregate AUC / AUC-PR.

    Args:
        params: Namespace-like experiment configuration (paths, dataset name,
            device, sampling options, number of runs).
    """
    simplefilter(action='ignore', category=UserWarning)
    simplefilter(action='ignore', category=SparseEfficiencyWarning)

    # Load the trained classifier from disk; its saved relation2id and
    # max_label_value drive test-time subgraph extraction below.
    graph_classifier = initialize_model(params, None, load_model=True)

    logging.info(f"Device: {params.device}")

    all_auc = []
    all_auc_pr = []
    for r in range(1, params.runs + 1):
        params.db_path = os.path.join(
            params.main_dir,
            f'data/{params.dataset}/test_subgraphs_{params.experiment_name}_{params.constrained_neg_prob}_en_{params.enclosing_sub_graph}'
        )

        # Extract test subgraphs with the training-time relation vocabulary
        # and label cap so features line up with what the model expects.
        generate_subgraph_datasets(
            params,
            splits=['test'],
            saved_relation2id=graph_classifier.relation2id,
            max_label_value=graph_classifier.gnn.max_label_value)

        test = SubgraphDataset(
            params.db_path,
            'test_pos',
            'test_neg',
            params.file_paths,
            graph_classifier.relation2id,
            add_traspose_rels=params.add_traspose_rels,
            num_neg_samples_per_link=params.num_neg_samples_per_link,
            use_kge_embeddings=params.use_kge_embeddings,
            dataset=params.dataset,
            kge_model=params.kge_model,
            file_name=params.test_file)

        test_evaluator = Evaluator(params, graph_classifier, test)

        result = test_evaluator.eval(save=True)
        logging.info('\nTest Set Performance:' + str(result))
        all_auc.append(result['auc'])
        all_auc_pr.append(result['auc_pr'])

    # Aggregate once with numpy. (Previously an incremental running mean was
    # maintained per run but never read, and std was computed twice.)
    auc_mean = np.mean(all_auc)
    auc_std = np.std(all_auc)
    auc_pr_mean = np.mean(all_auc_pr)
    auc_pr_std = np.std(all_auc_pr)

    logging.info('\nAvg test Set Performance -- mean auc :' +
                 str(auc_mean) + ' std auc: ' + str(auc_std))
    logging.info('\nAvg test Set Performance -- mean auc_pr :' +
                 str(auc_pr_mean) + ' std auc_pr: ' + str(auc_pr_std))
# Example No. 2 (score: 0)
def main(params):
    """Train a subgraph-based graph classifier.

    Builds (or reuses) the on-disk subgraph database, constructs the train and
    validation datasets, propagates dataset-derived dimensions into ``params``,
    and runs full-batch training with periodic validation.

    Args:
        params: Namespace-like experiment configuration (paths, dataset name,
            device, sampling and model options).
    """
    simplefilter(action='ignore', category=UserWarning)
    simplefilter(action='ignore', category=SparseEfficiencyWarning)

    params.db_path = os.path.join(
        params.main_dir,
        f'../../data/{params.dataset}/subgraphs_en_{params.enclosing_sub_graph}_neg_{params.num_neg_samples_per_link}_hop_{params.hop}'
    )

    # Only extract subgraphs when the database is missing; extraction is slow.
    if not os.path.isdir(params.db_path):
        generate_subgraph_datasets(params)

    # Train and valid datasets share every option except the split names and
    # the raw triplet file, so build them from one common kwargs dict.
    shared_kwargs = dict(
        add_traspose_rels=params.add_traspose_rels,
        num_neg_samples_per_link=params.num_neg_samples_per_link,
        dataset=params.dataset,
    )
    train_data = SubgraphDataset(
        params.db_path,
        'train_pos',
        'train_neg',
        params.file_paths,
        file_name=params.train_file,
        **shared_kwargs)
    valid_data = SubgraphDataset(
        params.db_path,
        'valid_pos',
        'valid_neg',
        params.file_paths,
        file_name=params.valid_file,
        **shared_kwargs)

    # Dimensions discovered from the training data feed model construction.
    params.num_rels = train_data.num_rels
    params.aug_num_rels = train_data.aug_num_rels
    params.inp_dim = train_data.n_feat_dim

    # Log the max label value to save it in the model. This will be used to cap the labels generated on test set.
    params.max_label_value = train_data.max_n_label

    graph_classifier = initialize_model(params, dgl_model, params.load_model)

    logging.info(f"Device: {params.device}")
    logging.info(
        f"Input dim : {params.inp_dim}, # Relations : {params.num_rels}, # Augmented relations : {params.aug_num_rels}"
    )

    valid_evaluator = Evaluator(params, graph_classifier, valid_data)
    trainer = Trainer(params, graph_classifier, train_data, valid_evaluator)

    logging.info('Starting training with full batch...')
    trainer.train()
# Example No. 3 (score: 0)
def main(params):
    """Train a DDI graph classifier with molecular drug features.

    Builds train/valid/test subgraph datasets (valid and test reuse the
    training graph and id maps so vocabularies stay consistent), loads
    per-drug molecular features for the chosen dataset, attaches them to the
    classifier, and trains with per-split evaluators.

    Args:
        params: Namespace-like experiment configuration (paths, dataset name,
            feature type, device, sampling and model options).

    Raises:
        ValueError: if ``params.dataset`` / ``params.feat`` has no feature
            loader (previously this surfaced later as a NameError on `mfeat`).
    """
    simplefilter(action='ignore', category=UserWarning)
    simplefilter(action='ignore', category=SparseEfficiencyWarning)

    params.db_path = os.path.join(
        params.main_dir,
        f'data/{params.dataset}/subgraphs_en_{params.enclosing_sub_graph}_neg_{params.num_neg_samples_per_link}_hop_{params.hop}'
    )

    # Only extract subgraphs when the database is missing; extraction is slow.
    if not os.path.isdir(params.db_path):
        generate_subgraph_datasets(params)

    train = SubgraphDataset(
        params.db_path,
        'train_pos',
        'train_neg',
        params.file_paths,
        add_traspose_rels=params.add_traspose_rels,
        num_neg_samples_per_link=params.num_neg_samples_per_link,
        use_kge_embeddings=params.use_kge_embeddings,
        dataset=params.dataset,
        kge_model=params.kge_model,
        file_name=params.train_file)
    # Valid/test share the training graph, id maps and relation count so all
    # splits are scored against one consistent vocabulary.
    valid = SubgraphDataset(
        params.db_path,
        'valid_pos',
        'valid_neg',
        params.file_paths,
        add_traspose_rels=params.add_traspose_rels,
        num_neg_samples_per_link=params.num_neg_samples_per_link,
        use_kge_embeddings=params.use_kge_embeddings,
        dataset=params.dataset,
        kge_model=params.kge_model,
        file_name=params.valid_file,
        ssp_graph=train.ssp_graph,
        id2entity=train.id2entity,
        id2relation=train.id2relation,
        rel=train.num_rels,
        graph=train.graph)
    test = SubgraphDataset(
        params.db_path,
        'test_pos',
        'test_neg',
        params.file_paths,
        add_traspose_rels=params.add_traspose_rels,
        num_neg_samples_per_link=params.num_neg_samples_per_link,
        use_kge_embeddings=params.use_kge_embeddings,
        dataset=params.dataset,
        kge_model=params.kge_model,
        # BUG FIX: was params.valid_file (copy-paste from the valid split);
        # the test split must read the test triplet file.
        file_name=params.test_file,
        ssp_graph=train.ssp_graph,
        id2entity=train.id2entity,
        id2relation=train.id2relation,
        rel=train.num_rels,
        graph=train.graph)

    # Dimensions discovered from the training data feed model construction.
    params.num_rels = train.num_rels
    params.aug_num_rels = train.aug_num_rels
    params.inp_dim = train.n_feat_dim
    params.train_rels = 200 if params.dataset == 'BioSNAP' else params.num_rels
    params.num_nodes = 35000  # NOTE(review): hard-coded node budget — confirm against dataset sizes

    # Log the max label value to save it in the model. This will be used to cap the labels generated on test set.
    params.max_label_value = train.max_n_label
    logging.info(f"Device: {params.device}")
    logging.info(
        f"Input dim : {params.inp_dim}, # Relations : {params.num_rels}, # Augmented relations : {params.aug_num_rels}"
    )

    graph_classifier = initialize_model(params, dgl_model, params.load_model)

    # --- Load per-drug molecular features for the chosen dataset/feature type.
    if params.dataset == 'drugbank':
        if params.feat == 'morgan':
            import pickle
            with open('data/{}/DB_molecular_feats.pkl'.format(params.dataset),
                      'rb') as f:
                x = pickle.load(f, encoding='utf-8')
            mfeat = list(x['Morgan_Features'])
            params.feat_dim = 1024
        elif params.feat == 'pca':
            mfeat = np.loadtxt('data/{}/PCA.txt'.format(params.dataset))
            params.feat_dim = 200
        elif params.feat == 'pretrained':
            mfeat = np.loadtxt('data/{}/pretrained.txt'.format(params.dataset))
            params.feat_dim = 200
        else:
            raise ValueError(f'Unsupported feature type: {params.feat}')
    elif params.dataset == 'BioSNAP':
        import pickle
        mfeat = []
        rfeat = []  # NOTE(review): collected but unused below — kept for parity with original
        with open('data/{}/id2drug_feat.pkl'.format(params.dataset),
                  'rb') as f:
            x = pickle.load(f, encoding='utf-8')
        for feats in x.values():
            mfeat.append(feats['Morgan'])
            rfeat.append(feats['rdkit2d'])
        # Hoisted out of the loop: feat_dim is a constant, not per-drug.
        params.feat_dim = 1024
    else:
        raise ValueError(f'Unsupported dataset: {params.dataset}')

    graph_classifier.drug_feat(
        torch.FloatTensor(np.array(mfeat)).to(params.device))

    # drugbank uses the plain Evaluator; everything else the DDI variant.
    evaluator_cls = Evaluator if params.dataset == 'drugbank' else Evaluator_ddi2
    valid_evaluator = evaluator_cls(params, graph_classifier, valid)
    test_evaluator = evaluator_cls(params, graph_classifier, test)
    # BUG FIX: the non-drugbank branch previously evaluated `valid` here,
    # so the "train" evaluator never saw the training split.
    train_evaluator = evaluator_cls(params, graph_classifier, train)

    trainer = Trainer(params, graph_classifier, train, train_evaluator,
                      valid_evaluator, test_evaluator)

    logging.info('Starting training with full batch...')
    trainer.train()