def __init__(self, args, kg, pn, fn_kg, fn, fn_secondary_kg=None):
        """Initialize reward-shaping policy gradient around a frozen, pretrained fact network.

        :param args: Configuration; provides state-dict paths, `mu`, and
            `reward_shaping_threshold`.
        :param kg: Knowledge graph environment.
        :param pn: Policy network.
        :param fn_kg: Fact-network KG embedding module.
        :param fn: Fact-network scoring module.
        :param fn_secondary_kg: Secondary KG, used only by the 'hypere' hybrid model.
        """
        super(RewardShapingPolicyGradient, self).__init__(args, kg, pn)
        self.reward_shaping_threshold = args.reward_shaping_threshold

        # Fact network modules
        self.fn_kg = fn_kg
        self.fn = fn
        self.fn_secondary_kg = fn_secondary_kg
        self.mu = args.mu

        # self.fn_model is set by the base class; pick the matching checkpoint.
        # NOTE(review): torch.load here has no map_location — assumed the
        # checkpoint's save device is available at load time; confirm.
        fn_model = self.fn_model
        if fn_model in ['conve']:
            # ConvE carries both a neural-net component and KG embeddings.
            fn_state_dict = torch.load(args.conve_state_dict_path)
            fn_nn_state_dict = get_conve_nn_state_dict(fn_state_dict)
            fn_kg_state_dict = get_conve_kg_state_dict(fn_state_dict)
            self.fn.load_state_dict(fn_nn_state_dict)
        elif fn_model == 'distmult':
            fn_state_dict = torch.load(args.distmult_state_dict_path)
            fn_kg_state_dict = get_distmult_kg_state_dict(fn_state_dict)
        elif fn_model == 'complex':
            fn_state_dict = torch.load(args.complex_state_dict_path)
            fn_kg_state_dict = get_complex_kg_state_dict(fn_state_dict)
        elif fn_model == 'hypere':
            # 'hypere' reuses the ConvE embeddings as its primary KG.
            fn_state_dict = torch.load(args.conve_state_dict_path)
            fn_kg_state_dict = get_conve_kg_state_dict(fn_state_dict)
        else:
            raise NotImplementedError

        self.fn_kg.load_state_dict(fn_kg_state_dict)
        if fn_model == 'hypere':
            # The secondary KG carries the ComplEx embeddings.
            complex_state_dict = torch.load(args.complex_state_dict_path)
            complex_kg_state_dict = get_complex_kg_state_dict(complex_state_dict)
            self.fn_secondary_kg.load_state_dict(complex_kg_state_dict)

        # The fact network is a frozen reward estimator: eval mode + detached params.
        self.fn.eval()
        self.fn_kg.eval()
        ops.detach_module(self.fn)
        ops.detach_module(self.fn_kg)
        if fn_model == 'hypere':
            self.fn_secondary_kg.eval()
            ops.detach_module(self.fn_secondary_kg)
