예제 #1
0
파일: EmbedTwoCons.py 프로젝트: xzlyu/RQE
 def __init__(self):
     """Create the graph, point it at the DBO/All id files and load it.

     After the three id-file paths are set, ``load_data()`` and then
     ``load_er_embedding()`` are called on the graph.
     """
     data_dir = "../../MyData/DBO/All/"
     kg = Graph()
     self.graph = kg
     kg.e2idx_file = data_dir + "entity2id.txt"
     kg.r2idx_file = data_dir + "relation2id.txt"
     kg.triple2idx_file = data_dir + "triple2id.txt"
     kg.load_data()
     kg.load_er_embedding()
예제 #2
0
파일: EmbedTwoCons.py 프로젝트: xzlyu/RQE
class GetMetricEmbedding:
    """Compute embedding-based ranking metrics for two-constraint queries.

    The instance owns a ``Graph`` configured with the ``MyData/DBO/All`` id
    files.  ``get_metric`` reads the examples for one pair of relations from
    ``./TwoCons_eval/<r1>_<r2>/2consQandA.txt`` and writes the resulting
    metrics next to it as ``metric.txt``.
    """

    def __init__(self):
        """Create the graph, set its id-file paths and load data and embeddings."""
        self.graph = Graph()
        self.graph.e2idx_file = "../../MyData/DBO/All/entity2id.txt"
        self.graph.r2idx_file = "../../MyData/DBO/All/relation2id.txt"
        self.graph.triple2idx_file = "../../MyData/DBO/All/triple2id.txt"
        self.graph.load_data()
        self.graph.load_er_embedding()

    def split_cons(self, cons, direction):
        """Split one constraint line into its entity and relation names.

        Args:
            cons: A line holding exactly two whitespace-separated tokens.
            direction: ``"left"`` when the line reads ``<relation> <entity>``;
                any other value means ``<entity> <relation>``.

        Returns:
            ``(entity_name, relation_name)``, always in that order.

        Raises:
            ValueError: If ``cons`` does not hold exactly two tokens.
        """
        first_token, second_token = cons.split()
        if direction == "left":
            return second_token, first_token
        return first_token, second_token

    def get_metric(self, r_name1, direction1, r_name2, direction2):
        """Evaluate the examples of one relation pair and write ``metric.txt``.

        The example file is a sequence of 3-line records: first constraint,
        second constraint, then the whitespace-separated answer entity names.
        Reading stops at the end of the file or at the first blank line that
        starts a record.

        Args:
            r_name1: First relation name; a ``dbo:`` prefix is stripped when
                building the folder name.
            direction1: Direction passed to ``split_cons`` for the first
                constraint and forwarded to the metric computation.
            r_name2: Second relation name, treated like ``r_name1``.
            direction2: Direction for the second constraint.

        Raises:
            IndexError: If the file ends in the middle of a 3-line record.
            KeyError: If an entity or relation name is not known to the graph.
            ValueError: If a constraint line does not hold exactly two tokens.
        """
        eval_folder = "./TwoCons_eval/{}_{}/".format(r_name1.replace("dbo:", ""),
                                                     r_name2.replace("dbo:", ""))
        example_file_path = eval_folder + "2consQandA.txt"
        metric_file = eval_folder + "metric.txt"

        # Read everything first so the file is closed before the id lookups.
        with open(example_file_path, 'r', encoding="UTF-8") as f:
            all_lines = f.readlines()

        e_name2id = self.graph.e_name2id
        r_name2id = self.graph.r_name2id

        first_e_idx_list = []
        first_r_idx_list = []
        second_e_idx_list = []
        second_r_idx_list = []
        answer_idx_list = []

        cnt = 0
        while cnt < len(all_lines) and all_lines[cnt].strip() != "":
            if cnt + 2 >= len(all_lines):
                # Same exception type as plain indexing would give, but it
                # names the file and the position of the broken record.
                raise IndexError(
                    "Incomplete example starting at line {} of {}: "
                    "each example needs 3 lines.".format(cnt + 1, example_file_path))

            first_e_name, first_r_name = self.split_cons(all_lines[cnt], direction1)
            first_e_idx_list.append(e_name2id[first_e_name])
            first_r_idx_list.append(r_name2id[first_r_name])

            second_e_name, second_r_name = self.split_cons(all_lines[cnt + 1], direction2)
            second_e_idx_list.append(e_name2id[second_e_name])
            second_r_idx_list.append(r_name2id[second_r_name])

            answer_idx_list.append([e_name2id[name] for name in all_lines[cnt + 2].split()])
            cnt += 3

        start_time = time.time()
        map_r, mrr_r, hits10_r, mean_rank = self.graph.ekg.get_map_mrr_hits10_meanRank_embed_2cons(
            first_e_idx_list, first_r_idx_list, direction1,
            second_e_idx_list, second_r_idx_list, direction2,
            answer_idx_list)
        end_time = time.time()

        # An empty example file must not abort with ZeroDivisionError after
        # the metric file has already been opened for writing.
        example_num = len(first_e_idx_list)
        average_time = (end_time - start_time) / example_num if example_num else 0.0

        with open(metric_file, 'w', encoding="UTF-8") as f:
            f.write("map:{}\n".format(map_r))
            f.write("mrr:{}\n".format(mrr_r))
            f.write("hits10:{}\n".format(hits10_r))
            f.write("mean_rank:{}\n".format(mean_rank))
            f.write("average_time:{}\n".format(average_time))
