Пример #1
0
    def get_kg(self, kg_path):
        '''
        Build entity-relation lookup tables from the KG triple files.

        Every file in ``kg_path + 'clubs/'`` and ``kg_path + 'country/'`` is read
        line by line. Each non-blank line must hold a tab-separated
        ``subject<TAB>relation<TAB>object`` triple; it is lower-cased and every
        field is passed through ``clean_str`` before being indexed.

        :param kg_path: directory prefix containing the ``clubs/`` and
            ``country/`` sub-directories (joined by plain string concatenation,
            so it must end with a path separator)
        :return: tuple ``(er_pair_dict, ent_l)``. ``er_pair_dict`` maps
            ``'<entity>#<relation>'`` to the entity on the other side of the
            triple (spaces replaced by ``_``), indexed in both directions.
            ``ent_l`` lists every subject and object in the order read
            (duplicates kept).
        '''
        er_pair_dict = {}
        ent_l = []
        for sub_dir in ('clubs/', 'country/'):
            directory = kg_path + sub_dir
            for kgf in os.listdir(directory):
                with open(directory + kgf, 'r') as f:
                    lines = f.readlines()
                for t in lines:
                    line = t.lower().replace('\n', '')
                    if not line.strip():
                        # Blank lines (e.g. a trailing newline at EOF) cannot be
                        # unpacked into a triple; skip them instead of crashing.
                        continue
                    e, r, o = line.split('\t')
                    e, r, o = clean_str(e), clean_str(r), clean_str(o)
                    # Index the triple both ways: subject#rel -> object and
                    # object#rel -> subject.
                    er_pair_dict[e.strip() + '#' + r.strip()] = o.replace(' ', '_')
                    er_pair_dict[o.strip() + '#' + r.strip()] = e.replace(' ', '_')
                    ent_l.append(e)
                    ent_l.append(o)

        return er_pair_dict, ent_l
 def _check_kgqa_ans(self, kgqa_dict, qid, q):
     '''
     Look up the gold relation and subject of a KG-answerable question.

     The candidates stored under ``qid`` are scanned in order. The first one
     whose text ``v[0]`` has a ``fuzz.ratio`` with the cleaned question strictly
     greater than ``self.max_similarity`` is selected.

     :param kgqa_dict: mapping from question/file id to a list of candidate
         records; a record is read as ``v[0]`` (question text), ``v[2]``
         (relation) and ``v[3]`` (returned as the subject)
     :param qid: id to look up in ``kgqa_dict``
     :param q: raw question text; cleaned with ``clean_str`` before matching
     :return: ``(relation, subject)``; both ``''`` when ``qid`` is unknown or
         no candidate is similar enough.
         NOTE(review): a caller unpacks the second value as ``o`` and stores it
         as the answer object - confirm whether ``v[3]`` is a subject or object.
     '''
     candidates = kgqa_dict.get(qid)
     if not candidates:
         return '', ''
     q = clean_str(q)  # loop-invariant: clean once, not once per candidate
     for v in candidates:
         if fuzz.ratio(q, v[0]) > self.max_similarity:
             return v[2], v[3]
     return '', ''
    def get_data(self, dataset):
        '''
        Load the conversations of ``dataset`` and build one feature dict per turn.

        Conversations are read from
        ``self.data_path + 'conversations/' + dataset + '_with_entities_er/'``.
        Every turn dict gets the cleaned question ``'q'``, the cleaned answer
        (``'a'`` and ``'_a'``) and the turn's ``kgER<n>`` map as ``'kgER'``
        (iterated as entity -> relations). When ``c['correct_ERcomb']`` names a
        subject, the turn is treated as KG-answerable and also gets:

        * ``'subject'``, ``'relations'``, ``'object'`` - the first gold subject,
          its relations and the object of each relation;
        * ``'er_pair'`` - ``'o<p>' -> '<entity>#<relation>'`` for every pair of
          the input KG, and ``'output_kg'`` - the indices ``p`` whose pair uses
          the gold subject and one of the gold relations;
        * ``'a'`` / ``'correct_ent'`` - only when ``'output_kg'`` is non-empty:
          answer tokens equal to a gold object are replaced by ``'o<p>'``. Here
          ``p`` is the ``'output_kg'`` entry at the object's position in
          ``'object'``. Tokens without such an entry are kept unchanged;
        * ``'sent_gate'`` - one float per token of ``'_a'``, 1.0 where a token
          was replaced (all zeros when nothing was replaced);
        * ``'X_feat'`` / ``'input_graph_feature'`` - per-node similarity to the
          question, and ``D^-1 (A + I) X`` for the nodes that are not keys of
          ``'kgER'``.

        :param dataset: split name used in the conversation directory name
        :return: list of per-turn dicts, in file order then turn order
        '''
        data_out = []
        found = []
        data_files = self.load_all_files(dataset)
        for dat in data_files:
            convo = self.get_conv(self.data_path + 'conversations/' + dataset +
                                  '_with_entities_er/' +
                                  dat)  # Get Conversations
            for j, c in enumerate(convo):
                turn = str(j + 1)
                convo_dict = {}
                correct_ent = []
                convo_dict['q'] = clean_str(c['q' + turn])
                convo_dict['_a'] = clean_str(c['a' + turn])
                convo_dict['a'] = clean_str(c['a' + turn])
                convo_dict['kgER'] = OrderedDict(c['kgER' + turn])
                subjects = list(c['correct_ERcomb'].keys())
                if subjects:  # Answerable by KG
                    s = subjects[0]
                    r_s = list(c['correct_ERcomb'][s].keys())  # gold relations
                    o = [c['correct_ERcomb'][s][r] for r in r_s]  # gold objects
                    convo_dict['object'] = o
                    convo_dict['subject'] = s
                    convo_dict['relations'] = r_s
                    inp_kg = convo_dict['kgER']

                    # Number every (entity, relation) pair of the input KG and
                    # remember the ones that match the gold subject/relations.
                    entity_reln_pair = [[ent, r] for ent in inp_kg.keys()
                                        for r in inp_kg[ent]]
                    correct_er_pair = []
                    object_dict = {}
                    for p, (e, r) in enumerate(entity_reln_pair):
                        object_dict['o' + str(p)] = e + '#' + r
                        if e == s and r in r_s:
                            correct_er_pair.append(p)
                            correct_ent.append(e + '#' + r)
                    convo_dict['output_kg'] = correct_er_pair
                    convo_dict['er_pair'] = object_dict

                    # Created fresh for every turn so a turn without a match
                    # never reuses (or fails on) an earlier turn's gate.
                    answer_tokens = convo_dict['_a'].split()
                    sent_gate = np.zeros(len(answer_tokens))
                    if correct_er_pair:
                        # Replace answer words that are gold objects by 'o<p>'.
                        answer = []
                        for k, w_a in enumerate(answer_tokens):
                            if w_a in o:
                                obj_idx = o.index(w_a)
                                if obj_idx < len(correct_er_pair):
                                    sent_gate[k] = 1
                                    answer.append('o' +
                                                  str(correct_er_pair[obj_idx]))
                                    continue
                                # No er-pair at that position: keep the word so
                                # 'a' stays token-aligned with 'sent_gate'.
                                print('no er-pair for object', w_a, 'in', dat,
                                      convo_dict['_a'])
                            answer.append(w_a)
                        convo_dict['a'] = ' '.join(answer)
                        convo_dict['correct_ent'] = correct_ent
                    convo_dict['sent_gate'] = sent_gate

                    # Graph features: degree-normalised propagation of the
                    # per-node question similarity, h = D^-1 (A + I) X.
                    A, I = gen_adjacency_mat(convo_dict['kgER'])
                    D = torch.from_numpy(get_degree_matrix(A))
                    er_vec = getER_vec(convo_dict['kgER'])
                    X = [self.calculate_similarity(ele, convo_dict['q'])
                         for ele in er_vec]
                    A_hat = A + I
                    dt = np.matmul(np.linalg.inv(D), A_hat)
                    h = np.matmul(dt, X)
                    convo_dict['X_feat'] = X
                    # Keep only the propagated scores of non-entity nodes.
                    convo_dict['input_graph_feature'] = [
                        h[k] for k, ele in enumerate(er_vec) if ele not in inp_kg
                    ]
                    if s:
                        found.append(s)
                data_out.append(convo_dict)
        print(len(found))
        return data_out
    def get_data(self, dataset):
        '''
        Load the manually annotated sketch conversations of ``dataset``.

        Files come from
        ``self.data_path + 'manually_annotated/' + dataset + '_sketch/'``.
        For turn ``n`` (1-based) of each conversation a dict is built with:

        * ``'f'`` - the file name without ``.json``;
        * ``'q'`` - the dialogue history up to and including this question. It
          is joined with ``self.args.bert_sep`` when ``self.args.use_bert`` is
          set, otherwise with ``self.args.eou_tok``; on the first turn it is
          just the question;
        * ``'_a'`` - the raw answer, ``'_q'`` - the cleaned current question;
        * ``'e'`` - ``input_ent<n>``, or ``self.args.no_ent_tok`` when empty;
        * ``'o'`` / ``'r'`` - comma-split ``obj<n>`` / ``corr_rel<n>``;
        * ``'a'`` - ``a<n>_v2`` with the span fuzzy-matched to ``'e'`` replaced
          by ``@entity``; falls back to ``'_a'`` when ``a<n>_v2`` is empty;
        * ``'h'`` (only when ``'e'`` is truthy) - ``{'@entity' | '@<node>': h}``
          where ``h = D^-1 (A + I) X`` over the graph built from
          ``self.conn_ent[e]``;
        * ``'s'`` - one float per token of ``'a'``, 1.0 for tokens containing
          ``'@'``.

        :param dataset: split name used in the directory name
        :return: list of per-turn dicts, in file order then turn order
        :raises KeyError: when a turn of a multi-turn conversation lacks its
            ``a<n>`` answer
        '''
        data_out = []
        data_files = self.load_all_files(dataset)
        for dat in data_files:
            convo = self.get_conv(self.data_path + 'manually_annotated/' +
                                  dataset + '_sketch/' +
                                  dat)  # Get Conversations

            conv_hist = []
            for j, c in enumerate(convo):
                turn = str(j + 1)
                convo_dict = {}
                convo_dict['f'] = dat.replace('.json', '')
                if conv_hist:
                    try:
                        conv_hist.append(c['q' + turn])
                    except KeyError as e:
                        print(dat, e)
                    if self.args.use_bert:
                        sep = ' ' + self.args.bert_sep + ' '  # For bert
                    else:
                        sep = ' ' + self.args.eou_tok + ' '
                    convo_dict['q'] = sep.join(conv_hist)
                    try:
                        conv_hist.append(c['a' + turn])
                    except KeyError as e:
                        print(e, dat)
                        # Re-raise rather than exit(): exit() ends the process
                        # with status 0 (signalling success) and only exists
                        # when the site module is loaded.
                        raise
                else:
                    convo_dict['q'] = c['q' + turn]
                    conv_hist.append(c['q' + turn])
                    conv_hist.append(c['a' + turn])
                convo_dict['_a'] = c['a' + turn]
                convo_dict['_q'] = clean_str(c['q' + turn])

                # Input entity of the turn (placeholder token when absent).
                if c['input_ent' + turn]:
                    convo_dict['e'] = c['input_ent' + turn]
                else:
                    convo_dict['e'] = self.args.no_ent_tok

                convo_dict['o'] = c['obj' + turn].split(',')
                convo_dict['r'] = c['corr_rel' + turn].split(',')

                convo_dict['a'] = c['a' + turn + '_v2']
                if convo_dict['e']:
                    # Replace the entity mention in the answer by '@entity'.
                    _, best_entity_ans = get_fuzzy_match(
                        convo_dict['e'], convo_dict['a'])
                    if best_entity_ans != "":
                        convo_dict['a'] = convo_dict['a'].replace(
                            best_entity_ans, '@entity')
                    # NOTE(review): when 'e' is no_ent_tok this looks up the
                    # placeholder in conn_ent - confirm it is a valid key.
                    input_kg_reln = self.conn_ent[convo_dict['e']]
                    input_kg = dict()
                    input_kg[convo_dict['e']] = list(input_kg_reln)
                    A, I = gen_adjacency_mat(input_kg)
                    D = get_degree_matrix(A)
                    # Per-node similarity to the current (cleaned) question.
                    er_vec = getER_vec(input_kg)
                    X = [self.calculate_similarity(ele, convo_dict['_q'])
                         for ele in er_vec]
                    A_hat = A + I
                    try:
                        dt = np.matmul(np.linalg.inv(D), A_hat)
                        h = np.matmul(dt, X)
                    except (np.linalg.LinAlgError, ValueError) as err:
                        # Singular degree matrix or shape mismatch: fall back
                        # to all-zero features for this turn.
                        print(err)
                        h = np.zeros(len(er_vec))
                        print(dat, A)

                    kg_dict = {}
                    for idx, ele in enumerate(er_vec):
                        if idx == 0:
                            kg_dict['@entity'] = h[0]
                        elif len(ele.split()) > 1:
                            kg_dict['@' + ele.replace(' ', '_')] = h[idx]
                        else:
                            kg_dict['@' + ele] = h[idx]
                    convo_dict['h'] = kg_dict
                # if answer is empty put the original answer
                if not convo_dict['a']:
                    convo_dict['a'] = convo_dict['_a']

                # Gate: 1.0 for placeholder tokens ('@...') in the answer.
                s_g = np.zeros(len(convo_dict['a'].split()))
                for k, w in enumerate(convo_dict['a'].split()):
                    if '@' in w:
                        s_g[k] = 1.0
                convo_dict['s'] = s_g
                data_out.append(convo_dict)

        return data_out