# Beispiel #2
# 0
    def __init__(self, args, kg, pn, fn_kg, fn, fn_secondary_kg=None):
        """Initialize reward shaping with a pretrained fact network mapped onto the GPU.

        :param args: Configuration; provides `gpu`, state-dict paths, `mu`,
            and `reward_shaping_threshold`.
        :param kg: Knowledge graph environment.
        :param pn: Policy network.
        :param fn_kg: Fact-network KG embedding module.
        :param fn: Fact-network scoring module.
        :param fn_secondary_kg: Secondary KG, used only by the 'hypere' model.
        """
        super(RewardShapingPolicyGradient, self).__init__(args, kg, pn)
        self.reward_shaping_threshold = args.reward_shaping_threshold

        # Fact network modules
        self.fn_kg = fn_kg
        self.fn = fn
        self.fn_secondary_kg = fn_secondary_kg
        self.mu = args.mu

        fn_model = self.fn_model
        if fn_model in ["conve"]:
            fn_state_dict = torch.load(
                args.conve_state_dict_path, map_location=("cuda:" + str(args.gpu))
            )
            fn_nn_state_dict = get_conve_nn_state_dict(fn_state_dict)
            fn_kg_state_dict = get_conve_kg_state_dict(fn_state_dict)
            self.fn.load_state_dict(fn_nn_state_dict)
        elif fn_model == "distmult":
            fn_state_dict = torch.load(
                args.distmult_state_dict_path, map_location=("cuda:" + str(args.gpu))
            )
            fn_kg_state_dict = get_distmult_kg_state_dict(fn_state_dict)
        elif fn_model == "complex":
            fn_state_dict = torch.load(
                args.complex_state_dict_path, map_location=("cuda:" + str(args.gpu))
            )
            fn_kg_state_dict = get_complex_kg_state_dict(fn_state_dict)
        elif fn_model == "hypere":
            fn_state_dict = torch.load(
                args.conve_state_dict_path, map_location=("cuda:" + str(args.gpu))
            )
            fn_kg_state_dict = get_conve_kg_state_dict(fn_state_dict)
        else:
            raise NotImplementedError
        # ================= newly added ===================
        # The pretrained checkpoint lacks the aggregation weight expected by
        # fn_kg, so synthesize a xavier-initialized one before loading.
        # FIX: state-dict values are plain tensors (no nn.Parameter wrapper
        # needed), and xavier_uniform_ overwrites every entry, so start from
        # torch.empty rather than torch.zeros.
        # NOTE(review): shape (400, 200) is hard-coded — assumed to match
        # fn_kg's AGG_W; confirm against the model definition.
        fn_kg_state_dict["AGG_W"] = torch.nn.init.xavier_uniform_(torch.empty(400, 200))
        # ================= newly added ===================
        self.fn_kg.load_state_dict(fn_kg_state_dict)
        if fn_model == "hypere":
            # FIX: use the same map_location as every other load in this
            # method; previously this one load ignored the target device.
            complex_state_dict = torch.load(
                args.complex_state_dict_path, map_location=("cuda:" + str(args.gpu))
            )
            complex_kg_state_dict = get_complex_kg_state_dict(complex_state_dict)
            self.fn_secondary_kg.load_state_dict(complex_kg_state_dict)

        # Freeze the fact network: eval mode + detached parameters.
        self.fn.eval()
        self.fn_kg.eval()
        ops.detach_module(self.fn)
        ops.detach_module(self.fn_kg)
        if fn_model == "hypere":
            self.fn_secondary_kg.eval()
            ops.detach_module(self.fn_secondary_kg)
# Beispiel #3
# 0
    def __init__(self, args, kg, pn, fn_kg, fn, fn_secondary_kg=None):
        """Set up the reward-miner trainer around a frozen fact network.

        Loads pretrained embeddings for the configured fact-network model,
        appends a dummy END embedding via ``calc_dummy_end_embedding``, and
        freezes every fact-network module.
        """
        super(RewardMinerGradient, self).__init__(args, kg, pn)
        self.reward_shaping_threshold = args.reward_shaping_threshold

        # Fact network modules
        self.fn_kg = fn_kg
        self.fn = fn
        self.fn_secondary_kg = fn_secondary_kg
        self.mu = args.mu

        def _report_sizes():
            # Debug trace of the embedding tables before/after resizing.
            print(self.fn_kg.relation_embeddings.weight.size(),
                  self.fn_kg.entity_embeddings.weight.size())

        _report_sizes()

        fn_model = self.fn_model
        if fn_model in ['conve']:
            # ConvE carries a neural-net component alongside the embeddings.
            checkpoint = torch.load(args.conve_state_dict_path)
            nn_weights = get_conve_nn_state_dict(checkpoint)
            kg_weights = get_conve_kg_state_dict(checkpoint)
            self.fn.load_state_dict(nn_weights)
        elif fn_model == 'distmult':
            checkpoint = torch.load(args.distmult_state_dict_path)
            kg_weights = get_distmult_kg_state_dict(checkpoint)
        elif fn_model == 'complex':
            checkpoint = torch.load(args.complex_state_dict_path)
            kg_weights = get_complex_kg_state_dict(checkpoint)
        elif fn_model == 'hypere':
            checkpoint = torch.load(args.conve_state_dict_path)
            kg_weights = get_conve_kg_state_dict(checkpoint)
        else:
            raise NotImplementedError
        self.fn_kg.load_state_dict(kg_weights)
        if fn_model == 'hypere':
            secondary_ckpt = torch.load(args.complex_state_dict_path)
            self.fn_secondary_kg.load_state_dict(
                get_complex_kg_state_dict(secondary_ckpt))

        _report_sizes()
        self.calc_dummy_end_embedding()
        _report_sizes()

        # Freeze the fact network: eval mode and detached parameters.
        self.fn.eval()
        self.fn_kg.eval()
        ops.detach_module(self.fn)
        ops.detach_module(self.fn_kg)
        if fn_model == 'hypere':
            self.fn_secondary_kg.eval()
            ops.detach_module(self.fn_secondary_kg)
