示例#1
0
def epoch_acc(batch_size, table_data, sql_data, db_path, cs, out=False):
    """Evaluate saved predictions against ground truth, in batches.

    Reads model predictions from ../submit/results_val.jsonl, aligns them
    with `sql_data`, and computes logical-form accuracy via check_acc plus
    execution accuracy by running both gold and predicted SQL through
    DBEngine.

    Args:
        batch_size: number of examples per evaluation batch.
        table_data: table metadata passed through to to_batch_seq.
        sql_data: list of ground-truth examples.
        db_path: path to the database used by DBEngine.
        cs: forwarded to genByout (semantics defined there).
        out: unused; kept for interface compatibility.

    Returns:
        Tuple of (one-component accuracy, full logical-form accuracy,
        execution accuracy), each normalized by len(sql_data).
    """
    with open('../submit/results_val.jsonl') as inf:
        sqlOutdata = []
        for idx, line in enumerate(inf):
            _sql = json.loads(line.strip())
            _sql['sql'] = _sql['query']
            _sql['question'] = _sql['nlu']
            if 'error' in _sql:
                # Failed prediction: reuse the previous line's prediction so
                # indices stay aligned with sql_data. Guard idx == 0, where
                # there is no previous prediction to fall back to (the
                # original code raised IndexError in that case).
                sqlOutdata.append(sqlOutdata[idx - 1] if idx > 0 else _sql)
            else:
                sqlOutdata.append(_sql)
    print(len(sql_data))
    print(len(sqlOutdata))
    engine = DBEngine(db_path)
    perm = list(range(len(sqlOutdata)))
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        # Clamp the final (possibly partial) batch to the data length.
        ed = min((st + 1) * batch_size, len(perm))
        st = st * batch_size
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, gt_type, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)

        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # query_gt: ground truth of sql, data['sql'], containing sel, agg, conds:{sel, op, value}
        raw_q_seq = [x[0] for x in raw_data]  # original question

        pred_queriesc, allfsc = genByout(sqlOutdata[st:ed], table_ids,
                                         raw_q_seq, gt_type, 'val', cs)

        one_err, tot_err = check_acc(raw_data, pred_queriesc, query_gt, allfsc)

        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution accuracy: a prediction is correct iff executing it yields
        # the same result as executing the gold query.
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queriesc, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # Malformed predicted SQL counts as an execution mismatch.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)
    return one_acc_num / len(sql_data), tot_acc_num / len(
        sql_data), ex_acc_num / len(sql_data)
示例#2
0
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):
    """Run inference over `data_loader`, returning one record per example.

    Each record holds the predicted canonical query, table id, question,
    rendered SQL (plain and parameterized), and the answer obtained by
    executing the predicted query against the database.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for batch_idx, batch in enumerate(data_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            batch, data_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(batch)
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        if EG:
            # Execution-guided decoding via beam search.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                nlu_t, nlu_tt, tt_to_t_idx, nlu, beam_size=beam_size)
            # Sort the where clauses and regenerate canonical form.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # Kept only for parity with the greedy branch; unused afterwards.
            pr_wvi = None
            pr_wv_str = None
            pr_wv_str_wp = None
        else:
            # Plain greedy decoding.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)

        pr_sql_q = generate_sql_q(pr_sql_i, tb)
        pr_sql_q_base = generate_sql_q_base(pr_sql_i, tb)

        for b, (query_i, sql_str, sql_str_base) in enumerate(
                zip(pr_sql_i, pr_sql_q, pr_sql_q_base)):
            answer = engine.execute_query(
                tb[b]["id"], Query.from_dict(query_i, ordered=True), lower=False)
            results.append({
                "query": query_i,
                "table_id": tb[b]["id"],
                "nlu": nlu[b],
                "sql": sql_str,
                "sql_with_params": sql_str_base,
                "answer": answer,
            })

    return results
示例#3
0
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):
    """Run greedy inference and collect per-example prediction records.

    Execution-guided decoding is not implemented for this model variant:
    when EG is True the function bails out immediately (returns None).
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for batch_idx, batch in enumerate(data_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            batch, data_table, no_hs_t=True, no_sql_t=True, result=True)
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        if EG:
            # Execution-guided decoding is unsupported here.
            return

        # Scores for select-number/column/agg, where-number/column/op/value,
        # and the condition connector.
        s_sn, s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_wco = model(
            wemb_n, l_n, wemb_h, l_hpu, l_hs)

        pr_sn, pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, dwpr_sc, pr_wco = pred_sw_se(
            s_sn, s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_wco, typ=True)
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        pr_sql_i = generate_sql_i(
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, pr_wco, nlu)

        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        for b, (query_i, sql_str) in enumerate(zip(pr_sql_i, pr_sql_q)):
            # Overwrite 'sel' with the raw predicted select columns before
            # recording the result.
            query_i['sel'] = pr_sc[b]
            results.append({
                "query": query_i,
                "table_id": tb[b]["id"],
                "nlu": nlu[b],
                "sql": sql_str,
            })

    return results