Пример #5
0
    def get_data(self, dataset):
        '''
        Load the conversations of ``dataset`` and build one feature dict per turn.

        Conversations are read from
        ``self.data_path + 'conversations/' + dataset + '_with_entities_er/'``.
        Each question is matched against ``self.get_kgqa(dataset)`` through
        ``self._check_kgqa_ans``, which yields a relation ``r`` and ``o``.
        Every turn dict holds:

        * ``'q'`` - the dialogue history up to this question, joined with
          ``self.args.eou_tok``;
        * ``'file_name'``;
        * ``'a'`` / ``'_a'`` - the cleaned answer;
        * ``'kgER'`` - the turn's ``kgER<n>`` map.

        When ``r`` is non-empty the dict also gets:

        * ``'object'`` (``o``), ``'subject'`` and ``'relation'`` (``r``). The
          subject is the input-KG entity carrying ``r``; with several
          candidates it is the one that fuzzy-matches ``'q'`` best;
        * ``'er_pair'`` - ``'o<p>'`` -> object looked up in ``self.ent_d`` for
          every cleaned (entity, relation) pair, or ``'o<p>'`` when unknown;
        * ``'output_kg'`` - index ``p`` of the pair equal to the chosen
          subject/relation, or ``None`` when no pair matches;
        * ``'a'`` with the span best fuzzy-matching ``o`` replaced by
          ``self.args.mem_tok`` (only when a pair matched), plus
          ``'sent_gate'`` and ``'correct_ent'``;
        * ``'X_feat'`` / ``'input_graph_feature'`` - per-node similarity to
          ``'q'``, and ``exp`` of ``D^-1 (A + I) X`` for nodes that are not keys
          of ``'kgER'``.

        :param dataset: split name used in the conversation directory name
        :return: list of per-turn dicts, in file order then turn order
        '''
        data_out = []
        found = []
        data_files = self.load_all_files(dataset)
        kgqa_answers = self.get_kgqa(dataset)
        for dat in data_files:
            convo = self.get_conv(self.data_path + 'conversations/' + dataset +
                                  '_with_entities_er/' +
                                  dat)  # Get Conversations
            conv_hist = []
            for j, c in enumerate(convo):
                turn = str(j + 1)
                convo_dict = {}
                r, o = self._check_kgqa_ans(kgqa_answers,
                                            dat.replace('.json', ''),
                                            c['q' + turn])
                if conv_hist:
                    conv_hist.append(c['q' + turn])
                    convo_dict['q'] = (' ' + self.args.eou_tok + ' ').join(
                        conv_hist)
                else:
                    convo_dict['q'] = c['q' + turn]
                    conv_hist.append(c['q' + turn])
                conv_hist.append(c['a' + turn])
                convo_dict['file_name'] = dat
                convo_dict['a'] = clean_str(c['a' + turn])
                convo_dict['_a'] = clean_str(c['a' + turn])
                convo_dict['kgER'] = OrderedDict(c['kgER' + turn])
                # Per-turn state, reset here so nothing leaks from the
                # previous turn (previously these could be stale or unbound).
                correct_ent = []
                correct_sub = ''
                correct_reln = ''
                correct_er_pair = None
                if r:
                    convo_dict['object'] = o
                    inp_kg = convo_dict['kgER']

                    # Candidate (entity, relation) pairs whose relation is r.
                    probable_reln = [[k, kg_r.lower()]
                                     for k, v in inp_kg.items()
                                     for kg_r in v if kg_r.lower() == r]
                    if len(inp_kg) > 1 and len(probable_reln) > 1:
                        # Several entities carry r: keep the best match to q.
                        best_sub = np.argmax([
                            get_fuzzy_match(opt[0], convo_dict['q'])[0]
                            for opt in probable_reln
                        ])
                        correct_sub, correct_reln = probable_reln[best_sub]
                    elif probable_reln:
                        correct_sub, correct_reln = probable_reln[0]
                    else:
                        print('no KG relation matching', r, 'in', dat)
                    convo_dict['subject'] = correct_sub
                    convo_dict['relation'] = r
                    # get_fuzzy_match returns (similarity, best matching span)
                    best_match = get_fuzzy_match(o, convo_dict['a'])[1]

                    # Resolve every cleaned (entity, relation) pair to its object.
                    entity_reln_pair = [[clean_str(ent), clean_str(rel)]
                                        for ent in inp_kg.keys()
                                        for rel in inp_kg[ent]]
                    object_dict = {}
                    for p, (ent, rel) in enumerate(entity_reln_pair):
                        slot = 'o' + str(p)
                        object_dict[slot] = slot  # fallback for unknown pairs
                        resolved = False
                        # ent_d keys may appear with or without a space before '#'.
                        for key in (ent + '#' + rel, ent + ' #' + rel):
                            try:
                                object_dict[slot] = self.ent_d[key]
                            except KeyError:
                                continue
                            resolved = True
                            break
                        if ent == correct_sub and rel == correct_reln:
                            correct_er_pair = p
                            if resolved:
                                correct_ent.append(object_dict[slot])
                    convo_dict['output_kg'] = correct_er_pair
                    convo_dict['er_pair'] = object_dict

                    sent_gate = np.zeros(len(convo_dict['a'].split()))
                    if correct_er_pair is not None:
                        answer = convo_dict['a']
                        # Guard: str.replace('', x) would insert x between
                        # every character of the answer.
                        if best_match:
                            answer = answer.replace(best_match,
                                                    self.args.mem_tok)
                        # NOTE(review): the gate is sized from the answer before
                        # the mem_tok replacement and looks for 'o<p>' rather
                        # than mem_tok - confirm this is intended.
                        target = 'o' + str(correct_er_pair)
                        for k, w_a in enumerate(answer.split()):
                            if w_a == target:
                                try:
                                    sent_gate[k] = 1
                                except IndexError:
                                    print(sent_gate, convo_dict['a'])
                        convo_dict['a'] = answer
                        if best_match:
                            convo_dict['_a'] = convo_dict['_a'].replace(
                                best_match, best_match.replace(' ', '_'))
                    convo_dict['correct_ent'] = correct_ent
                    convo_dict['sent_gate'] = sent_gate

                    # Graph features: h = D^-1 (A + I) X over the turn's KG.
                    A, I = gen_adjacency_mat(convo_dict['kgER'])
                    D = torch.from_numpy(get_degree_matrix(A))
                    er_vec = getER_vec(convo_dict['kgER'])
                    X = [self.calculate_similarity(ele, convo_dict['q'])
                         for ele in er_vec]
                    A_hat = A + I
                    dt = np.matmul(np.linalg.inv(D), A_hat)
                    h = np.matmul(dt, X)
                    convo_dict['X_feat'] = X
                    # Exponentiated scores of the non-entity nodes only.
                    convo_dict['input_graph_feature'] = [
                        np.exp(h[k]) for k, ele in enumerate(er_vec)
                        if ele not in inp_kg
                    ]
                    if correct_sub:
                        found.append(correct_sub)
                data_out.append(convo_dict)
        print(len(found))
        return data_out