# Beispiel #4
# 0
    def __init__(self, args, kg, pn, fn_kg, fn, fn_secondary_kg=None):
        """Initialize reward shaping with a pretrained fact network on the configured GPU.

        Every checkpoint is loaded with ``map_location`` pointing at
        ``cuda:<args.gpu>`` so tensors land on the training device.
        """
        super(RewardShapingPolicyGradient, self).__init__(args, kg, pn)
        self.reward_shaping_threshold = args.reward_shaping_threshold

        # Fact network modules
        self.fn_kg = fn_kg
        self.fn = fn
        self.fn_secondary_kg = fn_secondary_kg
        self.mu = args.mu

        # Single place that decides where checkpoint tensors are mapped.
        target_device = f'cuda:{args.gpu}'
        fn_model = self.fn_model
        if fn_model == 'conve':
            # ConvE carries a neural-net component alongside the embeddings.
            checkpoint = torch.load(args.conve_state_dict_path,
                                    map_location=target_device)
            self.fn.load_state_dict(get_conve_nn_state_dict(checkpoint))
            kg_weights = get_conve_kg_state_dict(checkpoint)
        elif fn_model == 'distmult':
            checkpoint = torch.load(args.distmult_state_dict_path,
                                    map_location=target_device)
            kg_weights = get_distmult_kg_state_dict(checkpoint)
        elif fn_model == 'complex':
            checkpoint = torch.load(args.complex_state_dict_path,
                                    map_location=target_device)
            kg_weights = get_complex_kg_state_dict(checkpoint)
        elif fn_model == 'hypere':
            # 'hypere' reuses the ConvE embeddings as its primary KG.
            checkpoint = torch.load(args.conve_state_dict_path,
                                    map_location=target_device)
            kg_weights = get_conve_kg_state_dict(checkpoint)
        else:
            raise NotImplementedError
        self.fn_kg.load_state_dict(kg_weights)
        if fn_model == 'hypere':
            secondary_ckpt = torch.load(args.complex_state_dict_path,
                                        map_location=target_device)
            self.fn_secondary_kg.load_state_dict(
                get_complex_kg_state_dict(secondary_ckpt))

        # Freeze the fact network: eval mode plus detached parameters.
        self.fn.eval()
        self.fn_kg.eval()
        ops.detach_module(self.fn)
        ops.detach_module(self.fn_kg)
        if fn_model == 'hypere':
            self.fn_secondary_kg.eval()
            ops.detach_module(self.fn_secondary_kg)
    def __init__(self, args, kg, agent, fn_kg, fn, fn_secondary_kg=None):
        """Initialize the reward-shaping trainer around a frozen fact network.

        Loads the pretrained embeddings for the configured fact-network
        model and puts every fact-network module into inference mode.
        """
        super(RewardShapingPolicyGradient, self).__init__(args, kg, agent)
        self.reward_shaping_threshold = args.reward_shaping_threshold

        # Fact network modules
        self.fn_kg = fn_kg
        self.fn = fn
        self.fn_secondary_kg = fn_secondary_kg
        self.mu = args.mu

        # Each supported model maps to the args attribute holding its
        # checkpoint path and the helper extracting its KG embeddings.
        sources = {
            "conve": ("conve_state_dict_path", get_conve_kg_state_dict),
            "distmult": ("distmult_state_dict_path", get_distmult_kg_state_dict),
            "complex": ("complex_state_dict_path", get_complex_kg_state_dict),
            "hypere": ("conve_state_dict_path", get_conve_kg_state_dict),
        }
        fn_model = self.fn_model
        if fn_model not in sources:
            raise NotImplementedError
        path_attr, extract_kg_state = sources[fn_model]
        checkpoint = torch.load(getattr(args, path_attr))
        fn_kg_state_dict = extract_kg_state(checkpoint)
        if fn_model == "conve":
            # ConvE additionally carries a neural-network component.
            self.fn.load_state_dict(get_conve_nn_state_dict(checkpoint))
        self.fn_kg.load_state_dict(fn_kg_state_dict)
        if fn_model == "hypere":
            # The secondary KG carries the ComplEx embeddings.
            secondary_ckpt = torch.load(args.complex_state_dict_path)
            self.fn_secondary_kg.load_state_dict(
                get_complex_kg_state_dict(secondary_ckpt))

        # Freeze the fact network: eval mode plus detached parameters.
        self.fn.eval()
        self.fn_kg.eval()
        ops.detach_module(self.fn)
        ops.detach_module(self.fn_kg)
        if fn_model == "hypere":
            self.fn_secondary_kg.eval()
            ops.detach_module(self.fn_secondary_kg)
