Ejemplo n.º 1
def infer(nlu1,
          table_name,
          data_table,
          path_db,
          db_name,
          model,
          model_bert,
          bert_config,
          max_seq_length,
          num_target_layers,
          beam_size=4,
          show_table=False,
          show_answer_only=False):
    """Run single-question inference with execution-guided beam decoding.

    Args:
        nlu1: natural-language question (string).
        table_name: table name printed when ``show_table`` is set.
        data_table: list of table dicts; the first entry is taken as the
            target table (assumes the caller pre-filtered — TODO confirm).
        path_db: directory containing the SQLite database file.
        db_name: database file name without the ``.db`` extension.
        model: decoder providing ``beam_forward``.
        model_bert: BERT encoder consumed by ``get_wemb_bert``.
        bert_config: BERT configuration passed to ``get_wemb_bert``.
        max_seq_length: maximum BERT input sequence length.
        num_target_layers: number of BERT output layers used for embeddings.
        beam_size: beam width for execution-guided decoding.
        show_table: when True (and not ``show_answer_only``), also print the
            table contents.
        show_answer_only: when True, print only question / answer / SQL.

    Returns:
        Tuple ``(pr_sql_i, pr_ans)``: predicted SQL structure(s) and the
        answer(s) obtained by executing them.

    Raises:
        EnvironmentError: if decoding yielded a number of SQL candidates
            other than one.
    """
    # NOTE: intentionally duplicates the batch-inference path (against DRY)
    # to minimize the risk of introducing bugs into training/eval code.
    model.eval()
    model_bert.eval()
    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    # Wrap the single question as a batch of size one.
    nlu = [nlu1]
    # nlu_t1 = tokenize_corenlp(client, nlu1)
    # `client` is a module-level CoreNLP client — presumably set up at import
    # time; verify against the surrounding module.
    nlu_t1 = tokenize_corenlp_direct_version(client, nlu1)
    nlu_t = [nlu_t1]

    tb1 = data_table[0]
    hds1 = tb1['header']
    tb = [tb1]
    hds = [hds1]
    hs_t = [[]]

    # Encode question + headers with BERT.
    wemb_n, wemb_h, l_n, l_hpu, l_hs, \
    nlu_tt, t_to_tt_idx, tt_to_t_idx \
        = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                        num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

    # Execution-guided beam decoding against the live database.
    prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
        wemb_n,
        l_n,
        wemb_h,
        l_hpu,
        l_hs,
        engine,
        tb,
        nlu_t,
        nlu_tt,
        tt_to_t_idx,
        nlu,
        beam_size=beam_size)

    # Sort where-clause components and assemble the final SQL string.
    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
    if len(pr_sql_i) != 1:
        raise EnvironmentError
    pr_sql_q1 = generate_sql_q(pr_sql_i, [tb1])
    pr_sql_q = [pr_sql_q1]

    try:
        pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                pr_sa[0], pr_sql_i[0]['conds'])
    except Exception:
        # Best-effort: the generated SQL may fail against the actual table.
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')
    else:
        print('START ============================================================= ')
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print('---------------------------------------------------------------------')

    return pr_sql_i, pr_ans