예제 #3
0
파일: RuleSizeSurvey.py 프로젝트: xzlyu/RQE
def get_graph(e2idx_file, r2idx_file, triple2idx_file):
    """Return a ``Graph`` configured with the three id files and already loaded.

    Args:
        e2idx_file: Path stored on the graph as ``e2idx_file``.
        r2idx_file: Path stored on the graph as ``r2idx_file``.
        triple2idx_file: Path stored on the graph as ``triple2idx_file``.

    Returns:
        The graph, after ``load_data()`` has been called on it.
    """
    kg = Graph()
    kg.e2idx_file, kg.r2idx_file, kg.triple2idx_file = (
        e2idx_file,
        r2idx_file,
        triple2idx_file,
    )
    kg.load_data()
    return kg
예제 #4
0
    def _build_container(self):
        """Create the train/search graphs and the empty per-relation tables."""
        self.train_graph, self.search_graph = Graph(), Graph()

        # r_name -> list of Rule objects; used to search candidates.
        #   {r_name: [Rule(), Rule(), ...], ...}
        self.r_name2rule_set = dict()

        # (An r_idx -> rule list table for feeding the test model used to be
        # kept here as r_rules_dict_4_feed_model; it is disabled.)

        # r_name -> trained test model for that relation.
        #   {r_name: LogisticRegression Model, ...}
        self.r_name2model = dict()