# Beispiel #6
# 0
def inference(lf):
    """Evaluate a trained link-prediction model `lf` on dev/test triples.

    Depending on module-level `args`, either computes MAP over the NELL
    relation tasks, hits/ranks broken down by relation type or by seen
    queries, or standard dev/test hits@k and MRR.

    :param lf: Trained learning framework exposing `kg`, `forward`,
        `eval`, and checkpoint loading.
    :return: Nested dict ``{'dev': {...}, 'test': {...}}`` of metrics.
    """
    lf.batch_size = args.dev_batch_size
    lf.eval()
    # Hybrid models compose several pretrained embedding sets; everything
    # else restores a single checkpoint.
    if args.model == 'hypere':
        conve_kg_state_dict = get_conve_kg_state_dict(torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == 'triplee':
        conve_kg_state_dict = get_conve_kg_state_dict(torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    # NELL splits filter queries to entities seen in the adjacency list.
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path, entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {
        'dev': {},
        'test': {}
    }

    if args.compute_map:
        relation_sets = [
            'concept:athletehomestadium',
            'concept:athleteplaysforteam',
            'concept:athleteplaysinleague',
            'concept:athleteplayssport',
            'concept:organizationheadquarteredincity',
            'concept:organizationhiredperson',
            'concept:personborninlocation',
            'concept:teamplayssport',
            'concept:worksfor'
        ]
        mps = []
        for r in relation_sets:
            print('* relation: {}'.format(r))
            test_path = os.path.join(args.data_dir, 'tasks', r, 'test.pairs')
            test_data, labels = data_utils.load_triples_with_label(
                test_path, r, entity_index_path, relation_index_path, seen_entities=seen_entities)
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data, pred_scores, labels, lf.kg.all_objects, verbose=True)
            mps.append(mp)
        map_ = np.mean(mps)
        print('Overall MAP = {}'.format(map_))
        # BUG FIX: previously assigned the builtin function `map` here
        # instead of the computed mean average precision `map_`.
        eval_metrics['test']['avg_map'] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print('Dev set evaluation by relation type (partial graph)')
        src.eval.hits_and_ranks_by_relation_type(
            dev_data, pred_scores, lf.kg.dev_objects, relation_by_types, verbose=True)
        print('Dev set evaluation by relation type (full graph)')
        src.eval.hits_and_ranks_by_relation_type(
            dev_data, pred_scores, lf.kg.all_objects, relation_by_types, verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir, entity_index_path, relation_index_path)
        print('Dev set evaluation by seen queries (partial graph)')
        src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.dev_objects, seen_queries, verbose=True)
        print('Dev set evaluation by seen queries (full graph)')
        src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.all_objects, seen_queries, verbose=True)
    else:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        test_path = os.path.join(args.data_dir, 'test.triples')
        dev_data = data_utils.load_triples(
            dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
        test_data = data_utils.load_triples(
            test_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
        print('Dev set performance:')
        pred_scores = lf.forward(dev_data, verbose=args.save_beam_search_paths)
        # Dev metrics are reported against the partial (dev) object graph.
        dev_metrics = src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.dev_objects, verbose=True)
        eval_metrics['dev'] = {}
        eval_metrics['dev']['hits_at_1'] = dev_metrics[0]
        eval_metrics['dev']['hits_at_3'] = dev_metrics[1]
        eval_metrics['dev']['hits_at_5'] = dev_metrics[2]
        eval_metrics['dev']['hits_at_10'] = dev_metrics[3]
        eval_metrics['dev']['mrr'] = dev_metrics[4]
        src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.all_objects, verbose=True)
        print('Test set performance:')
        pred_scores = lf.forward(test_data, verbose=False)
        test_metrics = src.eval.hits_and_ranks(test_data, pred_scores, lf.kg.all_objects, verbose=True)
        eval_metrics['test']['hits_at_1'] = test_metrics[0]
        eval_metrics['test']['hits_at_3'] = test_metrics[1]
        eval_metrics['test']['hits_at_5'] = test_metrics[2]
        eval_metrics['test']['hits_at_10'] = test_metrics[3]
        eval_metrics['test']['mrr'] = test_metrics[4]

    return eval_metrics