def test(
    data_loader,
    data_table,
    model,
    model_bert,
    tokenizer,
    sql_vocab,
    max_seq_length,
    detail=False,
    st_pos=0,
    cnt_tot=1,
    EG=False,
    beam_only=True,
    beam_size=4,
    path_db=None,
    dset_name='test',
    col_pool_type='start_tok',
    aug=False,
):
    """Evaluate a seq2seq (pointer-decoder) model on `data_loader`.

    Per batch: build BERT encodings, derive ground-truth pointer index
    sequences, decode predicted pointer sequences (greedy or
    execution-guided), reconstruct SQL from both, and accumulate loss and
    accuracy counters.

    Returns:
        acc: [average loss, logical-form accuracy, execution accuracy].
        results: one dict per example with the predicted query, table id and
            question (or an "error" record when value-span alignment failed).

    NOTE(review): execution accuracy is currently stubbed — cnt_x1_list is
    hard-coded to zeros below (the real computation is commented out), so
    acc_x is always 0 on this path. The `aug` parameter is consequently
    unused.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0  # number of examples seen (denominator for all averages)
    cnt_lx = 0  # logical-form (pointer-sequence) matches
    cnt_x = 0  # execution matches (stubbed to 0 — see NOTE above)
    results = []
    cnt_list = []

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(tqdm(data_loader)):

        cnt += len(t)
        # NOTE(review): cnt is incremented before the st_pos skip, so skipped
        # batches still count toward the final averages — confirm intended.
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, i_sql_vocab, \
        l_n, l_hpu, l_hs, l_input, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_bert_output_s2s(model_bert, tokenizer, nlu_t, hds, sql_vocab, max_seq_length,sample=False)
        try:
            # Map CoreNLP-token value spans onto BERT wordpiece indices.
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            # e.g. train: 32.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # Generate g_pnt_idx: the ground-truth pointer index sequence.
        g_pnt_idxs = gen_g_pnt_idx(g_wvi,
                                   sql_i,
                                   i_hds,
                                   i_sql_vocab,
                                   col_pool_type=col_pool_type)
        # Decode start/end pointer tokens, read from the last two entries of
        # the first example's SQL-vocab index.
        pnt_start_tok = i_sql_vocab[0][-2][0]
        pnt_end_tok = i_sql_vocab[0][-1][0]
        # check
        # print(array(tokens[0])[g_pnt_idxs[0]])
        wenc_s2s = all_encoder_layer[-1]

        # wemb_h = [B, max_header_number, hS]
        cls_vec = pooled_output

        if not EG:
            score = model(
                wenc_s2s,
                l_input,
                cls_vec,
                pnt_start_tok,
            )
            loss = Loss_s2s(score, g_pnt_idxs)

            pr_pnt_idxs = pred_pnt_idxs(score, pnt_start_tok, pnt_end_tok)
        else:
            # EG: execution-guided beam decoding; no loss is computed, a
            # constant placeholder tensor is used instead.
            pr_pnt_idxs, p_list, pnt_list_beam = model.EG_forward(
                wenc_s2s,
                l_input,
                cls_vec,
                pnt_start_tok,
                pnt_end_tok,
                i_sql_vocab,
                i_nlu,
                i_hds,  # for EG
                tokens,
                nlu,
                nlu_t,
                hds,
                tt_to_t_idx,  # for EG
                tb,
                engine,
                beam_size,
                beam_only=beam_only)
            if beam_only:
                loss = torch.tensor([0])
            else:
                # print('EG on!')
                loss = torch.tensor([1])

        # Reconstruct SQL text/structure from the ground-truth pointers.
        g_i_vg_list, g_i_vg_sub_list = gen_i_vg_from_pnt_idxs(
            g_pnt_idxs, i_sql_vocab, i_nlu, i_hds)
        g_sql_q_s2s, g_sql_i = gen_sql_q_from_i_vg(tokens, nlu, nlu_t, hds,
                                                   tt_to_t_idx, pnt_start_tok,
                                                   pnt_end_tok, g_pnt_idxs,
                                                   g_i_vg_list,
                                                   g_i_vg_sub_list)

        # Same reconstruction for the predicted pointers.
        pr_i_vg_list, pr_i_vg_sub_list = gen_i_vg_from_pnt_idxs(
            pr_pnt_idxs, i_sql_vocab, i_nlu, i_hds)
        pr_sql_q_s2s, pr_sql_i = gen_sql_q_from_i_vg(
            tokens, nlu, nlu_t, hds, tt_to_t_idx, pnt_start_tok, pnt_end_tok,
            pr_pnt_idxs, pr_i_vg_list, pr_i_vg_sub_list)

        g_sql_q = generate_sql_q(sql_i, tb)

        try:
            pr_sql_q = generate_sql_q(pr_sql_i, tb)
            # gen pr_sc, pr_sa
            pr_sc = []
            pr_sa = []
            for pr_sql_i1 in pr_sql_i:
                pr_sc.append(pr_sql_i1["sel"])
                pr_sa.append(pr_sql_i1["agg"])
        except:
            # Decoded pointers did not form valid SQL; mark batch as N/A.
            bS = len(sql_i)
            pr_sql_q = ['NA'] * bS
            pr_sc = ['NA'] * bS
            pr_sa = ['NA'] * bS

        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        # Cacluate accuracy
        cnt_lx1_list = get_cnt_lx_list_s2s(g_pnt_idxs, pr_pnt_idxs)

        # if not aug:
        #     cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)
        # else:
        cnt_x1_list = [0] * len(t)
        g_ans = ['N/A (data augmented'] * len(t)
        pr_ans = ['N/A (data augmented'] * len(t)

        # statistics
        ave_loss += loss.item()

        # count
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)
        # report
        if detail:
            print(f"Ground T  :   {g_pnt_idxs}")
            print(f"Prediction:   {pr_pnt_idxs}")
            print(f"Ground T  :   {g_sql_q}")
            print(f"Prediction:   {pr_sql_q}")

    ave_loss /= cnt

    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [ave_loss, acc_lx, acc_x]
    return acc, results
示例#5
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train',
          mvl=2):  #max value length
    """Train model (and its BERT encoder) for one pass over train_loader.

    Per batch: extract ground-truth SQL components, build BERT embeddings,
    compute scores and loss, back-propagate with optional gradient
    accumulation, then decode predictions and accumulate per-component
    accuracy counters. Batches whose where-value spans cannot be aligned to
    the tokenized question — or whose spans exceed `mvl` tokens — are
    skipped via the broad except below.

    Returns:
        acc: [ave_loss, acc_sn, acc_sc, acc_sa, acc_wn, acc_wr, acc_wc,
              acc_wo, acc_wvi, acc_wv, acc_lx, acc_x], all normalized by the
              number of examples seen.
        aux_out: constant 1 (placeholder).
    """
    model.train()
    model_bert.train()
    #train table is a dict, key is table id, value is the whole table

    ave_loss = 0
    cnt = 0  # count the # of examples

    cnt_sn = 0  # count select number

    cnt_sc = 0  # count the # of correct predictions of select column
    cnt_sa = 0  # of selectd aggregation
    cnt_wn = 0  # of where number

    cnt_wr = 0

    #where relation number = cnt_wn - 1

    cnt_wc = 0  # of where column
    cnt_wo = 0  # of where operator
    cnt_wv = 0  # of where-value
    cnt_wvi = 0  # of where-value index (on question tokens)
    cnt_lx = 0  # of logical form acc
    cnt_x = 0  # of execution acc

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, dset_name, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):  #generate each data batch
        cnt += len(t)

        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True, generate_mode=False)
        # nlu  : natural language utterance
        # nlu_t: tokenized nlu
        # sql_i: canonical form of SQL query
        # sql_q: full SQL query text. Not used.
        # sql_t: tokenized SQL query
        # tb   : table
        # hs_t : tokenized headers. Not used.
        '''
        print('nlu: ', nlu)
        print('nlu_t: ', nlu_t)
        print('sql_i: ', sql_i)
        print('sql_q: ', sql_q)
        print('sql_t: ', sql_t)
        #print('tb: ', tb)
        print('hs_t: ', hs_t)
        print('hds: ', hds)
        '''
        try:
            g_sn, g_sc, g_sa, g_wn, g_wr, g_dwn, g_wc, g_wo, g_wv, g_wrcn, wvi_change_index = get_g(
                sql_i)  #get the where values
            '''
            print('g_sn: ', g_sn)
            print('g_sc: ', g_sc)
            print('g_sa: ', g_sa)
            print('g_wn: ', g_wn)
            print('g_wr: ', g_wr)
            print('g_dwn: ', g_dwn)
            print('g_wc: ', g_wc)
            print('g_wo: ', g_wo)
            print('g_wv: ', g_wv)
            print('g_wrcn: ', g_wrcn)
            '''

            #g_sn: (a list of double) number of select column;
            #g_sc: (a list of list) select column names;
            #g_sa: (a list of list) agg for each col;
            #g_wr: (a list of double) if value=0, then there is only one condition, else there are two conditions;
            #g_wc: (a list of list) where col;
            #g_wo: (a list of list) where op;
            #g_wv: (a list of list) where val;
            # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset.
            g_wvi_corenlp = get_g_wvi_corenlp(t, wvi_change_index)
            # this function is to get the indices of where values from the question token

            wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx, wemb_v, l_npu, l_token \
                = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers, num_out_layers_v=num_target_layers)
            '''
            print('wemb_n: ', torch.tensor(wemb_n).size())
            print('wemb_h: ', torch.tensor(wemb_h).size())
            '''
            #print('l_n: ', l_n[0])
            #print('l_hpu: ', l_hpu)
            #print('l_hs: ', l_hs)
            #print('nlu_tt: ', nlu_tt[0])

            #print('t_to_tt_idx: ', t_to_tt_idx)
            #print('tt_to_t_idx: ', tt_to_t_idx)
            #print('g_wvi_corenlp', g_wvi_corenlp)

            # wemb_n: natural language embedding
            # wemb_h: header embedding
            # l_n: token lengths of each question
            # l_hpu: header token lengths
            # l_hs: the number of columns (headers) of the tables.

            #
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(
                t_to_tt_idx, g_wvi_corenlp
            )  #if not exist, it will not train not include the length, so the end value is the start index of this word, not the end index of this word, so it need to add sth
            # NOTE(review): the next line overwrites the BERT-aligned g_wvi
            # with the raw CoreNLP-token spans, so the call above serves only
            # as a validity check (its exception skips the batch). Confirm
            # this is intentional.
            g_wvi = g_wvi_corenlp
            if g_wvi:
                # Reject any where-value span longer than mvl tokens; the
                # whole batch is discarded (cnt is rolled back first).
                for L in g_wvi:
                    for e in L:
                        if e[1] - e[0] + 1 > mvl:
                            cnt -= len(t)
                            print('error: ', e)
                            raise RuntimeError(
                                'invalid training set'
                            )  #only train length no larger than 8 of where value
            g_wvi = get_g_wvi_stidx_length_jian_yi(
                g_wvi)  # do not sort here: sorting would break the alignment between the two
            #print('g_wvi', g_wvi[0][0])
        except:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            # e.g. train: 32.
            continue
        # score
        s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4 = model(
            mvl,
            wemb_n,
            l_n,
            wemb_h,
            l_hpu,
            l_hs,
            wemb_v,
            l_npu,
            l_token,
            g_sn=g_sn,
            g_sc=g_sc,
            g_sa=g_sa,
            g_wn=g_wn,
            g_dwn=g_dwn,
            g_wr=g_wr,
            g_wc=g_wc,
            g_wo=g_wo,
            g_wvi=g_wvi,
            g_wrcn=g_wrcn)

        #print('g_wvi: ', g_wvi[0])
        '''
        print('s_sn: ', s_sn)
        print('s_sc: ', s_sc)
        print('s_sa: ', s_sa)
        print('s_wn: ', s_wn)
        print('s_wr: ', s_wr)
        print('s_hrpc: ', s_hrpc)
        print('s_wrpc', s_wrpc)
        print('s_nrpc: ', s_nrpc)
        print('s_wc: ', s_wc)
        print('s_wo: ', s_wo)
        print('s_wv1: ', s_wv1)
        print('s_wv2: ', s_wv2)
        '''

        # Calculate loss & step
        loss = Loss_sw_se(s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo,
                          s_wv1, s_wv2, s_wv3, s_wv4, g_sn, g_sc, g_sa, g_wn,
                          g_dwn, g_wr, g_wc, g_wo, g_wvi, g_wrcn, mvl)
        '''
        print('ave_loss', ave_loss)
        print('loss: ', loss.item())
        print('cnt: ', cnt)
        '''
        # Calculate gradient with optional accumulation: zero grads at the
        # start of a window, step at the end, accumulate in between.
        if iB % accumulate_gradients == 0:  # mode
            # at start, perform zero_grad
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # at the final, take step with accumulated graident
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # at intermediate stage, just accumulates the gradients
            loss.backward()

        #print('grad finish')

        # Prediction
        #print('s_wc: ', s_wc.size())
        pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_hrpc, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2,
            s_wv3, s_wv4, mvl)
        '''
        print('pr_sn: ', pr_sn)
        print('pr_sc: ', pr_sc)
        print('pr_sa: ', pr_sa)
        print('pr_wn: ', pr_wn)
        print('pr_wr: ', pr_wr)
        print('pr_hrpc: ', pr_hrpc)
        print('pr_wrpc', pr_wrpc)
        print('pr_nrpc: ', pr_nrpc)
        print('pr_wc: ', pr_wc)
        print('pr_wo: ', pr_wo)
        print('pr_wvi: ', pr_wvi)
        '''
        pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi)
        #print('pr_wvi_decode: ', pr_wvi_decode)
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
        #print('pr_wv_str: ', pr_wv_str)
        #print('pr_wv_str_wp: ', pr_wv_str_wp)
        # Sort pr_wc:
        #   Sort pr_wc when training the model as pr_wo and pr_wvi are predicted using ground-truth where-column (g_wc)
        #   In case of 'dev' or 'test', it is not necessary as the ground-truth is not used during inference.
        pr_sc_sorted = sort_pr_wc(pr_sc, g_sc)
        pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
        #print('pr_wc: ', pr_wc)
        #print('g_wc: ', g_wc)
        pr_sql_i = generate_sql_i(pr_sc_sorted, pr_sa, pr_wn, pr_wr,
                                  pr_wc_sorted, pr_wo, pr_wv_str, nlu)

        #print('pr_sql_i: ', pr_sql_i)

        # Cacluate accuracy
        cnt_sn1_list, cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wr1_list, cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sn, g_sc, g_sa, g_wn, g_wr, g_wc, g_wo, g_wvi,
                                                                   pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo, pr_wvi,
                                                                   sql_i, pr_sql_i,
                                                                   mode='train')
        '''
        print('cnt_sn1_list: ', cnt_sn1_list)
        print('cnt_sc1_list: ', cnt_sc1_list)
        print('cnt_sa1_list: ', cnt_sa1_list)
        print('cnt_wn1_list: ', cnt_wn1_list)
        print('cnt_wr1_list: ', cnt_wr1_list)
        print('cnt_wc1_list: ', cnt_wc1_list)
        print('cnt_wo1_list', cnt_wo1_list)
        print('cnt_wvi1_list: ', cnt_wvi1_list)
        print('cnt_wv1_list: ', cnt_wv1_list)
        '''

        cnt_lx1_list = get_cnt_lx_list(cnt_sn1_list, cnt_sc1_list,
                                       cnt_sa1_list, cnt_wn1_list,
                                       cnt_wr1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # lx stands for logical form accuracy

        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i)

        # statistics
        ave_loss += loss.item()
        '''
        print('cnt_lx1_list: ', cnt_lx1_list)
        print('cnt_x1_list: ', cnt_x1_list)
        print('g_ans: ', g_ans)
        print('pr_ans: ', pr_ans)
        print('ave_loss: ', ave_loss)
        '''

        # count
        cnt_sn += sum(cnt_sn1_list)
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wr += sum(cnt_wr1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)
        if iB % 200 == 0:
            logger.info(
                '%d - th data batch -> loss: %.4f; acc_sn: %.4f; acc_sc: %.4f; acc_sa: %.4f; acc_wn: %.4f; acc_wr: %.4f; acc_wc: %.4f; acc_wo: %.4f; acc_wvi: %.4f; acc_wv: %.4f; acc_lx: %.4f; acc_x %.4f;'
                % (iB, ave_loss / cnt, cnt_sn / cnt, cnt_sc / cnt, cnt_sa /
                   cnt, cnt_wn / cnt, cnt_wr / cnt, cnt_wc / cnt, cnt_wo / cnt,
                   cnt_wvi / cnt, cnt_wv / cnt, cnt_lx / cnt, cnt_x / cnt))
            #print('train: [ ', iB, '- th data batch -> loss:', ave_loss / cnt, '; acc_sn: ', cnt_sn / cnt, '; acc_sc: ', cnt_sc / cnt, '; acc_sa: ', cnt_sa / cnt, '; acc_wn: ', cnt_wn / cnt, '; acc_wr: ', cnt_wr / cnt, '; acc_wc: ', cnt_wc / cnt, '; acc_wo: ', cnt_wo / cnt, '; acc_wvi: ', cnt_wvi / cnt, '; acc_wv: ', cnt_wv / cnt, '; acc_lx: ', cnt_lx / cnt, '; acc_x: ', cnt_x / cnt, ' ]')

    ave_loss = ave_loss / cnt
    acc_sn = cnt_sn / cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wr = cnt_wr / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sn, acc_sc, acc_sa, acc_wn, acc_wr, acc_wc, acc_wo,
        acc_wvi, acc_wv, acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
示例#6
0
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         EG=False,
         beam_size=4,
         path_db=None,
         dset_name='test'):
    """Evaluate the model on `data_loader`.

    Computes one-component accuracy and full logical-form accuracy via
    model.check_acc, plus execution accuracy (predicted SQL executes to the
    same result as the gold SQL).

    Returns:
        ((one_acc, tot_acc, ex_acc), results) where results holds one dict
        per example with the predicted query, table id and question.
    """
    model.eval()
    model_bert.eval()

    print('g_scn/wc/wn/wo dev/test不监督')
    cnt = 0

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    total = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for iB, t in enumerate(data_loader):

        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_t, tb, hs_t, hds = get_fields(t,
                                                             data_table,
                                                             no_hs_t=True,
                                                             no_sql_t=True)

        g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wv, g_sel_num_seq, g_sel_ag_seq, conds = get_g(
            sql_i)

        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Where-value span could not be located in the tokenized question:
            # record the example as an error and treat it as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # model specific part
        # score
        if not EG:
            # No Execution guided decoding
            s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs)

            # Select-clause scores bundled for gen_query. (The original code
            # built this with repeated append() and then called tuple(score)
            # as a no-op whose result was discarded.)
            score = [s_scn, s_sc, s_sa, s_sop]
            pr_sql_i1 = model.gen_query(score, nlu_tt, nlu)

            pr_wn, pr_wc, pr_sop, pr_wo, pr_wvi = pred_sw_se(
                s_sop, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)  # map spans back to strings

            pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
            pr_sql_i = generate_sql_i(pr_sql_i1, pr_wn, pr_wc_sorted, pr_wo,
                                      pr_wv_str, nlu)

        else:
            # Execution guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                engine,
                tb,
                nlu_t,
                nlu_tt,
                tt_to_t_idx,
                nlu,
                beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        one_err, tot_err = model.check_acc(nlu, pr_sql_i, sql_i)
        one_acc_num += (len(pr_sql_i) - one_err)
        tot_acc_num += (len(pr_sql_i) - tot_err)
        total += len(pr_sql_i)

        # Execution Accuracy
        table_ids = [tb[x]['id'] for x in range(len(tb))]

        for sql_gt, sql_pred, tid in zip(sql_i, pr_sql_i, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # Invalid predicted SQL counts as an execution mismatch.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    # NOTE(review): raises ZeroDivisionError when every batch was skipped
    # (total == 0) — confirm the loader is never empty on this path.
    return ((one_acc_num / total), (tot_acc_num / total),
            ex_acc_num / total), results
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          scheduler,
          tokenizer,
          sql_vocab,
          max_seq_length,
          accumulate_gradients=1,
          check_grad=False,
          st_pos=0,
          opt_bert=None,
          bert_scheduler=None,
          path_db=None,
          dset_name='train',
          col_pool_type='start_tok',
          aug=False):
    """Train the s2s (pointer-decoder) model for one epoch.

    Returns:
        acc: [ave_loss, acc_lx, acc_x] — average loss, logical-form
            accuracy, and execution accuracy.  Execution accuracy is
            stubbed to 0 here because the get_cnt_x_list check is
            disabled for augmented data (see below).
        aux_out: [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean]
            gradient statistics; all 1 unless check_grad is True.
    """
    model.train()
    model_bert.train()
    # Snapshot of BERT taken at the start of the epoch;
    # custom_regularization below penalizes drift away from this copy.
    model_old = deepcopy(model_bert)

    ave_loss = 0
    cnt = 0  # count of examples seen
    cnt_x = 0  # execution-accuracy count (stays 0: exec check disabled)
    cnt_lx = 0  # logical-form accuracy count

    # Defaults so aux_out is always defined, even when the loader is empty
    # or every batch is skipped (the original could raise NameError then).
    grad_abs_mean_mean = 1
    grad_abs_mean_sig = 1
    grad_abs_sig_mean = 1

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(tqdm(train_loader)):
        cnt += len(t)
        # NOTE(review): opt_bert/bert_scheduler default to None but are used
        # unconditionally in this loop — confirm callers always pass them.
        opt_bert.zero_grad()
        opt.zero_grad()
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth where-value index under the CoreNLP tokenization
        # scheme; already precomputed on the training set.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, i_sql_vocab, \
        l_n, l_hpu, l_hs, l_input, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_bert_output_s2s(model_bert, tokenizer, nlu_t, hds, sql_vocab, max_seq_length,sample=True)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Raised when a where-condition value cannot be located in
            # nlu_tt; the example is skipped during training (at test time
            # it would count as wrongly answered). e.g. train: 32.
            continue

        # Generate pointer-index targets for the decoder.
        g_pnt_idxs = gen_g_pnt_idx(g_wvi,
                                   sql_i,
                                   i_hds,
                                   i_sql_vocab,
                                   col_pool_type=col_pool_type)
        pnt_start_tok = i_sql_vocab[0][-2][0]
        pnt_end_tok = i_sql_vocab[0][-1][0]
        # The decoder attends over the last BERT encoder layer.
        wenc_s2s = all_encoder_layer[-1]

        # Pooled [CLS] representation.
        cls_vec = pooled_output

        score = model(wenc_s2s,
                      l_input,
                      cls_vec,
                      pnt_start_tok,
                      g_pnt_idxs=g_pnt_idxs)

        # Calculate loss & step
        loss = Loss_s2s(score, g_pnt_idxs)

        # Add the utility-regularization term against the epoch-start weights.
        loss = custom_regularization(model_old, model_bert, args.bS, loss)

        loss.backward()

        opt_bert.step()
        opt.step()
        bert_scheduler.step()
        scheduler.step()

        if check_grad:
            named_parameters = model.named_parameters()

            mu_list, sig_list = get_mean_grad(named_parameters)

            grad_abs_mean_mean = mean(mu_list)
            grad_abs_mean_sig = std(mu_list)
            grad_abs_sig_mean = mean(sig_list)
        else:
            grad_abs_mean_mean = 1
            grad_abs_mean_sig = 1
            grad_abs_sig_mean = 1

        # Prediction
        pr_pnt_idxs = pred_pnt_idxs(score, pnt_start_tok, pnt_end_tok)

        g_i_vg_list, g_i_vg_sub_list = gen_i_vg_from_pnt_idxs(
            g_pnt_idxs, i_sql_vocab, i_nlu, i_hds)

        g_sql_q_s2s, g_sql_i = gen_sql_q_from_i_vg(tokens, nlu, nlu_t, hds,
                                                   tt_to_t_idx, pnt_start_tok,
                                                   pnt_end_tok, g_pnt_idxs,
                                                   g_i_vg_list,
                                                   g_i_vg_sub_list)

        pr_i_vg_list, pr_i_vg_sub_list = gen_i_vg_from_pnt_idxs(
            pr_pnt_idxs, i_sql_vocab, i_nlu, i_hds)

        pr_sql_q_s2s, pr_sql_i = gen_sql_q_from_i_vg(
            tokens, nlu, nlu_t, hds, tt_to_t_idx, pnt_start_tok, pnt_end_tok,
            pr_pnt_idxs, pr_i_vg_list, pr_i_vg_sub_list)

        g_sql_q = generate_sql_q(sql_i, tb)

        try:
            pr_sql_q = generate_sql_q(pr_sql_i, tb)
            # Extract predicted select-column / select-aggregation.
            pr_sc = []
            pr_sa = []
            for pr_sql_i1 in pr_sql_i:
                pr_sc.append(pr_sql_i1["sel"])
                pr_sa.append(pr_sql_i1["agg"])
        except Exception:
            # Malformed predicted SQL: fall back to placeholder outputs.
            bS = len(sql_i)
            pr_sql_q = ['NA'] * bS
            pr_sc = ['NA'] * bS
            pr_sa = ['NA'] * bS

        # Calculate accuracy
        cnt_lx1_list = get_cnt_lx_list_s2s(g_pnt_idxs, pr_pnt_idxs)

        # Execution accuracy is disabled (data-augmentation mode); keep the
        # zero/placeholder values so downstream bookkeeping is unchanged.
        cnt_x1_list = [0] * len(t)
        g_ans = ['N/A (data augmented)'] * len(t)
        pr_ans = ['N/A (data augmented)'] * len(t)

        # statistics
        ave_loss += loss.item()

        # count
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [ave_loss, acc_lx, acc_x]
    aux_out = [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean]

    return acc, aux_out
示例#8
0
def run_split(split):
    """Parse the tokenized examples for *split* line by line.

    Opens a DBEngine over `<split>.db` (path_db is presumably a
    module-level constant — verify against the caller) and reads
    `<split>_tok.jsonl`, JSON-decoding each line into t1.
    """
    engine = DBEngine(os.path.join(path_db, f"{split}.db"))
    # FIX: the `with` line was indented 5 spaces in the original,
    # which is an IndentationError.
    with open(split + '_tok.jsonl') as f:
        for idx, line in enumerate(f):
            t1 = json.loads(line.strip())
示例#9
0
文件: train.py 项目: Mars-Wei/MISP
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):
    """Train the SQLova-style model for one epoch over *train_loader*.

    Returns:
        acc: [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi,
              acc_wv, acc_lx, acc_x] — per-component accuracies plus
              logical-form (lx) and execution (x) accuracy.
        aux_out: constant 1 (placeholder for API symmetry).
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # count of examples seen
    cnt_sc = 0  # correct select-column predictions
    cnt_sa = 0  # correct select-aggregation predictions
    cnt_wn = 0  # correct where-number predictions
    cnt_wc = 0  # correct where-column predictions
    cnt_wo = 0  # correct where-operator predictions
    cnt_wv = 0  # correct where-value predictions
    cnt_wvi = 0  # correct where-value index predictions (question tokens)
    cnt_lx = 0  # logical-form accuracy count
    cnt_x = 0  # execution accuracy count

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        cnt += len(t)

        if cnt < st_pos:
            continue
        # Get fields.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        # nlu  : natural language utterance
        # nlu_t: tokenized nlu
        # sql_i: canonical form of SQL query
        # sql_q: full SQL query text. Not used.
        # sql_t: tokenized SQL query
        # tb   : table
        # hs_t : tokenized headers. Not used.

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth where-value index under the CoreNLP tokenization
        # scheme; already precomputed on the training set.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        # wemb_n: natural language embedding
        # wemb_h: header embedding
        # l_n: token lengths of each question
        # l_hpu: header token lengths
        # l_hs: the number of columns (headers) of the tables.

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Raised when a where-condition value cannot be located in
            # nlu_tt; the example is skipped during training (at test time
            # it would count as wrongly answered). e.g. train: 32.
            continue

        # Scores for each SQL component.
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n,
                                                   l_n,
                                                   wemb_h,
                                                   l_hpu,
                                                   l_hs,
                                                   g_sc=g_sc,
                                                   g_sa=g_sa,
                                                   g_wn=g_wn,
                                                   g_wc=g_wc,
                                                   g_wvi=g_wvi)

        # Calculate loss.
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn,
                          g_wc, g_wo, g_wvi)

        # Gradient accumulation: zero at the start of a window, step at
        # the end, and just accumulate in between.
        if iB % accumulate_gradients == 0:
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # At the final micro-batch, step with accumulated gradient.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # Intermediate micro-batch: accumulate only.
            loss.backward()

        # Prediction.
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc,
            s_sa,
            s_wn,
            s_wc,
            s_wo,
            s_wv,
        )
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        # Sort pr_wc: during training pr_wo and pr_wvi are predicted using
        # the ground-truth where-columns (g_wc), so align the column order.
        # Not needed at dev/test time where no ground truth is used.
        pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                  pr_wv_str, nlu)

        # Calculate accuracy.
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_i,
                                                      mode='train')

        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # lx stands for logical-form accuracy.

        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i)

        # Statistics.
        ave_loss += loss.item()

        # Counts.
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    # FIX: original computed acc_wvi from cnt_wv (copy-paste bug).
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
示例#10
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):
    """Train the SQLova-style model for one epoch (annotated variant).

    Each item of *train_loader* is a whole line from `*_tok.jsonl`.

    Returns:
        acc: [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi,
              acc_wv, acc_lx, acc_x]
        aux_out: constant 1 (placeholder for API symmetry).
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # count of examples seen
    cnt_sc = 0  # correct select-column predictions
    cnt_sa = 0  # correct select-aggregation predictions
    cnt_wn = 0  # correct where-number predictions
    cnt_wc = 0  # correct where-column predictions
    cnt_wo = 0  # correct where-operator predictions
    cnt_wv = 0  # correct where-value predictions
    cnt_wvi = 0  # correct where-value index predictions (question tokens)
    cnt_lx = 0  # logical-form accuracy count
    cnt_x = 0  # execution accuracy count

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):

        cnt += len(t)

        if cnt < st_pos:
            continue
        # Get fields.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        # nlu  : natural language utterance (whole sentence)
        # nlu_t: tokenized nlu (word tokens)
        # sql_i: canonical form of SQL query (same content as sql_q)
        # sql_q: full SQL query text. Not used.
        # sql_t: tokenized SQL query. None here.
        # tb   : table with content
        # hs_t : tokenized headers. Not used.
        # hds  : table headers

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # g_sc: select column; g_sa: aggregation; g_wn: number of where
        # conditions; g_wc: where columns; g_wo: where operators;
        # g_wv: where values.

        g_wvi_corenlp = get_g_wvi_corenlp(t)
        # (start, end) token index of each where value in the question,
        # under the CoreNLP tokenization scheme.

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        # wemb_n: question embedding [batch, seq_len, hidden * num_target_layers]
        # wemb_h: header embedding [total_headers_in_batch, max_header_len,
        #         hidden * num_target_layers]; padded positions are 0 —
        #         use l_hpu to find the valid part.
        # l_n: token lengths of each question
        # l_hpu: per-header token lengths (flat across the batch)
        # l_hs: number of headers per table; splits wemb_h's first dim.

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Raised when a where value cannot be located in nlu_tt
            # ("wvi_corenlp": null in the jsonl) — e.g. the condition is
            # 'block i' while the question says 'block is i'.  The whole
            # batch is skipped; an empty condition list ([]) does NOT raise.
            continue

        # Scores:
        # s_sc: [batch, max_header_number]
        # s_sa: [batch, n_agg_ops]
        # s_wn: [batch, max_where_conds + 1]  (+1 for "no condition")
        # s_wc: [batch, max_header_number]
        # s_wo: [batch, max_where_conds, n_cond_ops]
        # s_wv: [batch, max_where_conds, max_question_len, 2]
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n,
                                                   l_n,
                                                   wemb_h,
                                                   l_hpu,
                                                   l_hs,
                                                   g_sc=g_sc,
                                                   g_sa=g_sa,
                                                   g_wn=g_wn,
                                                   g_wc=g_wc,
                                                   g_wvi=g_wvi)

        # Calculate loss.
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn,
                          g_wc, g_wo, g_wvi)

        # Gradient accumulation: zero at the start of a window, step at
        # the end, and just accumulate in between.
        if iB % accumulate_gradients == 0:
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # At the final micro-batch, step with accumulated gradient.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # Intermediate micro-batch: accumulate only.
            loss.backward()

        # Prediction.
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc,
            s_sa,
            s_wn,
            s_wc,
            s_wo,
            s_wv,
        )
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str,
                                  nlu)

        # Calculate accuracy.
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_i,
                                                      mode='train')

        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # lx stands for logical-form accuracy.

        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i)

        # Statistics.
        ave_loss += loss.item()

        # Counts.
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    # FIX: original computed acc_wvi from cnt_wv (copy-paste bug).
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
示例#11
0
def infernew(dev_loader,
             data_table,
             model,
             model_bert,
             bert_config,
             tokenizer,
             max_seq_length,
             num_target_layers,
             detail=False,
             path_db=None,
             st_pos=0,
             dset_name='train',
             EG=False,
             beam_size=4):
    """Run inference over *dev_loader*, print the first answer, and return it.

    NOTE(review): the `return` at the bottom is inside the batch loop, so
    only the FIRST batch is ever processed — presumably intentional for
    single-question interactive inference; confirm with the caller.

    Returns:
        (pr_sql_i, pr_ans) for the first batch: predicted canonical SQL
        and the executed answer (or 'Answer not found.' on failure).
    """
    model.eval()
    model_bert.eval()

    # Accumulators kept for symmetry with the evaluation loop; most are
    # never updated in this function.
    ave_loss = 0
    cnt = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []

    # Engine for executing SQL against the split's database.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    count = 0
    for iB, t in enumerate(dev_loader):

        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        except:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            # Record an error entry per example so output stays aligned.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # Per-example question-token knowledge vectors; zero-filled when
        # the example carries no "bertindex_knowledge" field.
        knowledge = []
        for k in t:
            if "bertindex_knowledge" in k:
                knowledge.append(k["bertindex_knowledge"])
            else:
                knowledge.append(max(l_n) * [0])

        # Per-example header knowledge vectors; zero-filled when absent.
        knowledge_header = []
        for k in t:
            if "header_knowledge" in k:
                knowledge_header.append(k["header_knowledge"])
            else:
                knowledge_header.append(max(l_hs) * [0])

        if not EG:
            # No Execution guided decoding
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                knowledge=knowledge,
                knowledge_header=knowledge_header)

            # get loss & step
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa,
                              g_wn, g_wc, g_wo, g_wvi)

            # prediction
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc,
                s_sa,
                s_wn,
                s_wc,
                s_wo,
                s_wv,
            )
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            # g_sql_i = generate_sql_i(g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_str, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                engine,
                tb,
                nlu_t,
                nlu_tt,
                tt_to_t_idx,
                nlu,
                beam_size=beam_size,
                knowledge=knowledge,
                knowledge_header=knowledge_header)

            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            # Following variables are just for consistency with the no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        pr_sql_q1 = generate_sql_q(pr_sql_i, tb)
        pr_sql_q = [pr_sql_q1]

        # Execute the predicted SQL for the first example of the batch.
        try:
            pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                    pr_sa[0],
                                                    pr_sql_i[0]['conds'])
        except:
            pr_ans = ['Answer not found.']
            pr_sql_q = ['Answer not found.']

        # NOTE(review): hard-coded flag; the else branch below is dead code
        # and references an undefined `table_name` — confirm before reuse.
        yes = True
        if yes:
            print(f'Q: {nlu[0]}')
            print(f'A: {pr_ans[0]}')
            print(f'SQL: {pr_sql_q}')

        else:
            print(
                f'START ============================================================= '
            )
            print(f'{hds}')
            if yes:
                print(engine.show_table(table_name))
            print(f'nlu: {nlu}')
            print(f'pr_sql_i : {pr_sql_i}')
            print(f'pr_sql_q : {pr_sql_q}')
            print(f'pr_ans: {pr_ans}')
            print(
                f'---------------------------------------------------------------------'
            )

        return pr_sql_i, pr_ans
示例#12
0
def infer(nlu1,
          table_name,
          data_table,
          path_db,
          db_name,
          model,
          model_bert,
          bert_config,
          max_seq_length,
          num_target_layers,
          beam_size=4,
          show_table=False,
          show_answer_only=False):
    """Answer a single natural-language question against one named table.

    Against-DRY single-question variant of the batch inference path,
    kept separate to minimize the risk of introducing bugs there.

    Args:
        nlu1: the question string (Chinese; segmented char-by-char below).
        table_name: name of the table in *data_table* to query.
        show_table / show_answer_only: control how much is printed.

    Returns:
        (pr_sql_i, pr_ans): predicted canonical SQL and the executed
        answer (both replaced by 'Answer not found.' placeholders when
        execution fails).

    Raises:
        ValueError: if *table_name* is not present in *data_table*.
        EnvironmentError: if prediction yields other than one query.
    """
    model.eval()
    model_bert.eval()
    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    # Question input: wrap the single question in a batch of size 1.
    nlu = [nlu1]
    # Tokenization: stanza/CoreNLP segmentation was dropped (2020/12/11);
    # the Chinese question is split character-by-character instead.
    nlu_t1 = list(nlu1)
    nlu_t = [nlu_t1]

    # Select the table matching table_name (2020/12/01 change: the table
    # is chosen per question rather than hard-coded to data_table[0]).
    tb1 = None
    for candidate_table in data_table:
        if candidate_table['name'] == table_name:
            tb1 = candidate_table
            break
    if tb1 is None:
        # FIX: the original fell through to a NameError on tb1 here;
        # fail fast with a clear message instead.
        raise ValueError(f"table '{table_name}' not found in data_table")
    hds1 = tb1['header']
    tb = [tb1]
    hds = [hds1]
    hs_t = [[]]

    # BERT encodings.
    # NOTE(review): `tokenizer` is read from module scope, not a parameter —
    # confirm it is defined in the enclosing module.
    wemb_n, wemb_h, l_n, l_hpu, l_hs, \
    nlu_tt, t_to_tt_idx, tt_to_t_idx \
        = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                        num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

    # Plain forward pass (2020/12/11 change: beam_forward replaced by the
    # standard model prediction path).
    s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu,
                                               l_hs)
    # Turn component scores into discrete predictions.
    pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
        s_sc,
        s_sa,
        s_wn,
        s_wc,
        s_wo,
        s_wv,
    )
    # Recover where-value strings from the predicted token spans.
    pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt,
                                                       tt_to_t_idx, nlu)
    # Assemble the canonical SQL (sel/agg are lists per 2020/12/03 change).
    pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str,
                              nlu)

    # Split out where-col / where-op / where-val.
    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
    if len(pr_sql_i) != 1:  # exactly one query expected for one question
        raise EnvironmentError
    pr_sql_q1 = generate_sql_q(pr_sql_i, [tb1])  # render SQL text
    pr_sql_q = [pr_sql_q1]  # keep list form for symmetry with batch infer

    # Execute the predicted SQL.
    try:
        pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                pr_sa[0], pr_sql_i[0]['conds'])
    except Exception:
        # Execution failure (bad SQL, missing value, ...): return
        # placeholder strings instead of crashing.
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')

    else:
        print(
            f'START ============================================================= '
        )
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print(
            f'---------------------------------------------------------------------'
        )

    return pr_sql_i, pr_ans
示例#13
0
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         path_db=None,
         dset_name='test'):
    """Evaluate on `data_loader` (model variant with and/or + order-by heads).

    For each batch the ground truth and the model prediction are compared
    component-wise (select column/aggregation, where number/column/operator/
    value and value index, and/or connector, order-by), and logical-form and
    execution accuracy are accumulated via the DB `engine`.

    Args:
        data_loader: iterable of batches of examples.
        data_table: table metadata used by `get_fields`.
        model, model_bert: decoder model and BERT encoder (set to eval mode).
        bert_config, tokenizer, max_seq_length, num_target_layers:
            encoding configuration forwarded to `get_wemb_bert`.
        detail: unused in this variant; kept for signature compatibility.
        st_pos: skip batches until the running example count reaches this.
        cnt_tot: unused in this variant; kept for signature compatibility.
        path_db: directory containing `{dset_name}.db`.
        dset_name: dataset split name (db file stem).

    Returns:
        (acc, results, cnt_list) where acc is
        [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
         acc_ao, acc_ord, acc_lx, acc_x], results holds per-example dicts for
        the official evaluation (predicted query or "Skip happened" marker),
        and cnt_list holds per-batch lists of 0/1 correctness flags.

    Raises:
        ZeroDivisionError: if no example was processed (cnt == 0).
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0        # examples seen
    cnt_sc = 0     # correct select-column
    cnt_sa = 0     # correct select-aggregation
    cnt_wn = 0     # correct where-number
    cnt_wc = 0     # correct where-column
    cnt_wo = 0     # correct where-operator
    cnt_wv = 0     # correct where-value (string)
    cnt_wvi = 0    # correct where-value index
    cnt_ao = 0     # correct and/or connector
    cnt_ord = 0    # correct order-by
    cnt_lx = 0     # logical form fully correct
    cnt_x = 0      # execution result correct

    cnt_list = []

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):

        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields (this variant also returns per-header types `hs_type`).
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds, hs_type = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_ao, g_ord = get_g(sql_i)

        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu, hs_type)

        except Exception:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_ao, s_ord = model(
            wemb_n, l_n, wemb_h, l_hpu, l_hs, hs_type=hs_type)

        # Calculate loss & step
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_ao, s_ord,
                          g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, g_ao, g_ord,
                          hs_type)

        # prediction
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, pr_ao, pr_ord = pred_sw_se(
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_ao, s_ord, hs_type)
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu, hs_type)

        # Canonical SQL representation built from the predicted components.
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str,
                                  pr_ao, pr_ord, nlu, hs_type)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        # Per-example 0/1 correctness for every SQL component.
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list, cnt_ao1_list, cnt_ord1_list = get_cnt_sw_list(
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, g_ao, g_ord,
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, pr_ao, pr_ord,
            sql_i, pr_sql_i, hs_type,
            mode='test')

        # lx stands for logical-form accuracy (all components correct).
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list,
                                       cnt_ao1_list, cnt_ord1_list, hs_type)

        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i, pr_ao, pr_ord,
                                                    hds)  # xiajing:need change

        # stat
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_ao += sum(cnt_ao1_list)
        cnt_ord += sum(cnt_ord1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        cnt_list1 = [
            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
            cnt_wo1_list, cnt_wv1_list, cnt_ao1_list, cnt_ord1_list,
            cnt_lx1_list, cnt_x1_list
        ]
        cnt_list.append(cnt_list1)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_ao = cnt_ao / cnt
    acc_ord = cnt_ord / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_ao, acc_ord, acc_lx, acc_x
    ]
    return acc, results, cnt_list
