def inference(lf):
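    """Evaluate a trained link-prediction model on the dev/test splits.

    Depending on the command-line flags, this computes mean average precision
    over the NELL relation tasks, hits@k/MRR broken down by relation type or
    by seen queries, or overall test-set hits@k and MRR.
    """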
    lf.batch_size = args.dev_batch_size
    lf.eval()
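    # Ensemble variants ("hypere", "triplee") load pre-trained KG embedding
    # state dicts directly; every other model restores a saved checkpoint.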
    if args.model == "hypere":
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == "triplee":
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(
            torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
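    # Debug: inspect the loaded embedding tables.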
    print(lf.kg.entity_embeddings)
    print(lf.kg.relation_embeddings)
    entity_index_path = os.path.join(args.data_dir, "entity2id.txt")
    relation_index_path = os.path.join(args.data_dir, "relation2id.txt")
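    # On NELL, evaluation is restricted to entities seen in the training
    # adjacency list; other datasets apply no such filter.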
    if "NELL" in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, "adj_list.pkl")
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {"dev": {}, "test": {}}

    if args.compute_map:
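        # Mean average precision over the NELL-995 query relations, each with
        # its own labeled test pairs under tasks/<relation>/test.pairs.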
        relation_sets = [
            "concept:athletehomestadium",
            "concept:athleteplaysforteam",
            "concept:athleteplaysinleague",
            "concept:athleteplayssport",
            "concept:organizationheadquarteredincity",
            "concept:organizationhiredperson",
            "concept:personborninlocation",
            "concept:teamplayssport",
            "concept:worksfor",
        ]
        mps = []
        for r in relation_sets:
            print("* relation: {}".format(r))
            test_path = os.path.join(args.data_dir, "tasks", r, "test.pairs")
            test_data, labels = data_utils.load_triples_with_label(
                test_path,
                r,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
            )
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data,
                                   pred_scores,
                                   labels,
                                   lf.kg.all_objects,
                                   verbose=True)
            mps.append(mp)
        import numpy as np

        map_ = np.mean(mps)
        print("Overall MAP = {}".format(map_))
        eval_metrics["test"]["avg_map"] = map
    elif args.eval_by_relation_type:
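        # Break down dev-set hits@k/MRR by to-many vs. to-one relations.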
        dev_path = os.path.join(args.data_dir, "dev.triples")
        dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities,
        )
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(
            args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print("Dev set evaluation by relation type (partial graph)")
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.dev_objects,
                                                 relation_by_types,
                                                 verbose=True)
        print("Dev set evaluation by relation type (full graph)")
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.all_objects,
                                                 relation_by_types,
                                                 verbose=True)
    elif args.eval_by_seen_queries:
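        # Break down dev-set hits@k/MRR by whether the (entity, relation)
        # query was seen during training.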
        dev_path = os.path.join(args.data_dir, "dev.triples")
        dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities,
        )
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir,
                                                   entity_index_path,
                                                   relation_index_path)
        print("Dev set evaluation by seen queries (partial graph)")
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.dev_objects,
                                                seen_queries,
                                                verbose=True)
        print("Dev set evaluation by seen queries (full graph)")
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.all_objects,
                                                seen_queries,
                                                verbose=True)
    else:
        if args.few_shot:
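            # Few-shot mode: load_triples returns the test triples grouped by
            # relation, and each relation is scored with its own checkpoint.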
            dev_path = os.path.join(args.data_dir, "dev.triples")
            test_path = os.path.join(args.data_dir, "test.triples")
            _, dev_data = data_utils.load_triples(
                dev_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
                few_shot=True,
                lf=lf,
            )
            _, test_data = data_utils.load_triples(
                test_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
                few_shot=True,
                lf=lf,
            )
            num = 0
            hits_10 = 0.0
            hits_1 = 0.0
            hits_3 = 0.0
            hits_5 = 0.0
            mrr = 0.0
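            # Accumulate length-weighted sums so the reported numbers are
            # micro-averages over all test triples.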
            for x in test_data:
                lf.load_checkpoint(
                    args.checkpoint_path.replace("[relation]", str(x)))
                print("Test set of relation {} performance:".format(x))
                pred_scores = lf.forward(test_data[x], verbose=False)
                test_metrics = src.eval.hits_and_ranks(test_data[x],
                                                       pred_scores,
                                                       lf.kg.all_objects,
                                                       verbose=True)
                eval_metrics["test"]["hits_at_1"] = test_metrics[0]
                eval_metrics["test"]["hits_at_3"] = test_metrics[1]
                eval_metrics["test"]["hits_at_5"] = test_metrics[2]
                eval_metrics["test"]["hits_at_10"] = test_metrics[3]
                eval_metrics["test"]["mrr"] = test_metrics[4]
                num += len(test_data[x])
                hits_1 += float(test_metrics[0]) * len(test_data[x])
                hits_3 += float(test_metrics[1]) * len(test_data[x])
                hits_5 += float(test_metrics[2]) * len(test_data[x])
                hits_10 += float(test_metrics[3]) * len(test_data[x])
                mrr += float(test_metrics[4]) * len(test_data[x])
            print("Hits@1 = {}".format(hits_1 / num))
            print("Hits@3 = {}".format(hits_3 / num))
            print("Hits@5 = {}".format(hits_5 / num))
            print("Hits@10 = {}".format(hits_10 / num))
            print("MRR = {}".format(mrr / num))
        else:
            test_path = os.path.join(args.data_dir, "test.triples")
            test_data = data_utils.load_triples(
                test_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
            )
            print("Test set performance:")
            pred_scores = lf.forward(test_data, verbose=False)
            test_metrics = src.eval.hits_and_ranks(
                test_data,
                pred_scores,
                lf.kg.all_objects,
                verbose=True,
                output=True,
                kg=lf.kg,
                model_name=args.model,
            )
            eval_metrics["test"]["hits_at_1"] = test_metrics[0]
            eval_metrics["test"]["hits_at_3"] = test_metrics[1]
            eval_metrics["test"]["hits_at_5"] = test_metrics[2]
            eval_metrics["test"]["hits_at_10"] = test_metrics[3]
            eval_metrics["test"]["mrr"] = test_metrics[4]

    return eval_metrics
Example #2
File: experiments.py, Project: h-peng17/KGE
def inference(lf):
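    """As above, but the default branch also evaluates the dev set (on both
    the partial and the full graph) before scoring the test set."""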
    lf.batch_size = args.dev_batch_size
    lf.eval()
    if args.model == 'hypere':
        conve_kg_state_dict = get_conve_kg_state_dict(torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == 'triplee':
        conve_kg_state_dict = get_conve_kg_state_dict(torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path, entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {
        'dev': {},
        'test': {}
    }

    if args.compute_map:
        relation_sets = [
            'concept:athletehomestadium',
            'concept:athleteplaysforteam',
            'concept:athleteplaysinleague',
            'concept:athleteplayssport',
            'concept:organizationheadquarteredincity',
            'concept:organizationhiredperson',
            'concept:personborninlocation',
            'concept:teamplayssport',
            'concept:worksfor'
        ]
        mps = []
        for r in relation_sets:
            print('* relation: {}'.format(r))
            test_path = os.path.join(args.data_dir, 'tasks', r, 'test.pairs')
            test_data, labels = data_utils.load_triples_with_label(
                test_path, r, entity_index_path, relation_index_path, seen_entities=seen_entities)
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data, pred_scores, labels, lf.kg.all_objects, verbose=True)
            mps.append(mp)
        map_ = np.mean(mps)
        print('Overall MAP = {}'.format(map_))
        eval_metrics['test']['avg_map'] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print('Dev set evaluation by relation type (partial graph)')
        src.eval.hits_and_ranks_by_relation_type(
            dev_data, pred_scores, lf.kg.dev_objects, relation_by_types, verbose=True)
        print('Dev set evaluation by relation type (full graph)')
        src.eval.hits_and_ranks_by_relation_type(
            dev_data, pred_scores, lf.kg.all_objects, relation_by_types, verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir, entity_index_path, relation_index_path)
        print('Dev set evaluation by seen queries (partial graph)')
        src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.dev_objects, seen_queries, verbose=True)
        print('Dev set evaluation by seen queries (full graph)')
        src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.all_objects, seen_queries, verbose=True)
    else:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        test_path = os.path.join(args.data_dir, 'test.triples')
        dev_data = data_utils.load_triples(
            dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
        test_data = data_utils.load_triples(
            test_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
        print('Dev set performance:')
        pred_scores = lf.forward(dev_data, verbose=args.save_beam_search_paths)
        dev_metrics = src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.dev_objects, verbose=True)
        eval_metrics['dev'] = {}
        eval_metrics['dev']['hits_at_1'] = dev_metrics[0]
        eval_metrics['dev']['hits_at_3'] = dev_metrics[1]
        eval_metrics['dev']['hits_at_5'] = dev_metrics[2]
        eval_metrics['dev']['hits_at_10'] = dev_metrics[3]
        eval_metrics['dev']['mrr'] = dev_metrics[4]
        src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.all_objects, verbose=True)
        print('Test set performance:')
        pred_scores = lf.forward(test_data, verbose=False)
        test_metrics = src.eval.hits_and_ranks(test_data, pred_scores, lf.kg.all_objects, verbose=True)
        eval_metrics['test']['hits_at_1'] = test_metrics[0]
        eval_metrics['test']['hits_at_3'] = test_metrics[1]
        eval_metrics['test']['hits_at_5'] = test_metrics[2]
        eval_metrics['test']['hits_at_10'] = test_metrics[3]
        eval_metrics['test']['mrr'] = test_metrics[4]

    return eval_metrics
Example #3
def inference(lf):
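    """Variant of the function above whose default branch loads and scores
    the test set only."""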
    lf.batch_size = args.dev_batch_size
    lf.eval()
    if args.model == 'hypere':
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == 'triplee':
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(
            torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
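    # Debug: inspect the loaded embedding tables.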
    print(lf.kg.entity_embeddings)
    print(lf.kg.relation_embeddings)
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {'dev': {}, 'test': {}}

    if args.compute_map:
        relation_sets = [
            'concept:athletehomestadium', 'concept:athleteplaysforteam',
            'concept:athleteplaysinleague', 'concept:athleteplayssport',
            'concept:organizationheadquarteredincity',
            'concept:organizationhiredperson', 'concept:personborninlocation',
            'concept:teamplayssport', 'concept:worksfor'
        ]
        mps = []
        for r in relation_sets:
            print('* relation: {}'.format(r))
            test_path = os.path.join(args.data_dir, 'tasks', r, 'test.pairs')
            test_data, labels = data_utils.load_triples_with_label(
                test_path,
                r,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities)
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data,
                                   pred_scores,
                                   labels,
                                   lf.kg.all_objects,
                                   verbose=True)
            mps.append(mp)
        import numpy as np
        map_ = np.mean(mps)
        print('Overall MAP = {}'.format(map_))
        eval_metrics['test']['avg_map'] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path,
                                           entity_index_path,
                                           relation_index_path,
                                           seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(
            args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print('Dev set evaluation by relation type (partial graph)')
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.dev_objects,
                                                 relation_by_types,
                                                 verbose=True)
        print('Dev set evaluation by relation type (full graph)')
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.all_objects,
                                                 relation_by_types,
                                                 verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path,
                                           entity_index_path,
                                           relation_index_path,
                                           seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir,
                                                   entity_index_path,
                                                   relation_index_path)
        print('Dev set evaluation by seen queries (partial graph)')
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.dev_objects,
                                                seen_queries,
                                                verbose=True)
        print('Dev set evaluation by seen queries (full graph)')
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.all_objects,
                                                seen_queries,
                                                verbose=True)
    else:
        test_path = os.path.join(args.data_dir, 'test.triples')
        test_data = data_utils.load_triples(test_path,
                                            entity_index_path,
                                            relation_index_path,
                                            seen_entities=seen_entities,
                                            verbose=False)
        print('Test set performance:')
        pred_scores = lf.forward(test_data, verbose=False)
        test_metrics = src.eval.hits_and_ranks(test_data,
                                               pred_scores,
                                               lf.kg.all_objects,
                                               verbose=True,
                                               output=False,
                                               kg=lf.kg,
                                               model_name=args.model,
                                               split_relation=False)
        eval_metrics['test']['hits_at_1'] = test_metrics[0]
        eval_metrics['test']['hits_at_3'] = test_metrics[1]
        eval_metrics['test']['hits_at_5'] = test_metrics[2]
        eval_metrics['test']['hits_at_10'] = test_metrics[3]
        eval_metrics['test']['mrr'] = test_metrics[4]

    return eval_metrics
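
A minimal driver sketch for these snippets (a sketch only: construct_model,
the module-level args, and the torch.no_grad() wrapper are assumptions about
how the surrounding experiments.py script wires things together; none of them
appear in the excerpts above):

    import torch

    lf = construct_model(args)  # assumed helper that builds the model from args
    lf.cuda()                   # assumed: evaluate on GPU, as during training
    with torch.no_grad():       # gradients are not needed at evaluation time
        eval_metrics = inference(lf)
    print(eval_metrics['test'])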