def inference(lf):
    """Evaluate a trained link-prediction model `lf` on dev/test triples.

    Depending on module-level `args`, computes MAP over the NELL relation
    tasks, hits/ranks by relation type or by seen queries, a per-relation
    few-shot evaluation, or standard test-set hits@k and MRR.

    :param lf: Trained learning framework exposing `kg`, `forward`,
        `eval`, and checkpoint loading.
    :return: Nested dict ``{'dev': {...}, 'test': {...}}`` of metrics.
    """
    lf.batch_size = args.dev_batch_size
    lf.eval()
    # Hybrid models compose several pretrained embedding sets; everything
    # else restores a single checkpoint.
    if args.model == "hypere":
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == "triplee":
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(
            torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
    print(lf.kg.entity_embeddings)
    print(lf.kg.relation_embeddings)
    entity_index_path = os.path.join(args.data_dir, "entity2id.txt")
    relation_index_path = os.path.join(args.data_dir, "relation2id.txt")
    # NELL splits filter queries to entities seen in the adjacency list.
    if "NELL" in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, "adj_list.pkl")
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {"dev": {}, "test": {}}

    if args.compute_map:
        relation_sets = [
            "concept:athletehomestadium",
            "concept:athleteplaysforteam",
            "concept:athleteplaysinleague",
            "concept:athleteplayssport",
            "concept:organizationheadquarteredincity",
            "concept:organizationhiredperson",
            "concept:personborninlocation",
            "concept:teamplayssport",
            "concept:worksfor",
        ]
        mps = []
        for r in relation_sets:
            print("* relation: {}".format(r))
            test_path = os.path.join(args.data_dir, "tasks", r, "test.pairs")
            test_data, labels = data_utils.load_triples_with_label(
                test_path,
                r,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
            )
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data,
                                   pred_scores,
                                   labels,
                                   lf.kg.all_objects,
                                   verbose=True)
            mps.append(mp)
        import numpy as np

        map_ = np.mean(mps)
        print("Overall MAP = {}".format(map_))
        # BUG FIX: previously assigned the builtin function `map` here
        # instead of the computed mean average precision `map_`.
        eval_metrics["test"]["avg_map"] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, "dev.triples")
        dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities,
        )
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(
            args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print("Dev set evaluation by relation type (partial graph)")
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.dev_objects,
                                                 relation_by_types,
                                                 verbose=True)
        print("Dev set evaluation by relation type (full graph)")
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.all_objects,
                                                 relation_by_types,
                                                 verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, "dev.triples")
        dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities,
        )
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir,
                                                   entity_index_path,
                                                   relation_index_path)
        print("Dev set evaluation by seen queries (partial graph)")
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.dev_objects,
                                                seen_queries,
                                                verbose=True)
        print("Dev set evaluation by seen queries (full graph)")
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.all_objects,
                                                seen_queries,
                                                verbose=True)
    else:
        if args.few_shot:
            # Few-shot mode: data is grouped by relation and a relation-
            # specific checkpoint is restored for each group.
            dev_path = os.path.join(args.data_dir, "dev.triples")
            test_path = os.path.join(args.data_dir, "test.triples")
            _, dev_data = data_utils.load_triples(
                dev_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
                few_shot=True,
                lf=lf,
            )
            _, test_data = data_utils.load_triples(
                test_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
                few_shot=True,
                lf=lf,
            )
            num = 0
            hits_10 = 0.0
            hits_1 = 0.0
            hits_3 = 0.0
            hits_5 = 0.0
            mrr = 0.0
            for x in test_data:
                lf.load_checkpoint(
                    args.checkpoint_path.replace("[relation]", str(x)))
                print("Test set of relation {} performance:".format(x))
                pred_scores = lf.forward(test_data[x], verbose=False)
                test_metrics = src.eval.hits_and_ranks(test_data[x],
                                                       pred_scores,
                                                       lf.kg.all_objects,
                                                       verbose=True)
                # NOTE(review): these entries are overwritten on every
                # iteration, so only the last relation's metrics survive;
                # the size-weighted aggregates below are printed instead.
                eval_metrics["test"]["hits_at_1"] = test_metrics[0]
                eval_metrics["test"]["hits_at_3"] = test_metrics[1]
                eval_metrics["test"]["hits_at_5"] = test_metrics[2]
                eval_metrics["test"]["hits_at_10"] = test_metrics[3]
                eval_metrics["test"]["mrr"] = test_metrics[4]
                num += len(test_data[x])
                hits_1 += float(test_metrics[0]) * len(test_data[x])
                hits_3 += float(test_metrics[1]) * len(test_data[x])
                hits_5 += float(test_metrics[2]) * len(test_data[x])
                hits_10 += float(test_metrics[3]) * len(test_data[x])
                mrr += float(test_metrics[4]) * len(test_data[x])
            print("Hits@1 = {}".format(hits_1 / num))
            print("Hits@3 = {}".format(hits_3 / num))
            print("Hits@5 = {}".format(hits_5 / num))
            print("Hits@10 = {}".format(hits_10 / num))
            print("MRR = {}".format(mrr / num))
        else:
            test_path = os.path.join(args.data_dir, "test.triples")
            test_data = data_utils.load_triples(
                test_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
            )
            print("Test set performance:")
            pred_scores = lf.forward(test_data, verbose=False)
            test_metrics = src.eval.hits_and_ranks(
                test_data,
                pred_scores,
                lf.kg.all_objects,
                verbose=True,
                output=True,
                kg=lf.kg,
                model_name=args.model,
            )
            eval_metrics["test"]["hits_at_1"] = test_metrics[0]
            eval_metrics["test"]["hits_at_3"] = test_metrics[1]
            eval_metrics["test"]["hits_at_5"] = test_metrics[2]
            eval_metrics["test"]["hits_at_10"] = test_metrics[3]
            eval_metrics["test"]["mrr"] = test_metrics[4]

    return eval_metrics