示例#14
0
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         EG=False,
         beam_size=4,
         path_db=None,
         dset_name='test'):
    """Evaluate on `data_loader`, optionally with execution-guided decoding.

    Compared with `train`, this additionally collects a `results` list for
    the official evaluation. With `EG=True` the model's `beam_forward`
    performs execution-guided beam decoding against the DB engine (loss is
    then reported as 0).

    Args:
        data_loader: iterable of batches of examples.
        data_table: table metadata used by `get_fields`.
        model, model_bert: decoder model and BERT encoder (set to eval mode).
        bert_config, tokenizer, max_seq_length, num_target_layers:
            encoding configuration forwarded to `get_wemb_bert`.
        detail: if True, print per-batch details via `report_detail`.
        st_pos: skip batches until the running example count reaches this.
        cnt_tot: forwarded into the `current_cnt` report only.
        EG: enable execution-guided decoding.
        beam_size: beam width for execution-guided decoding.
        path_db: directory containing `{dset_name}.db`.
        dset_name: dataset split name (db file stem).

    Returns:
        (acc, results, cnt_list) where acc is
        [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
         acc_lx, acc_x].

    Raises:
        ZeroDivisionError: if no example was processed (cnt == 0).
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []
    # Initialize the database query engine.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []  # unlike train, predictions are collected here
    for iB, t in enumerate(data_loader):
        # t holds the details of batch_size examples.
        cnt += len(t)  # examples seen so far
        if cnt < st_pos:
            continue
        # Split each question into its parts and pair it with its table.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        # nlu: the batch of questions
        # nlu_t: tokenized questions (not sub-word split)
        # sql_i: canonical form of the SQL query
        # sql_q: full SQL query text (unused)
        # sql_t: unused
        # tb: tables for the batch (not necessarily 1:1, but every needed table is present)
        # hs_t: tokenized headers (unused)
        # hds: table headers

        # Ground-truth sc, sa, wn, wc, wo, wv per question (lists for multi-condition).
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth where-value start/end indices from the loader.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        # BERT encoder outputs:
        # wemb_n: question embeddings
        # wemb_h: header embeddings
        # l_n: question lengths
        # l_hpu: per-header token lengths
        # l_hs: number of headers per table
        # nlu_tt: sub-word-tokenized questions
        # t_to_tt_idx / tt_to_t_idx: token index maps between the two tokenizations

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        except Exception:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # model specific part
        # score
        if not EG:
            # Feed the BERT outputs into the decoder to score the six key
            # SQL components for the batch.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h,
                                                       l_hpu, l_hs)

            # Compute the loss.
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa,
                              g_wn, g_wc, g_wo, g_wvi)

            # Predict the most likely sc/sa/wn/wc/wo/wvi.
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc,
                s_sa,
                s_wn,
                s_wc,
                s_wo,
                s_wv,
            )
            # Recover the where-value strings from the predicted spans.
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            # Build the canonical SQL representation from the predictions.
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                engine,
                tb,
                nlu_t,
                nlu_tt,
                tt_to_t_idx,
                nlu,
                beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            # Following variables are just for consistency with the no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        g_sql_q = generate_sql_q(sql_i, tb)
        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        # Per-example 0/1 correctness for each SQL component.
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_i,
                                                      mode='test')

        # lx stands for logical-form accuracy (all components correct).
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)

        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i)

        # stat
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        current_cnt = [
            cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv,
            cnt_wvi, cnt_lx, cnt_x
        ]
        cnt_list1 = [
            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
            cnt_wo1_list, cnt_wv1_list, cnt_lx1_list, cnt_x1_list
        ]
        cnt_list.append(cnt_list1)
        # report
        if detail:
            report_detail(hds, nlu, g_sc, g_sa, g_wn, g_wc, g_wo, g_wv,
                          g_wv_str, g_sql_q, g_ans, pr_sc, pr_sa, pr_wn, pr_wc,
                          pr_wo, pr_wv_str, pr_sql_q, pr_ans, cnt_list1,
                          current_cnt)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    return acc, results, cnt_list
示例#15
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):
    """Train the decoder (and optionally BERT) for one epoch.

    Encodes each batch with BERT, scores the six SQL components with the
    decoder, back-propagates `Loss_sw_se`, steps the optimizer(s) with
    optional gradient accumulation, and tracks per-component accuracy.

    Args:
        train_loader: iterable of training batches.
        train_table: table metadata used by `get_fields`.
        model, model_bert: decoder model and BERT encoder (set to train mode).
        opt: optimizer for the decoder.
        bert_config, tokenizer, max_seq_length, num_target_layers:
            encoding configuration forwarded to `get_wemb_bert`.
        accumulate_gradients: step the optimizer(s) every N batches.
        check_grad: unused here; kept for signature compatibility.
        st_pos: skip batches until the running example count reaches this.
        opt_bert: optional optimizer for the BERT encoder.
        path_db: directory containing `{dset_name}.db`.
        dset_name: dataset split name (db file stem).

    Returns:
        (acc, aux_out) where acc is
        [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
         acc_lx, acc_x] and aux_out is a placeholder (always 1).

    Raises:
        ZeroDivisionError: if no example was processed (cnt == 0).
    """
    model.train()  # put modules into training mode
    model_bert.train()

    ave_loss = 0
    cnt = 0  # count the # of examples
    cnt_sc = 0  # count the # of correct predictions of select column
    cnt_sa = 0  # of selectd aggregation
    cnt_wn = 0  # of where number
    cnt_wc = 0  # of where column
    cnt_wo = 0  # of where operator
    cnt_wv = 0  # of where-value
    cnt_wvi = 0  # of where-value index (on question tokens)
    cnt_lx = 0  # of logical form acc
    cnt_x = 0  # of execution acc

    # Initialize the database query engine.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        # t holds batch_size examples.

        cnt += len(t)  # examples seen so far

        if cnt < st_pos:
            continue
        # Split each question into its parts and pair it with its table.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        # nlu: the batch of questions
        # nlu_t: tokenized questions (not sub-word split)
        # sql_i: canonical form of the SQL query
        # sql_q: full SQL query text (unused)
        # sql_t: unused
        # tb: tables for the batch (not necessarily 1:1, but every needed table is present)
        # hs_t: tokenized headers (unused)
        # hds: table headers

        # Ground-truth sc, sa, wn, wc, wo, wv per question (lists for multi-condition).
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        # BERT encoder outputs:
        # wemb_n: question embeddings
        # wemb_h: header embeddings
        # l_n: question lengths
        # l_hpu: per-header token lengths
        # l_hs: number of headers per table
        # nlu_tt: sub-word-tokenized questions
        # t_to_tt_idx / tt_to_t_idx: token index maps between the two tokenizations

        try:
            # Map the CoreNLP-token spans onto the BERT tokenization.
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            # e.g. train: 32.
            continue

        # Feed BERT outputs (plus ground-truth guidance) into the decoder to
        # score the six key SQL components for the batch.
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n,
                                                   l_n,
                                                   wemb_h,
                                                   l_hpu,
                                                   l_hs,
                                                   g_sc=g_sc,
                                                   g_sa=g_sa,
                                                   g_wn=g_wn,
                                                   g_wc=g_wc,
                                                   g_wvi=g_wvi)

        # Compute the loss.
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn,
                          g_wc, g_wo, g_wvi)

        # Gradient step with optional accumulation:
        if iB % accumulate_gradients == 0:  # mode
            # at start, perform zero_grad
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # at the final, take step with accumulated graident
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # at intermediate stage, just accumulates the gradients
            loss.backward()

        # Predict the most likely sc/sa/wn/wc/wo/wvi.
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc,
            s_sa,
            s_wn,
            s_wc,
            s_wo,
            s_wv,
        )
        # Recover the where-value strings from the predicted spans.
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        # Sort pr_wc:
        #   Sort pr_wc when training the model as pr_wo and pr_wvi are predicted using ground-truth where-column (g_wc)
        #   In case of 'dev' or 'test', it is not necessary as the ground-truth is not used during inference.
        pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
        # Build the canonical SQL representation from the predictions.
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                  pr_wv_str, nlu)

        # Per-example 0/1 correctness for each SQL component.
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_i,
                                                      mode='train')
        # 1 when the whole logical form is correct, else 0.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # lx stands for logical form accuracy

        # Execute both queries against the DB to measure execution accuracy.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i)

        # statistics
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
示例#16
0
def predict(data_loader,
            data_table,
            model,
            model_bert,
            bert_config,
            tokenizer,
            max_seq_length,
            num_target_layers,
            detail=False,
            st_pos=0,
            cnt_tot=1,
            EG=False,
            beam_size=4,
            path_db=None,
            dset_name='test'):
    """Run inference over `data_loader` and collect predicted SQL forms.

    With `EG=False` the model's scoring heads are decoded directly
    (`model.gen_query` + `pred_sw_se`); with `EG=True` execution-guided beam
    decoding is used via `model.beam_forward`.

    Args:
        data_loader: iterable of batches of examples.
        data_table: table metadata used by `get_fields`.
        model, model_bert: decoder model and BERT encoder (set to eval mode).
        bert_config, tokenizer, max_seq_length, num_target_layers:
            encoding configuration forwarded to `get_wemb_bert`.
        detail, st_pos, cnt_tot: unused here; kept for signature compatibility.
        EG: enable execution-guided decoding.
        beam_size: beam width for execution-guided decoding.
        path_db: directory containing `{dset_name}.db`.
        dset_name: dataset split name (db file stem).

    Returns:
        A list with one entry per example: the predicted canonical SQL dict,
        or an error-marker dict ({"error": "Skip happened", ...}) for
        examples whose where-value span could not be aligned.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []

    for iB, t in enumerate(data_loader):
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wv, g_sel_num_seq, g_sel_ag_seq = get_g(
            sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # model specific part
        # score
        if not EG:
            # No Execution guided decoding
            g_sel_seq = [x[1] for x in g_sel_ag_seq]
            s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                g_sc=g_sc,
                g_sa=g_sa,
                g_wn=g_wn,
                g_wc=g_wc,
                g_sop=g_sop,
                g_wo=g_wo,
                g_wvi=g_wv,
                g_sel=g_sel_seq,
                g_scn=g_sel_num_seq)
            # prediction: select-part scores are decoded by the model itself.
            # NOTE(review): the original code had a no-op `tuple(score)` here
            # whose result was discarded; if a tuple was intended, it should
            # have been `score = tuple(score)`. The list is passed as-is.
            score = [s_scn, s_sc, s_sa, s_sop]
            pr_sql_i1 = model.gen_query(score, nlu_tt, nlu)

            pr_wn, pr_wc, pr_sop, pr_wo, pr_wvi = pred_sw_se(
                s_sop, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)  # map spans to strings

            pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
            pr_sql_i = generate_sql_i(pr_sql_i1, pr_wn, pr_wc_sorted, pr_wo,
                                      pr_wv_str, nlu)

        else:
            # Execution guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                engine,
                tb,
                nlu_t,
                nlu_tt,
                tt_to_t_idx,
                nlu,
                beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results.append(pr_sql_i1)

    return results