예제 #5
0
class QE:
    """Train per-relation models and rank candidates for one SPARQL query.

    A ``train_graph`` and a ``search_graph`` are created on construction.
    ``_setSparql`` must be called before ``train_rules`` or ``get_candidates``:
    both read the relation names the parser extracts from the query.
    """

    def __init__(self, args):
        """Store ``args`` and create the empty containers.

        Args:
            args: Caller-supplied settings; only stored on ``self.args`` here.
        """
        self.logger = ALogger("Graph", True).getLogger()
        self.util = Util()

        self.sp = None
        # Declared here so the attributes exist before _setSparql is called.
        self.sparql = None
        self.r_name_list = []
        self.args = args

        self._build_container()

    def _build_container(self):
        """Create the two graphs and the empty per-relation lookup tables."""
        self.train_graph = Graph()
        self.search_graph = Graph()

        # r_name -> list of Rule objects; used to search candidates.
        #   {r_name: [Rule(), Rule(), ...], ...}
        self.r_name2rule_set = {}

        # r_name -> trained model for that relation.
        #   {r_name: LogisticRegression Model, ...}
        self.r_name2model = {}

    def _setSparql(self, sparql):
        """Parse ``sparql`` and remember the relation names it uses.

        Args:
            sparql: Query text handed to ``SparqlParser``.
        """
        self.sparql = sparql
        self.sp = SparqlParser(sparql=sparql)
        self.sp.parse_sparql()
        self.r_name_list = self.sp.r_name_list

    def train_rules(self):
        """Obtain a PRA model for every relation of the parsed query.

        For each relation the train graph is pointed at that relation's files,
        reloaded, and the model it returns is cached in ``r_name2model``.
        """
        for r_name in self.sp.r_name_list:
            self.logger.info("Train\t{}".format(r_name))
            train_args = self.get_train_args(r_name)
            self.train_graph._build_train_file_path(train_args)
            self.train_graph.load_data()
            r_idx = self.train_graph.r_name2id[r_name]
            self.r_name2model[r_name] = self.train_graph.get_pra_model4r(r_idx=r_idx)

    def get_rule_set_model(self, graph):
        """Load the rule objects and the model of every query relation.

        Args:
            graph: Graph whose ``rule_file``, ``rule_num_to_use_file`` and
                ``model_file`` attributes are re-pointed for each relation
                before its rules and model are loaded.
        """
        for r_name in self.r_name_list:
            # Folders are named after the local part of the prefixed name.
            local_name = r_name.split(":")[-1]
            model_folder = "../MyData/DBO/United_States/model/{}/".format(local_name)
            graph.rule_file = model_folder + "rule.txt"
            graph.rule_num_to_use_file = model_folder + "rule_num_to_use.txt"
            graph.model_file = "{}{}_model.tar".format(model_folder, pca_or_cwa)

            self.logger.info("Collect Rules for R: {}".format(r_name))
            # Only the first rules_num_to_search_cands rules are kept.
            self.r_name2rule_set[r_name] = graph.load_rule_obj_from_file(r_name)[:rules_num_to_search_cands]

            r_idx = graph.r_id_by_name(r_name)
            self.r_name2model[r_name] = graph.get_pra_model4r(r_idx=r_idx)

    def get_candidates(self, query_name):
        """Search, score, sort and display candidates for the parsed query.

        Args:
            query_name: Name used to build the default result-file paths.
        """
        self.logger.info("Get candidates.")
        search_args = self.get_search_args(query_name)
        self.search_graph._build_search_file_path(search_args)

        self.search_graph.load_data()

        self.get_rule_set_model(self.search_graph)

        # Timing covers the search itself, not loading the graph and rules.
        start_time = time.time()

        self.logger.info("Get candidates and execute 1 var BGP.")
        self.sp.execute_var1BGP(self.r_name2rule_set, self.search_graph)

        self.logger.info("Execute 2 var BGP.")
        self.sp.execute_var2BGP(self.r_name2rule_set, self.search_graph)

        print("Start normalize searched res.")
        self.sp.normalize_searched_res()

        # Cap the number of results that get scored at a random 1500.
        if len(self.sp.searched_res) > 1500:
            self.sp.searched_res = random.sample(self.sp.searched_res, 1500)

        self.logger.info("Calculate confidence for candidates.")
        self.sp.gen_pra_conf_and_rule_path(self.r_name2rule_set, self.r_name2model, self.search_graph)

        self.logger.info("Sorting and displaying.")
        self.sp.sort_cand_obj_list("pra")
        self.sp.display_cands(self.search_graph)

        end_time = time.time()
        self.logger.info("Finishing generating and displaying candidates. Epalsed: {}.".
                         format(end_time - start_time))

    @staticmethod
    def _parse_cli(parser):
        """Parse ``sys.argv`` with ``parser``, ignoring options it does not define."""
        # parse_args() would exit with "unrecognized arguments" as soon as the
        # command line carries an option that belongs to another parser.
        known_args, _unknown = parser.parse_known_args()
        return known_args

    def get_search_args(self, query_name):
        """Build the file-path settings used for candidate search.

        Every option has a default; each can still be overridden on the
        command line.

        Args:
            query_name: Name embedded in the default result-file paths.

        Returns:
            ``argparse.Namespace`` with ``e2id_file``, ``r2id_file``,
            ``triple2id_file``, ``qe_res_all`` and ``qe_res_topk``.
        """
        parser = argparse.ArgumentParser()
        search_folder = "../MyData/DBO/All/"

        parser.add_argument('--e2id_file', type=str, default=search_folder + "entity2id.txt",
                            help='entity2id file')
        parser.add_argument('--r2id_file', type=str, default=search_folder + "relation2id.txt",
                            help='relation2id file')
        parser.add_argument('--triple2id_file', type=str, default=search_folder + "triple2id.txt",
                            help='triple2id file')

        parser.add_argument('--qe_res_all', type=str,
                            default="../MyData/DBO/United_States/EmptyQ/{}_{}_qe_res_all.txt".format(query_name,
                                                                                                        pca_or_cwa),
                            help='{} qe resutls with {}'.format(query_name, pca_or_cwa))
        parser.add_argument('--qe_res_topk', type=str,
                            default="../MyData/DBO/United_States/EmptyQ/{}_{}_qe_res_topk.txt".format(query_name,
                                                                                                         pca_or_cwa),
                            help='{} qe resutls with {}'.format(query_name, pca_or_cwa))

        return self._parse_cli(parser)

    def get_train_args(self, r_name):
        """Build the file-path settings used to train the model of ``r_name``.

        The root folder and the relation's model folder are created through
        ``self.util.createFolder`` before the settings are assembled.

        Args:
            r_name: Prefixed relation name; the part after the last ``:``
                names the model folder.

        Returns:
            ``argparse.Namespace`` holding the root/model folders and the id,
            rule, training-data and model file paths.
        """
        scope = "United_States"
        root_folder = "../MyData/{}/{}/".format("DBO", scope)
        model_folder = root_folder + "model/{}/".format(r_name.split(":")[-1])

        self.util.createFolder(root_folder)
        self.util.createFolder(model_folder)

        parser = argparse.ArgumentParser()

        parser.add_argument('--root_folder', type=str, default=root_folder,
                            help='root folder file')

        parser.add_argument('--e2id_file', type=str, default=root_folder + "entity2id.txt",
                            help='entity2id file')
        parser.add_argument('--r2id_file', type=str, default=root_folder + "relation2id.txt",
                            help='relation2id file')
        parser.add_argument('--triple2id_file', type=str, default=root_folder + "triple2id.txt",
                            help='triple2id file')

        parser.add_argument('--model_folder', type=str, default=model_folder,
                            help='model folder for {}'.format(r_name))

        parser.add_argument('--rule_file', type=str, default="{}rule.txt".format(model_folder),
                            help='rule file for {}'.format(r_name))
        parser.add_argument('--rule_num_to_use_file', type=str, default="{}rule_num_to_use.txt".format(model_folder),
                            help='rule num to use for {}'.format(r_name))

        parser.add_argument('--train_id_data_file', type=str, default="{}train_id_data.npy".format(model_folder),
                            help='train id data for {}'.format(r_name))

        parser.add_argument('--train_feature_data_file', type=str,
                            default="{}{}_train_feature_data.npy".format(model_folder, pca_or_cwa),
                            help='train feature data for {}'.format(r_name))

        parser.add_argument('--model_file', type=str, default="{}{}_model.tar".format(model_folder, pca_or_cwa),
                            help='lg model for {}'.format(r_name))

        return self._parse_cli(parser)
