Example #1
0
def auto_test():
    """Log the parsed fields of approximate comparative-reasoning questions.

    Loads the training QA data, groups it with ``getQA_by_state`` and walks
    the 'Comparative Reasoning (All)' group.  A question is skipped unless
    its utterance contains "around" or "approximately".  For the remaining
    questions the utterance, entities, relations (plus their '-'-prefixed
    variants) and types are logged, and the running index together with the
    utterance is printed.
    """

    def flat(record, key):
        # Raw fields carry newline characters that are dropped before use.
        return record[key].replace("\n", "")

    qa_map = getQA_by_state(
        load_qadata("/data/zjy/preprocessed_data_10k/train"))

    symbolic_seqs = auto_generate()
    index = 0
    for qa in qa_map['Comparative Reasoning (All)\n']:
        context = flat(qa, 'context').strip()
        context_utterance = flat(qa, 'context_utterance')

        mentions_approximation = ("around" in context_utterance
                                  or "approximately" in context_utterance)
        if not mentions_approximation:
            continue

        context_entities = flat(qa, 'context_entities').split("|")
        context_relations = flat(qa, 'context_relations').split("|")
        context_types = flat(qa, 'context_types').split("|")
        context_ints = flat(qa, 'context_ints')
        # Append a '-'-prefixed copy of every relation to the same list.
        context_relations += ['-' + relation for relation in context_relations]
        response_entities = flat(qa, 'response_entities').split("|")
        orig_response = flat(qa, 'orig_response')

        logging.info(str(index))
        for label, text in (
                ('context_utterance:', context_utterance),
                ('context_entities:', ",".join(context_entities)),
                ('context_relations:', ",".join(context_relations)),
                ('context_types:', ",".join(context_types)),
        ):
            logging.info(label + text)

        # Drop the first empty entity string, if any, after logging.
        try:
            context_entities.remove("")
        except ValueError:
            pass

        print(index, context_utterance)
        index += 1
Example #2
0
def auto_test():
    """Log the parsed fields of every quantitative (count) question.

    Loads the training QA data, groups it with ``getQA_by_state`` and, for
    each entry of the 'Quantitative Reasoning (Count) (All)' group, logs the
    running index, the utterance, the entities, the relations (plus their
    '-'-prefixed variants) and the types.
    """

    def cleaned(record, key):
        # Raw fields carry newline characters that are dropped before use.
        return record[key].replace("\n", "")

    def pieces(record, key):
        # Multi-valued fields are '|'-separated.
        return cleaned(record, key).split("|")

    qa_map = getQA_by_state(
        load_qadata("/data/zjy/preprocessed_data_10k/train"))

    symbolic_seqs = auto_generate()
    counter = 0
    for qa in qa_map['Quantitative Reasoning (Count) (All)\n']:
        context = cleaned(qa, 'context').strip()
        context_utterance = cleaned(qa, 'context_utterance')
        context_entities = pieces(qa, 'context_entities')
        context_relations = pieces(qa, 'context_relations')
        context_types = pieces(qa, 'context_types')
        context_ints = cleaned(qa, 'context_ints')
        # Append a '-'-prefixed copy of every relation to the same list.
        context_relations += ['-' + rel for rel in context_relations]
        response_entities = pieces(qa, 'response_entities')
        orig_response = cleaned(qa, 'orig_response')

        messages = [
            str(counter),
            "context_utterance:" + context_utterance,
            "context_entities:" + ",".join(context_entities),
            "context_relations:" + ",".join(context_relations),
            "context_types:" + ",".join(context_types),
        ]
        for message in messages:
            logging.info(message)
        counter += 1
