def __init__(
        self,
        data_path='data/soccer/',
        vec_dim=300,
        fasttext_model='/data/dchaudhu/soccerbot_acl/vocab/wiki.en.bin'):
        '''
        Load the fastText model and the three data splits, then build the
        vocabulary, the token/index maps and the embedding matrix.

        :param data_path: stored on ``self.data_path``.
        :param vec_dim: number of columns of ``self.vectors``.
        :param fasttext_model: path handed to ``datapath`` and then to
            ``load_facebook_model``.
        '''
        self.data_path = data_path
        self.max_similarity = 85
        self.vec_dim = vec_dim
        self.word_emb = load_facebook_model(datapath(fasttext_model))
        self.stop = set(stopwords.words('english'))
        self.punc = string.punctuation

        # Splits are loaded one after another, in this order, before
        # ``self.args`` is assigned further down.
        for split in ('train', 'val', 'test'):
            setattr(self, split + '_dataset', self.get_data(split))

        # Per training record: total number of relations over all its entities.
        self.max_er_vec = [
            sum(len(relns) for relns in record['kgER'].values())
            for record in self.train_dataset
        ]
        self.max_out_reln = np.max(self.max_er_vec)
        graph_sizes = [
            len(getER_vec(record['kgER'])) for record in self.train_dataset
        ]
        self.inp_graph_max_size = np.max(graph_sizes)
        print('input graph size:' + str(self.inp_graph_max_size))
        print(self.max_out_reln)

        # One placeholder token ("o0", "o1", ...) per possible relation slot.
        self.objects = ['o' + str(slot) for slot in range(self.max_out_reln)]
        self.args = get_args()

        # Vocabulary: words collected by get_vocab, then the special tokens,
        # then the object placeholders (insertion order defines the ids).
        self.vocab = defaultdict(float)
        for split_data in (self.train_dataset, self.test_dataset,
                           self.val_dataset):
            self.get_vocab(split_data)
        special_tokens = [
            self.args.unk_tok, self.args.sos_tok, self.args.eos_tok
        ]
        for token in special_tokens + self.objects:
            self.vocab[token] += 1.0

        self.stoi = {token: idx for idx, token in enumerate(self.vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}
        print(len(self.stoi))
        self.n_words = len(self.stoi)

        # Only ids below the end-of-sentence id get a fastText vector; every
        # other row of the matrix stays zero.
        self.vectors = np.zeros((len(self.stoi), vec_dim))
        eos_idx = self.stoi[self.args.eos_tok]
        for token, idx in self.stoi.items():
            if idx >= eos_idx:
                continue
            self.vectors[idx] = self.word_emb.wv[token]
示例#2
0
    def get_graph_lap(self, entity, question):
        '''
        Weight the target-vocabulary KG placeholders of ``entity`` by their
        propagated similarity to ``question``.

        A one-entity graph is built from ``self.chat_data.e_r_l[entity]``.
        Every element returned by ``getER_vec`` is scored against ``question``
        with ``self.calculate_similarity`` and the scores are propagated once
        through ``inv(D) . (A + I)``. The result is written into a vector of
        ones holding one slot per entry of ``self.chat_data.trg_stoi``:
        element 0 goes to the ``'@entity'`` slot, every other element to the
        slot of ``'@<ele>'`` or, failing that, ``'@@<ele>'`` (spaces in
        ``ele`` replaced by underscores). Placeholders that are missing from
        the vocabulary are skipped, so their slots keep the initial 1.0.

        :param entity: key into ``self.chat_data.e_r_l``.
        :param question: passed unchanged to ``self.calculate_similarity``.
        :return: tensor created with ``torch.from_numpy`` from the weight
            vector of length ``len(self.chat_data.trg_stoi)``.
        :raises KeyError: if ``entity`` is not in ``self.chat_data.e_r_l``.
        '''
        trg_stoi = self.chat_data.trg_stoi
        i_g = np.ones(len(trg_stoi))

        input_kg = {entity: list(self.chat_data.e_r_l[entity])}
        A, I = gen_adjacency_mat(input_kg)
        D = get_degree_matrix(A)

        # Similarity of every entity/relation element to the question.
        er_vec = getER_vec(input_kg)
        X = [self.calculate_similarity(ele, question) for ele in er_vec]

        A_hat = A + I
        try:
            h = np.matmul(np.matmul(np.linalg.inv(D), A_hat), X)
        except Exception:
            # Deliberate best effort (e.g. a degree matrix that cannot be
            # inverted): fall back to all-zero scores instead of failing.
            h = np.zeros(len(er_vec))

        for j, ele in enumerate(er_vec):
            if j == 0:
                candidates = ('@entity',)
            else:
                ele = ele.replace(' ', '_')
                # 1-hop placeholder first, 2-hop placeholder as fallback.
                candidates = ('@' + ele, '@@' + ele)
            for token in candidates:
                try:
                    slot = trg_stoi[token]
                except KeyError:
                    # The previous fallback for '@entity' repeated the same
                    # failing lookup and therefore raised again; skipping
                    # leaves the slot at its initial value of 1.0.
                    continue
                i_g[slot] = h[j]
                break

        return torch.from_numpy(i_g)
示例#3
0
    def get_data(self, dataset):
        '''
        Load the conversations of one split and build one record per turn.

        For every file name returned by ``self.load_all_files(dataset)`` the
        conversation is read with ``self.get_conv`` from
        ``<data_path>manually_annotated/<dataset>_sketch/<file>``. Each turn
        becomes a dict with the keys:

        * ``'f'``   file name without ``.json``
        * ``'q'``   first question, or the whole history joined with
          ``args.bert_sep`` / ``args.eou_tok`` for later turns
        * ``'_q'``, ``'_a'``  raw question and answer of the turn
        * ``'e'``   input entity, or ``args.no_ent_tok`` when it is empty
        * ``'o'``, ``'r'``  comma-split ``obj<n>`` and ``corr_rel<n>`` fields
        * ``'kvr'`` the ``kvr_entlist_qa<n>`` field
        * ``'a'``   answer with KG-matched words replaced by ``@entity``,
          ``@<relation>`` or ``@@<relation>`` (raw answer if that is empty)
        * ``'h'``   placeholder -> propagated similarity score

        Reads ``self.args``, ``self.global_ent``, ``self.eo_dict`` and
        ``self.e_r_l``; for navigate/weather files it also reads
        ``<data_path>KG/<file>_kg.txt``.

        :param dataset: split name used to build the directory name.
        :return: list of per-turn dicts, in file and turn order.
        '''
        data_out = []
        # found = []
        data_files = self.load_all_files(dataset)
        for dat in data_files:
            # print (dat)
            # print ("dat is here")
            convo = self.get_conv(self.data_path + 'manually_annotated/' + dataset + '_sketch/' + dat)  # Get Conversations

            # Running list of all utterances of this file (q1, a1, q2, a2, ...).
            conv_hist = []
            for j, c in enumerate(convo):
                # Turn fields are keyed with a 1-based suffix: 'q1', 'a1', ...
                convo_dict = {}
                convo_dict['f'] = dat.replace('.json', '')
                # conv_hist = []
                if conv_hist:
                    # Later turns: the model input is the full history so far.
                    try:
                        conv_hist.append(c['q' + str(j + 1)])
                    except Exception as e:
                        print (dat, e)
                    if self.args.use_bert:
                        convo_dict['q'] = (' '+self.args.bert_sep+' ').join(u for u in conv_hist)  # For bert
                    else:
                        convo_dict['q'] = (' '+self.args.eou_tok+' ').join(u for u in conv_hist)
                    try:
                        conv_hist.append(c['a' + str(j + 1)])
                    except Exception as e:
                        # A turn without an answer aborts the whole process.
                        print (e, dat)
                        exit()
                else:
                    # First turn: the question alone is the model input.
                    convo_dict['q'] = c['q' + str(j + 1)]
                    conv_hist.append(c['q' + str(j + 1)])
                    conv_hist.append(c['a' + str(j + 1)])
                # No-op self-assignment.
                convo_dict['q'] = convo_dict['q']
                convo_dict['_a'] = c['a' + str(j + 1)]
                convo_dict['_q'] = c['q' + str(j + 1)]
                # convo_dict['a'] = c['a' + str(j + 1)]

                # Get KG

                if c['input_ent' + str(j + 1)]:
                    convo_dict['e'] = c['input_ent' + str(j + 1)]
                    # Any entity containing 'changs' is mapped to one fixed name.
                    if 'changs' in convo_dict['e']:
                        convo_dict['e'] = 'p_._f_._changs'
                else:
                    convo_dict['e'] = self.args.no_ent_tok

                convo_dict['o'] = c['obj' + str(j + 1)].split(',')
                convo_dict['r'] = c['corr_rel' + str(j + 1)].split(',')
                convo_dict['kvr'] = c['kvr_entlist_qa' + str(j + 1)]

                # Answer tokens, with KG-matched words replaced by placeholders.
                ir = []
                if 'navigate' in dat:  # add 2 hop information in navigation
                    # NOTE(review): the KG file is re-read for every turn of the file.
                    with open(self.data_path + 'KG/' + dat.replace('.json', '') + '_kg.txt', 'r') as f:  # Read KG file
                        nav_kg = f.readlines()
                    nav_dict = defaultdict(list)
                    nav_1hop = defaultdict(list)
                    for t in nav_kg:
                        # Each line must hold exactly three tab-separated fields.
                        e, r, o = t.replace('\n', '').split('\t')
                        # Triples are stored in both directions.
                        nav_dict[e].append([r, o])
                        nav_dict[o].append([r, e])
                        # nav_1hop[e].append(o)
                        # nav_1hop[o].append(e)
                    if convo_dict['e'] in nav_dict.keys():
                        for w in convo_dict['_a'].split():
                            # 1-hop neighbours of the input entity.
                            objects = [o for r, o in nav_dict[convo_dict['e']]]
                            # [relation, node] pairs reachable from the 1-hop neighbours.
                            _2hop_obj = []
                            for o in objects:
                                for conn_obj in nav_dict[o]:
                                    _2hop_obj.append(conn_obj)
                            # Same membership test as the enclosing `if`.
                            if convo_dict['e'] in nav_dict.keys():  # if object is in 1 hop
                                if w == convo_dict['e']:
                                    ir.append('@entity')
                                elif w in objects:
                                    relation = [r for r, o in nav_dict[convo_dict['e']] if o == w]
                                    # object = w
                                    ir.append('@'+relation[0])
                                elif w in [o for r, o in _2hop_obj]:
                                    # _2hop_obj = [o for r, o in nav_dict[object]]
                                    # if w in _2hop_obj:
                                    relation = [r for r, o in _2hop_obj if o == w]
                                        #object = w
                                    ir.append('@@' + relation[0])
                                    #else:
                                    #    ir.append(w)
                                else:
                                    ir.append(w)

                elif 'weather' not in dat: # Process data for weather separately
                    for w in convo_dict['_a'].split():
                        # Words are only replaced when the object list has no
                        # empty entry and no 'none' entry.
                        if '' not in convo_dict['o'] and 'none' not in convo_dict['o']:
                            if w in convo_dict['o']:  # If word in output objects
                                ans_obj = [o for o in convo_dict['o'] if o == w]
                                ans_obj = ans_obj[0]
                                # NOTE(review): 'e' was set to args.no_ent_tok when the
                                # input entity was empty, so this presumably always
                                # passes unless that token is itself empty - confirm.
                                if convo_dict['e']: # check if there's input entity
                                    if convo_dict['e'] not in self.global_ent:
                                        print (convo_dict['e'], dat)
                                    try:
                                        # connected_reln = self.eo_dict[convo_dict['e']+'#'+ans_obj[0]]
                                        # Relations of this turn listed under the key
                                        # '<entity>#<object>' in self.eo_dict.
                                        connected_reln = [r for r in convo_dict['r'] if r
                                                          in self.eo_dict[convo_dict['e']+'#'+ans_obj]]
                                        ir.append('@' + connected_reln[0])
                                    except Exception as e:
                                        # Any failure (missing key, no relation found)
                                        # keeps the word unchanged.
                                        # print (e)
                                        # print (dat)
                                        ir.append(w)
                                        # exit()
                                else:
                                    ir.append(w)
                            elif w == convo_dict['e']:
                                ir.append('@entity')
                            else:
                                ir.append(w)
                        else:
                            ir.append(w)
                else:
                    # Weather files: 1-hop lookup only, forward direction only.
                    with open(self.data_path + 'KG/' + dat.replace('.json', '') + '_kg.txt', 'r') as f:  # Read KG file
                        weather_kg = f.readlines()
                    weather_dict = defaultdict(list)
                    for t in weather_kg:
                        e, r, o = t.replace('\n', '').split('\t')
                        weather_dict[e].append([r, o])
                    if convo_dict['e'] in weather_dict.keys():
                        for w in convo_dict['_a'].split():
                            objects = [o for r, o in weather_dict[convo_dict['e']]]
                            # Same membership test as the enclosing `if`.
                            if convo_dict['e'] in weather_dict.keys():
                                if w == convo_dict['e']:
                                    ir.append('@entity')
                                elif w in objects:
                                    relation = [r for r, o in weather_dict[convo_dict['e']] if o == w]
                                    ir.append('@'+relation[0])
                                else:
                                    ir.append(w)

                convo_dict['a'] = ' '.join(ir)
                # if answer is empty put the original answer
                if not convo_dict['a']:
                    convo_dict['a'] = convo_dict['_a']
                # Add the graph laplacian information
                if convo_dict['e']:
                    # Raises KeyError when the entity is not a key of self.e_r_l.
                    input_kg_reln = self.e_r_l[convo_dict['e']]
                    input_kg = dict()
                    input_kg[convo_dict['e']] = list(input_kg_reln)
                    A, I = gen_adjacency_mat(input_kg)
                    D = get_degree_matrix(A)
                    X = []
                    # get entity relation matrix
                    er_vec = getER_vec(input_kg)
                    for e, ele in enumerate(er_vec):
                        # ele_present = get_fuzzy_match(ele, convo_dict['q'])[0]/100
                        # X.append([self.calculate_similarity(ele, convo_dict['q']), ele_present])
                        # Scored against the raw question of this turn, not the history.
                        X.append(self.calculate_similarity(ele, convo_dict['_q']))
                    # Calculate features
                    A_hat = A + I
                    try:
                        # One propagation step: inv(D) . (A + I) . X
                        dt = np.matmul(np.linalg.inv(D), A_hat)
                        h = np.matmul(dt, X)
                    except Exception as e:
                        # Best effort: report and fall back to all-zero scores.
                        print (e)
                        h = np.zeros(len(er_vec))
                        print (dat, A)

                    kg_dict = {}
                    # NOTE: this rebinds the turn index `j`; it is not read
                    # again before the outer loop assigns it anew.
                    for j, ele in enumerate(er_vec):
                        if j == 0:
                            # The first element is stored under '@entity'.
                            kg_dict['@entity'] = h[0]
                        else:
                            # Substring test against the rewritten answer.
                            if '@@'+ele in convo_dict['a']:  # check for dual 2-hop relation
                                kg_dict['@@'+ele] = h[j]
                            else:
                                kg_dict['@'+ele] = h[j]
                    convo_dict['h'] = kg_dict
                data_out.append(convo_dict)

        return data_out
    def get_data(self, dataset):
        '''
        Load the conversations of one split and attach KG targets and graph
        features to every turn.

        For every file name returned by ``self.load_all_files(dataset)`` the
        conversation is read with ``self.get_conv`` from
        ``<data_path>conversations/<dataset>_with_entities_er/<file>``.
        Each turn becomes a dict that always holds ``'q'``, ``'_a'``, ``'a'``
        (cleaned with ``clean_str``) and ``'kgER'``. When
        ``c['correct_ERcomb']`` is non-empty the dict additionally holds:

        * ``'subject'``, ``'relations'``, ``'object'``  first subject, its
          relations and the matching objects
        * ``'er_pair'``    ``'o<p>'`` -> ``'<entity>#<relation>'`` for every
          entity/relation pair of ``'kgER'``
        * ``'output_kg'``  indices ``p`` of the pairs that belong to the
          subject and one of its correct relations
        * ``'a'``          answer with every object word replaced by the
          ``'o<p>'`` token of its own relation (only if such pairs exist)
        * ``'correct_ent'`` the matching ``'<entity>#<relation>'`` strings
          (only if such pairs exist)
        * ``'sent_gate'``  one value per word of ``'_a'``; 1 where the word
          was replaced, 0 elsewhere (all zeros when nothing was replaced)
        * ``'X_feat'``, ``'input_graph_feature'``  raw and propagated
          similarity scores

        :param dataset: split name used to build the directory name.
        :return: list of per-turn dicts, in file and turn order.
        '''
        data_out = []
        found = []
        data_files = self.load_all_files(dataset)
        # The result is not used below; the call is kept because it is not
        # visible here whether get_kgqa has side effects.
        self.get_kgqa(dataset)
        for dat in data_files:
            convo = self.get_conv(self.data_path + 'conversations/' + dataset +
                                  '_with_entities_er/' +
                                  dat)  # Get Conversations
            for j, c in enumerate(convo):
                turn = str(j + 1)
                convo_dict = {}
                convo_dict['q'] = clean_str(c['q' + turn])
                convo_dict['_a'] = clean_str(c['a' + turn])
                convo_dict['a'] = clean_str(c['a' + turn])
                convo_dict['kgER'] = OrderedDict(c['kgER' + turn])

                subjects = list(c['correct_ERcomb'].keys())
                if not subjects:  # Not answerable by the KG
                    data_out.append(convo_dict)
                    continue

                subject = subjects[0]
                r_s = list(c['correct_ERcomb'][subject].keys())
                # Objects are aligned with r_s: o[i] answers relation r_s[i].
                o = [c['correct_ERcomb'][subject][r] for r in r_s]
                convo_dict['object'] = o
                convo_dict['subject'] = subject
                convo_dict['relations'] = r_s
                inp_kg = convo_dict['kgER']

                # Enumerate every entity/relation pair of the input KG.
                object_dict = {}
                correct_er_pair = []
                correct_ent = []
                pair_of_reln = {}  # correct relation -> index of its pair
                pair_idx = 0
                for ent in inp_kg.keys():
                    for reln in inp_kg[ent]:
                        object_dict['o' + str(pair_idx)] = ent + '#' + reln
                        if ent == subject and reln in r_s:
                            correct_er_pair.append(pair_idx)
                            correct_ent.append(ent + '#' + reln)
                            pair_of_reln.setdefault(reln, pair_idx)
                        pair_idx += 1
                convo_dict['output_kg'] = correct_er_pair
                convo_dict['er_pair'] = object_dict

                # The gate is rebuilt for every turn so that it can neither be
                # undefined nor carried over from an earlier turn.
                answer_words = convo_dict['_a'].split()
                sent_gate = np.zeros(len(answer_words))
                if correct_er_pair:
                    answer = []
                    for k, w_a in enumerate(answer_words):
                        slot = None
                        if w_a in o:
                            # Look the slot up through the relation of the
                            # object; indexing correct_er_pair by the position
                            # in `o` is only right when the KG lists the
                            # relations in the same order as r_s.
                            slot = pair_of_reln.get(r_s[o.index(w_a)])
                        if slot is None:
                            answer.append(w_a)
                        else:
                            sent_gate[k] = 1
                            answer.append('o' + str(slot))
                    convo_dict['a'] = ' '.join(answer)
                    convo_dict['correct_ent'] = correct_ent
                convo_dict['sent_gate'] = sent_gate

                # Get Degree Matrix
                A, I = gen_adjacency_mat(convo_dict['kgER'])
                D = torch.from_numpy(get_degree_matrix(A))
                # Similarity of every graph element to the question.
                er_vec = getER_vec(convo_dict['kgER'])
                X = [
                    self.calculate_similarity(ele, convo_dict['q'])
                    for ele in er_vec
                ]
                # One propagation step: inv(D) . (A + I) . X
                A_hat = A + I
                dt = np.matmul(np.linalg.inv(D), A_hat)
                h = np.matmul(dt, X)
                convo_dict['X_feat'] = X
                # Keep the propagated scores of the elements that are not
                # entity keys of the KG.
                convo_dict['input_graph_feature'] = [
                    h[k] for k, ele in enumerate(er_vec) if ele not in inp_kg
                ]
                if subject:
                    found.append(subject)
                data_out.append(convo_dict)
        print(len(found))
        return data_out
    def get_data(self, dataset):
        '''
        Read every conversation file of a split and turn each turn into a
        record.

        Files come from ``self.load_all_files(dataset)`` and are read with
        ``self.get_conv`` from
        ``<data_path>manually_annotated/<dataset>_sketch/<file>``.

        Keys of a record, in insertion order: ``'f'`` (file name without
        ``.json``), ``'q'`` (question, or joined history for later turns),
        ``'_a'`` (raw answer), ``'_q'`` (question cleaned with
        ``clean_str``), ``'e'`` (input entity or ``args.no_ent_tok``),
        ``'o'``, ``'r'`` (comma-split fields), ``'a'`` (the ``a<n>_v2``
        field with the fuzzy-matched entity replaced by ``@entity``),
        ``'h'`` (placeholder -> propagated score) and ``'s'`` (1.0 for
        every word of ``'a'`` that contains ``'@'``).

        :param dataset: split name used to build the directory name.
        :return: list of records, in file and turn order.
        '''
        records = []
        for dat in self.load_all_files(dataset):
            convo = self.get_conv(self.data_path + 'manually_annotated/' +
                                  dataset + '_sketch/' + dat)

            history = []
            for turn_idx, c in enumerate(convo):
                n = str(turn_idx + 1)
                q_key = 'q' + n
                a_key = 'a' + n
                record = {'f': dat.replace('.json', '')}

                if not history:
                    # First turn: the question alone is the input.
                    record['q'] = c[q_key]
                    history.append(c[q_key])
                    history.append(c[a_key])
                else:
                    # Later turns: the input is the whole history so far.
                    try:
                        history.append(c[q_key])
                    except Exception as err:
                        print(dat, err)
                    if self.args.use_bert:
                        separator = self.args.bert_sep
                    else:
                        separator = self.args.eou_tok
                    record['q'] = (' ' + separator + ' ').join(history)
                    try:
                        history.append(c[a_key])
                    except Exception as err:
                        # A turn without an answer aborts the whole process.
                        print(err, dat)
                        exit()

                record['_a'] = c[a_key]
                record['_q'] = clean_str(c[q_key])

                entity = c['input_ent' + n]
                record['e'] = entity if entity else self.args.no_ent_tok
                record['o'] = c['obj' + n].split(',')
                record['r'] = c['corr_rel' + n].split(',')
                record['a'] = c["a" + n + "_v2"]

                if record['e']:
                    # Replace the best fuzzy match of the entity in the answer.
                    _score, matched_span = get_fuzzy_match(
                        record['e'], record['a'])
                    if matched_span != "":
                        record['a'] = record['a'].replace(
                            matched_span, '@entity')

                    # One-entity graph built from the connected relations.
                    input_kg = {record['e']: list(self.conn_ent[record['e']])}
                    A, I = gen_adjacency_mat(input_kg)
                    D = get_degree_matrix(A)
                    er_vec = getER_vec(input_kg)
                    X = [
                        self.calculate_similarity(ele, record['_q'])
                        for ele in er_vec
                    ]
                    A_hat = A + I
                    try:
                        # One propagation step: inv(D) . (A + I) . X
                        h = np.matmul(np.matmul(np.linalg.inv(D), A_hat), X)
                    except Exception as err:
                        print(err)
                        h = np.zeros(len(er_vec))
                        print(dat, A)

                    scores = {}
                    for pos, ele in enumerate(er_vec):
                        if pos == 0:
                            placeholder = '@entity'
                        elif len(ele.split()) > 1:
                            # Multi-word elements are joined with underscores.
                            placeholder = '@' + ele.replace(' ', '_')
                        else:
                            placeholder = '@' + ele
                        scores[placeholder] = h[pos]
                    record['h'] = scores

                # An empty answer falls back to the raw answer.
                if not record['a']:
                    record['a'] = record['_a']

                # Gate: 1.0 for every answer word that carries a placeholder.
                record['s'] = np.array(
                    [float('@' in w) for w in record['a'].split()])
                records.append(record)

        return records
示例#6
0
    def get_data(self, dataset):
        '''
        Load the conversations of one split and attach KGQA-based targets and
        graph features to every turn.

        For every file name returned by ``self.load_all_files(dataset)`` the
        conversation is read with ``self.get_conv`` from
        ``<data_path>conversations/<dataset>_with_entities_er/<file>``.
        ``self._check_kgqa_ans`` supplies a relation/object pair per question;
        when the relation is truthy the turn is additionally given
        ``'object'``, ``'subject'``, ``'relation'``, ``'output_kg'``,
        ``'er_pair'``, ``'correct_ent'``, ``'sent_gate'``, ``'X_feat'`` and
        ``'input_graph_feature'``. Every turn holds ``'q'``, ``'file_name'``,
        ``'a'``, ``'_a'`` and ``'kgER'``.

        Reads ``self.args`` and ``self.ent_d``.

        :param dataset: split name used to build the directory name.
        :return: list of per-turn dicts, in file and turn order.
        '''
        data_out = []
        found = []
        data_files = self.load_all_files(dataset)
        kgqa_answers = self.get_kgqa(dataset)
        for dat in data_files:
            # print (dat)
            # print ("dat is here")
            convo = self.get_conv(self.data_path + 'conversations/' + dataset +
                                  '_with_entities_er/' +
                                  dat)  # Get Conversations
            # Running list of all utterances of this file (q1, a1, q2, a2, ...).
            conv_hist = []
            for j, c in enumerate(convo):
                # Turn fields are keyed with a 1-based suffix: 'q1', 'a1', ...
                convo_dict = {}
                r, o = self._check_kgqa_ans(kgqa_answers,
                                            dat.replace('.json', ''),
                                            c['q' + str(j + 1)])
                if conv_hist:
                    # Later turns: the input is the history joined with eou_tok.
                    conv_hist.append(c['q' + str(j + 1)])
                    convo_dict['q'] = (' ' + self.args.eou_tok + ' ').join(
                        u for u in conv_hist)
                    conv_hist.append(c['a' + str(j + 1)])
                else:
                    convo_dict['q'] = c['q' + str(j + 1)]
                    conv_hist.append(c['q' + str(j + 1)])
                    conv_hist.append(c['a' + str(j + 1)])
                convo_dict['file_name'] = dat
                convo_dict['a'] = clean_str(c['a' + str(j + 1)])
                convo_dict['_a'] = clean_str(c['a' + str(j + 1)])
                convo_dict['kgER'] = OrderedDict(c['kgER' + str(j + 1)])
                correct_ent = []
                correct_sub = ''
                if r:
                    # convo_dict['relation'] = r
                    convo_dict['object'] = o
                    # if dataset == 'train':
                    # [entity, relation] candidates whose relation equals r.
                    probable_reln = []
                    #inp_kg = c['kgER_e' + str(j + 1)]
                    inp_kg = convo_dict['kgER']
                    # if dat == '3S8A4GJRD4FZ3YXJ1VN53CDZINRV6E.json':
                    # for e_r in c['kgER_e'+str(j+1)]:
                    if len(inp_kg.keys()) > 1:  # More than 1 entity in KG
                        for k, v in inp_kg.items():
                            for r_in_v, kg_r in enumerate(v):
                                if kg_r.lower() == r:
                                    probable_reln.append([k, kg_r.lower()])
                        if len(probable_reln) > 1:
                            # Several candidates: pick the entity that matches
                            # the question best.
                            best_sub = np.argmax([
                                get_fuzzy_match(opt[0], convo_dict['q'])[0]
                                for opt in probable_reln
                            ])
                            correct_sub, correct_reln = probable_reln[best_sub]
                    else:
                        for k, v in inp_kg.items():
                            for r_in_v, kg_r in enumerate(v):
                                if kg_r.lower() == r:
                                    correct_sub, correct_reln = k, kg_r.lower()
                    if not correct_sub:
                        try:
                            # Single candidate, or nothing found at all.
                            correct_sub, correct_reln = probable_reln[0]
                        except Exception as e:
                            # NOTE(review): with no candidate, correct_reln is
                            # not assigned in this turn; it is then unbound or
                            # still holds the value of an earlier turn.
                            # print (dat, '\t', convo_dict['q'])
                            print(e)
                    convo_dict['subject'] = correct_sub
                    convo_dict['relation'] = r
                    # best_sub_match = get_fuzzy_match(correct_sub, convo_dict['a'])[1]
                    # convo_dict['q'] = convo_dict['q'].replace(best_sub_match, self.args.ent_tok)
                    best_match = get_fuzzy_match(o, convo_dict['a'])[
                        1]  # get_fuzzy_match return similarity and object

                    # Get output relation
                    entity_reln_pair = []
                    object_dict = {}
                    relations = []
                    # NOTE: from here on `r` is rebound to KG relations and no
                    # longer holds the relation returned by _check_kgqa_ans.
                    for ent in inp_kg.keys():
                        for r in inp_kg[ent]:
                            entity_reln_pair.append(
                                [clean_str(ent),
                                 clean_str(r)])  # get all entity relation pair
                            relations.append(r)
                    # if 'relation' in data:
                    for p, er_pair in enumerate(entity_reln_pair):
                        e, r = er_pair
                        try:
                            object_dict['o' + str(p)] = self.ent_d[e + '#' + r]
                        except KeyError:
                            try:
                                # Second attempt with a space before '#'.
                                object_dict['o' +
                                            str(p)] = self.ent_d[e + ' #' + r]
                            except KeyError:
                                # Unknown pair: the slot maps to its own name.
                                object_dict['o' + str(p)] = 'o' + str(p)
                        # NOTE(review): e and r were passed through clean_str,
                        # correct_sub and correct_reln were not - confirm that
                        # the comparison is intended to be on different forms.
                        if e == correct_sub and r == correct_reln:
                            correct_er_pair = p
                            # Unlike the lookup above, this one has no fallback.
                            correct_ent.append(self.ent_d[e + '#' + r])
                    # NOTE(review): correct_er_pair is only assigned when a pair
                    # matched. Otherwise this line raises on the first such
                    # turn of a call or reuses the index of an earlier turn.
                    convo_dict['output_kg'] = correct_er_pair
                    convo_dict['er_pair'] = object_dict
                    # Replace answer with object
                    # True once correct_er_pair was ever assigned in this call.
                    if 'correct_er_pair' in vars():
                        sent_gate = np.zeros(len(convo_dict['a'].split()))
                        answer = convo_dict['a'].replace(
                            best_match, self.args.mem_tok)
                        # sent_gate[b] = 1.0
                        for k, w_a in enumerate(answer.split()):
                            # print(self.get_w2i(w_a))
                            # y[b, k] = self.get_w2i(w_a)
                            # NOTE(review): the answer was filled with
                            # args.mem_tok, so the gate is only set when that
                            # token equals 'o<index>' - confirm.
                            if w_a == ('o' + str(correct_er_pair)):
                                try:
                                    sent_gate[k] = 1
                                except Exception:
                                    print(sent_gate, convo_dict['a'])
                                    print(sent_gate)
                        convo_dict['a'] = answer
                        # In the raw answer the matched span is kept, with its
                        # spaces replaced by underscores.
                        convo_dict['_a'] = convo_dict['_a'].replace(
                            best_match, best_match.replace(' ', '_'))
                        # convo_dict['a'] = ' '.join(w for w in answer)
                    convo_dict['correct_ent'] = correct_ent
                    try:
                        convo_dict['sent_gate'] = sent_gate
                    except Exception as e:
                        print(e, correct_er_pair, convo_dict['q'])
                    # Get Degree Matrix
                    A, I = gen_adjacency_mat(convo_dict['kgER'])
                    D = torch.from_numpy(get_degree_matrix(A))
                    # get entity relation matrix
                    X = []
                    er_vec = getER_vec(convo_dict['kgER'])
                    for e, ele in enumerate(er_vec):
                        # ele_present = get_fuzzy_match(ele, convo_dict['q'])[0]/100
                        # X.append([self.calculate_similarity(ele, convo_dict['q']), ele_present])
                        X.append(
                            self.calculate_similarity(ele, convo_dict['q']))
                    # Calculate features
                    # One propagation step: inv(D) . (A + I) . X; unlike the
                    # lookups above this is not guarded by try/except.
                    A_hat = A + I
                    dt = np.matmul(np.linalg.inv(D), A_hat)
                    h = np.matmul(dt, X)
                    convo_dict['X_feat'] = X
                    inp_ft = []
                    all_similarities = []
                    for k, e in enumerate(er_vec):
                        all_similarities.append([e, h[k]])
                    for s in all_similarities:
                        ele, sim = s
                        # Elements that are not entity keys of the KG contribute
                        # exp(score); the score of entity keys is not stored.
                        if ele not in inp_kg:
                            inp_ft.append(np.exp(sim))
                        else:
                            ent_sim = sim
                    convo_dict['input_graph_feature'] = inp_ft
                    if correct_sub:
                        found.append(correct_sub)
                        # print (correct_sub, convo_dict['q'], correct_reln, dat)
                data_out.append(convo_dict)
        # Number of turns for which a subject was identified.
        print(len(found))
        return data_out