Example #1
File: view.py Project: viperasi/eftol
 def POST(self):
     data = web.input(email=None, passwd=None, auto=False)
     if data.email and data.passwd:
         db = DBEngine().getInstance()
         user = db.select(
             ['users as u'],
             vars={'email': data.email},
             what='u.id, u.name, u.passwd',
             where='u.email=$email'
             )
         rows = list(user)
         if not rows:
             # no account with this email address
             return simplejson.dumps({'succ': False, 'msg': '邮箱地址不存在'})
         currUser = rows[0]
         if data.passwd == currUser.passwd:
             if data.auto:
                 web.setcookie("email", data.email)
             return simplejson.dumps({'succ': True,
                                      'user': {'id': currUser.id, 'name': currUser.name}})
         # email and password do not match
         return simplejson.dumps({'succ': False, 'msg': '邮箱与密码不匹配,请重新输入'})
     # incomplete login form
     return simplejson.dumps({'succ': False, 'msg': '请输入完整后登录'})
Example #2
def epoch_reinforce_train(model, optimizer, batch_size, sql_data, table_data,
                          db_path):
    engine = DBEngine(db_path)

    #XXX print ("=== call to train")
    model.train()
    perm = np.random.permutation(len(sql_data))
    cum_reward = 0.0
    st = 0
    while st < len(sql_data):
        ed = st + batch_size if st + batch_size < len(perm) else len(perm)

        q_seq, col_seq, col_num, ans_seq, query_seq, gt_whr_seq, raw_data =\
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        #XXX print ("=== generate g_s")
        g_s = model.generate_g_s(q_seq, col_seq, query_seq)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        #XXX print ("=== fwd")
        score = model.forward(q_seq,
                              col_seq,
                              col_num, (True, True, True),
                              reinforce=True)
        clasif_queries = model.gen_query(score,
                                         q_seq,
                                         col_seq,
                                         raw_q_seq,
                                         raw_col_seq, (True, True, True),
                                         reinforce=True)

        rewards = []
        for idx, (sql_gt, sql_pred,
                  tid) in enumerate(zip(query_gt, clasif_queries, table_ids)):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                ret_pred = None

            if ret_pred is None:
                rewards.append(-2)
            elif ret_pred != ret_gt:
                rewards.append(-1)
            else:
                rewards.append(1)

        cum_reward += (sum(rewards))
        optimizer.zero_grad()
        model.reinforce_backward(score, rewards)
        optimizer.step()

        st = ed

    return cum_reward / len(sql_data)
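
The reward scheme above is easy to misread at a glance: a prediction that fails to execute is penalized more heavily (-2) than one that executes to a wrong result (-1), and only a matching execution result earns +1. A minimal standalone sketch of that scheme follows; `execute` stands in for DBEngine.execute with the (tid, sel, agg, conds) call shape used above, and the helper name is an assumption, not part of the original code.

def execution_reward(execute, tid, sql_gt, sql_pred):
    # Sketch of the reward assignment in Example #2 (not the original code).
    ret_gt = execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
    try:
        ret_pred = execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
    except Exception:
        ret_pred = None
    if ret_pred is None:
        return -2   # predicted SQL could not be executed
    if ret_pred != ret_gt:
        return -1   # executed, but the result differs from the ground truth
    return 1        # execution result matches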
Example #3
 def getInstance(self):
     db = DBEngine().getInstance()
     results = db.select(self.table, vars=self.var, what=self.what,where=self.where,order=self.order)
     count = results[0].count
     mod = int(count) % self.limit
     if mod == 0:
         self.allPage = int(count) // self.limit
     else:
         self.allPage = int(count) // self.limit + 1
     return self
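
The page-count logic above is ceiling division of the row count by the page size. A minimal sketch that expresses it directly and sidesteps the Python 2/3 division difference (the helper name is an assumption):

def page_count(count, limit):
    # Ceiling division: equivalent to the mod/branch logic in Example #3.
    return (count + limit - 1) // limit

assert page_count(20, 10) == 2   # exact multiple, no extra page
assert page_count(21, 10) == 3   # remainder spills onto one more page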
Example #4
File: view.py Project: viperasi/eftol
 def POST(self):
     data = web.input(id=None, type=None)
     db = DBEngine().getInstance()
     lists = db.select(
                     ['invtypes as t', 'trntranslations as tt', 'trntranslations as ttt'],
                     vars={'gid': data.id},
                     what='t.typeid,t.capacity,tt.text AS typename,ttt.text AS racename',
                     where='tt.keyid=t.typeid and tt.tcid=8 and tt.languageid="ZH" AND ttt.keyid=t.raceid AND ttt.tcid=9 AND ttt.languageid="ZH" and t.groupid=$gid',
                     order='t.raceid,t.typeid')
     options = [{'id': st.typeid, 'capacity': st.capacity,
                 'name': st.typename, 'race': st.racename} for st in lists]
     return simplejson.dumps({'succ': True, 'type': data.type, 'options': options})
Example #5
File: view.py Project: viperasi/eftol
 def GET(self):
     db = DBEngine().getInstance()
     shipTypes = db.select(
                     ['invgroups', 'trntranslations'],
                     where='trntranslations.keyid=invgroups.groupid and trntranslations.tcid=7 and trntranslations.languageid="ZH" and invgroups.categoryid=6')
     web.header('Content-Type', 'application/json')
     options = [{'id': st.groupID, 'name': st.text} for st in shipTypes]
     return simplejson.dumps({'succ': True, 'type': 'group', 'options': options})
Example #6
def epoch_acc(model, batch_size, sql_data, table_data, db_path):
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    badcase = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = min((st + 1) * batch_size, len(perm))
        st = st * batch_size
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # q_seq: char-based sequence of question
        # gt_sel_num: number of selected columns and aggregation functions, new added field
        # col_seq: char-based column name
        # col_num: number of headers in one table
        # ans_seq: (sel, number of conds, sel list in conds, op list in conds)
        # gt_cond_seq: ground truth of conditions
        # raw_data: ori question, headers, sql
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # query_gt: ground truth of sql, data['sql'], containing sel, agg,
        # conds:{sel, op, value}
        raw_q_seq = [x[0] for x in raw_data]  # original question
        try:
            score = model.forward(q_seq, col_seq, col_num)
            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            # generate predicted format
            one_err, tot_err = model.check_acc(raw_data, pred_queries,
                                               query_gt)
        except Exception:
            badcase += 1
            print('badcase', badcase)
            continue
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution Accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)
    return (one_acc_num / len(sql_data), tot_acc_num / len(sql_data),
            ex_acc_num / len(sql_data))
Example #7
File: view.py Project: viperasi/eftol
 def GET(self, shipId):
     db = DBEngine().getInstance()
     ship = db.select(
                     ['invtypes as i', 'trntranslationcolumns as t1', 'trntranslations as t11',
                     'invgroups as g', 'trntranslationcolumns as t2', 'trntranslations as t22',
                     'chrraces as c', 'trntranslationcolumns as t3', 'trntranslations as t33',
                     'trntranslationcolumns as t4', 'trntranslations as t44'],
                     vars = {'sid' : shipId},
                     what = 'i.typeid, t11.text as name, t22.text as `group`, t33.text as race, t44.text as `desc`',
                     where = 't11.keyid=i.typeid and t1.tablename="dbo.invtypes" and t1.columnName="typename" and t1.tcid=t11.tcid and t11.languageid="ZH"'+
                             ' and t22.keyid=g.groupid and t2.tablename="dbo.invgroups" and t2.columnName="groupname" and t2.tcid=t22.tcid and t22.languageid="ZH"'+
                             ' and t33.keyid=c.raceid and t3.tablename="dbo.chrraces" and t3.columnName="racename" and t3.tcid=t33.tcid and t33.languageid="ZH"'+
                             ' and t44.keyid=i.typeid and t4.tablename="dbo.invtypes" and t4.columnName="description" and t4.tcid=t44.tcid and t44.languageid="ZH"'+
                             ' and i.typeid=$sid and i.groupid=g.groupid and c.raceid=i.raceid')
     return render.ceft(ship[0])
Example #8
File: view.py Project: viperasi/eftol
 def GET(self, currPage):
     data = web.input(gid='-1', rid='-1', tname='')
     web.header('Content-Type', 'text/html; charset=utf-8', unique=True)
     db = DBEngine().getInstance()
     if currPage == '':
         currPage = '1'
     shipTypes = db.select(
                     ['invgroups', 'trntranslationcolumns', 'trntranslations'],
                     what = 'invgroups.groupid as groupid,trntranslations.text as groupname',
                     where = 'trntranslations.keyid=invgroups.groupid and trntranslationcolumns.tablename="dbo.invgroups" AND trntranslationcolumns.columnName="groupname" AND trntranslations.tcid=trntranslationcolumns.tcid and trntranslations.languageid="ZH" and invgroups.categoryid=6')
     chrraces = db.select(
                     ['chrraces', 'trntranslationcolumns', 'trntranslations'],
                     what='chrraces.raceid as raceid,trntranslations.text as racename',
                     where='trntranslations.keyid=chrraces.raceid and trntranslationcolumns.tablename="dbo.chrraces" AND trntranslationcolumns.columnName="racename" AND trntranslations.tcid=trntranslationcolumns.tcid and trntranslations.languageid="ZH"')
     var = {}
     where = ''
     countWhere = ''
     if data.gid != '-1':
         var['gid'] = data.gid
         where = where + ' and i.groupid=$gid '
     if data.rid != '-1':
         var['rid'] = data.rid
         where = where + ' and i.raceid=$rid'
     if data.tname != '':
         var['tname'] = '%' + data.tname + '%'
         where = where + ' and t11.text like $tname'
     offset = (int(currPage) - 1) * 10
     ships = db.select(
                     ['invtypes as i', 'trntranslationcolumns as t1', 'trntranslations as t11',
                     'invgroups as g', 'trntranslationcolumns as t2', 'trntranslations as t22',
                     'chrraces as c', 'trntranslationcolumns as t3', 'trntranslations as t33'],
                     vars = var,
                     what = 'i.typeid,t11.text as typename,t22.text as groupname,t33.text as racename',
                     where = 't11.keyid=i.typeid and t1.tablename="dbo.invtypes" and t1.columnName="typename" and t1.tcid=t11.tcid and t11.languageid="ZH"'+
                             ' and t22.keyid=g.groupid and t2.tablename="dbo.invgroups" and t2.columnName="groupname" and t2.tcid=t22.tcid and t22.languageid="ZH"'+
                             ' and t33.keyid=c.raceid and t3.tablename="dbo.chrraces" and t3.columnName="racename" and t3.tcid=t33.tcid and t33.languageid="ZH"'+
                             ' and g.categoryid=6 and i.groupid=g.groupid and c.raceid=i.raceid' + where,
                     order = 'c.raceid,g.groupid,i.typeid',
                     limit = 10,
                     offset = offset)
     table = ['invtypes as i', 'invgroups as g', 'chrraces as c', 'trntranslationcolumns as t1', 'trntranslations as t11']
     what = 'count(1) as count'
     where = 't11.keyid=i.typeid and t1.tablename="dbo.invtypes" and t1.columnname="typename" and t1.tcid=t11.tcid and t11.languageid="ZH" and g.categoryid=6 and i.groupid=g.groupid and c.raceid=i.raceid ' + where
     order = 'c.raceid,g.groupid,i.typeid'
     page = Pagination(table, var, what, where, order, 10, currPage).getInstance()
     return render.eft(shipTypes, chrraces, ships, page, data)
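
Example #8 builds its WHERE clause incrementally: every optional request parameter adds a `$name` placeholder to the SQL fragment and a matching entry to the vars dict, so web.py escapes the values. A minimal sketch of that pattern with illustrative request values (the literals below are placeholders, not project data):

var = {}
where = ''
params = {'gid': '25', 'rid': '-1', 'tname': ''}   # illustrative request values
if params['gid'] != '-1':
    var['gid'] = params['gid']
    where += ' and i.groupid=$gid'
if params['rid'] != '-1':
    var['rid'] = params['rid']
    where += ' and i.raceid=$rid'
if params['tname'] != '':
    var['tname'] = '%' + params['tname'] + '%'
    where += ' and t11.text like $tname'
# var and where would then be passed to db.select(..., vars=var, where=base_where + where)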
Example #9
File: view.py Project: viperasi/eftol
 def GET(self, shipId):
     db = DBEngine().getInstance()
     web.header('Content-Type', 'application/json; charset=utf-8', unique=True)
     ship = db.select(
                     ['invtypes as i', 'trntranslationcolumns as t1', 'trntranslations as t11',
                     'invgroups as g', 'trntranslationcolumns as t2', 'trntranslations as t22',
                     'chrraces as c', 'trntranslationcolumns as t3', 'trntranslations as t33',
                     'trntranslationcolumns as t4', 'trntranslations as t44'],
                     vars = {'sid' : shipId},
                     what = 'i.typeid, t11.text as name, t22.text as `group`, t33.text as race, t44.text as `desc`, i.radius, i.mass, i.volume, i.capacity ',
                     where = 't11.keyid=i.typeid and t1.tablename="dbo.invtypes" and t1.columnName="typename" and t1.tcid=t11.tcid and t11.languageid="ZH"'+
                             ' and t22.keyid=g.groupid and t2.tablename="dbo.invgroups" and t2.columnName="groupname" and t2.tcid=t22.tcid and t22.languageid="ZH"'+
                             ' and t33.keyid=c.raceid and t3.tablename="dbo.chrraces" and t3.columnName="racename" and t3.tcid=t33.tcid and t33.languageid="ZH"'+
                             ' and t44.keyid=i.typeid and t4.tablename="dbo.invtypes" and t4.columnName="description" and t4.tcid=t44.tcid and t44.languageid="ZH"'+
                             ' and i.typeid=$sid and i.groupid=g.groupid and c.raceid=i.raceid')
     attr = db.select(
                     ['dgmtypeattributes AS dt', 'dgmattributetypes as d',
                     'trntranslationcolumns AS t1', 'trntranslations AS t11'],
                     vars = {'sid' : shipId},
                     what = 't11.text AS displayname,dt.attributeid, coalesce(dt.valuefloat,dt.valueint) as value , d.categoryid',
                     where = 'dt.typeid=$sid AND d.attributeid=dt.attributeid'+
                             ' and (dt.attributeID <> 182 AND dt.attributeID <> 277) AND (dt.attributeID <> 183 AND dt.attributeID <> 278) '+ 
                             ' and (dt.attributeID <> 184 AND dt.attributeID <> 279) AND (dt.attributeID <> 1285 AND dt.attributeID <> 1286) '+ 
                             ' and (dt.attributeID <> 1289 AND dt.attributeID <> 1287) AND (dt.attributeID <> 1290 AND dt.attributeID <> 1288)'+
                             ' AND t11.keyid=dt.attributeid AND t1.tablename="dbo.dgmattributetypes" AND t1.columnname="displayname" AND t1.tcid=t11.tcid AND t11.languageid="ZH"')
     skillssql = '''select tt.text as skill, COALESCE(skillLevel.valueFloat, skillLevel.valueInt) AS requiredLevel, attr.attributeid,output.*
              from ( select prereqs(dgmtypeattributes.typeid) as id, @level as treelevel, @parent as parent, substr(@path,2) as path
                  from ( select @start_with:=$typeid, @id:=@start_with,@level:=0,@parent:=0,@path:="" ) vars, dgmtypeattributes
                  where @id is not null ) output inner join trntranslations as tt on tt.tcid=8 and tt.languageid="zh" and tt.keyid=output.id 
                  INNER JOIN dgmtypeattributes AS attr ON attr.typeID = output.parent AND attr.attributeID IN (182,183,184,1285,1289,1290) AND COALESCE(attr.valueFloat, attr.valueInt) = output.id 
                  INNER JOIN dgmtypeattributes AS skillLevel ON skillLevel.typeID = output.parent AND skillLevel.attributeID IN (277,278,279,1286,1287,1288) 
                  where ( (attr.attributeID = 182 AND skillLevel.attributeID = 277) OR (attr.attributeID = 183 AND skillLevel.attributeID = 278) OR 
                      (attr.attributeID = 184 AND skillLevel.attributeID = 279) OR (attr.attributeID = 1285 AND skillLevel.attributeID = 1286) OR 
                      (attr.attributeID = 1289 AND skillLevel.attributeID = 1287) OR (attr.attributeID = 1290 AND skillLevel.attributeID = 1288)) order by treelevel,attr.attributeid'''
     skills = db.query(skillssql, vars={'typeid': shipId})
     shipJson = dict(ship[0])
     shipJson['prop'] = attr.list()
     shipJson['skills'] = skills.list()
     return simplejson.dumps(shipJson)
Example #10
def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
    engine = DBEngine(db_path)

    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    acc_of_log = 0.0
    st = 0
    while st < len(sql_data):
        ed = st + batch_size if st + batch_size < len(perm) else len(perm)
        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq)
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]
        score = model.forward(q_seq,
                              col_seq,
                              col_num, (True, True, True),
                              gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                       raw_col_seq, (True, True, True))

        for idx, (sql_gt, sql_pred,
                  tid) in enumerate(zip(query_gt, pred_queries, table_ids)):
            try:
                ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                        sql_gt['conds'])
            except Exception:
                ret_gt = None
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)

        st = ed

    return tot_acc_num / len(sql_data)
Example #11
def execute_accuracy(query_gt, pred_queries, table_ids, db_path, sql_data):
    """
        Execution Accuracy 执行精确性,只要sql的执行结果一致就行

    """
    engine = DBEngine(db_path)
    ex_acc_num = 0
    for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
        ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                sql_gt['conds'], sql_gt['cond_conn_op'])

        try:
            ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'],
                                      sql_pred['conds'],
                                      sql_pred['cond_conn_op'])
        except Exception as e:
            print(e)
            ret_pred = None
        ex_acc_num += (ret_gt == ret_pred)
    print('\nexecute acc is {}'.format(ex_acc_num / len(sql_data)))
    return ex_acc_num / len(sql_data)
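
A hedged usage sketch for execute_accuracy above: the query dicts follow the sel/agg/conds/cond_conn_op layout the function reads, but the concrete values, table id, and database path are placeholders, not real data.

gt = [{'sel': [0], 'agg': [0], 'conds': [[1, 2, '2018']], 'cond_conn_op': 0}]
pred = [{'sel': [0], 'agg': [0], 'conds': [[1, 2, '2018']], 'cond_conn_op': 0}]
table_ids = ['some_table_id']                     # placeholder table id
acc = execute_accuracy(gt, pred, table_ids, 'data/val.db', sql_data=gt)
# acc is the fraction of predictions whose execution result matches the ground truth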
Example #12
def parse_sql():
    table_id = request.json['table_id']
    question = request.json['question']

    flag_childfind = 0
    if table_id == 'device':
        matchObj = re.search(r'最(.*)手机是', question, re.M | re.I)
        if matchObj:
            str_match = matchObj.group()
            str_mat_tmp = str_match.replace('手机', '') + '多少'
            question = question.replace(str_match, str_mat_tmp)
            flag_childfind = 1
            key_col_index = 5

    if table_id == 'telbill':
        matchObj = re.search(r'最(.*)账期是', question, re.M | re.I)
        if matchObj:
            str_match = matchObj.group()
            str_mat_tmp = str_match.replace('账期', '') + '多少'
            question = question.replace(str_match, str_mat_tmp)
            flag_childfind = 1
            key_col_index = 0

    # build the JSON line with json.dumps so quotes inside the question are escaped correctly
    test_json_line = json.dumps({'question': question, 'table_id': table_id},
                                ensure_ascii=False)
    test_data = read_line(test_json_line, test_tables)
    print(test_json_line)
    test_dataseq = DataSequence(data=test_data,
                                tokenizer=query_tokenizer,
                                label_encoder=label_encoder,
                                is_train=False,
                                shuffle_header=False,
                                max_len=160,
                                shuffle=False,
                                batch_size=1)

    header_lens = np.sum(test_dataseq[0]['input_header_mask'], axis=-1)
    model = models['stage1']
    model2 = models['stage2']
    with graph.as_default():
        preds_cond_conn_op, preds_sel_agg, preds_cond_op = model.predict_on_batch(
            test_dataseq[0])
        sql = outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op,
                              header_lens, test_dataseq.label_encoder)
        te_qc_pairs = QuestionCondPairsDataset(
            test_data,
            candidate_extractor=CandidateCondsExtractor(share_candidates=True),
            has_label=False,
            model_1_outputs=sql)

        te_qc_pairs_seq = QuestionCondPairsDataseq(te_qc_pairs,
                                                   tokenizer,
                                                   sampler=FullSampler(),
                                                   shuffle=False,
                                                   batch_size=1)
        te_result = model2.predict_generator(te_qc_pairs_seq, verbose=1)

    task2_result = merge_result(te_qc_pairs, te_result, threshold=0.995)
    cond = list(task2_result.get(0, []))
    sql[0]['conds'] = cond

    engine = DBEngine()
    #table_id = json.loads(test_json_line)['table_id']
    header = test_tables[table_id]._df.columns.values.tolist()
    sql_gen = engine.execute(table_id, sql[0]['sel'], sql[0]['agg'],
                             sql[0]['conds'], sql[0]['cond_conn_op'], header)

    if flag_childfind == 1 and sql[0]['agg'][0] > 0:
        #print(sql[0]['sel'])
        header_index = int(sql[0]['sel'][0])

        childcol = header[header_index]
        key_col = header[key_col_index]
        sql_gen = 'select ' + key_col + ' from Table_' + table_id + ' where ' + childcol + '=( ' + sql_gen + ' )'

    return jsonify({'task': sql_gen})
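
parse_sql reads `table_id` and `question` from the request's JSON body and returns the generated SQL under the `task` key, which points to a Flask-style endpoint. A hedged client-side sketch; the route and port are assumptions, only the two field names come from the code above.

import requests

resp = requests.post('http://localhost:5000/parse_sql',           # assumed route and port
                     json={'table_id': 'device',
                           'question': '最便宜的手机是哪一款'})     # illustrative question
print(resp.json()['task'])   # the generated SQL string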
Example #13
header_lens = np.sum(test_dataseq[0]['input_header_mask'], axis=-1)
preds_cond_conn_op, preds_sel_agg, preds_cond_op = model.predict_on_batch(
    test_dataseq[0])
sql = outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op,
                      header_lens, test_dataseq.label_encoder)
te_qc_pairs = QuestionCondPairsDataset(
    test_data,
    candidate_extractor=CandidateCondsExtractor(share_candidates=True),
    has_label=False,
    model_1_outputs=sql)

te_qc_pairs_seq = QuestionCondPairsDataseq(te_qc_pairs,
                                           tokenizer,
                                           sampler=FullSampler(),
                                           shuffle=False,
                                           batch_size=1)
te_result = model2.predict_generator(te_qc_pairs_seq, verbose=1)

task2_result = merge_result(te_qc_pairs, te_result, threshold=0.995)
cond = list(task2_result.get(0, []))
sql[0]['conds'] = cond

engine = DBEngine()
table_id = json.loads(test_json_line)['table_id']
header = test_tables[table_id]._df.columns.values.tolist()
print(
    engine.execute(table_id, sql[0]['sel'], sql[0]['agg'], sql[0]['conds'],
                   sql[0]['cond_conn_op'], header))
#print(engine.execute(sql_json['table_id'], sql_json['sql']['sel'], sql_json['sql']['agg'], sql_json['sql']['conds'], sql_json['sql']['cond_conn_op']))
Example #14
    def is_same_execute(self, tid, sql_gt, sql_pred):
        """
        sql_gt & sql_pred is a dict and must contain sel, agg, conds, cond_conn_op 
        """
        ret_gt = self.engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                     sql_gt['conds'], sql_gt['cond_conn_op'])
        try:
            ret_pred = self.engine.execute(tid, sql_pred['sel'],
                                           sql_pred['agg'], sql_pred['conds'],
                                           sql_pred['cond_conn_op'])
        except Exception:
            return False
        return ret_gt == ret_pred


engine = DBEngine(os.path.join(valid_data_path, 'val.db'))
sqlite_oper = Sqlite3Oper(engine)


def check_part_acc(pred_queries, gt_queries, tables_list, valid_data):
    """
        判断各个组件的精确度
        param: 
                pred_queries: array of query
                gt_queries: array of query
                tables_list: 表列表
                valid_data: valid data 带有数据比较多
        ouput: xxx 
 
    """
    NEED_REWRITE_LOG = True
Example #15
def epoch_reinforce_train(model, optimizer, batch_size, sql_dataloader):
    """
    :param model: (Seq2SQL class)
    :param optimizer: (optimizer object)
    :param batch_size: (int)
    :param sql_data: (list) each entry is a dict containing one training example.
                    Dict includes table_id for relevant table
    :param table_data (dict) table data dict with keys as table_id's
    :param db_path: (str) path to the table db file
    """

    # engine = DBEngine(db_path) #Init database
    engine = DBEngine(sql_dataloader.dataset.TRAIN_DB)
    model.train()  #Set model in training mode
    # perm = np.random.permutation(len(sql_data))
    cum_reward = 0.0
    st = 0

    for batch_idx, sql_data in enumerate(sql_dataloader):
        gt_where_batch = model.generate_gt_where_seq(
            sql_data['question_tokens'], sql_data['column_headers'],
            sql_data['query_tokens']
        )  #Get where clauses of examples with tokens replaced by their token_ids
        raw_q_batch = [x[0] for x in sql_data['question_raw']
                       ]  # Get questions for each training example
        raw_col_batch = [x[1] for x in sql_data['question_raw']
                         ]  # Get Column Headers for each training example
        gt_sel_batch = [x[1] for x in sql_data['sql_query']
                        ]  # Get selector_id's for each training example
        table_ids = sql_data['table_id']
        gt_sql_entry = sql_data['sql_entry']
        score = model.forward(q=sql_data['question_tokens'],
                              col=sql_data['column_headers'],
                              col_num=sql_data['column_num'],
                              pred_entry=(True, True, True),
                              reinforce=True,
                              gt_sel=gt_sel_batch)
        loss = model.loss(score, sql_data['sql_query'], (True, True, False),
                          gt_where_batch)
        pred_queries = model.gen_query(score,
                                       sql_data['question_tokens'],
                                       sql_data['column_headers'],
                                       raw_q_batch,
                                       raw_col_batch, (True, True, True),
                                       reinforce=True)

        rewards = []
        # import pdb; pdb.set_trace()
        for (sql_gt, sql_pred, tid) in zip(gt_sql_entry, pred_queries,
                                           table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                ret_pred = None

            if ret_pred is None:
                rewards.append(-2)
            elif ret_pred != ret_gt:
                rewards.append(-1)
            else:
                rewards.append(3)

        cum_reward += (sum(rewards))

        lossrl = model.reinforce_backward(score[2], rewards, optimizer)
        # print("Avg RL Loss for batch: {}".format(lossrl.mean()))
        loss_batch = loss + lossrl

        # Optimization step batch-wise
        optimizer.zero_grad()
        loss_batch.backward(torch.FloatTensor([1]))
        optimizer.step()

        # Optimization step example-wise
        # for l in loss_batch:
        #     optimizer.zero_grad()
        #     l.backward(retain_graph=True)
        #     optimizer.step()

    print("Avg RL Loss for Epoch's last batch: {}. Avg CE Loss: {}".format(
        loss_batch.data[0], loss.data[0]))
    return cum_reward / len(sql_dataloader.dataset)