예제 #6
0
import random

from Empty_Answer_Query import eaqs
from RuleBased.Graph import Graph
from RuleBased.Params import pca_or_cwa
from RuleBased.SparqlParser import SparqlParser
from Util import Util

util = Util()

search_folder = "../../MyData/DBO/All/"

# Graph over the full DBO id files.
graph = Graph()
graph.e2idx_file = search_folder + "entity2id.txt"
graph.r2idx_file = search_folder + "relation2id.txt"
graph.triple2idx_file = search_folder + "triple2id.txt"

graph.load_data()

# Collect every relation used by the empty-answer queries exactly once, in
# first-seen order.  Rebuilding list(set(...)) after each query repeated the
# dedup work and produced an order that depends on string hash randomisation,
# so the relations were processed in a different order from run to run.
r_name_list = []
_seen_r_names = set()

for sparql in eaqs:
    sp = SparqlParser(sparql=sparql)
    sp.parse_sparql()
    for relation in sp.r_name_list:
        if relation not in _seen_r_names:
            _seen_r_names.add(relation)
            r_name_list.append(relation)

# Graph ids of the collected relations, aligned with r_name_list.
r_idx_list = [graph.r_id_by_name(r_name) for r_name in r_name_list]

for r_name in r_name_list:
예제 #7
0
파일: RDQETwoCons.py 프로젝트: xzlyu/RQE
        print(
            "mrr:{}\thits10:{}\tmap:{}\tmean rank:{}\taverage time: {}".format(
                mrr_r, hits10_r, map_r, mean_rank_r, average_time))

        with open(metric_file, 'w', encoding="UTF-8") as f:
            f.write("hits10\t{}\n".format(hits10_r))
            f.write("map\t{}\n".format(map_r))
            f.write("mrr\t{}\n".format(mrr_r))
            f.write("average time\t{}\n".format(average_time))