示例#17
0
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         EG=False,
         beam_size=4,
         path_db=None,
         dset_name='test',
         mvl=2):
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wr = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0
    cnt_err = 0
    cnt_still = 0
    cnt_skip = 0

    cnt_hrpc = 0

    cnt_list = []

    engine = DBEngine(os.path.join(path_db, dset_name, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):

        #print('iB : %d' % iB)

        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True, generate_mode=False)

        g_sn, g_sc, g_sa, g_wn, g_wr, g_dwn, g_wc, g_wo, g_wv, g_r_c_n, wvi_change_index = get_g(
            sql_i)
        g_wrcn = g_r_c_n


        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx, wemb_v, l_npu, l_token \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers, num_out_layers_v=num_target_layers)
        try:  #here problem
            #print('ok')
            g_wvi_corenlp = get_g_wvi_corenlp(t, wvi_change_index)
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            #print('no')

            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx)
            g_wvi = get_g_wvi_stidx_length_jian_yi(g_wvi_corenlp)
            #print('gogogo:', g_wvi)
            #这里需要连同脏数据一起计算准确率
        except:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            cnt_skip += len(nlu)
            continue

        # model specific part
        # score
        if not EG:
            # No Execution guided decoding
            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4 = model(
                mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token)

            # get loss & step
            #loss = Loss_sw_se(s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wrpc, s_nrpc, s_wc, s_wo, s_wv1, s_wv2, g_sn, g_sc, g_sa, g_wn, g_dwn, g_wr, g_wc, g_wo, g_wvi, g_wrcn)
            #unable for loss
            loss = torch.tensor([0])
            # prediction
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_hrpc, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2,
                s_wv3, s_wv4, mvl)
            pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
            # g_sql_i = generate_sql_i(g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_str, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:

            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4 = model(
                mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token)
            pr_sn1, pr_sc1, pr_sa1, pr_wn1, pr_wr1, pr_hrpc1, pr_wc1, pr_wo1, pr_wvi1 = pred_sw_se(
                s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2,
                s_wv3, s_wv4, mvl)
            pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi1)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
            pr_sql_i1 = generate_sql_i(pr_sc1, pr_sa1, pr_wn1, pr_wr1, pr_wc1,
                                       pr_wo1, pr_wv_str, nlu)

            # Execution guided decoding
            pr_sql_i, exe_error1, still_error1 = model.beam_forward(
                pr_sql_i1,
                mvl,
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                wemb_v,
                l_npu,
                l_token,
                engine,
                tb,
                nlu_t,
                beam_size=beam_size)
            # sort and generate
            #pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            cnt_err += exe_error1
            cnt_still += still_error1

            pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo, pr_wv = generate_pr(
                pr_sql_i)

            # Follosing variables are just for the consistency with no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])
            '''
            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4 = model.beam_forward(mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token, engine, tb, nlu_t, beam_size=beam_size)

            # get loss & step
            #loss = Loss_sw_se(s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wrpc, s_nrpc, s_wc, s_wo, s_wv1, s_wv2, g_sn, g_sc, g_sa, g_wn, g_dwn, g_wr, g_wc, g_wo, g_wvi, g_wrcn)
            #unable for loss
            loss = torch.tensor([0])
            # prediction
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_hrpc, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4, mvl)
            pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
            # g_sql_i = generate_sql_i(g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_str, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo, pr_wv_str, nlu)
            '''

        g_sql_q = generate_sql_q(sql_i, tb)
        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        cnt_sn1_list, cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wr1_list, cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sn, g_sc, g_sa, g_wn, g_wr, g_wc, g_wo, g_wvi,
                                                                   pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo, pr_wvi,
                                                                   sql_i, pr_sql_i,
                                                                   mode='test')

        cnt_lx1_list = get_cnt_lx_list(cnt_sn1_list, cnt_sc1_list,
                                       cnt_sa1_list, cnt_wn1_list,
                                       cnt_wr1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)

        # Execution accura y test
        cnt_x1_list = []
        # lx stands for logical form accuracy

        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i)

        # stat
        ave_loss += loss.item()

        #print('loss: ', ave_loss / cnt)

        # count
        cnt_sn += sum(cnt_sn1_list)
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wr += sum(cnt_wr1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        if iB % 10 == 0:
            logger.info(
                '%d - th data batch -> loss: %.4f; acc_sn: %.4f; acc_sc: %.4f; acc_sa: %.4f; acc_wn: %.4f; acc_wr: %.4f; acc_wc: %.4f; acc_wo: %.4f; acc_wvi: %.4f; acc_wv: %.4f; acc_lx: %.4f; acc_x: %.4f; execute_error: %.4f; skip_error: %.4f; still_error: %.4f'
                % (iB, ave_loss / cnt, cnt_sn / cnt, cnt_sc / cnt, cnt_sa /
                   cnt, cnt_wn / cnt, cnt_wr / cnt, cnt_wc / cnt, cnt_wo / cnt,
                   cnt_wvi / cnt, cnt_wv / cnt, cnt_lx / cnt, cnt_x / cnt,
                   cnt_err / cnt, cnt_skip / cnt, cnt_still / cnt))

        current_cnt = [
            cnt_tot, cnt, cnt_sn, cnt_sc, cnt_sa, cnt_wn, cnt_wr, cnt_wc,
            cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x
        ]
        cnt_list1 = [
            cnt_sn1_list, cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
            cnt_wr1_list, cnt_wc1_list, cnt_wo1_list, cnt_wv1_list,
            cnt_lx1_list, cnt_x1_list
        ]
        cnt_list.append(cnt_list1)
        # report
        if detail:
            report_detail(hds, nlu, g_sn, g_sc, g_sa, g_wn, g_wr, g_wc, g_wo,
                          g_wv, g_wv_str, g_sql_q, g_ans, pr_sn, pr_sc, pr_sa,
                          pr_wn, pr_wr, pr_wc, pr_wo, pr_wv_str, pr_sql_q,
                          pr_ans, cnt_list1, current_cnt)

    ave_loss /= cnt
    acc_sn = cnt_sn / cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wr = cnt_wr / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sn, acc_sc, acc_sa, acc_wn, acc_wr, acc_wc, acc_wo,
        acc_wvi, acc_wv, acc_lx, acc_x
    ]
    return acc, results, cnt_list