Example #3
0
def test_select(self):
    """Run a randomly chosen symbolic program on every demo QA item.

    For each question in the demo QA set the question is parsed into
    entities, relations and types, a symbolic sequence is produced by
    ``Seq2Seq.simple`` for a random state in 1..18, and the sequence is
    executed.  The answer and the execution time are printed.

    Side effects: creates/truncates ``log.txt`` in the working directory
    (nothing is written to it here) and prints to stdout.
    """
    params = get_params("/data/zjy/csqa_data",
                        "/home/zhangjingyao/preprocessed_data_10k")

    # NOTE: loading of the knowledge base (LuceneSearch index, pickled
    # wikidata / entity items / type KB, or load_wikidata) is disabled in
    # this test; only the QA files are read.

    # Read the QA file set.
    qa_set = load_qadata("/home/zhangjingyao/preprocessed_data_10k/demo")
    question_parser = QuestionParser(params, True)

    # 'w+' creates/truncates the log file; the context manager guarantees the
    # handle is closed even if parsing or execution raises.
    with open("log.txt", 'w+'):
        # .values() works on both Python 2 and Python 3 dicts.
        for qafile in qa_set.values():
            for qid in range(len(qafile["context"])):
                # Build one QA record from the column-oriented file data.
                q = {k: v[qid] for k, v in qafile.items()}

                # Parse the question.
                qstring = q["context_utterance"]
                entities = question_parser.getNER(q)
                relations = question_parser.getRelations(q)
                types = question_parser.getTypes(q)

                # Obtain the operation sequence for a randomly drawn state.
                states = random.randint(1, 18)
                seq2seq = Seq2Seq()
                symbolic_seq = seq2seq.simple(qstring, entities, relations,
                                              types, states)

                # Symbolic execution (timed, including the printing below).
                time_start = time.time()
                symbolic_exe = symbolics.Symbolics(symbolic_seq)
                answer = symbolic_exe.executor()

                print("answer is :", answer)
                if isinstance(answer, dict):
                    for key in answer:
                        print(list(answer[key]))

                time_end = time.time()
                print('time cost:', time_end - time_start)
                print(
                    "--------------------------------------------------------------------------------"
                )

    print(0)
Example #4
0
def auto_test():
    """Brute-force one-step 'A1' programs for simple direct questions.

    For every entry of the 'Simple Question (Direct)' group, each
    (entity, relation, type) combination is tried as the parameter of the
    single symbolic step 'A1'.  A candidate counts as a hit when
    ``cal_precesion`` accepts its answer; hits are logged and the search for
    a question stops after four hits.
    """

    def text_of(record, key):
        # Raw fields carry newline characters that are dropped before use.
        return record[key].replace("\n", "")

    qa_map = getQA_by_state(
        load_qadata("/data/zjy/preprocessed_data_10k/train"))

    symbolic_seqs = auto_generate()
    question_no = 0
    for qa in qa_map['Simple Question (Direct)\n']:
        context = text_of(qa, 'context').strip()
        context_utterance = text_of(qa, 'context_utterance')
        context_entities = text_of(qa, 'context_entities').split("|")
        context_relations = text_of(qa, 'context_relations').split("|")
        context_types = text_of(qa, 'context_types').split("|")
        context_ints = text_of(qa, 'context_ints')
        # Reverse relations: for has_child also offer -has_child.
        context_relations += ['-' + rel for rel in context_relations]
        response_entities = text_of(qa, 'response_entities').split("|")
        orig_response = text_of(qa, 'orig_response')

        logging.info(str(question_no) + " " + context_utterance)
        print(question_no, time.time())
        flag = 0
        question_no += 1

        for seq in [['A1']]:
            # Candidate parameterisations, keyed by step position.
            candidates = {}
            for position, symbolic in enumerate(seq):
                candidates[position] = []
                if int(symbolic[1:]) == 1:
                    candidates[position].extend(
                        {symbolic: (e, r, t)}
                        for e in context_entities
                        for r in context_relations
                        for t in context_types)

            if len(candidates) != 1:
                continue

            for sym1 in candidates[0]:
                # Four accepted programs per question are enough.
                if flag == 4:
                    break

                sym_seq = [sym1]
                answer = Symbolics(sym_seq).executor()
                if cal_precesion(orig_response, response_entities, answer):
                    flag += 1
                    logging.info(sym_seq)
