def eval_one_qelos(db_file, pred_file, source_file):
    engine = DBEngine(db_file)
    exact_match = []
    with open(source_file) as fs, open(pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            qp = None
            try:
                qp = Query.from_dict(ep)
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            grades.append(correct)
            exact_match.append(match)
    result = {
        'ex_accuracy': sum(grades) / len(grades),
        'lf_accuracy': sum(exact_match) / len(exact_match),
    }
    return result
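# Hedged usage sketch (not from the original source): how eval_one_qelos
# might be called on WikiSQL-style files; the three paths are placeholders.
metrics = eval_one_qelos(
    db_file='data/dev.db',            # SQLite database for the split
    pred_file='data/dev_pred.jsonl',  # one predicted Query dict per line
    source_file='data/dev.jsonl',     # gold examples with 'sql' and 'table_id'
)
print(json.dumps(metrics, indent=2))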
def hello():
    pred_json = json.loads(request.args.get('pred'))
    gold_json = json.loads(request.args.get('gold'))
    dataset = request.args.get('dataset')
    engine = DBEngine(os.path.join(DATABASE_PATH, "{}.db".format(dataset)))
    exact_match = []
    grades = []
    for lp, ls in tqdm(zip(pred_json, gold_json), total=len(gold_json)):
        eg = ls
        ep = lp
        qg = Query.from_dict(eg['sql'])
        gold = engine.execute_query(eg['table_id'], qg, lower=True)
        pred = ep['error']
        qp = None
        if not ep['error']:
            try:
                qp = Query.from_dict(ep['query'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
        correct = pred == gold
        match = qp == qg
        grades.append(correct)
        exact_match.append(match)
    ex_accuracy = sum(grades) / len(grades)
    lf_accuracy = sum(exact_match) / len(exact_match)
    return json.dumps({"ex_accuracy": ex_accuracy, "lf_accuracy": lf_accuracy})
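# Hedged client-side sketch (not from the original source): assuming hello()
# is registered as a Flask route at '/', it could be exercised like this.
# Host, port, table id, and payload values are illustrative assumptions.
import requests

resp = requests.get('http://localhost:5000/', params={
    'pred': json.dumps([{'error': '', 'query': {'sel': 0, 'agg': 0, 'conds': []}}]),
    'gold': json.dumps([{'sql': {'sel': 0, 'agg': 0, 'conds': []}, 'table_id': '1-10015132-11'}]),
    'dataset': 'dev',
})
print(resp.json())  # {'ex_accuracy': ..., 'lf_accuracy': ...}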
def eval(self, gold, sql_gold, engine):
    self.exception_raised = False  # reset this flag for a new evaluation
    if self.agg == gold['query']['agg']:
        self.correct['agg'] = 1
    if self.sel == gold['query']['sel']:
        self.correct['sel'] = 1
    op_list_pred = [op for col, op, span in self.cond]
    op_list_gold = [op for col, op, span in gold['query']['conds']]
    col_list_pred = [col for col, op, span in self.cond]
    col_list_gold = [col for col, op, span in gold['query']['conds']]
    q = gold['question']['words']
    span_list_pred = [' '.join(q[span[0]:span[1] + 1]) for col, op, span in self.cond]
    span_list_gold = [' '.join(span['words']) for col, op, span in gold['query']['conds']]
    where_pred = list(zip(col_list_pred, op_list_pred, span_list_pred))
    where_gold = list(zip(col_list_gold, op_list_gold, span_list_gold))
    where_pred.sort()
    where_gold.sort()
    if (where_pred == where_gold
            and len(col_list_pred) == len(col_list_gold)
            and len(op_list_pred) == len(op_list_gold)
            and len(span_list_pred) == len(span_list_gold)):
        self.correct['where'] = 1
    if (len(col_list_pred) == len(col_list_gold)
            and [it[0] for it in where_pred] == [it[0] for it in where_gold]):
        self.correct['col'] = 1
    if (len(op_list_pred) == len(op_list_gold)
            and [it[1] for it in where_pred] == [it[1] for it in where_gold]):
        self.correct['lay'] = 1
    if (len(span_list_pred) == len(span_list_gold)
            and [it[2] for it in where_pred] == [it[2] for it in where_gold]):
        self.correct['span'] = 1
    if all(self.correct[it] == 1 for it in ('agg', 'sel', 'where')):
        self.correct['all'] = 1

    # execution
    table_id = gold['table_id']
    ans_gold = engine.execute_query(table_id, Query.from_dict(sql_gold), lower=True)
    try:
        sql_pred = {
            'agg': self.agg,
            'sel': self.sel,
            'conds': self.recover_cond_to_gloss(gold),
        }
        ans_pred = engine.execute_query(table_id, Query.from_dict(sql_pred), lower=True)
    except Exception as e:
        self.exception_raised = True
        ans_pred = repr(e)
    if set(ans_gold) == set(ans_pred):
        self.correct['exe'] = 1
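# Hedged aggregation sketch (not from the original source): eval() only sets
# per-example flags in self.correct, so a driver loop like the one below could
# average them over a dataset. `predictions`, `gold_examples`, and `gold_sqls`
# are hypothetical names, and resetting self.correct per example is an
# assumption; the hosting class may handle that elsewhere.
from collections import Counter, defaultdict

totals = Counter()
for p, gold_ex, gold_sql in zip(predictions, gold_examples, gold_sqls):
    p.correct = defaultdict(int)  # assumed reset between examples
    p.eval(gold_ex, gold_sql, engine)
    totals.update(p.correct)
print({k: totals[k] / len(gold_examples) for k in ('agg', 'sel', 'where', 'all', 'exe')})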
def get_query_from_json(self, json_line):
    """Return both the Table object and the Query object for a json input line."""
    q = Query.from_dict(json_line["sql"])
    t_id = json_line["table_id"]
    table = self.table_map[t_id]
    t = Table("", table["header"], table["types"], table["rows"])
    return t, q
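# Hedged usage sketch (not from the original source): `loader` stands in for
# whatever class hosts get_query_from_json, and the path is a placeholder.
# It assumes self.table_map maps table ids to dicts with 'header', 'types',
# and 'rows', e.g. built from a *.tables.jsonl file.
with open('data/dev.jsonl') as f:
    first_example = json.loads(next(f))
table, query = loader.get_query_from_json(first_example)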
def main(argv):
    del argv  # Unused.
    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)
    engine = DBEngine(db_file)
    exact_match = []
    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)
            qp = None
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            if pred == gold and qp != qg:
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)
    print(json.dumps(
        {
            'ex_accuracy': sum(grades) / len(grades),
            'lf_accuracy': sum(exact_match) / len(exact_match),
        },
        indent=2))
def main(anno_file_name, col_headers, raw_args=None, verbose=True):
    parser = argparse.ArgumentParser(description='evaluate.py')
    opts.translate_opts(parser)
    opt = parser.parse_args(raw_args)
    torch.cuda.set_device(opt.gpu)
    opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split))
    opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding')
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    opt.anno = anno_file_name

    engine = DBEngine(opt.db_file)
    js_list = table.IO.read_anno_json(opt.anno)
    prev_best = (None, None)
    sql_query = []
    for fn_model in glob.glob(opt.model_path):
        opt.model = fn_model
        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(
            dataset=data, device=opt.gpu, batch_size=opt.batch_size,
            train=False, sort=True, sort_within_batch=False)
        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        pred = r_list[-1]
        sql_pred = {
            'agg': pred.agg,
            'sel': pred.sel,
            'conds': pred.recover_cond_to_gloss(js_list[-1]),
        }
        if verbose:
            print('\n sql_pred: ', sql_pred, '\n')
            print('\n col_headers: ', col_headers, '\n')
        sql_query = Query(sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
        try:
            ans_pred = engine.execute_query(
                js_list[-1]['table_id'], Query.from_dict(sql_pred),
                lower=True, verbose=verbose)
        except Exception as e:
            ans_pred = None
    return sql_query.get_complete_query(col_headers), ans_pred
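# Hypothetical invocation (not from the original source): the annotation file
# name and column headers below are placeholders for illustration only.
complete_query, answer = main('annotated/dev.anno.json',
                              ['Player', 'Team', 'Points'])
print(complete_query)
print(answer)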
import os
import sys
import json

from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lib.query import Query
from lib.dbengine import DBEngine

if __name__ == '__main__':
    for split in ['train', 'dev', 'test']:
        print('checking {}'.format(split))
        engine = DBEngine('data/{}.db'.format(split))
        n_lines = 0
        with open('data/{}.jsonl'.format(split)) as f:
            for l in f:
                n_lines += 1
        with open('data/{}.jsonl'.format(split)) as f:
            for l in tqdm(f, total=n_lines):
                d = json.loads(l)
                query = Query.from_dict(d['sql'])
                # make sure it's executable
                result = engine.execute_query(d['table_id'], query)
                if result:
                    for a, b, c in d['sql']['conds']:
                        if str(c).lower() not in d['question'].lower():
                            raise Exception(
                                'Could not find condition {} in question {} for query {}'
                                .format(c, d['question'], query))
                else:
                    raise Exception(
                        'Query {} did not execute to a valid result'.format(query))
parser.add_argument('--topk', type=int, default=3, help='k of top_k')
args = parser.parse_args()

engine = DBEngine(args.db_file)
temp = []
with open(args.source_file) as fs, open(args.pred_file) as fp:
    grades = []
    exact_match = []
    for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
        eg = json.loads(ls)
        qg = Query.from_dict(eg['sql'], ordered=args.ordered)
        gold = engine.execute_query(eg['table_id'], qg, lower=True)
        pred_topk = []
        qp_topk = []
        ep = json.loads(lp)
        pred = ep.get('error', None)
        qp = None
        for i in range(args.topk):
            if not ep.get('error', None):
                try:
                    # normalize an empty condition list serialized as [[]]
                    if ep['query'][str(i)]['conds'] == [[]]:
                        ep['query'][str(i)]['conds'] = []
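                    # NOTE: the original snippet is truncated here; the rest
                    # of this loop is a hedged reconstruction that mirrors the
                    # single-prediction evaluation scripts in this collection:
                    # parse the i-th candidate, execute it, collect results.
                    qp = Query.from_dict(ep['query'][str(i)], ordered=args.ordered)
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    qp = None
                    pred = repr(e)
            qp_topk.append(qp)
            pred_topk.append(pred)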
def eval(self, split, gold, sql_gold, engine=None):
    if self.agg == gold['query']['agg']:
        self.correct['agg'] = 1
    if self.sel == gold['query']['sel']:
        self.correct['sel'] = 1
    op_list_pred = [op for col, op, span in self.cond]
    op_list_gold = [op for col, op, span in gold['query']['conds']]
    col_list_pred = [col for col, op, span in self.cond]
    col_list_gold = [col for col, op, span in gold['query']['conds']]
    q = gold['question']['words']
    span_list_pred = [' '.join(q[span[0]:span[1] + 1]) for col, op, span in self.cond]
    span_list_gold = [' '.join(span['words']) for col, op, span in gold['query']['conds']]
    where_pred = list(zip(col_list_pred, op_list_pred, span_list_pred))
    where_gold = list(zip(col_list_gold, op_list_gold, span_list_gold))
    where_pred.sort()
    where_gold.sort()
    if (where_pred == where_gold
            and len(col_list_pred) == len(col_list_gold)
            and len(op_list_pred) == len(op_list_gold)
            and len(span_list_pred) == len(span_list_gold)):
        self.correct['where'] = 1
    if (len(col_list_pred) == len(col_list_gold)
            and [it[0] for it in where_pred] == [it[0] for it in where_gold]):
        self.correct['col'] = 1
    if (len(op_list_pred) == len(op_list_gold)
            and [it[1] for it in where_pred] == [it[1] for it in where_gold]):
        self.correct['lay'] = 1
    if (len(span_list_pred) == len(span_list_gold)
            and [it[2] for it in where_pred] == [it[2] for it in where_gold]):
        self.correct['span'] = 1
    if all(self.correct[it] == 1 for it in ('agg', 'sel', 'where')):
        self.correct['all'] = 1

    # execution
    table_id = gold['table_id']
    if engine is not None:
        ans_gold = engine.execute_query(table_id, Query.from_dict(sql_gold), lower=True)
        try:
            sql_pred = {
                'agg': self.agg,
                'sel': self.sel,
                'conds': self.recover_cond_to_gloss(gold),
            }
            ans_pred = engine.execute_query(table_id, Query.from_dict(sql_pred), lower=True)
        except Exception as e:
            ans_pred = repr(e)
    else:
        # no engine available: use unequal placeholders so 'exe' stays 0
        ans_gold = '0'
        ans_pred = '1'
    if set(ans_gold) == set(ans_pred):
        self.correct['exe'] = 1

    error_case = {}
    if split == 'finaltest':
        error_case['sel'] = self.correct['sel']
        error_case['where'] = self.correct['where']
        error_case['all'] = self.correct['all']
        error_case['table_id'] = gold['table_id']
        error_case['question_id'] = gold['id']
        error_case['question'] = gold['question']['words']
        error_case['table_head'] = [head['words'] for head in gold['table']['header']]
        error_case['gold'] = sql_gold
        error_case['predict'] = {
            'agg': self.agg.item(),
            'sel': self.sel.item(),
            'conds': self.print_recover_cond_to_gloss(gold),
        }
        error_case['BIO'] = [(x.item(), y) for x, y in
                             zip(list(self.BIO), list(gold['question']['words']))]
        error_case['BIO_col'] = self.BIO_col.tolist()
    return error_case
opts = Options()
fill_from_args(opts)
for split in ['train', 'dev', 'test']:
    orig = os.path.join(opts.data_dir, f'{split}.jsonl')
    db_file = os.path.join(opts.data_dir, f'{split}.db')
    ans_file = os.path.join(opts.data_dir, f"{split}_ans.jsonl.gz")
    tbl_file = os.path.join(opts.data_dir, f"{split}.tables.jsonl")
    engine = DBEngine(db_file)
    exact_match = []
    with open(orig) as fs, write_open(ans_file) as fo:
        grades = []
        for ls in tqdm(fs, total=count_lines(orig)):
            eg = json.loads(ls)
            sql = eg['sql']
            qg = Query.from_dict(sql, ordered=False)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            assert isinstance(gold, list)
            eg['answer'] = gold
            eg['rowids'] = engine.execute_query_rowid(eg['table_id'], qg, lower=True)
            # CONSIDER: if it is not an agg query, somehow identify the particular cell
            fo.write(json.dumps(eg) + '\n')
    convert(jsonl_lines(ans_file), jsonl_lines(tbl_file),
            os.path.join(opts.data_dir, f"{split}_agg.jsonl.gz"),
            skip_aggregation=False)
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.pred_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            pred = ep['error']
            qp = None
            if not ep['error']:
                try:
                    qp = Query.from_dict(ep['query'])
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    pred = repr(e)
            correct = pred == gold
            match = qp == qg
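            # The original snippet ends here; the accumulation and the final
            # report below are a hedged reconstruction mirroring the other
            # evaluation scripts in this collection.
            grades.append(correct)
            exact_match.append(match)
    print(json.dumps({
        'ex_accuracy': sum(grades) / len(grades),
        'lf_accuracy': sum(exact_match) / len(exact_match),
    }, indent=2))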
def main():
    parser = ArgumentParser()
    parser.add_argument('--din', default=data_path, help='data directory')
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()
    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)

    for split in ['dev']:  # for split in ['train', 'test', 'dev']:
        with open(save_path + '%s.qu' % split, 'w') as qu_file, \
                open(save_path + '%s.lon' % split, 'w') as lon_file, \
                open(save_path + '%s.out' % split, 'w') as out, \
                open(save_path + '%s_sym_pairs.txt' % split, 'w') as sym_file, \
                open(save_path + '%s_ground_truth.txt' % split, 'w') as S_file:
            fsplit = os.path.join(args.din, split) + '.jsonl'
            ftable = os.path.join(args.din, split) + '.tables.jsonl'
            with open(fsplit) as fs, open(ftable) as ft:
                print('loading tables')
                tables = {}
                for line in tqdm(ft, total=count_lines(ftable)):
                    d = json.loads(line)
                    tables[d['id']] = d
                print('loading tables done.')

                print('loading examples')
                n, acc, acc_pair, acc_all, error = 0, 0, 0, 0, 0
                target = -1
                ADD_FIELDS = False
                step2 = True
                for line in tqdm(fs, total=count_lines(fsplit)):
                    ADD_TO_FILE = True
                    d = json.loads(line)
                    Q = d['question']
                    rows = np.asarray(tables[d['table_id']]['rows'])
                    # renamed from `fs` in the original, which shadowed the
                    # open file handle above
                    headers = tables[d['table_id']]['header']
                    # all fields are sorted by length in descending order,
                    # for string-match purposes
                    all_fields = sorted([_preclean(f) for f in headers],
                                        key=len, reverse=True)

                    smap = defaultdict(list)         # field -> values
                    reverse_map = defaultdict(list)  # value -> fields
                    for row in rows:
                        for i in range(len(headers)):
                            cur_f = _preclean(str(headers[i]))
                            cur_row = _preclean(str(row[i]))
                            smap[cur_f].append(cur_row)
                            if cur_f not in reverse_map[cur_row]:
                                reverse_map[cur_row].append(cur_f)

                    # all values are sorted by length in descending order,
                    # for string-match purposes
                    keys = sorted(reverse_map.keys(), key=len, reverse=True)

                    Q = _preclean(Q)
                    Q_ori = Q

                    #####################################
                    ########## Annotate question ########
                    #####################################
                    candidates, cond_fields = _match_pairs(Q, Q_ori, keys, reverse_map)
                    Q_head, head2partial = _match_head(Q, Q_ori, smap, all_fields, cond_fields)
                    Q, Qpairs = annotate_Q(Q, Q_ori, Q_head, candidates,
                                           all_fields, head2partial, n, target)
                    qu_file.write(Q + '\n')

                    validation_pairs = copy.copy(Qpairs)
                    validation_pairs.append((Q_head, '<f0>', 'head'))
                    for i, f in enumerate(all_fields):
                        validation_pairs.append((f, '<c' + str(i) + '>', 'c'))

                    #####################################
                    ########## Annotate SQL #############
                    #####################################
                    q_sent = Query.from_dict(d['sql'])
                    S, col_names, val_names = q_sent.to_sentence(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S = _preclean(S)
                    S_ori = S
                    S_noparen = q_sent.to_sentence_noparenthesis(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S_noparen = _preclean(S_noparen)

                    col_names = [_preclean(col_name) for col_name in col_names]
                    val_names = [_preclean(val_name) for val_name in val_names]
                    HEAD = col_names[-1]
                    S_head = _preclean(HEAD)

                    # annotate the SQL side
                    name_pairs = []
                    for col_name, val_name in zip(col_names, val_names):
                        if col_name == val_name:
                            name_pairs.append([_preclean(col_name), 'true'])
                        else:
                            name_pairs.append([_preclean(col_name), _preclean(val_name)])
                    # sort to compare with candidates
                    name_pairs.sort(key=lambda x: x[1])
                    new_name_pairs = [['<f' + str(i + 1) + '>', '<v' + str(i + 1) + '>']
                                      for i, (field, value) in enumerate(name_pairs)]

                    # only annotate S when the identified (f, v) pairs are right
                    if _equal(name_pairs, candidates):
                        pairs = []
                        for (f, v), (new_f, new_v) in zip(name_pairs, new_name_pairs):
                            pairs.append((f, new_f, 'f'))
                            pairs.append((v, new_v, 'v'))
                        # sort (word, symbol) pairs by length in descending order
                        pairs.sort(key=lambda x: 100 - len(x[0]))
                        for p, new_p, t in pairs:
                            cp = _backslash(p)
                            if new_p in Q:
                                if t == 'v':
                                    S = S.replace(p + ' )', new_p + ' )')
                                if t == 'f':
                                    S = re.sub(r'\( ' + cp + ' (equal|less|greater)',
                                               '( ' + new_p + r' \1', S)

                    # only annotate S when the identified HEAD is right
                    if S_head == Q_head and '<f0>' in Q:
                        S = S.replace(S_head, '<f0>')

                    # annotate unseen fields
                    if ADD_FIELDS:
                        for i, f in enumerate(all_fields):
                            cf = _backslash(f)
                            S = re.sub(r'(\s|^)' + cf + r'(\s|$|s)',
                                       ' <c' + str(i) + '> ', S)

                    S = _clean(S)
                    lon_file.write(S + '\n')

                    ###############################
                    ######### VALIDATION ##########
                    ###############################
                    recover_S = S
                    for word, sym, t in validation_pairs:
                        recover_S = recover_S.replace(sym, word)
                        sym_file.write(sym + '=>' + word + '<>')
                    sym_file.write('\n')
                    S_file.write(S_noparen + '\n')

                    # ----------------------------------------------------------
                    if _equal(name_pairs, candidates):
                        acc_pair += 1
                    if Q_head == S_head:
                        acc += 1
                    if _equal(name_pairs, candidates) and Q_head == S_head:
                        acc_all += 1

                    full_anno = True
                    for s in S.split():
                        if s[0] != '<' and s not in [
                                '(', ')', 'where', 'less', 'greater', 'equal',
                                'max', 'min', 'count', 'sum', 'avg', 'and', 'true'
                        ]:
                            error += 1
                            full_anno = False
                            break

                    # disabled debugging dump; Q_head/head2partial stand in for
                    # the undefined `head`/`p2partial` names in the original
                    if False and not (_equal(name_pairs, candidates) and Q_head == col_names[-1]):
                        print('--------' + str(n) + '-----------')
                        print(Q_ori)
                        print(Q)
                        print(S_ori)
                        print(S)
                        print('head:')
                        print(Q_head)
                        print(head2partial)
                        print('true HEAD')
                        print(Q_head)
                        print('fields:')
                        print(candidates)
                        print(head2partial)
                        print('true fields')
                        print(name_pairs)

                    n += 1

                print('total number of examples:' + str(n))
                print('fully annotated:' + str(1 - error * 1.0 / n))
                print('accurate all percent:' + str(acc_all * 1.0 / n))
                print('accurate HEAD match percent:' + str(acc * 1.0 / n))
                print('accurate fields pair match percent:' + str(acc_pair * 1.0 / n))
def _gen_files(save_path=save_path):
    """Prepare the files used by the Machine Comprehension binary classifier.

    <split>.txt        original column content, without truncation
    <split>_model.txt  column names truncated or padded to length 3
    """
    for split in ['train', 'test', 'dev']:
        print('------%s-------' % split)
        n = 0
        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        with open(fsplit) as fs, open(ftable) as ft, \
                open(os.path.join(save_path, split + '.txt'), mode='w') as fw, \
                open(os.path.join(save_path, '%s_model.txt' % split), mode='w') as fw_model, \
                open(os.path.join(save_path, split + '.lon'), mode='w') as fsql, \
                open(os.path.join(save_path, '%s.ori.qu' % split), 'w') as qu_file, \
                open(os.path.join(save_path, '%s.ori.lon' % split), 'w') as lon_file:
            print('loading tables...')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')

            print('loading examples')
            f2v_all, v2f_all = [], []
            for line in tqdm(fs, total=count_lines(fsplit)):
                d = json.loads(line)
                Q = _preclean(d['question']).replace('\t', '')
                qu_file.write(Q + '\n')

                q_sent = Query.from_dict(d['sql'])
                rows = tables[d['table_id']]['rows']
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)
                lon_file.write(S + '\n')

                rows = np.asarray(rows)
                # renamed from `fs` in the original, which shadowed the open
                # file handle above
                header_fields = tables[d['table_id']]['header']
                all_fields = [_preclean(f) for f in header_fields]
                # all fields are sorted by length in descending order,
                # for string-match purposes
                headers = sorted(all_fields, key=len, reverse=True)

                f2v = defaultdict(list)  # field -> values
                v2f = defaultdict(list)  # value -> fields
                for row in rows:
                    for i in range(len(header_fields)):
                        cur_f = _preclean(str(header_fields[i]))
                        cur_row = _preclean(str(row[i]))
                        f2v[cur_f].append(cur_row)
                        if cur_f not in v2f[cur_row]:
                            v2f[cur_row].append(cur_f)
                f2v_all.append(f2v)
                v2f_all.append(v2f)

                #####################################
                ########## Annotate SQL #############
                #####################################
                # (the original recomputed q_sent, S, col_names, and val_names
                # here with identical calls; the duplicate is dropped)
                S_noparen = q_sent.to_sentence_noparenthesis(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S_noparen = _preclean(S_noparen)

                col_names = [_preclean(col_name) for col_name in col_names]
                val_names = [_preclean(val_name) for val_name in val_names]
                HEAD = col_names[-1]
                S_head = _preclean(HEAD)

                # annotate the SQL side
                name_pairs = []
                for col_name, val_name in zip(col_names, val_names):
                    if col_name == val_name:
                        name_pairs.append([_preclean(col_name), 'true'])
                    else:
                        name_pairs.append([_preclean(col_name), _preclean(val_name)])
                # sort to compare with candidates
                name_pairs.sort(key=lambda x: x[1])

                fsql.write('#%d\n' % n)
                fw.write('#%d\n' % n)
                for f in col_names:
                    fsql.write(S.replace(f, '[' + f + ']') + '\n')
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                for f in [f for f in headers if f not in col_names]:
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                fsql.write(S + '\n')

                for f in col_names:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')
                for f in [f for f in headers if f not in col_names]:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')

                n += 1

            fsql.write('#%d\n' % n)
            fw.write('#%d\n' % n)

        # np.savez, not scipy.savez: scipy does not provide a `savez` function
        np.savez(os.path.join(save_path, '%s_dict.npz' % split),
                 f2v_all=f2v_all, v2f_all=v2f_all)
        print('num of records:%d' % n)
parser.add_argument("--db_file") parser.add_argument("--pred_file") parser.add_argument("--ordered", action='store_true') args = parser.parse_args() engine = DBEngine(args.db_file) ex_acc_list = [] lf_acc_list = [] with open(args.source_file) as sf, open(args.pred_file) as pf: for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)): # line별 정답과 예측 샘플 가져오기 gold_example = json.loads(source_line) pred_example = json.loads(pred_line) # 정답 샘플 lf, ex 구하기 lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered) ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True) # error가 아닌 경우 예측 샘플 lf, ex 구하기 lf_pred_query = None ex_pred = pred_example.get('error', None) if not ex_pred: try: lf_pred_query = Query.from_dict(pred_example['query'], ordered=args.ordered) ex_pred = engine.execute_query(gold_example['table_id'], lf_pred_query, lower=True) except Exception as e: ex_pred = repr(e) # lf, ex의 gold, pred 매칭결과 구하기 ex_acc_list.append(ex_pred == ex_gold) lf_acc_list.append(lf_pred_query == lf_gold_query) # query의 __eq__를 호출