示例#18
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=False,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train',
          col_pool_type='start_tok',
          aug=False):
    """Run one training epoch of the NL2SQL model on top of BERT embeddings.

    For each batch: encode question+headers with BERT, score the SQL slots
    (select-column, select-agg, where-number/column/op/value), compute the
    loss, step the optimizer(s) with optional gradient accumulation, and
    accumulate per-slot accuracy statistics.

    Args:
        train_loader: iterable of training batches.
        train_table: table metadata keyed by table id.
        model: slot-prediction model (trained).
        model_bert: BERT encoder (trained).
        opt: optimizer for `model`.
        bert_config: BERT configuration (hidden_size, num_hidden_layers).
        tokenizer: BERT word-piece tokenizer.
        max_seq_length: maximum BERT input length.
        num_target_layers: number of BERT output layers used (kept for
            interface compatibility; the scalar-mix helpers take layer counts
            directly).
        accumulate_gradients: step the optimizers every N batches.
        check_grad: if True, record mean/std statistics of gradient magnitudes.
        st_pos: skip batches until this many examples have been seen.
        opt_bert: optional optimizer for `model_bert`.
        path_db: directory containing `{dset_name}.db`.
        dset_name: dataset split name, used to locate the SQLite DB.
        col_pool_type: header-pooling strategy for `get_wemb_h_FT_Scalar_1`.
        aug: if True, data is augmented and has no DB answers, so execution
            accuracy is skipped.

    Returns:
        (acc, aux_out) where
        acc = [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo,
               acc_wvi, acc_wv, acc_lx, acc_x] and
        aux_out = [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean].
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # count the # of examples
    cnt_sc = 0  # count the # of correct predictions of select column
    cnt_sa = 0  # of selectd aggregation
    cnt_wn = 0  # of where number
    cnt_wc = 0  # of where column
    cnt_wo = 0  # of where operator
    cnt_wv = 0  # of where-value
    cnt_wvi = 0  # of where-value index (on question tokens)
    cnt_lx = 0  # of logical form acc
    cnt_x = 0  # of execution acc

    # Defaults in case check_grad is False (also guards an empty loader).
    grad_abs_mean_mean = 1
    grad_abs_mean_sig = 1
    grad_abs_sig_mean = 1

    # Engine for SQL querying (execution-accuracy evaluation).
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        cnt += len(t)

        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \
        l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length)

        try:
            # Map CoreNLP-token where-value indices onto word-piece tokens.
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            # e.g. train: 32.
            continue

        wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size,
                            bert_config.num_hidden_layers, all_encoder_layer,
                            1)
        wemb_h = get_wemb_h_FT_Scalar_1(i_hds,
                                        l_hs,
                                        bert_config.hidden_size,
                                        all_encoder_layer,
                                        col_pool_type=col_pool_type)
        # wemb_h = [B, max_header_number, hS]
        cls_vec = pooled_output

        # Model-specific part: score each SQL slot, teacher-forced with the
        # ground truth (g_*) during training.
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n,
                                                   l_n,
                                                   wemb_h,
                                                   l_hs,
                                                   cls_vec,
                                                   g_sc=g_sc,
                                                   g_sa=g_sa,
                                                   g_wn=g_wn,
                                                   g_wc=g_wc,
                                                   g_wo=g_wo,
                                                   g_wvi=g_wvi)

        # Calculate loss & step
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa, g_wn,
                          g_wc, g_wo, g_wvi)

        # Calculate gradient (with optional accumulation over N batches).
        if iB % accumulate_gradients == 0:
            # at start of an accumulation window, clear old gradients
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # at the end of the window, take a step with accumulated gradient
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # at intermediate stage, just accumulate the gradients
            loss.backward()

        if check_grad:
            named_parameters = model.named_parameters()

            mu_list, sig_list = get_mean_grad(named_parameters)

            grad_abs_mean_mean = mean(mu_list)
            grad_abs_mean_sig = std(mu_list)
            grad_abs_sig_mean = mean(sig_list)
        else:
            grad_abs_mean_mean = 1
            grad_abs_mean_sig = 1
            grad_abs_sig_mean = 1

        # Prediction (for accuracy bookkeeping only; no gradient involved).
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc,
            s_sa,
            s_wn,
            s_wc,
            s_wo,
            s_wv,
        )
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str,
                                  nlu)

        # Calculate per-example accuracy flags for each slot.
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_i,
                                                      mode='train')

        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # lx stands for logical form accuracy

        # Execution accuracy test.
        if not aug:
            cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
                engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)
        else:
            # Augmented data has no DB answer; execution accuracy undefined.
            cnt_x1_list = [0] * len(t)
            g_ans = ['N/A (data augmented'] * len(t)
            pr_ans = ['N/A (data augmented'] * len(t)
        # statistics
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    # BUGFIX: previously computed as cnt_wv / cnt, which duplicated acc_wv
    # and misreported the where-value-index accuracy.
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    aux_out = [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean]

    return acc, aux_out
示例#19
0
def infer(nlu1,
          table_name,
          data_table,
          path_db,
          db_name,
          model,
          model_bert,
          bert_config,
          max_seq_length,
          num_target_layers,
          beam_size=4,
          show_table=False,
          show_answer_only=False):
    """Answer a single natural-language question against one table.

    Tokenizes the question with CoreNLP, encodes it with BERT, beam-decodes a
    SQL query with execution-guided decoding, executes it, and prints either
    just the answer or a verbose dump.

    Args:
        nlu1: the natural-language question (a single string).
        table_name: table name, only used by `engine.show_table`.
        data_table: list whose first element is the target table dict
            (must contain 'header' and 'id').
        path_db: directory containing `{db_name}.db`.
        db_name: DB file stem.
        model: slot-prediction model (must provide `beam_forward`).
        model_bert: BERT encoder.
        bert_config: BERT configuration.
        max_seq_length: maximum BERT input length.
        num_target_layers: number of BERT output layers used for embeddings.
        beam_size: beam width for execution-guided decoding.
        show_table: if True, also print the table contents.
        show_answer_only: if True, print only question/answer/SQL.

    Returns:
        (pr_sql_i, pr_ans): the predicted SQL dict list (length 1) and the
        execution result (or ['Answer not found.'] on execution failure).

    Raises:
        EnvironmentError: if beam decoding does not yield exactly one query.
    """
    # NOTE: this duplicates parts of the test loop (against DRY) to minimize
    # the risk of introducing bugs there; hence a separate infer function.
    model.eval()
    model_bert.eval()
    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    # Get inputs
    nlu = [nlu1]
    # nlu_t1 = tokenize_corenlp(client, nlu1)
    nlu_t1 = tokenize_corenlp_direct_version(client, nlu1)
    nlu_t = [nlu_t1]

    tb1 = data_table[0]
    hds1 = tb1['header']
    tb = [tb1]
    hds = [hds1]
    hs_t = [[]]

    wemb_n, wemb_h, l_n, l_hpu, l_hs, \
    nlu_tt, t_to_tt_idx, tt_to_t_idx \
        = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                        num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

    # Execution-guided beam decoding of the full SQL query.
    prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
        wemb_n,
        l_n,
        wemb_h,
        l_hpu,
        l_hs,
        engine,
        tb,
        nlu_t,
        nlu_tt,
        tt_to_t_idx,
        nlu,
        beam_size=beam_size)

    # sort and generate
    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
    if len(pr_sql_i) != 1:
        raise EnvironmentError
    pr_sql_q1 = generate_sql_q(pr_sql_i, [tb1])
    pr_sql_q = [pr_sql_q1]

    try:
        pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                pr_sa[0], pr_sql_i[0]['conds'])
    except Exception:
        # Execution failed (e.g. malformed condition); report a placeholder
        # instead of crashing the interactive session.
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')

    else:
        print(
            f'START ============================================================= '
        )
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print(
            f'---------------------------------------------------------------------'
        )

    return pr_sql_i, pr_ans
示例#20
0
        print(idx)
    return gen_sqls


if __name__ == "__main__":
    # Filter distant-supervision SQL candidates by executing each candidate
    # against the train DB and keeping only those whose execution result
    # matches the gold answer for the corresponding question.
    root = '/mnt/sda/qhz/sqlova'
    # Tokenized questions (one JSON object per line).
    query_path = os.path.join(root, 'data_and_model', 'train_tok_origin.jsonl')
    table_path = os.path.join(root, 'data_and_model', 'train.tables.jsonl')
    # Candidate ("potential") SQL queries produced by distant supervision.
    p_sqls_path = os.path.join(root, 'data/distant_data',
                               'train_distant.jsonl')
    queries = extract.read_queries(query_path)
    p_sqlss = extract.read_potential_sqls(p_sqls_path)
    answer_path = './syn.txt'
    g_answers = extract.read_gold_answers(answer_path)
    print(len(g_answers))
    engine = DBEngine_s('./data_and_model/train.db')
    rr_p_sqlss = []
    for i, p_sqls in enumerate(tqdm(p_sqlss)):
        rr_p_sqls = []
        if (len(p_sqls)) < 3:
            # Too few candidates to filter meaningfully: keep them all.
            rr_p_sqls = [query['query'] for query in p_sqls]
        else:
            # Keep only candidates whose execution result equals the gold
            # answer for question i.
            for p_sql in p_sqls:
                qg = Query.from_dict(p_sql['query'], ordered=True)
                res = engine.execute_query(queries[i]['table_id'],
                                           qg,
                                           lower=True)
                if res == g_answers[i]:
                    rr_p_sqls.append(p_sql['query'])
        if len(rr_p_sqls) == 0:
            # No candidate survived filtering; log the question index.
            print(f"{i}\n")
        # NOTE(review): rr_p_sqls is never appended to rr_p_sqlss in the
        # visible code — presumably the snippet is truncated here; confirm
        # against the full script before relying on rr_p_sqlss.
示例#21
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):
    model.train()
    model_bert.train()

    ave_loss = 0
    count = 0  # count the # of examples
    count_sc = 0  # count the # of correct predictions of select column
    count_sa = 0  # of selectd aggregation
    count_wn = 0  # of where number
    count_wc = 0  # of where column
    count_wo = 0  # of where operator
    count_wv = 0  # of where-value
    count_wvi = 0  # of where-value index (on question tokens)
    count_logic_form_acc = 0  # of logical form acc
    count_execute_acc = 0  # of execution acc

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for batch_index, batch_data in enumerate(train_loader):
        count += len(batch_data)

        if count < st_pos:
            continue
        # Get fields
        question, question_token, sql, sql_text, sql_t, table, header_token, header \
            = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True)

        gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo, g_wv = get_gt(
            sql)
        # get ground truth where-value index under CoreNLP tokenization scheme. It's done already on trainset.
        gt_wherevalueindex_corenlp = get_gt_wherevalueindex_corenlp(batch_data)

        emb_question, emb_header, len_question, len_header_token, number_header, \
        question_token_bert, token_to_berttoken_index, berttoken_to_token_index \
            = get_wemb_bert(bert_config, model_bert, tokenizer, question_token, header, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            #
            gt_wherevalueindex = get_gt_wherevalueindex_bert_from_gt_wherevalueindex_corenlp(
                token_to_berttoken_index, gt_wherevalueindex_corenlp)
        except:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            # e.g. train: 32.
            continue

        # score
        # gt_wherevalueindex = start_index, end_index
        # score_where_value: [batch,4,question_len,2]
        score_select_column, score_select_agg, score_where_number, score_where_column, score_where_op, score_where_value, \
            score_select_column_softmax, score_select_agg_softmax, score_where_number_softmax,\
            score_where_column_softmax, score_whereop_softmax, score_where_value_softmax \
                    = model(emb_question, len_question, emb_header, len_header_token, number_header,
                           g_sc=gt_select_column, g_sa=gt_select_agg, g_wn=gt_wherenumber,
                           g_wc=gt_wherecolumn, g_wvi=gt_wherevalueindex)

        # Calculate loss & step
        loss = Loss_selectwhere_startend_v2(
            score_select_column, score_select_agg, score_where_number,
            score_where_column, score_where_op, score_where_value,
            gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn,
            g_wo, gt_wherevalueindex)

        # RL

        # Random explore
        pred_selectcolumn_random, pred_selectagg_random, pred_wherenumber_random, pred_wherecolumn_random, pred_whereop_random, pred_wherevalueindex_random = \
                  pred_selectwhere_startend_random(score_select_column_softmax, score_select_agg_softmax, score_where_number_softmax,
                                                               score_where_column_softmax, score_whereop_softmax, score_where_value_softmax,)
        """
        pred_wherevalue_str_random, pred_wherevalue_str_bert_random = convert_pred_wvi_to_string(pred_wherevalueindex_random, question_token,
                                                                                  question_token_bert,
                                                                                  berttoken_to_token_index, question)
        
        # Sort pr_wc:
        #   Sort pr_wc when training the model as pr_wo and pr_wvi are predicted using ground-truth where-column (g_wc)
        #   In case of 'dev' or 'test', it is not necessary as the ground-truth is not used during inference.
        pred_wherecolumn_sorted_random = sort_pred_wherecolumn(pred_wherecolumn_random, gt_wherecolumn)
        
        random_sql_int = generate_sql_i_v2(pred_selectcolumn_random, pred_selectagg_random, pred_wherenumber_random,
                                        pred_wherecolumn_sorted_random, pred_whereop_random, pred_wherevalue_str_random, question)
        """

        # Prediction
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_selectwhere_startend(
            score_select_column,
            score_select_agg,
            score_where_number,
            score_where_column,
            score_where_op,
            score_where_value,
        )
        pred_wherevalue_str, pred_wherevalue_str_bert_random = convert_pred_wvi_to_string(
            pr_wvi, question_token, question_token_bert,
            berttoken_to_token_index, question)

        # Sort pr_wc:
        #   Sort pr_wc when training the model as pr_wo and pr_wvi are predicted using ground-truth where-column (g_wc)
        #   In case of 'dev' or 'test', it is not necessary as the ground-truth is not used during inference.
        pr_wc_sorted = sort_pred_wherecolumn(pr_wc, gt_wherecolumn)
        pred_sql_int = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                      pred_wherevalue_str, question)

        # select column
        def equal_in(cell_, pred_answer_column):
            for cell in pred_answer_column:
                if cell == cell_:
                    return True
            return False

        # select_column_random_int = torch.squeeze(torch.softmax(score_select_column,dim=-1).multinomial(1),dim=-1)
        #select_column_int = torch.squeeze(torch.argmax(score_select_column,dim=-1),dim=-1)

        # if batch_index%100==0:
        #     print("RL_loss_select_column",RL_loss_select_column.data.cpu().numpy())

        # RL
        # where number

        def list_in_list(small, big):
            for cell in big:
                try:
                    cell = str(int(cell))
                except:
                    cell = str(cell)
                if cell.lower() in small:
                    return True

        batch_size = score_where_value_softmax.shape[0]
        batch_have_good_reward = False
        while batch_have_good_reward == False:

            selectcolumn_random_v2 = []
            for i in range(batch_size):
                col = random.randint(1, len(header[i])) - 1
                selectcolumn_random_v2.append(col)

            wherenumber_random_v2 = []
            for i in range(batch_size):
                num = random.randint(1, 4)
                wherenumber_random_v2.append(num)

            selectagg_random_v2 = []
            for i in range(batch_size):
                agg = random.randint(1, 6) - 1
                selectagg_random_v2.append(agg)

            whereop_random_v2 = []
            for i in range(batch_size):
                cond = []
                for j in range(4):
                    col = random.randint(1, 3) - 1
                    cond.append(col)
                whereop_random_v2.append(cond)

            wherecolumn_random_v2 = []
            for i in range(batch_size):
                cond = []
                for j in range(4):
                    col = random.randint(1, len(header[i])) - 1
                    cond.append(col)
                wherecolumn_random_v2.append(cond)

            pred_wherevalueindex_random_v2 = []  # [8,4,2]
            for i in range(batch_size):
                cond = []
                for j in range(4):
                    start = random.randint(1, len_question[i]) - 1
                    end = random.randint(1, len_question[i]) - 1
                    while start >= end:
                        start = random.randint(1, len_question[i]) - 1
                        end = random.randint(1, len_question[i]) - 1
                    cond.append([start, end])
                pred_wherevalueindex_random_v2.append(cond)

            pred_wherevalue_str_random_v2, pred_wherevalue_str_bert_random_v2 = convert_pred_wvi_to_string(
                pred_wherevalueindex_random_v2, question_token,
                question_token_bert, berttoken_to_token_index, question)

            # pred_wherecolumn_sorted_random_v2 = sort_pred_wherecolumn(wherecolumn_random_v2, gt_wherecolumn)

            random_sql_int_v2 = generate_sql_i_v2(
                pred_selectcolumn_random, pred_selectagg_random,
                wherenumber_random_v2, wherecolumn_random_v2,
                whereop_random_v2, pred_wherevalue_str_random_v2, question)

            batch_reward_all = []
            for i in range(len(batch_data)):
                gt_answer_list = batch_data[i]["answer"]
                reward = 0
                tmp_conds = []
                for j in range(wherenumber_random_v2[i]):
                    cond = []
                    cond.append(wherecolumn_random_v2[i][j])
                    cond.append(whereop_random_v2[i][j])
                    cond.append(random_sql_int_v2[i]["conds"][j][2])
                    tmp_conds.append(cond)

                pred_answer_column = engine.execute(table[i]['id'],
                                                    selectcolumn_random_v2[i],
                                                    selectagg_random_v2[i],
                                                    tmp_conds)
                # answer in
                for cell in gt_answer_list:
                    if cell in pred_answer_column or equal_in(
                            cell, pred_answer_column):
                        reward = 1
                if reward > 0:
                    batch_have_good_reward = True
                batch_reward_all.append(reward)
            """
            batch_reward_where = []

            for i in range(len(batch_data)):
                gt_answer_list = batch_data[i]["answer"]
                reward = 0
                if len(random_sql_int_v2[i]["conds"])>0 and len(pred_sql_int[i]["conds"])>0:
                    tmp_conds = []
                    tmp_cond = pred_sql_int[i]["conds"][0]
                    tmp_cond[0] = wherecolumn_random_v2[i][0]
                    tmp_cond[1] = whereop_random_v2[i][0]
                    tmp_cond[2] = random_sql_int_v2[i]["conds"][0][2]
                    tmp_conds.append(tmp_cond)
                    if tmp_conds[0][2]!="":
                        pred_answer_column = engine.execute(table[i]['id'], "*", 0, tmp_conds)
                        # answer in
                        for cell in gt_answer_list:
                            if cell in pred_answer_column or equal_in(cell, pred_answer_column):
                                reward = 1

                        pred_answer_column4 = engine.execute(table[i]['id'], pred_sql_int[i]["sel"], 0, tmp_conds)
                        # answer absolute in
                        for cell in gt_answer_list:
                            if cell in pred_answer_column4 or equal_in(cell, pred_answer_column4):
                                reward = 1

                    if len(tmp_conds)>=1:
                        # same column: word in question and where column
                        pred_answer_column2 = engine.execute(table[i]['id'], tmp_conds[0][0], 0, [])
                        for cell in pred_answer_column2:
                            try:
                                cell = str(int(cell))
                            except:
                                cell = str(cell)
                            if cell in question[i].lower():
                                reward = 1
                        # same column: where value and where column
                            if cell == tmp_conds[0][2]:
                                reward = 1


                    tmp_conds2 = []
                    tmp_cond2 = pred_sql_int[i]["conds"][0]
                    tmp_cond2[0] = wherecolumn_random_v2[i][0]
                    tmp_cond2[2] = random_sql_int_v2[i]["conds"][0][2]
                    tmp_cond2[1] = 0 # EQUAL
                    tmp_conds2.append(tmp_cond2)
                    if len(tmp_conds2) >= 1:
                        pred_answer_column3 = engine.execute(table[i]['id'], tmp_conds2[0][0], 0, tmp_conds2)
                        # same row: the answer and this cell
                        for row in table[i]["rows"]:
                            if list_in_list(pred_answer_column3,row) and list_in_list(gt_answer_list,row):
                                reward = 1
                if reward>0:
                    batch_have_good_reward=True
                batch_reward_where.append(reward)
            """

        onehot_action_batch_wherenumber = []
        for action_int in wherenumber_random_v2:
            tmp = [0] * score_where_number.shape[1]
            tmp[action_int - 1] = 1
            onehot_action_batch_wherenumber.append(tmp)

        RL_loss_where_number = torch.mean(-torch.log(
            torch.sum(torch.softmax(score_where_number, dim=-1) * torch.tensor(
                onehot_action_batch_wherenumber, dtype=torch.float).to(device),
                      dim=-1)) * torch.tensor(batch_reward_all,
                                              dtype=torch.float).to(device))
        """
        batch_reward_select_column = []
        for i in range(len(batch_data)):
            gt_answer_list = batch_data[i]["answer"]
            pred_answer_column = engine.execute(table[i]['id'], random_sql_int[i]["sel"], 0, [])
            reward = -1
        for cell in gt_answer_list:
            if cell in pred_answer_column or equal_in(cell,pred_answer_column):
                reward = 1
        batch_reward_select_column.append(reward)
        """

        onehot_action_batch_selectcolumn = []
        for sel in selectcolumn_random_v2:
            tmp = [0] * score_select_column.shape[1]
            tmp[sel] = 1
            onehot_action_batch_selectcolumn.append(tmp)

        RL_loss_select_column = torch.mean(
            -torch.log(
                torch.sum(torch.softmax(score_select_column, dim=-1) *
                          torch.tensor(onehot_action_batch_selectcolumn,
                                       dtype=torch.float).to(device),
                          dim=-1)) *
            torch.tensor(batch_reward_all, dtype=torch.float).to(device))

        onehot_action_batch_selectagg = []
        for agg in selectagg_random_v2:
            tmp = [0] * score_select_agg.shape[1]
            tmp[agg] = 1
            onehot_action_batch_selectagg.append(tmp)

        RL_loss_select_agg = torch.mean(-torch.log(
            torch.sum(torch.softmax(score_select_agg, dim=-1) * torch.tensor(
                onehot_action_batch_selectagg, dtype=torch.float).to(device),
                      dim=-1)) * torch.tensor(batch_reward_all,
                                              dtype=torch.float).to(device))

        # RL
        # where column
        # where_column_int = torch.squeeze(torch.argmax(score_where_column, dim=-1), dim=-1)
        # where_column_int = torch.squeeze(torch.softmax(score_where_column, dim=-1).multinomial(1), dim=-1)

        where_column_int = []
        for tmp_sql in random_sql_int_v2:
            if len(tmp_sql["conds"]) == 0:
                where_column_int.append(-1)
            else:
                where_column_int.append(tmp_sql["conds"][0][0])
        onehot_action_batch_wherecolumn = []
        for action_int in wherecolumn_random_v2:
            tmp = [0] * score_where_column.shape[1]
            if action_int[0] != -1:  # all 0 will runtime error
                tmp[action_int[0]] = 1
            else:  # this is must
                tmp = [0.01] * score_where_column.shape[1]
            onehot_action_batch_wherecolumn.append(tmp)

        RL_loss_where_column = torch.mean(-torch.log(
            torch.sum(torch.softmax(score_where_column, dim=-1) * torch.tensor(
                onehot_action_batch_wherecolumn, dtype=torch.float).to(device),
                      dim=-1)) * torch.tensor(batch_reward_all,
                                              dtype=torch.float).to(device))

        # RL
        # where value
        # pred_wherevalueindex_random [8,4,2]
        action_wherevalue = []
        for one in pred_wherevalueindex_random_v2:
            if len(one) > 0:
                action = []
                for i in range(4):
                    if len(one) > i:
                        start = one[i][0]
                        tmp_start = [0] * score_where_value.shape[2]
                        tmp_start[start] = 1
                        end = one[i][1]
                        tmp_end = [0] * score_where_value.shape[2]
                        tmp_end[end] = 1
                        action.append([tmp_start, tmp_end])
                    else:
                        action.append([[0.01] * score_where_value.shape[2],
                                       [0.01] * score_where_value.shape[2]])
                action_wherevalue.append(action)
            else:
                action = []
                for i in range(4):
                    action.append([[0.01] * score_where_value.shape[2],
                                   [0.01] * score_where_value.shape[2]])
                action_wherevalue.append(action)

        tmp_action = torch.tensor(action_wherevalue,
                                  dtype=torch.float).to(device)
        tmp_score_where_value = torch.transpose(score_where_value, 2, 3)
        RL_loss_where_value = torch.mean(-torch.log(
            torch.sum(
                torch.softmax(tmp_score_where_value, dim=-1) * tmp_action,
                dim=-1)) * torch.unsqueeze(torch.unsqueeze(
                    torch.tensor(batch_reward_all, dtype=torch.float), dim=-1),
                                           dim=-1).to(device))

        # RL
        # where op
        action_whereop = []
        for one in whereop_random_v2:
            if len(one) > 0:
                action = []
                for i in range(4):
                    if len(one) > i:
                        op = one[0]
                        tmp_start = [0] * score_where_op.shape[2]
                        tmp_start[op] = 1
                        action.append(tmp_start)
                    else:
                        action.append([0.01] * score_where_op.shape[2])
                action_whereop.append(action)
            else:
                action = []
                for i in range(4):
                    action.append([0.01] * score_where_op.shape[2])
                action_whereop.append(action)

        tmp_action_op = torch.tensor(action_whereop,
                                     dtype=torch.float).to(device)
        RL_loss_where_op = torch.mean(-torch.log(
            torch.sum(torch.softmax(score_where_op, dim=-1) * tmp_action_op,
                      dim=-1)) * torch.unsqueeze(torch.tensor(
                          batch_reward_all, dtype=torch.float),
                                                 dim=-1).to(device))

        loss += RL_loss_where_number
        loss += RL_loss_select_agg
        loss += RL_loss_select_column
        loss += RL_loss_where_op
        loss += RL_loss_where_column
        loss += RL_loss_where_value
        # Calculate gradient
        if batch_index % accumulate_gradients == 0:  # mode
            # at start, perform zero_grad
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif batch_index % accumulate_gradients == (accumulate_gradients - 1):
            # at the final, take step with accumulated graident
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # at intermediate stage, just accumulates the gradients
            loss.backward()

        # Cacluate accuracy
        count_sc1_list, count_sa1_list, count_wn1_list, \
        count_wc1_list, count_wo1_list, \
        count_wvi1_list, count_wv1_list = get_count_sw_list(gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo, gt_wherevalueindex,
                                                                   pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                                   sql, pred_sql_int,
                                                                   mode='train')

        count_lx1_list = get_count_lx_list(count_sc1_list, count_sa1_list,
                                           count_wn1_list, count_wc1_list,
                                           count_wo1_list, count_wv1_list)
        # lx stands for logical form accuracy

        # Execution accuracy test.
        count_x1_list, g_ans, pr_ans = get_count_x_list(
            engine, table, gt_select_column, gt_select_agg, sql, pr_sc, pr_sa,
            pred_sql_int)

        # statistics
        ave_loss += loss.item()

        # count
        count_sc += sum(count_sc1_list)
        count_sa += sum(count_sa1_list)
        count_wn += sum(count_wn1_list)
        count_wc += sum(count_wc1_list)
        count_wo += sum(count_wo1_list)
        count_wvi += sum(count_wvi1_list)
        count_wv += sum(count_wv1_list)
        count_logic_form_acc += sum(count_lx1_list)
        count_execute_acc += sum(count_x1_list)

    ave_loss /= count
    acc_sc = count_sc / count
    acc_sa = count_sa / count
    acc_wn = count_wn / count
    acc_wc = count_wc / count
    acc_wo = count_wo / count
    acc_wvi = count_wv / count
    acc_wv = count_wv / count
    acc_lx = count_logic_form_acc / count
    acc_x = count_execute_acc / count

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
示例#22
0
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         EG=False,
         beam_size=4,
         path_db=None,
         dset_name='test',
         mvl=2):
    """Run inference over ``data_loader`` and collect predicted SQL dicts.

    Puts both models in eval mode, encodes each batch with BERT, decodes one
    SQL prediction per example (optionally refined with execution-guided
    beam decoding against the split's database), and returns the flat list
    of predicted SQL structures.

    Args:
        data_loader: iterable yielding batches of examples.
        data_table: table metadata consumed by ``get_fields``.
        model: decoding model; also provides ``beam_forward`` for EG mode.
        model_bert: BERT encoder.
        bert_config: BERT configuration object.
        tokenizer: BERT tokenizer.
        max_seq_length: maximum BERT input length.
        num_target_layers: number of top BERT layers used for embeddings.
        detail: unused here; kept for interface compatibility.
        st_pos: unused here; kept for interface compatibility.
        cnt_tot: unused here; kept for interface compatibility.
        EG: if True, use execution-guided beam decoding.
        beam_size: beam width for EG decoding.
        path_db: root directory containing ``<dset_name>/<dset_name>.db``.
        dset_name: dataset split name.
        mvl: presumably the maximum where-value length forwarded to the
            model and span decoder -- TODO confirm against the model code.

    Returns:
        list: one predicted SQL structure per example, in loader order.
    """
    model.eval()
    model_bert.eval()

    # SQL execution engine over this split's database (used by EG decoding).
    engine = DBEngine(os.path.join(path_db, dset_name, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):

        # Get fields (generate_mode: no gold SQL is required).
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True, generate_mode=True)

        # BERT embeddings for question tokens, headers and values, plus the
        # index maps between raw tokens and wordpiece tokens.
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx, wemb_v, l_npu, l_token \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers, num_out_layers_v=num_target_layers)

        # model specific part
        # score
        if not EG:
            # No execution-guided decoding: single greedy pass.
            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4 = model(
                mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token)

            # No ground truth in generate mode, so no loss can be computed;
            # a zero placeholder keeps the variable defined.
            loss = torch.tensor([0])
            # prediction
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_hrpc, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2,
                s_wv3, s_wv4, mvl)
            pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi)
            # Map predicted wordpiece spans back to surface strings.
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution-guided decoding: first take a greedy prediction,
            # then let beam_forward refine it by executing candidates.
            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4 = model(
                mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token)
            pr_sn1, pr_sc1, pr_sa1, pr_wn1, pr_wr1, pr_hrpc1, pr_wc1, pr_wo1, pr_wvi1 = pred_sw_se(
                s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2,
                s_wv3, s_wv4, mvl)
            pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi1)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
            pr_sql_i1 = generate_sql_i(pr_sc1, pr_sa1, pr_wn1, pr_wr1, pr_wc1,
                                       pr_wo1, pr_wv_str, nlu)

            # Execution-guided beam search seeded with the greedy prediction.
            pr_sql_i, exe_error1, still_error1 = model.beam_forward(
                pr_sql_i1,
                mvl,
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                wemb_v,
                l_npu,
                l_token,
                engine,
                tb,
                nlu_t,
                beam_size=beam_size)

            # Following variables are just for consistency with the no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):

            results1 = pr_sql_i1
            results.append(results1)

    return results
示例#23
0
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         EG=False,
         beam_size=4,
         path_db=None,
         dset_name='test',
         col_pool_type='start_tok',
         aug=False):
    """Evaluate the model on ``data_loader`` and report accuracies.

    For every batch: encode with BERT, predict the SQL components (select
    column / aggregation, where number / column / operator / value),
    optionally with execution-guided decoding, then score the predictions
    against the gold SQL, including logical-form and execution accuracy.

    Args:
        data_loader: iterable yielding batches of examples.
        data_table: table metadata consumed by ``get_fields``.
        model: decoding model; provides ``forward_EG`` when ``EG`` is True.
        model_bert: BERT encoder.
        bert_config: BERT configuration (hidden size, layer count).
        tokenizer: BERT tokenizer.
        max_seq_length: maximum BERT input length.
        num_target_layers: unused here; kept for interface compatibility.
        detail: if True, print a per-batch report via ``report_detail``.
        st_pos: skip examples until the running count reaches this position.
        cnt_tot: passed through to the detail report only.
        EG: if True, use execution-guided decoding.
        beam_size: beam width for EG decoding.
        path_db: directory containing ``<dset_name>.db``.
        dset_name: dataset split name.
        col_pool_type: header-pooling strategy for ``get_wemb_h_FT_Scalar_1``.
        aug: if True, the data is augmented and execution accuracy is
            skipped (no gold answers exist).

    Returns:
        tuple: (acc, results, cnt_list, p_list, data_list) where acc is
        [ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x].
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0  # number of examples seen (including skipped batches)
    cnt_sc = 0  # correct select-column predictions
    cnt_sa = 0  # correct select-aggregation predictions
    cnt_wn = 0  # correct where-number predictions
    cnt_wc = 0  # correct where-column predictions
    cnt_wo = 0  # correct where-operator predictions
    cnt_wv = 0  # correct where-value (string) predictions
    cnt_wvi = 0  # correct where-value index predictions
    cnt_lx = 0  # correct logical forms
    cnt_x = 0  # correct execution results

    cnt_list = []
    p_list = []  # List of prediction probabilities.
    data_list = []  # Miscellaneous data, saved for later analysis.

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):

        cnt += len(t)
        # NOTE(review): cnt is incremented before the skip check, so skipped
        # batches still count toward the accuracy denominators below.
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)

        # Gold SQL components and gold where-value token spans.
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \
        l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        except Exception:
            # Exception happens when the where-condition is not found in
            # nlu_tt. In training such an example is not used; during test
            # it is considered as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # model specific part
        # score
        wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size,
                            bert_config.num_hidden_layers, all_encoder_layer,
                            1)
        wemb_h = get_wemb_h_FT_Scalar_1(i_hds,
                                        l_hs,
                                        bert_config.hidden_size,
                                        all_encoder_layer,
                                        col_pool_type=col_pool_type)
        # wemb_h = [B, max_header_number, hS]
        cls_vec = pooled_output
        if not EG:
            # No execution-guided decoding: single greedy pass.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h,
                                                       l_hs, cls_vec)

            # get loss & step
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa,
                              g_wn, g_wc, g_wo, g_wvi)

            # prediction
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc,
                s_sa,
                s_wn,
                s_wc,
                s_wo,
                s_wv,
            )
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)

            # calculate probability of the prediction for later analysis
            p_tot, p_select, p_where, p_sc, p_sa, p_wn, p_wc, p_wo, p_wvi \
                = cal_prob(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                           pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi)

        else:
            # Execution guided decoding
            pr_sc_best, pr_sa_best, pr_wn_based_on_prob, pr_wvi_best, \
            pr_sql_i, p_tot, p_select, p_where, p_sc_best, p_sa_best, \
            p_wn_best, p_wc_best, p_wo_best, p_wvi_best \
                = model.forward_EG(wemb_n, l_n, wemb_h, l_hs, cls_vec, engine, tb,
                                   nlu_t, nlu_tt, tt_to_t_idx, nlu,
                                   beam_size=beam_size)

            pr_sc = pr_sc_best
            pr_sa = pr_sa_best
            pr_wn = pr_wn_based_on_prob

            p_sc = p_sc_best
            p_sa = p_sa_best
            p_wn = p_wn_best

            # sort and generate: prob-based-sort (descending) -> wc-idx-based-sort (ascending)
            pr_wc, pr_wo, pr_wv_str, pr_wvi, pr_sql_i, \
            p_wc, p_wo, p_wvi = sort_and_generate_pr_w(pr_sql_i, pr_wvi_best, p_wc_best, p_wo_best, p_wvi_best)

            # Following variables are just for consistency with the no-EG case.
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        p_list_batch = [
            p_tot, p_select, p_where, p_sc, p_sa, p_wn, p_wc, p_wo, p_wvi
        ]
        p_list.append(p_list_batch)

        g_sql_q = generate_sql_q(sql_i, tb)
        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        # Per-example correctness of each SQL component.
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_i,
                                                      mode='test')

        # lx stands for logical-form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)

        # Execution accuracy test.
        if not aug:
            cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
                engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)
        else:
            # Augmented data has no gold answers; count executions as 0.
            cnt_x1_list = [0] * len(t)
            g_ans = ['N/A (data augmented)'] * len(t)
            pr_ans = ['N/A (data augmented)'] * len(t)
        # stat
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        current_cnt = [
            cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv,
            cnt_wvi, cnt_lx, cnt_x
        ]
        cnt_list_batch = [
            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
            cnt_wo1_list, cnt_wv1_list, cnt_lx1_list, cnt_x1_list
        ]
        cnt_list.append(cnt_list_batch)
        # report
        if detail:
            report_detail(hds, nlu, g_sc, g_sa, g_wn, g_wc, g_wo, g_wv,
                          g_wv_str, g_sql_q, g_ans, pr_sc, pr_sa, pr_wn, pr_wc,
                          pr_wo, pr_wv_str, pr_sql_q, pr_ans, cnt_list_batch,
                          current_cnt)
        data_batch = []
        for b, nlu1 in enumerate(nlu):
            data1 = [
                nlu[b], nlu_t[b], sql_i[b], g_sql_q[b], g_ans[b], pr_sql_i[b],
                pr_sql_q[b], pr_ans[b], tb[b]
            ]
            data_batch.append(data1)

        data_list.append(data_batch)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    return acc, results, cnt_list, p_list, data_list