def auto_test():
    """Brute-force three-step programs for "atleast" quantitative questions.

    Walks the 'Quantitative Reasoning (All)' group and keeps only questions
    whose utterance contains "atleast".  For each kept question and each
    sequence returned by ``auto_generate``, candidate parameters are built
    per step:

    * step 15 gets every (entity, relation, type) combination;
    * steps 2 and 16 get every (type, relation, type) combination.

    For sequences of length three, every combination of candidates is
    executed with ``Symbolics`` and checked with ``cal_precesion``.  Accepted
    programs are logged and printed; the search for a question stops after
    four hits.
    """
    qa_set = load_qadata("/data/zjy/preprocessed_data_10k/train")

    qa_map = getQA_by_state(qa_set)

    symbolic_seqs = auto_generate()
    a = 0
    for qa in qa_map['Quantitative Reasoning (All)\n']:

        context_utterance = qa['context_utterance'].replace("\n", "")
        # Keyword filter: only "atleast" questions are searched.  The skip
        # must be conditional, otherwise nothing below would ever run.
        if "atleast" not in context_utterance:
            continue
        context_entities = qa['context_entities'].replace("\n", "").split("|")
        context_relations = qa['context_relations'].replace("\n",
                                                            "").split("|")
        context_types = qa['context_types'].replace("\n", "").split("|")
        # Also offer each relation with a '-' prefix.
        context_relations.extend(['-' + r for r in context_relations])
        response_entities = qa['response_entities'].replace("\n",
                                                            "").split("|")
        orig_response = qa['orig_response'].replace("\n", "")
        logging.info(str(a) + " " + context_utterance)
        print(context_utterance)
        print(a, time.time())
        flag = 0  # number of accepted programs for this question
        a += 1
        for seq in symbolic_seqs:
            # Candidate parameterisations, keyed by step position.
            seq_with_param = {i: [] for i in range(len(seq))}
            for i, symbolic in enumerate(seq):
                op = int(symbolic[1:])
                if op == 15:
                    seq_with_param[i].extend(
                        {symbolic: (e, r, t)}
                        for e in context_entities
                        for r in context_relations
                        for t in context_types)
                if op in (2, 16):
                    seq_with_param[i].extend(
                        {symbolic: (et, r, t)}
                        for et in context_types
                        for r in context_relations
                        for t in context_types)

            if len(seq_with_param) == 3:

                for sym1 in seq_with_param[0]:
                    if flag == 4:
                        break
                    for sym2 in seq_with_param[1]:
                        if flag == 4:
                            break
                        for sym3 in seq_with_param[2]:
                            if flag == 4:
                                break
                            sym_seq = [sym1, sym2, sym3]
                            symbolic_exe = Symbolics(sym_seq)
                            answer = symbolic_exe.executor()
                            if cal_precesion(orig_response, response_entities,
                                             answer):
                                flag += 1
                                logging.info(sym_seq)
                                print(sym_seq, time.time())
Example #6
0
def auto_test():
    """Brute-force two-step programs for logical-reasoning questions.

    First appends the marker line "ssss" to
    ``quantative_auto_symbolic.txt`` (after a ``truncate()`` at the current
    position).  Then, for every entry of the 'Logical Reasoning (All)'
    group whose 1-based index is at least ``continue_num`` (a name this
    function expects to find at module level), every sequence from
    ``auto_generate`` is parameterised: steps 1, 8, 9 and 10 get every
    (entity, relation, type) combination.  For sequences of length two,
    every pair of candidates is executed with ``Symbolics`` and checked with
    ``cal_precesion``; accepted programs are logged and the search for a
    question stops after four hits.
    """
    fname = "quantative_auto_symbolic.txt"
    # file.write works on Python 2 and 3 (the `print >> f` form raises
    # TypeError on Python 3); the context manager closes the handle.
    with open(fname, "a+") as qa_result:
        qa_result.truncate()
        qa_result.write("ssss\n")
    qa_set = load_qadata("/data/zjy/preprocessed_data_10k/train")

    qa_map = getQA_by_state(qa_set)

    symbolic_seqs = auto_generate()
    a = 0
    for qa in qa_map['Logical Reasoning (All)\n']:

        context_utterance = qa['context_utterance'].replace("\n", "")
        context_entities = qa['context_entities'].replace("\n", "").split("|")
        context_relations = qa['context_relations'].replace("\n",
                                                            "").split("|")
        context_types = qa['context_types'].replace("\n", "").split("|")
        # Also offer each relation with a '-' prefix.
        context_relations.extend(['-' + r for r in context_relations])
        response_entities = qa['response_entities'].replace("\n",
                                                            "").split("|")
        orig_response = qa['orig_response'].replace("\n", "")
        logging.info(str(a) + " " + context_utterance)

        print(a, time.time())
        flag = 0  # number of accepted programs for this question
        a += 1
        # Resume support: questions before continue_num are only logged.
        if a < continue_num:
            continue
        for seq in symbolic_seqs:
            # Candidate parameterisations, keyed by step position.
            seq_with_param = {i: [] for i in range(len(seq))}
            for i, symbolic in enumerate(seq):
                if int(symbolic[1:]) in (1, 8, 9, 10):
                    seq_with_param[i].extend(
                        {symbolic: (e, r, t)}
                        for e in context_entities
                        for r in context_relations
                        for t in context_types)

            if len(seq_with_param) == 2:

                for sym1 in seq_with_param[0]:
                    if flag == 4:
                        break
                    for sym2 in seq_with_param[1]:
                        if flag == 4:
                            break
                        sym_seq = [sym1, sym2]
                        symbolic_exe = Symbolics(sym_seq)
                        answer = symbolic_exe.executor()
                        if cal_precesion(orig_response, response_entities,
                                         answer):
                            flag += 1
                            logging.info(sym_seq)