Ejemplo n.º 2
def infernew(dev_loader,
             data_table,
             model,
             model_bert,
             bert_config,
             tokenizer,
             max_seq_length,
             num_target_layers,
             detail=False,
             path_db=None,
             st_pos=0,
             dset_name='train',
             EG=False,
             beam_size=4):
    """Decode the first usable batch of ``dev_loader`` and return its prediction.

    NOTE(review): the original body returns from *inside* the batch loop, so
    only the first batch that survives preprocessing is ever decoded. That
    behavior is preserved here; confirm it is intentional before relying on
    this for full-dataset evaluation.

    Args:
        dev_loader: iterable of token batches (project data loader).
        data_table: table metadata consumed by ``get_fields``.
        model: decoder; called directly (no-EG) or via ``beam_forward`` (EG).
        model_bert: BERT encoder used by ``get_wemb_bert``.
        bert_config: BERT configuration.
        tokenizer: BERT tokenizer.
        max_seq_length: maximum BERT input sequence length.
        num_target_layers: number of BERT output layers used for embeddings.
        detail: unused here; kept for interface compatibility.
        path_db: directory containing the SQLite database file.
        st_pos: unused here; kept for interface compatibility.
        dset_name: database file name without the ``.db`` extension.
        EG: when True, use execution-guided beam decoding.
        beam_size: beam width for execution-guided decoding.

    Returns:
        Tuple ``(pr_sql_i, pr_ans)`` for the first decoded batch, or None if
        every batch was skipped during preprocessing.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    # Records for examples skipped because the gold where-value could not be
    # located in the tokenized question.
    results = []
    for iB, t in enumerate(dev_loader):

        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)

        # Gold annotations (used only for the loss in the no-EG branch).
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        # Encode question + headers with BERT.
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        except Exception:
            # Happens when a where-condition value is not found in nlu_tt.
            # Record the skip and move on to the next batch.
            for b in range(len(nlu)):
                results.append({
                    "error": "Skip happened",
                    "nlu": nlu[b],
                    "table_id": tb[b]["id"],
                })
            continue

        # Optional per-example knowledge features; fall back to zero vectors.
        knowledge = [
            k["bertindex_knowledge"] if "bertindex_knowledge" in k
            else max(l_n) * [0]
            for k in t
        ]
        knowledge_header = [
            k["header_knowledge"] if "header_knowledge" in k
            else max(l_hs) * [0]
            for k in t
        ]

        if not EG:
            # Plain (non-execution-guided) decoding.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                knowledge=knowledge,
                knowledge_header=knowledge_header)

            # Loss is computed but not aggregated (single-batch inference).
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa,
                              g_wn, g_wc, g_wo, g_wvi)

            # Turn scores into discrete predictions.
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc,
                s_sa,
                s_wn,
                s_wc,
                s_wo,
                s_wv,
            )
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution-guided beam decoding against the live database.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                engine,
                tb,
                nlu_t,
                nlu_tt,
                tt_to_t_idx,
                nlu,
                beam_size=beam_size,
                knowledge=knowledge,
                knowledge_header=knowledge_header)

            # Sort where-clause components and regenerate the SQL structure.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            # Placeholders kept only for consistency with the no-EG branch.
            pr_wvi = None
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        pr_sql_q1 = generate_sql_q(pr_sql_i, tb)
        pr_sql_q = [pr_sql_q1]

        try:
            pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                    pr_sa[0],
                                                    pr_sql_i[0]['conds'])
        except Exception:
            # Best-effort: the generated SQL may fail against the table.
            pr_ans = ['Answer not found.']
            pr_sql_q = ['Answer not found.']

        # The original guarded this with an always-true flag (`yes = True`);
        # the dead else-branch (which referenced an undefined `table_name`)
        # has been removed.
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')

        # Returns after the first decoded batch — see the docstring NOTE.
        return pr_sql_i, pr_ans
Ejemplo n.º 3
def infer(nlu1,
          table_name,
          data_table,
          path_db,
          db_name,
          model,
          model_bert,
          bert_config,
          max_seq_length,
          num_target_layers,
          beam_size=4,
          show_table=False,
          show_answer_only=False):
    """Run single-question inference (Chinese variant: character-level tokenization).

    Unlike the CoreNLP-based variant, the question is tokenized by splitting
    into individual characters, and decoding uses the plain forward pass of
    ``model`` rather than execution-guided beam search.

    Args:
        nlu1: natural-language question (string).
        table_name: name of the table to query; looked up in ``data_table``.
        data_table: list of table dicts, each with ``'name'`` and ``'header'``.
        path_db: directory containing the SQLite database file.
        db_name: database file name without the ``.db`` extension.
        model: decoder; called directly to produce component scores.
        model_bert: BERT encoder consumed by ``get_wemb_bert``.
        bert_config: BERT configuration.
        max_seq_length: maximum BERT input sequence length.
        num_target_layers: number of BERT output layers used for embeddings.
        beam_size: unused in this variant; kept for interface compatibility.
        show_table: when True (and not ``show_answer_only``), also print the
            table contents.
        show_answer_only: when True, print only question / answer / SQL.

    Returns:
        Tuple ``(pr_sql_i, pr_ans)``: predicted SQL structure(s) and the
        answer(s) obtained by executing them.

    Raises:
        ValueError: if ``table_name`` is not present in ``data_table``.
        EnvironmentError: if decoding yielded a number of SQL candidates
            other than one.
    """
    # NOTE: intentionally duplicates the batch-inference path (against DRY)
    # to minimize the risk of introducing bugs into training/eval code.
    model.eval()
    model_bert.eval()
    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    # Wrap the single question as a batch of size one.
    nlu = [nlu1]
    # 2020/12/02: fixed Chinese tokenization in infer.
    # 2020/12/11: dropped stanza-based tokenization; split into characters.
    # nlu_t1 = tokenize_corenlp_direct_version(client, nlu1)
    nlu_t1 = list(nlu1)
    nlu_t = [nlu_t1]

    # 2020/12/01: select the table matching the question instead of data_table[0].
    # Bug fix: the original left `tb1` unbound (UnboundLocalError) when no
    # table matched; now we fail with an explicit error.
    for candidate in data_table:
        if candidate['name'] == table_name:
            tb1 = candidate
            break
    else:
        raise ValueError(f"table '{table_name}' not found in data_table")
    hds1 = tb1['header']
    tb = [tb1]
    hds = [hds1]
    hs_t = [[]]

    # Encode question + headers with BERT.
    wemb_n, wemb_h, l_n, l_hpu, l_hs, \
    nlu_tt, t_to_tt_idx, tt_to_t_idx \
        = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                        num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

    # 2020/12/11: switched from model.beam_forward (execution-guided) to the
    # plain forward pass.
    s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu,
                                               l_hs)
    # Turn scores into discrete predictions.
    pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
        s_sc,
        s_sa,
        s_wn,
        s_wc,
        s_wo,
        s_wv,
    )
    # Recover where-values as strings from the tokenized question.
    pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt,
                                                       tt_to_t_idx, nlu)
    # Assemble the predicted SQL structure (conds etc.).
    # 2020/12/03: agg and sel are lists (changed inside generate_sql_i).
    pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str,
                              nlu)

    # Sort out where-col / where-op / where-val components.
    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
    if len(pr_sql_i) != 1:  # exactly one set of conds expected
        raise EnvironmentError
    pr_sql_q1 = generate_sql_q(pr_sql_i, [tb1])  # render the SQL string
    pr_sql_q = [pr_sql_q1]  # wrapped in a list for interface consistency

    # Execute the generated SQL.
    try:
        pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                pr_sa[0], pr_sql_i[0]['conds'])
    except Exception:
        # Best-effort: the generated SQL may fail against the actual table.
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')
    else:
        print('START ============================================================= ')
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print('---------------------------------------------------------------------')

    return pr_sql_i, pr_ans