if __name__ == "__main__":
    # Load the graph over the full DBO id files and hand it to the evaluator.
    search_folder = "../../MyData/DBO/All/"

    graph = Graph()
    graph.e2idx_file = "{}entity2id.txt".format(search_folder)
    graph.r2idx_file = "{}relation2id.txt".format(search_folder)
    graph.triple2idx_file = "{}triple2id.txt".format(search_folder)
    graph.load_data()

    two_conts_eval = TwoConts()
    two_conts_eval._setGraph(graph)

    # r_name_list_right_left = [["dbo:starring", "dbo:birthPlace", ]]
    #
    # for r_name in r_name_list_right_left:
    #     print("Right Left")
    #     print("R_name:{}\tR_name:{}".format(r_name[0], r_name[1]))
예제 #8
0
파일: EmbedOneCons.py 프로젝트: xzlyu/RQE
    with open(file_path, 'r', encoding="UTF-8") as f:
        for idx, line in enumerate(f.readlines()):
            if idx == 0:
                continue
            e_idx_list = [
                graph.get_e_idx_by_e_name(e_name)
                for e_name in line.strip().split("\t")
            ]
            h_idx_list.append(e_idx_list[0])
            t_idx_list.append(e_idx_list[1:])
    return h_idx_list, t_idx_list


search_folder = "../../MyData/EmbedDBO/"

# Graph over the EmbedDBO id files.
graph = Graph()
graph.e2idx_file = "{}entity2id.txt".format(search_folder)
graph.r2idx_file = "{}relation2id.txt".format(search_folder)
graph.triple2idx_file = "{}triple2id.txt".format(search_folder)

# NOTE(review): a True flag is switched off before anything is loaded; the
# reason is not visible here - confirm this is intended.
if transe_embed is True:
    transe_embed = False

graph.load_data()
graph.load_er_embedding()

r_name_list = []

# Parse each empty-answer query in turn.
for sparql in eaqs:
    sp = SparqlParser(sparql=sparql)
    sp.parse_sparql()