# Beispiel #8
# 0
def inference(lf):
    """Evaluate a trained link-prediction model `lf` on the test triples.

    Depending on module-level `args`, computes MAP over the NELL relation
    tasks, hits/ranks by relation type or by seen queries, or standard
    test-set hits@k and MRR.

    :param lf: Trained learning framework exposing `kg`, `forward`,
        `eval`, and checkpoint loading.
    :return: Nested dict ``{'dev': {...}, 'test': {...}}`` of metrics.
    """
    lf.batch_size = args.dev_batch_size
    lf.eval()
    # Hybrid models compose several pretrained embedding sets; everything
    # else restores a single checkpoint.
    if args.model == 'hypere':
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == 'triplee':
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(
            torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
    print(lf.kg.entity_embeddings)
    print(lf.kg.relation_embeddings)
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    # NELL splits filter queries to entities seen in the adjacency list.
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {'dev': {}, 'test': {}}

    if args.compute_map:
        relation_sets = [
            'concept:athletehomestadium', 'concept:athleteplaysforteam',
            'concept:athleteplaysinleague', 'concept:athleteplayssport',
            'concept:organizationheadquarteredincity',
            'concept:organizationhiredperson', 'concept:personborninlocation',
            'concept:teamplayssport', 'concept:worksfor'
        ]
        mps = []
        for r in relation_sets:
            print('* relation: {}'.format(r))
            test_path = os.path.join(args.data_dir, 'tasks', r, 'test.pairs')
            test_data, labels = data_utils.load_triples_with_label(
                test_path,
                r,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities)
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data,
                                   pred_scores,
                                   labels,
                                   lf.kg.all_objects,
                                   verbose=True)
            mps.append(mp)
        import numpy as np
        map_ = np.mean(mps)
        print('Overall MAP = {}'.format(map_))
        # BUG FIX: previously assigned the builtin function `map` here
        # instead of the computed mean average precision `map_`.
        eval_metrics['test']['avg_map'] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path,
                                           entity_index_path,
                                           relation_index_path,
                                           seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(
            args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print('Dev set evaluation by relation type (partial graph)')
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.dev_objects,
                                                 relation_by_types,
                                                 verbose=True)
        print('Dev set evaluation by relation type (full graph)')
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.all_objects,
                                                 relation_by_types,
                                                 verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path,
                                           entity_index_path,
                                           relation_index_path,
                                           seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir,
                                                   entity_index_path,
                                                   relation_index_path)
        print('Dev set evaluation by seen queries (partial graph)')
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.dev_objects,
                                                seen_queries,
                                                verbose=True)
        print('Dev set evaluation by seen queries (full graph)')
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.all_objects,
                                                seen_queries,
                                                verbose=True)
    else:
        # Dead commented-out dev evaluation and a dev_data load whose only
        # consumer was that commented code have been removed.
        test_path = os.path.join(args.data_dir, 'test.triples')
        test_data = data_utils.load_triples(test_path,
                                            entity_index_path,
                                            relation_index_path,
                                            seen_entities=seen_entities,
                                            verbose=False)
        print('Test set performance:')
        pred_scores = lf.forward(test_data, verbose=False)
        test_metrics = src.eval.hits_and_ranks(test_data,
                                               pred_scores,
                                               lf.kg.all_objects,
                                               verbose=True,
                                               output=False,
                                               kg=lf.kg,
                                               model_name=args.model,
                                               split_relation=False)
        # NOTE(review): test-set metrics are deliberately(?) stored under
        # the 'dev' key here — preserved for caller compatibility; confirm
        # downstream consumers before renaming.
        eval_metrics['dev'] = {}
        eval_metrics['dev']['hits_at_1'] = test_metrics[0]
        eval_metrics['dev']['hits_at_3'] = test_metrics[1]
        eval_metrics['dev']['hits_at_5'] = test_metrics[2]
        eval_metrics['dev']['hits_at_10'] = test_metrics[3]
        eval_metrics['dev']['mrr'] = test_metrics[4]

    return eval_metrics