示例#24
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):

    ave_loss = 0
    count = 0  # count the # of examples
    count_sc = 0  # count the # of correct predictions of select column
    count_sa = 0  # of selectd aggregation
    count_wn = 0  # of where number
    count_wc = 0  # of where column
    count_wo = 0  # of where operator
    count_wv = 0  # of where-value
    count_wvi = 0  # of where-value index (on question tokens)
    count_logic_form_acc = 0  # of logical form acc
    count_execute_acc = 0  # of execution acc

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    explored_data_list = []

    for batch_index, batch_data in enumerate(train_loader):

        count += len(batch_data)

        if count < st_pos:
            continue
        # Get fields
        question, question_token, sql, sql_text, sql_t, table, header_token, header \
            = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True)

        len_question_bert, len_header_token, number_header, \
        question_token_bert, token_to_berttoken_index, berttoken_to_token_index \
            = get_wemb_bert_v2(bert_config, model_bert, tokenizer, question_token, header, max_seq_length,
                               num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        # select column
        def equal_in(cell_, pred_answer_column):
            for cell in pred_answer_column:
                if cell == cell_:
                    return True
            return False

        # RL
        # where number

        def list_in_list(small, big):
            for cell in big:
                try:
                    cell_ = int(cell)
                    if cell_ in small:
                        return True
                    cell_ = float(cell)
                    if cell_ in small:
                        return True
                except:
                    cell_ = str(cell)
                    if cell_.lower() in small:
                        return True

            for cell in small:
                try:
                    cell_ = int(cell)
                    if cell_ in big:
                        return True
                    cell_ = float(cell)
                    if cell_ in big:
                        return True
                except:
                    cell_ = str(cell)
                    if cell_.lower() in big:
                        return True

            return False

        def list_exact_match(input1, input2):
            tmp1 = [str(item) for item in input1]
            tmp2 = [str(item) for item in input2]
            if sorted(tmp1) == sorted(tmp2):
                return True
            return False

        def contains(big_list, small_list):
            return set(small_list).issubset(set(big_list))

        for i in range(len(batch_data)):
            print(sql[i])
            explored_data = {}
            explore_count = 0
            breakall = False

            reward_where = False
            reward_where_cond1 = 0
            reward_where_cond2 = 0
            reward_where_cond3 = 0
            reward_where_cond4 = 0

            len_question = len(question_token[i])

            gt_answer_list = batch_data[i]["answer"]

            where_number_random = 0
            select_column_random = -1
            select_agg_random = 0
            col = 0
            op = 0
            start = 0
            end = 0

            col2 = 0
            op2 = 0
            start2 = 0
            end2 = 0

            col3 = 0
            op3 = 0
            start3 = 0
            end3 = 0

            col4 = 0
            op4 = 0
            start4 = 0
            end4 = 0

            saved_where_number = -1
            saved_col1 = -1
            saved_op1 = -1
            saved_start1 = -1
            saved_end1 = -1
            saved = False

            # select_column_random = 3
            # where_number_random = 1
            while True:

                final_select_agg = None
                final_select_column = None
                final_conds = []

                tmp_conds = []
                if where_number_random == 0:
                    select_column_random += 1
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 1:

                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        select_column_random += 1
                        col = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if saved == True and where_number_random == 2:
                    col = saved_col1
                    start = saved_start1
                    end = saved_end1
                    op = saved_op1
                    end2 += 1
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        select_column_random += 1
                        col2 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if saved == False and where_number_random == 2:
                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        end2 += 1
                        col = 0
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        select_column_random += 1
                        col2 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 3:
                    # break #TODO
                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        end2 += 1
                        col = 0
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        end3 += 1
                        col2 = 0
                    if end3 >= len_question + 1:
                        start3 += 1
                        end3 = start3 + 1
                    if start3 >= len_question:
                        op3 += 1
                        start3 = 0
                    if op3 == 3:
                        col3 += 1
                        op3 = 0
                    if col3 == len(header[i]):
                        select_column_random += 1
                        col3 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 4:
                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        end2 += 1
                        col = 0
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        end3 += 1
                        col2 = 0
                    if end3 >= len_question + 1:
                        start3 += 1
                        end3 = start3 + 1
                    if start3 >= len_question:
                        op3 += 1
                        start3 = 0
                    if op3 == 3:
                        col3 += 1
                        op3 = 0
                    if col3 == len(header[i]):
                        end4 += 1
                        col3 = 0
                    if end4 >= len_question + 1:
                        start4 += 1
                        end4 = start4 + 1
                    if start4 >= len_question:
                        op4 += 1
                        start4 = 0
                    if op4 == 3:
                        col4 += 1
                        op4 = 0
                    if col4 == len(header[i]):
                        select_column_random += 1
                        col4 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 1:
                    cond = []
                    cond.append(col)
                    cond.append(op)

                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    # if type(cond_value_) == str:  # and random.randint(1,2)==1:
                    #     op = 0
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                if where_number_random == 2:
                    cond = []
                    cond.append(col)
                    cond.append(op)
                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col2)
                    cond.append(op2)
                    pr_wv_str = question_token[i][start2:end2]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                if where_number_random == 3:
                    cond = []
                    cond.append(col)
                    cond.append(op)
                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col2)
                    cond.append(op2)
                    pr_wv_str = question_token[i][start2:end2]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col3)
                    cond.append(op3)
                    pr_wv_str = question_token[i][start3:end3]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                if where_number_random == 4:
                    cond = []
                    cond.append(col)
                    cond.append(op)
                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col2)
                    cond.append(op2)
                    pr_wv_str = question_token[i][start2:end2]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col3)
                    cond.append(op3)
                    pr_wv_str = question_token[i][start3:end3]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col4)
                    cond.append(op4)
                    pr_wv_str = question_token[i][start4:end4]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                # print(select_column_random, select_agg_random, tmp_conds)
                pred_answer_column = engine.execute(table[i]['id'],
                                                    select_column_random,
                                                    select_agg_random,
                                                    tmp_conds)

                explore_count += 1
                if explore_count % 100000 == 0:
                    print(explore_count)
                if explore_count > 500000:
                    break

                exact_match = list_exact_match(gt_answer_list,
                                               pred_answer_column)
                if where_number_random==1 and not exact_match and contains(pred_answer_column,gt_answer_list)\
                        and saved_where_number==-1 and op==0: # 可能会导致1condition的错过
                    # where_number_random = 2 # 可能会导致1condition的错过
                    saved_start1 = start
                    saved_end1 = end
                    saved_col1 = col
                    saved_op1 = op
                    saved = True

                # answer in
                if exact_match:
                    if pred_answer_column == [None]:
                        break

                    print("explore sql", select_column_random,
                          select_agg_random, tmp_conds)

                    if type(gt_answer_list[0]
                            ) == str and select_agg_random != 0:
                        print("fake sql")
                    elif where_number_random == 1 and type(tmp_conds[0][2])==str and tmp_conds[0][1]!=0 or\
                        where_number_random == 2 and type(tmp_conds[1][2]) == str and tmp_conds[1][1] != 0 or\
                        where_number_random == 3 and type(tmp_conds[2][2]) == str and tmp_conds[2][1] != 0 or \
                        where_number_random == 4 and type(tmp_conds[3][2]) == str and tmp_conds[3][1] != 0:
                        print("fake sql")
                    else:
                        # print("explore answer", pred_answer_column)
                        if type(pred_answer_column[0]) == int or type(
                                pred_answer_column[0]) == float:
                            final_select_agg = select_agg_random
                        else:
                            final_select_agg = 0

                        if final_select_agg == 0:
                            pred_answer_column2 = engine.execute(
                                table[i]['id'], select_column_random, 0, [])
                            for cell in gt_answer_list:
                                if cell in pred_answer_column2 or equal_in(
                                        cell, pred_answer_column2):
                                    final_select_column = select_column_random
                                    break
                        else:
                            final_select_column = select_column_random

                        if final_select_agg == 0:
                            pred_answer_column3 = engine.execute(
                                table[i]['id'], "*", 0, tmp_conds)
                            # answer in
                            for cell in gt_answer_list:
                                if cell in pred_answer_column3 or equal_in(
                                        cell, pred_answer_column3):
                                    reward_where = True
                                    break
                        else:
                            reward_where = True

                        # same column: word in question and where column
                        if where_number_random >= 1:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[0][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond1 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond1 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond1 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[0][2]
                            if value in pred_answer_column4:
                                reward_where_cond1 += 0.1
                            try:
                                value = float(tmp_conds[0][2])
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[0][2])
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[0][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[0][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass

                        # same column: word in question and where column
                        if where_number_random >= 2:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[1][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond2 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond2 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond2 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[1][2]
                            if value in pred_answer_column4:
                                reward_where_cond2 += 0.1
                            try:
                                value = float(tmp_conds[1][2])
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[1][2])
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[1][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[1][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass

                        # same column: word in question and where column
                        if where_number_random >= 3:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[2][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond3 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond3 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond3 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[2][2]
                            if value in pred_answer_column4:
                                reward_where_cond3 += 0.1
                            try:
                                value = float(tmp_conds[2][2])
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[2][2])
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[2][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[2][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass

                        # same column: word in question and where column
                        if where_number_random >= 4:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[3][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond4 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond4 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond4 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[3][2]
                            if value in pred_answer_column4:
                                reward_where_cond4 += 0.1
                            try:
                                value = float(tmp_conds[3][2])
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[3][2])
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[3][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[3][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                        """ 有问题,cond op 只能强制为 = 因为 > 或 < 不在一行
                        if where_number_random >= 1 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[0][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond1 += 0.1
                                    break
    
                        if where_number_random >= 2 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            tmp_conds2[1][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[1][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond2 += 0.1
                                    break
    
                        if where_number_random >= 3 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            tmp_conds2[1][1] = 0  # EQUAL
                            tmp_conds2[2][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[2][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond3 += 0.1
                                    break
    
                        if where_number_random >= 4 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            tmp_conds2[1][1] = 0  # EQUAL
                            tmp_conds2[2][1] = 0  # EQUAL
                            tmp_conds2[3][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[3][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond4 += 0.1
                                    break
                        """

                        if reward_where_cond1 >= 0.2 and reward_where == True and where_number_random >= 1:
                            final_conds.append(tmp_conds[0])
                        if reward_where_cond2 >= 0.2 and reward_where == True and where_number_random >= 2:
                            final_conds.append(tmp_conds[1])
                        if reward_where_cond3 >= 0.2 and reward_where == True and where_number_random >= 3:
                            final_conds.append(tmp_conds[2])
                        if reward_where_cond4 >= 0.2 and reward_where == True and where_number_random >= 4:
                            final_conds.append(tmp_conds[3])
                        if final_select_agg != None and final_select_column != None and (
                                where_number_random == 1 and len(final_conds)
                                == 1 or where_number_random == 2
                                and len(final_conds) == 2
                                or where_number_random == 3
                                and len(final_conds) == 3
                                or where_number_random == 4
                                and len(final_conds) == 4):
                            break
                        if final_select_agg != None and final_select_column != None and where_number_random == 0:
                            break

            if final_select_column != None:
                explored_data["sel"] = final_select_column
                explored_data["agg"] = final_select_agg
                explored_data["conds"] = final_conds
                explored_data_list.append(explored_data)
                print(len(explored_data_list))
                one_data = batch_data[i]
                one_data["sql"] = explored_data
                one_data["query"] = explored_data
                f = open("gen_data.jsonl", mode="a", encoding="utf-8")
                json.dump(one_data, f)
                f.write('\n')
                f.close()

    print("Done")
    ave_loss /= count
    acc_sc = count_sc / count
    acc_sa = count_sa / count
    acc_wn = count_wn / count
    acc_wc = count_wc / count
    acc_wo = count_wo / count
    acc_wvi = count_wv / count
    acc_wv = count_wv / count
    acc_lx = count_logic_form_acc / count
    acc_x = count_execute_acc / count

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
示例#25
0
from sqlnet.dbengine import DBEngine
from rl.train_rl import config
from train import *
import json
# SQL execution engines over the WikiSQL train/dev SQLite databases.
engine_train = DBEngine("train.db")
engine_dev = DBEngine("dev.db")

# Load the WikiSQL splits from the current directory.
# no_w2i / no_hs_tok presumably skip word-to-index and header tokenization
# artifacts that are not needed for answer generation — TODO confirm.
train_data, train_table, dev_data, dev_table, _, _ = load_wikisql(
    "./", False, -1, no_w2i=True, no_hs_tok=True)
# Batch iterators over both splits (third positional arg looks like the
# batch size); training order is kept deterministic (shuffle_train=False).
train_loader, dev_loader = get_loader_wikisql(train_data,
                                              dev_data,
                                              32,
                                              shuffle_train=False)


def process(train_data_, name, engine_):
    """Annotate every record of ``train_data_`` in place with its answer.

    Each record must carry a ``table_id`` and a ``sql`` entry; the answer is
    obtained by executing the SQL against ``engine_`` and stored under the
    new key ``"answer"``.

    Args:
        train_data_: list of example dicts (mutated in place).
        name: split label; currently unused but kept for caller compatibility.
        engine_: database engine exposing ``execute_query_v2(table_id, sql)``.
    """
    for idx, record in enumerate(train_data_):
        # Progress heartbeat every 100 records.
        if idx % 100 == 0:
            print(idx)

        result = engine_.execute_query_v2(record["table_id"], record["sql"])
        # Flag queries that yielded no usable result.
        if result == [None]:
            print(None)
        record["answer"] = result
示例#26
0
import os

from sqlnet.dbengine import DBEngine
import json
import time

# Location of the tianchi WikiSQL-style data and which split to benchmark.
path_db = './wikisql/data/tianchi/'
dset_name = 'val'

# The engine for searching results
engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))


# A named function instead of `query = lambda ...` (PEP 8 E731): same name,
# same call signature, but debuggable and documentable.
def query(sql_tmp):
    """Execute one SQL record against the engine and return the result rows."""
    return engine.execute(sql_tmp['table_id'], sql_tmp['sql']['sel'],
                          sql_tmp['sql']['agg'], sql_tmp['sql']['conds'],
                          sql_tmp['sql']['cond_conn_op'])


fname = os.path.join(path_db, f"{dset_name}.json")
with open(fname, encoding='utf-8') as fs:
    total_count = 0
    start = time.time()
    # Read only the first record of the split; the loop below re-executes it.
    for line in fs:
        record = json.loads(line)
        break
    # Micro-benchmark: run the same query 10 times and report wall time.
    while total_count < 10:
        res = query(record)
        total_count += 1
    print('%d times of invoking cost %.2fs.' % (total_count, time.time() - start))
    # print('The number of empty results is %d, the number of result is 0 %d' % (count, tmp))
示例#27
0
文件: train.py 项目: Mars-Wei/MISP
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         EG=False,
         beam_size=4,
         path_db=None,
         dset_name='test'):
    """Evaluate a SQLova-style text-to-SQL model on one data split.

    Encodes each batch with BERT, decodes SQL (greedily, or with
    execution-guided beam search when ``EG`` is True), and accumulates
    per-component accuracies: select-column (sc), select-aggregate (sa),
    where-number (wn), where-column (wc), where-operator (wo),
    where-value (wv/wvi), logical-form (lx) and execution (x).

    Args:
        data_loader: batch iterator over evaluation examples.
        data_table: mapping from table id to table schema/contents.
        model: SQL decoder module (put into eval mode here).
        model_bert: BERT encoder module (put into eval mode here).
        bert_config, tokenizer, max_seq_length, num_target_layers:
            encoding parameters forwarded to ``get_wemb_bert``.
        detail: if True, print a per-batch report via ``report_detail``.
        st_pos: skip batches until this many examples have been seen.
        cnt_tot: total-count placeholder included in the progress stats.
        EG: use execution-guided beam decoding instead of greedy decoding.
        beam_size: beam width used when ``EG`` is True.
        path_db: directory containing the ``f"{dset_name}.db"`` SQLite file.
        dset_name: split name used to locate the database file.

    Returns:
        Tuple ``(acc, results, cnt_list)`` where ``acc`` is
        ``[ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi,
        acc_wv, acc_lx, acc_x]``, ``results`` holds per-example predictions
        for the official evaluator, and ``cnt_list`` the per-batch
        correctness lists.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0  # number of examples seen so far
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []

    # Engine used for execution-accuracy checks (and EG decoding).
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):

        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        except Exception:
            # (was a bare `except:`, which also swallowed SystemExit /
            # KeyboardInterrupt.)
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # model specific part
        # score
        if not EG:
            # No Execution guided decoding
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h,
                                                       l_hpu, l_hs)

            # get loss & step
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa,
                              g_wn, g_wc, g_wo, g_wvi)

            # prediction
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc,
                s_sa,
                s_wn,
                s_wc,
                s_wo,
                s_wv,
            )
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            # g_sql_i = generate_sql_i(g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_str, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                engine,
                tb,
                nlu_t,
                nlu_tt,
                tt_to_t_idx,
                nlu,
                beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            # Following variables are just for the consistency with no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        g_sql_q = generate_sql_q(sql_i, tb)
        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results.append(results1)

        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa,g_wn, g_wc,g_wo, g_wvi,
                                                                   pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                                   sql_i, pr_sql_i,
                                                                   mode='test')

        # lx stands for logical-form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)

        # Execution accuracy test.
        # (A dead `cnt_x1_list = []` that was immediately overwritten here
        # has been removed.)
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa,
                                                    sql_i, pr_sc, pr_sa,
                                                    pr_sql_i)

        # stat
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        current_cnt = [
            cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv,
            cnt_wvi, cnt_lx, cnt_x
        ]
        cnt_list1 = [
            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
            cnt_wo1_list, cnt_wv1_list, cnt_lx1_list, cnt_x1_list
        ]
        cnt_list.append(cnt_list1)
        # report
        if detail:
            report_detail(hds, nlu, g_sc, g_sa, g_wn, g_wc, g_wo, g_wv,
                          g_wv_str, g_sql_q, g_ans, pr_sc, pr_sa, pr_wn, pr_wc,
                          pr_wo, pr_wv_str, pr_sql_q, pr_ans, cnt_list1,
                          current_cnt)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    return acc, results, cnt_list
示例#28
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Parameters
    ----------
    train_loader : iterable yielding batches of tokenized examples.
    train_table : table metadata, keyed by table id.
    model : the sequence-to-SQL head being trained.
    model_bert : BERT encoder producing contextual embeddings.
    opt, opt_bert : optimizers for ``model`` and ``model_bert``
        (``opt_bert`` is optional and may be ``None``).
    bert_config, tokenizer, max_seq_length, num_target_layers :
        BERT encoding configuration forwarded to ``get_wemb_bert``.
    accumulate_gradients : number of batches to accumulate gradients over
        before taking an optimizer step (1 = step every batch).
    check_grad : unused here; kept for interface compatibility.
    st_pos : skip batches until this many examples have been seen
        (crude resume support).
    path_db, dset_name : directory and name of the SQLite database opened
        by ``DBEngine``.

    Returns
    -------
    float
        Accumulated loss divided by the number of examples seen
        (0.0 if the loader was empty).
    """
    model.train()
    model_bert.train()

    ave_loss = 0.0
    cnt = 0  # number of examples processed so far

    # Engine for SQL querying. NOTE(review): `engine` is never used in this
    # function, but constructing it opens the .db file; kept to preserve
    # behavior. (Original note: remember to switch to the new engine here.)
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    # Progress is advanced in examples (pbar.update(len(t))), so size the
    # bar by the dataset length instead of a hard-coded batch size of 16.
    pbar = tqdm(total=len(train_loader.dataset))

    for iB, t in enumerate(train_loader):
        # `t` is one batch of examples from the tokenized file.
        cnt += len(t)
        if cnt < st_pos:
            continue

        # Get fields.
        # nlu  : natural language utterance (original question)
        # nlu_t: tokenized nlu
        # sql_i: canonical form of the SQL query
        # sql_t: tokenized SQL query (= nlu_t); sql_q (full text) removed
        # tb   : table
        # hs_t : tokenized headers (not used)
        # hds  : headers
        nlu, nlu_t, sql_i, sql_t, tb, hs_t, hds = get_fields(t,
                                                             train_table,
                                                             no_hs_t=True,
                                                             no_sql_t=True)

        # Ground-truth decomposition of each SQL query.
        # g_sel_num_seq: true number of SELECT columns
        # g_sel_ag_seq : tuples of (agg count, selected columns, agg ops)
        g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wv, g_sel_num_seq, g_sel_ag_seq, conds = get_g(
            sql_i)

        # The non-baseline path needs g_sc / g_sa as flat lists: keep only
        # the first selected column / aggregation of each example.
        g_sc = [sc[0] for sc in g_sc]
        g_sa = [sa[0] for sa in g_sa]

        # Where-value token indices under the CoreNLP tokenization scheme
        # (already precomputed on the training set).
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        # wemb_n: natural-language token embeddings
        # wemb_h: header embeddings
        # l_n   : token length of each question
        # l_hpu : header token lengths
        # l_hs  : number of columns (headers) per table
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            # Re-map CoreNLP token indices onto BERT word-piece indices.
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
        except Exception:
            # Index remapping failed for this batch; report and skip it.
            print('索引转值出错')
            continue

        # Slot scores from the model (teacher forcing with ground truth).
        s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
            wemb_n,
            l_n,
            wemb_h,
            l_hpu,
            l_hs,
            g_scn=g_sel_num_seq,
            g_sc=g_sc,
            g_sa=g_sa,
            g_wn=g_wn,
            g_wc=g_wc,
            g_sop=g_sop,
            g_wo=g_wo,
            g_wvi=g_wvi)

        loss = Loss_sw_se(s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv,
                          g_sel_num_seq, g_sc, g_sa, g_sop, g_wn, g_wc, g_wo,
                          g_wvi)

        # Gradient accumulation: zero the gradients at the start of a
        # window, step at its end, and only accumulate in between.
        if iB % accumulate_gradients == 0:
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # Final batch of the window: step with the accumulated gradient.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # Intermediate batch: just accumulate the gradients.
            loss.backward()

        ave_loss += loss.item()
        pbar.update(len(t))

    pbar.close()
    # Guard against an empty loader (avoids ZeroDivisionError).
    return ave_loss / cnt if cnt else 0.0