예제 #9
0
                f.write("{}\t{}\n".format(res[1][0], res[1][1]))
                # print("Answer:", end="\t")
                for answer_name in res[2]:
                    # print("{}".format(answer_name), end="\t")
                    f.write("{}\t".format(answer_name))
                if res_idx != len(res_list) - 1:
                    f.write("\n")
                # print("\n")


if __name__ == "__main__":
    # Id files of the full DBO dataset.
    e2idx_file = "../../../MyData/DBO/All/entity2id.txt"
    r2idx_file = "../../../MyData/DBO/All/relation2id.txt"
    triple2idx_file = "../../../MyData/DBO/All/triple2id.txt"

    graph = Graph()
    graph.e2idx_file, graph.r2idx_file, graph.triple2idx_file = (
        e2idx_file,
        r2idx_file,
        triple2idx_file,
    )
    graph.load_data()

    # Relation pairs to sample; a fresh sampler is used for every pair.
    r_name_list_left = [
        ["dbo:owner", "dbo:foundationPlace"],
        ["dbo:regionServed", "dbo:product"],
        ["dbo:locationCountry", "dbo:foundationPlace"],
        ["dbo:regionServed", "dbo:owner"],
    ]
    for r_name in r_name_list_left:
        tcs = TwoConsSampler(graph)
        tcs.set_relation(*r_name)
        tcs.sample_data()
        tcs.record_test_sample()
예제 #10
0
파일: RDQEOneCons.py 프로젝트: xzlyu/RQE
                continue

            e_idx_list = [
                graph.get_e_idx_by_e_name(e_name)
                for e_name in line.strip().split("\t")
            ]
            h_idx_list.append(e_idx_list[0])
            t_idx_list.append(e_idx_list[1:])
    return h_idx_list, t_idx_list


util = Util()

search_folder = "../../MyData/DBO/All/"

# Graph over the full DBO id files.
graph = Graph()
graph.e2idx_file = search_folder + "entity2id.txt"
graph.r2idx_file = search_folder + "relation2id.txt"
graph.triple2idx_file = search_folder + "triple2id.txt"

graph.load_data()

# Collect every relation used by the empty-answer queries exactly once, in
# first-seen order.  Rebuilding list(set(...)) after each query repeated the
# dedup work and produced an order that depends on string hash randomisation,
# so the relations were processed in a different order from run to run.
r_name_list = []
r_idx_list = []
_seen_r_names = set()

for sparql in eaqs:
    sp = SparqlParser(sparql=sparql)
    sp.parse_sparql()
    for relation in sp.r_name_list:
        if relation not in _seen_r_names:
            _seen_r_names.add(relation)
            r_name_list.append(relation)
예제 #11
0
            res_list.append(res)

        with open(self.fileName, 'w', encoding="UTF-8") as f:
            print("Res Num: {}".format(len(res_list)))
            for res_idx,res in enumerate(res_list):
                f.write("{}\t{}\n".format(res[0][0], res[0][1]))
                f.write("{}\t{}\n".format(res[1][0], res[1][1]))
                for answer_name in res[2]:
                    f.write("{}\t".format(answer_name))
                if res_idx != len(res_list) - 1:
                    f.write("\n")


if __name__ == "__main__":
    filePathConfig = FilePathConfig()
    # The test_* paths are used, each prefixed with "../".  (The search_*
    # paths of FilePathConfig were the disabled alternative.)
    e2idx_file, r2idx_file, triple2idx_file = (
        "../" + filePathConfig.test_e2idx_file,
        "../" + filePathConfig.test_r2idx_file,
        "../" + filePathConfig.test_triple2idx_file,
    )
    graph = Graph(e2idx_file, r2idx_file, triple2idx_file)
    graph.load_data()

    # Relation pairs to sample; a fresh sampler is used for every pair.
    r_name_list_right_left = [["dbo:starring", "dbo:birthPlace"]]
    for r_name in r_name_list_right_left:
        tcs = TwoConsSampler(graph)
        tcs.set_relation(*r_name)
        tcs.sample_data()
        tcs.record_test_sample()