Example #7
0
def auto_test():
    """Brute-force three-step programs ending in 'A11' for count questions.

    For every entry of the 'Quantitative Reasoning (Count) (All)' group
    (skipping those whose 1-based index is below ``continue_num``, a name
    that is not defined in this function), each sequence returned by
    ``auto_generate`` is parameterised per step and, when it qualifies,
    every combination of candidates is executed with ``Symbolics`` and
    checked with ``cal_precesion``.  Accepted programs are logged and
    printed; the search for a question stops after four hits (``flag``).
    """
    qa_set = load_qadata("/data/zjy/preprocessed_data_10k/train")

    qa_map = getQA_by_state(qa_set)

    symbolic_seqs = auto_generate()
    a = 0
    for qa in qa_map['Quantitative Reasoning (Count) (All)\n']:

        # Every raw field has its newline characters removed; multi-valued
        # fields are then split on '|'.
        context = qa['context'].replace("\n", "").strip()
        context_utterance = qa['context_utterance'].replace("\n", "")
        context_entities = qa['context_entities'].replace("\n", "").split("|")
        context_relations = qa['context_relations'].replace("\n",
                                                            "").split("|")
        context_types = qa['context_types'].replace("\n", "").split("|")
        context_ints = qa['context_ints'].replace("\n", "")
        # Also offer each relation with a '-' prefix.
        context_relations.extend(['-' + r for r in context_relations])
        response_entities = qa['response_entities'].replace("\n",
                                                            "").split("|")
        orig_response = qa['orig_response'].replace("\n", "")
        logging.info(str(a) + " " + context_utterance)
        # Remove the first empty entity string, if present.
        if "" in context_entities: context_entities.remove("")
        print(a, context_utterance)
        # Per-question clock; used below for the 120-second budget.
        start_time = time.time()
        flag = 0
        a += 1
        # a was already incremented, so the comparison is on a 1-based index.
        if a < continue_num:
            continue
        for seq in symbolic_seqs:
            print(seq)
            # Candidate parameterisations, keyed by step position.
            seq_with_param = {i: [] for i in range(len(seq))}
            for i in range(len(seq)):
                symbolic = seq[i]
                # Steps 1, 8, 9, 10: every (entity, relation, type) triple.
                if (int(symbolic[1:]) in [1, 8, 9, 10]):
                    for e in context_entities:
                        for r in context_relations:
                            for t in context_types:
                                seq_with_param[i].append({symbolic: (e, r, t)})

                # Steps 2, 16: every (type, relation, type) triple.
                if (int(symbolic[1:]) in [2, 16]):
                    for et in context_types:
                        for r in context_relations:
                            for t in context_types:
                                seq_with_param[i].append(
                                    {symbolic: (et, r, t)})
                                # print symbolic,e,r,t
                # Steps 12-15: one candidate per whitespace-separated integer
                # in context_ints (skipped when that field is empty).
                if (int(symbolic[1:]) in [12, 13, 14, 15]
                        and context_ints != ""):
                    for N in [int(n) for n in context_ints.split()]:
                        seq_with_param[i].append({symbolic: (str(N), '', '')})

                # Steps 6, 7: one candidate per entity.
                if (int(symbolic[1:]) in [6, 7]):
                    for e in context_entities:
                        seq_with_param[i].append({symbolic: (e, '', '')})

                # Steps 4, 5, 11: four fixed candidates whose first
                # parameter is '', '&', '-' or '|'.
                if (int(symbolic[1:]) in [4, 5, 11]):
                    seq_with_param[i].append({symbolic: ('', '', '')})
                    seq_with_param[i].append({symbolic: ('&', '', '')})
                    seq_with_param[i].append({symbolic: ('-', '', '')})
                    seq_with_param[i].append({symbolic: ('|', '', '')})
            print(time.time() - start_time)
            # Search only three-step sequences whose third step has
            # candidates keyed "A11", and only while fewer than 120 seconds
            # have elapsed since this question started.  The budget is
            # checked once per sequence, not inside the loops below.
            if (len(seq_with_param) == 3 and seq_with_param[2] != []
                    and "A11" in seq_with_param[2][0]
                    and time.time() - start_time < 120):

                for sym1 in seq_with_param[0]:
                    if flag == 4:
                        break
                    for sym2 in seq_with_param[1]:
                        if flag == 4:
                            break
                        for sym3 in seq_with_param[2]:
                            if flag == 4: break
                            sym_seq = [sym1, sym2, sym3]
                            symbolic_exe = Symbolics(sym_seq)
                            answer = symbolic_exe.executor()
                            # Re-run only the first two steps to obtain the
                            # intermediate result.
                            answer_e = Symbolics(sym_seq[0:2]).executor()
                            answer_entities = []
                            # Take the entry under '|', '&' or '-' (in that
                            # priority); otherwise fall back to the keys.
                            # NOTE(review): presumably answer_e is a dict --
                            # confirm against Symbolics.executor.
                            if ('|' in answer_e):
                                answer_entities = answer_e['|']
                            elif ('&' in answer_e):
                                answer_entities = answer_e['&']
                            elif ('-' in answer_e):
                                answer_entities = answer_e['-']
                            else:
                                answer_entities = answer_e.keys()
                            print(sym_seq, answer, orig_response)
                            # print sorted(answer_entities), sorted(response_entities)
                            # NOTE(review): cal_precesion receives four
                            # arguments here -- confirm this matches its
                            # definition.
                            if cal_precesion(orig_response, answer_entities,
                                             response_entities, answer):
                                print(sorted(answer_entities),
                                      sorted(response_entities))
                                flag += 1
                                logging.info(sym_seq)
                                print(sym_seq, time.time())
def auto_test():
    """Brute-force three-step programs for approximate comparative counts.

    Walks the 'Comparative Reasoning (Count) (All)' group and keeps only
    questions whose utterance contains "around" or "approximately".
    Questions whose 1-based index is below ``continue_num`` (a name not
    defined in this function) are logged and printed but not searched.

    For each sequence from ``auto_generate`` the steps are parameterised:
    1/8/9/10 with (entity, relation, type) triples, 2/16 with
    (type, relation, type) triples, 15 with single entities, and 4/5/11
    with the four fixed markers '', '&', '-' and '|'.  Every combination of
    a three-step sequence is executed; a program is accepted when its
    answer shares more than half of ``response_entities``.  Accepted
    programs are logged and printed, at most four per question.
    """

    def scrub(record, key):
        # Raw fields carry newline characters that are dropped before use.
        return record[key].replace("\n", "")

    qa_map = getQA_by_state(
        load_qadata("/data/zjy/preprocessed_data_10k/train"))

    symbolic_seqs = auto_generate()
    seen = 0
    for qa in qa_map['Comparative Reasoning (Count) (All)\n']:
        context = scrub(qa, 'context').strip()
        context_utterance = scrub(qa, 'context_utterance')
        is_approximate = ("around" in context_utterance
                          or "approximately" in context_utterance)
        if not is_approximate:
            continue

        context_entities = scrub(qa, 'context_entities').split("|")
        context_relations = scrub(qa, 'context_relations').split("|")
        context_types = scrub(qa, 'context_types').split("|")
        context_ints = scrub(qa, 'context_ints')
        # Also offer each relation with a '-' prefix.
        context_relations += ['-' + rel for rel in context_relations]
        response_entities = scrub(qa, 'response_entities').split("|")
        orig_response = scrub(qa, 'orig_response')

        logging.info(str(seen) + " " + context_utterance)
        # Drop the first empty entity string, if any.
        try:
            context_entities.remove("")
        except ValueError:
            pass
        print(seen, context_utterance)
        start_time = time.time()
        flag = 0
        seen += 1
        if seen < continue_num:
            continue

        for seq in symbolic_seqs:
            print(seq)
            # Candidate parameterisations, keyed by step position.
            seq_with_param = {}
            for position, symbolic in enumerate(seq):
                options = seq_with_param[position] = []
                op = int(symbolic[1:])
                if op in (1, 8, 9, 10):
                    options.extend(
                        {symbolic: (e, r, t)}
                        for e in context_entities
                        for r in context_relations
                        for t in context_types)
                elif op in (2, 16):
                    options.extend(
                        {symbolic: (et, r, t)}
                        for et in context_types
                        for r in context_relations
                        for t in context_types)
                elif op == 15:
                    options.extend(
                        {symbolic: (e, '', '')} for e in context_entities)
                elif op in (4, 5, 11):
                    options.extend(
                        {symbolic: (marker, '', '')}
                        for marker in ('', '&', '-', '|'))

            if len(seq_with_param) != 3:
                continue

            for sym1 in seq_with_param[0]:
                if flag == 4:
                    break
                for sym2 in seq_with_param[1]:
                    if flag == 4:
                        break
                    for sym3 in seq_with_param[2]:
                        if flag == 4:
                            break
                        sym_seq = [sym1, sym2, sym3]
                        answer = Symbolics(sym_seq).executor()

                        # Accept when more than half of the expected
                        # entities appear in the answer.
                        shared = set(answer).intersection(response_entities)
                        if len(shared) > len(response_entities) / 2:
                            flag += 1
                            logging.info(sym_seq)
                            print(sym_seq, time.time())