def process_examples(fexample, ftable, fout):
    """Annotate the examples in ``fexample`` and write the valid ones to ``fout``.

    Args:
        fexample: path of a JSON-lines file of examples; each record must have
            a ``table_id`` key.
        ftable: path of a JSON-lines file of tables; each record must have an
            ``id`` key.
        fout: path of the JSON-lines output file (overwritten).

    Raises:
        Exception: when the query rebuilt from an annotated example's
            ``seq_output`` differs (case-insensitively) from its gold query.
    """
    print('annotating {}'.format(fexample))
    with open(fexample) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
        print('loading tables')
        tables = {}
        for raw_table in tqdm(ft, total=count_lines(ftable)):
            table = json.loads(raw_table)
            tables[table['id']] = table

        print('loading examples')
        n_written = 0
        for raw_example in tqdm(fs, total=count_lines(fexample)):
            example = json.loads(raw_example)
            annotated = annotate_example(example, tables[example['table_id']])
            if is_valid_example(annotated):
                gold = Query.from_tokenized_dict(annotated['query'])
                reconstruct = Query.from_sequence(
                    annotated['seq_output'], annotated['table'], lowercase=True)
                # The annotation must round-trip to the gold query.
                if gold.lower() != reconstruct.lower():
                    raise Exception('Expected:\n{}\nGot:\n{}'.format(gold, reconstruct))
                fo.write(json.dumps(annotated) + '\n')
                n_written += 1
            else:
                # Invalid examples are reported and skipped, not fatal.
                print('Invalid example: {}'.format(str(annotated)))
        print('wrote {} examples'.format(n_written))
def build_sql_train(filename):
    """Write ``question<TAB>SQL`` training pairs to ``filename``.

    Reads the table file ``OPTIONS.train_table_file`` and the example file
    ``OPTIONS.train_source_file`` (both JSON lines).  For every example it
    renders the ``sql`` dict as a ``SELECT ... FROM table_<id> WHERE ...``
    string and writes it next to the re-tokenised question.

    The code relies on the ``unicode`` builtin and on concatenating the result
    of ``.encode()`` to ``str``, i.e. it targets Python 2.

    Args:
        filename: path of the output file (overwritten).
    """
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']
    # ``with`` closes the output file even when a record fails part-way
    # (the original only closed it on the success path).
    with open(filename, "w") as train_file:
        # NOTE(review): credentials are hard-coded in source; they should come
        # from configuration or the environment.
        db = MySQLdb.connect(host="localhost",  # your host
                             user="******",  # username
                             passwd="cfhoCPkr",  # password
                             db="wikisql")  # name of the database
        try:
            # The cursor is only needed by the disabled gold-execution step
            # at the bottom of the loop.
            cur = db.cursor()
            # NOTE(review): table_header is filled but never read in this
            # function; kept so the table file is still parsed as before.
            table_header = {}
            with open(OPTIONS.train_table_file) as tb:
                for ls in tqdm(tb, total=count_lines(OPTIONS.train_table_file)):
                    eg = json.loads(ls)
                    table_header[eg['id']] = [re.split(b'[\/|,|;|\s]\s*', h.encode('utf8')) for h in eg['header']]
            with open(OPTIONS.train_source_file) as fs:
                for ls in tqdm(fs, total=count_lines(OPTIONS.train_source_file)):
                    eg = json.loads(ls)
                    table_id = eg['table_id']

                    # One "col<N><op>'<value>' " fragment per condition.
                    cond_parts = []
                    for cond in eg['sql']['conds']:
                        col = cond[0]
                        op = cond[1]
                        if isinstance(cond[2], str):
                            cval = cond[2].encode('utf8')
                        elif isinstance(cond[2], unicode):
                            cval = unicodedata.normalize('NFKD', cond[2]).encode('ascii', 'ignore')
                        else:
                            cval = cond[2]
                        cond_parts.append(('col' + str(col)) + cond_ops[op] + '\'' + str(cval) + '\' ')
                    # Fragments already end in a space, so the result is
                    # "... AND ..." with a single trailing space.
                    cq = 'AND '.join(cond_parts)

                    selected = 'col' + str(eg['sql']['sel'])
                    if eg['sql']['agg'] > 0:
                        selected = agg_ops[eg['sql']['agg']] + '(' + selected + ')'
                    query = 'SELECT ' + selected + ' FROM table_{}'.format(table_id.replace('-', '_')) + ' WHERE ' + cq

                    question = eg['question']
                    q = ''
                    # The capturing group makes re.split return the separators
                    # too, so they are kept in the rebuilt question.
                    for t in re.split(r'(,|;|/|/\|!|@|#|$|\"|\(|\)|`|=|\s)\s*', question):
                        if isinstance(t, str):
                            q = q + '' + t.encode('utf8')
                        elif isinstance(t, unicode):
                            q = q + '' + unicodedata.normalize('NFKD', t).encode('ascii', 'ignore').replace(u'\u0101', 'a')
                    # NOTE(review): the replacement is the two characters
                    # backslash + "s", not a space; confirm that is intended.
                    question = re.sub('\t+', '\s', q)
                    train_file.write(question + '\t' + query + '\n')
                    #gold = cur.execute(query)
        finally:
            # Release the connection on success and on failure alike.
            db.close()
def process_tables(ftable, fout=None):
    """Annotate the header and cells of every table in ``ftable``.

    Each JSON line of ``ftable`` must have ``id``, ``header`` and ``rows``.
    Header names and stringified cells are passed through ``annotate`` and
    re-joined into one string, with ``^`` standing in for a single-space
    separator.

    Processing stops at the first table that fails: the offending line is
    printed and the tables collected so far are kept.

    Args:
        ftable: path of the JSON-lines table file.
        fout: optional output path; when given, the annotated tables are
            written there as JSON lines.  When ``None`` nothing is persisted.
    """

    # join words together
    def _join_words(entry):
        pieces = []
        for i, word in enumerate(entry["words"]):
            after = entry["after"][i]
            # Compare by value: ``is not " "`` tested object identity, which
            # only works by accident of string interning.
            pieces.append(word + (after if after != " " else "^"))
        return "".join(pieces)

    tables = []
    with open(ftable) as ft:
        for line in tqdm(ft, total=count_lines(ftable)):
            raw_table = json.loads(line)
            try:
                table = {
                    "id": raw_table["id"],
                    "header": [_join_words(annotate(h)) for h in raw_table["header"]],
                    "rows": [[_join_words(annotate(str(tok))) for tok in l] for l in raw_table["rows"]],
                }
                tables.append(table)
            except Exception:
                # Best-effort: report the bad line and stop.  A bare
                # ``except:`` would also trap KeyboardInterrupt/SystemExit.
                print(line)
                break

    if fout is not None:
        with open(fout, "w+") as fo:
            for table in tables:
                fo.write(json.dumps(table) + "\n")
def eval_one_qelos(db_file, pred_file, source_file):
    """Score predicted queries against gold queries, line by line.

    Args:
        db_file: database path handed to ``DBEngine``.
        pred_file: JSON-lines file; each line is a query dict.
        source_file: JSON-lines file of gold examples with ``sql`` and
            ``table_id`` keys.

    Returns:
        dict with ``ex_accuracy`` (fraction of predictions whose execution
        result equals the gold result) and ``lf_accuracy`` (fraction whose
        query object equals the gold query).
    """
    engine = DBEngine(db_file)
    exact_match = []
    with open(source_file) as gold_lines, open(pred_file) as pred_lines:
        grades = []
        paired = tqdm(zip(gold_lines, pred_lines), total=count_lines(source_file))
        for gold_raw, pred_raw in paired:
            gold_example = json.loads(gold_raw)
            pred_example = json.loads(pred_raw)
            table_id = gold_example['table_id']

            gold_query = Query.from_dict(gold_example['sql'])
            gold_result = engine.execute_query(table_id, gold_query, lower=True)

            # A prediction that cannot be parsed or executed is scored with
            # the repr of the exception as its "result".
            pred_query = None
            try:
                pred_query = Query.from_dict(pred_example)
                pred_result = engine.execute_query(table_id, pred_query, lower=True)
            except Exception as err:
                pred_result = repr(err)

            grades.append(pred_result == gold_result)
            exact_match.append(pred_query == gold_query)

    return {
        'ex_accuracy': sum(grades) / len(grades),
        'lf_accuracy': sum(exact_match) / len(exact_match),
    }
def load_seq2sql_file(train_file):
    """Parse every line of ``train_file`` as a Python literal.

    All space characters are stripped from a line before it is evaluated with
    ``ast.literal_eval``.

    Returns:
        list of the parsed values, in file order.
    """
    with open(train_file) as lf:
        progress = tqdm(lf, total=count_lines(train_file))
        parsed = [ast.literal_eval(raw.replace(' ', '')) for raw in progress]
    return parsed
def load_word2vec(file_loc):
    """Fill the module-level ``w2v`` mapping from a space-separated vector file.

    On each line the first field is the word; the remaining fields are stored
    unchanged as a list of strings (the last one keeps its line terminator).
    """
    with open(file_loc) as w2v_file:
        for row in tqdm(w2v_file, total=count_lines(file_loc)):
            fields = row.split(' ')
            w2v[fields[0]] = fields[1:]
def main(argv):
    """Print execution and logical-form accuracy of predicted SQL as JSON.

    File names are taken from ``FLAGS`` (``data_root``, ``db_file``,
    ``parsed_std_sql_file``, ``parsed_pred_sql_file``).  The gold and the
    prediction file are read in lockstep, one JSON record per line.

    A query that cannot be built or executed is scored with the ``repr`` of
    the raised exception as its result.  A record whose gold query could not
    be built never counts as a logical-form match.

    Raises:
        ZeroDivisionError: if no record pair was read.
    """
    del argv  # Unused.
    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)
    engine = DBEngine(db_file)
    exact_match = []
    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)

            # Reset per record.  Previously ``qg`` was assigned only inside
            # the ``try``, so a failing gold query left it unbound (first
            # record) or equal to the previous record's query.
            qg = None
            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)

            qp = None
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)

            correct = pred == gold
            # Two unparseable queries (both None) must not count as a match.
            match = qg is not None and qp == qg
            if correct and qp != qg:
                # Same execution result from a different logical form.
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)
        print(
            json.dumps(
                {
                    'ex_accuracy': sum(grades) / len(grades),
                    'lf_accuracy': sum(exact_match) / len(exact_match),
                },
                indent=2))
def __init__(self): self.data_dir = '' opts = Options() fill_from_args(opts) for split in ['train', 'dev', 'test']: orig = os.path.join(opts.data_dir, f'{split}.jsonl') db_file = os.path.join(opts.data_dir, f'{split}.db') ans_file = os.path.join(opts.data_dir, f"{split}_ans.jsonl.gz") tbl_file = os.path.join(opts.data_dir, f"{split}.tables.jsonl") engine = DBEngine(db_file) exact_match = [] with open(orig) as fs, write_open(ans_file) as fo: grades = [] for ls in tqdm(fs, total=count_lines(orig)): eg = json.loads(ls) sql = eg['sql'] qg = Query.from_dict(sql, ordered=False) gold = engine.execute_query(eg['table_id'], qg, lower=True) assert isinstance(gold, list) #if len(gold) != 1: # print(f'for {sql} : {gold}') eg['answer'] = gold eg['rowids'] = engine.execute_query_rowid(eg['table_id'], qg, lower=True) # CONSIDER: if it is not an agg query, somehow identify the particular cell fo.write(json.dumps(eg) + '\n') convert(jsonl_lines(ans_file),
def main():
    """Re-insert double quotes around non-symbol spans of inferred logical forms.

    For each of the ``test`` and ``dev`` splits:

    1. load the split's tables, keyed by their ``id``;
    2. read ``<split>_infer.txt`` from ``data_path`` together with the ground
       truth, marked ground truth, table-id and symbol-pair files from
       ``save_path`` (read in lockstep, one record per line);
    3. for every inferred line, temporarily mask table phrases that contain a
       word from ``symbols``, wrap each run of non-symbol tokens in ``"``,
       then unmask;
    4. write the quoted line to ``save_path + '<split>_infer.txt'``.

    NOTE(review): the statement nesting below was reconstructed from a
    whitespace-collapsed source; confirm it against the original file.
    """
    for split in ['test', 'dev']:
        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        # fsplit is opened but only the table file is read in this block.
        with open(fsplit) as fs, open(ftable) as ft:
            print('loading tables')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')
        # Paths are built by plain concatenation, so data_path and save_path
        # must already end with a path separator.
        with open(data_path+'%s_infer.txt'%split, 'r') as in_file, \
                open(save_path+'%s_infer.txt'%split, 'w') as out_file,\
                open(save_path+'%s_ground_truth.txt'%split,'r') as true0,\
                open(save_path+'%s_ground_truth_mark.txt'%split,'r') as true,\
                open(save_path+'%s_SQL2tableid.txt'%split,'r') as id_file,\
                open(save_path+'%s_sym_pairs.txt'%split,'r') as pairs_file:
            lines = in_file.readlines()
            gt_lines = true0.readlines()
            gt_mark_lines = true.readlines()
            ids = id_file.readlines()
            pairs = pairs_file.readlines()
            # idx is incremented per record but never read.
            idx = 1
            # zip stops at the shortest file; `pair` is unpacked but unused.
            for line, gt, gt_mark, t_id, pair in zip(lines, gt_lines, gt_mark_lines, ids, pairs):
                t_id = t_id.replace('\n', '')
                new = ''
                # Keep the untouched inferred line for the diagnostic below.
                inf = line
                rows = tables[t_id]['rows']
                rows = np.asarray(rows)
                # `fs` is rebound here to the table's header list.
                fs = tables[t_id]['header']
                # Every header name and every cell value, pre-cleaned.
                all_f = []
                for f in fs:
                    f = _preclean(str(f))
                    all_f.append(f)
                for row in rows:
                    for i in range(len(fs)):
                        f = _preclean(str(row[i]))
                        all_f.append(f)
                # all fields (including values) are sorted in descending order
                # of length, so longer phrases are considered first
                all_f.sort(key=len, reverse=True)
                added = [ ]
                # collect phrases with SYMBOL word (e.g.,'MAX') in it
                for f in all_f:
                    PASS = False
                    ADD = False
                    # skip a phrase contained in an already collected phrase
                    for a in added:
                        if f in a:
                            PASS = True
                            break
                    # skip a single-word phrase that is not a whole token of
                    # the line
                    if len(f.split()) == 1 and f not in line.split():
                        PASS = True
                    # whether SYMBOL is in it
                    for s in symbols:
                        if s in f.split():
                            ADD = True
                            break
                    if f in line and PASS == False and ADD and f != 'where':
                        added.append(f)
                # added phrases are replaced away with <pN> placeholders so
                # their symbol words are not treated as symbols below
                for iden, phrase in enumerate(added):
                    line = line.replace(phrase, '<p' + str(iden) + '>')
                tokens = line.split()
                for i in range(len(tokens)):
                    token = tokens[i]
                    # Opening quote: a non-symbol token that is first or
                    # follows a symbol.
                    if i == 0:
                        if token not in symbols:
                            new += '\"'
                    elif tokens[i - 1] in symbols and token not in symbols:
                        new += '\"'
                    new += token
                    # Closing quote: a non-symbol token that is last or is
                    # followed by a symbol.
                    if i == len(tokens) - 1:
                        if token not in symbols:
                            new += '\"'
                    elif tokens[i + 1] in symbols and token not in symbols:
                        new += '\"'
                    new += ' '
                new = new.strip()
                # replace back
                for iden, phrase in enumerate(added):
                    new = new.replace('<p' + str(iden) + '>', phrase)
                gt_mark = gt_mark.replace('\n', '')
                # Diagnostic: the inference equals the ground truth, yet the
                # quoting produced here differs from the marked ground truth.
                if gt == inf and gt_mark != (new):
                    print(added)
                    print('inf:' + line)
                    print('gt:' + gt)
                    print('gt_mark:' + gt_mark + '.')
                    print('inf_mark:' + new + '.')
                out_file.write(new + '\n')
                idx += 1
from lib.dbengine import DBEngine from lib.query import Query from lib.common import count_lines if __name__ == '__main__': parser = ArgumentParser() parser.add_argument('source_file', help='source file for the prediction') parser.add_argument('db_file', help='source database for the prediction') parser.add_argument('pred_file', help='predictions by the model') args = parser.parse_args() engine = DBEngine(args.db_file) exact_match = [] with open(args.source_file) as fs, open(args.pred_file) as fp: grades = [] for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.pred_file)): eg = json.loads(ls) ep = json.loads(lp) qg = Query.from_dict(eg['sql']) #print qg gold = engine.execute_query(eg['table_id'], qg, lower=True) pred = ep['error'] qp = None if not ep['error']: try: qp = Query.from_dict(ep['query']) pred = engine.execute_query(eg['table_id'], qp, lower=True) except Exception as e: pred = repr(e) correct = pred == gold match = qp == qg
def data_(inv_, args):
    """Build ``(input ids, target ids)`` pairs from a WikiSQL-style source file.

    The table file name is derived from ``args.source_file`` by inserting
    ``.tables`` before its extension.  For every example the ``sql`` dict is
    rendered as a SQL string; the input is the ids of known header words,
    known question words and the fixed symbol list, and the target is the ids
    of the SQL tokens.  SQL tokens not yet in ``inv_`` are added to both
    ``inv_`` and the module-level ``id_``.

    Args:
        inv_: token -> id mapping; extended in place.
        args: object with a ``source_file`` attribute.

    Returns:
        list of ``(input_ids, target_ids)`` tuples.
    """
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']
    syms = [
        'SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION',
        'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS'
    ]
    # NOTE(review): credentials are hard-coded in source; they should come
    # from configuration or the environment.
    db = MySQLdb.connect(
        host="localhost",  # your host
        user="******",  # username
        passwd="cfhoCPkr",  # password
        db="wikisql")  # name of the database
    try:
        data = []
        # Only needed by the disabled gold-execution step below.
        cur = db.cursor()
        table_file = args.source_file.split(
            '.')[0] + '.tables.' + args.source_file.split('.')[1]
        table_header = {}
        with open(table_file) as tb:
            for ls in tqdm(tb, total=count_lines(table_file)):
                eg = json.loads(ls)
                table_header[eg['id']] = [
                    re.split(b'[\/|,|;|\s]\s*', h.encode('utf8'))
                    for h in eg['header']
                ]
        with open(args.source_file) as fs:
            for ls in tqdm(fs, total=count_lines(args.source_file)):
                eg = json.loads(ls)
                table_id = eg['table_id']

                # One "col<N><op>'<value>' " fragment per condition.
                # NOTE(review): under Python 3, str() of the encoded value
                # yields "b'...'"; confirm the intended interpreter.
                cond_parts = []
                for cond in eg['sql']['conds']:
                    col = cond[0]
                    op = cond[1]
                    if isinstance(cond[2], str):
                        cval = cond[2].encode('utf8')
                    else:
                        cval = str(cond[2]).encode('utf8')
                    cond_parts.append(('col' + str(col)) + cond_ops[op] + '\'' + str(cval) + '\' ')
                cq = 'AND '.join(cond_parts)

                selected = 'col' + str(eg['sql']['sel'])
                if eg['sql']['agg'] > 0:
                    selected = agg_ops[eg['sql']['agg']] + '(' + selected + ')'
                query = 'SELECT ' + selected + ' FROM table_{}'.format(
                    table_id.replace('-', '_')) + ' WHERE ' + cq
                query = query.replace(str(u'\u2013'.encode('utf8')), '-').replace(
                    '\xc4\x81', 'a').replace('\xe1\xb9\x83', 'm').replace('[[', '').replace(
                        ',', '').replace('||', '').replace(']]', '').replace('|', '')
                #gold = cur.execute(query) ## TO DO .. compare pred with gold

                hl = []  # header list
                header_words = table_header[table_id.replace('table_', '').replace('_', '-')]
                for a in header_words:
                    for ap in a:
                        t = re.sub(r'(\(|\)|,|;|!)', '', str.lower(str(ap)))
                        if t in inv_:
                            hl.append(inv_[t])

                qv = re.split(r'[\(|\)|!|,|?|\-.|\\|\/|\{|\}|\[|\]|#|$|\&|\s]\s*',
                              str.lower(str(eg['question'].encode('UTF8'))))
                # Known question words, de-duplicated in first-seen order.
                qvu = []
                for w in qv:
                    if w in inv_ and w not in qvu:
                        qvu.append(w)

                for q in re.split(r'(\(|\)|\'|\=|\/|\s)\s*', query):
                    q = str.lower(q)
                    if q not in inv_:
                        # Take the id BEFORE inserting so both maps agree.
                        # The original used len(inv_) again after the insert,
                        # keying id_ one higher than inv_[q].
                        new_id = len(inv_)
                        inv_[q] = new_id
                        id_[new_id] = q

                input_ = hl + [inv_[w] for w in qvu] + [inv_[str.lower(s)] for s in syms]
                data.append((input_, [
                    inv_[str.lower(q)]
                    for q in re.split(r'[\(|\)|\'|=|/|\s]\s*', query)
                ]))
        return data
    finally:
        # Release the connection on success and on failure alike.
        db.close()
def main():
    """Annotate questions and SQL sentences of a split with placeholder symbols.

    Command-line options: ``--din`` (data directory, default ``data_path``)
    and ``--dout`` (output directory, default ``annotated``; created if it is
    missing).  Only the ``dev`` split is processed.

    Files written under ``save_path``: ``<split>.qu`` (annotated questions),
    ``<split>.lon`` (annotated SQL sentences), ``<split>_sym_pairs.txt``
    (symbol => word pairs) and ``<split>_ground_truth.txt`` (SQL sentence
    without parentheses).  ``<split>.out`` is opened but not written here.

    Annotation accuracy statistics are printed at the end; they divide by the
    number of examples, so an empty split raises ``ZeroDivisionError``.

    NOTE(review): the statement nesting below was reconstructed from a
    whitespace-collapsed source; confirm it against the original file.
    """
    parser = ArgumentParser()
    parser.add_argument('--din', default=data_path, help='data directory')
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()
    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)
    for split in ['dev']:
        #for split in ['train', 'test', 'dev']:
        # Output paths are built by plain concatenation with save_path.
        with open(save_path+'%s.qu'%split, 'w') as qu_file, open(save_path+'%s.lon'%split, 'w') as lon_file, \
                open(save_path+'%s.out'%split, 'w') as out, open(save_path+'%s_sym_pairs.txt'%split, 'w') as sym_file, \
                open(save_path+'%s_ground_truth.txt'%split, 'w') as S_file:
            fsplit = os.path.join(args.din, split) + '.jsonl'
            ftable = os.path.join(args.din, split) + '.tables.jsonl'
            with open(fsplit) as fs, open(ftable) as ft:
                print('loading tables')
                tables = {}
                for line in tqdm(ft, total=count_lines(ftable)):
                    d = json.loads(line)
                    tables[d['id']] = d
                print('loading tables done.')
                print('loading examples')
                # n: examples seen; acc: head matches; acc_pair: (field, value)
                # matches; acc_all: both; error: not fully annotated.
                n, acc, acc_pair, acc_all, error = 0, 0, 0, 0, 0
                target = -1
                ADD_FIELDS = False
                # step2 and ADD_TO_FILE are assigned but never read.
                step2 = True
                for line in tqdm(fs, total=count_lines(fsplit)):
                    ADD_TO_FILE = True
                    d = json.loads(line)
                    Q = d['question']
                    rows = tables[d['table_id']]['rows']
                    rows = np.asarray(rows)
                    # `fs` is rebound to the header list; the tqdm iterator
                    # above keeps its own reference to the open file.
                    fs = tables[d['table_id']]['header']
                    all_fields = []
                    for f in fs:
                        all_fields.append(_preclean(f))
                    # all fields are sorted by length in descending order
                    # for string match purpose
                    all_fields = sorted(all_fields, key=len, reverse=True)
                    smap = defaultdict(list)  #f2v
                    reverse_map = defaultdict(list)  #v2f
                    for row in rows:
                        for i in range(len(fs)):
                            cur_f = _preclean(str(fs[i]))
                            cur_row = _preclean(str(row[i]))
                            smap[cur_f].append(cur_row)
                            if cur_f not in reverse_map[cur_row]:
                                reverse_map[cur_row].append(cur_f)
                    #----------------------------------------------------------
                    # all values are sorted by length in descending order
                    # for string match purpose
                    keys = sorted(reverse_map.keys(), key=len, reverse=True)
                    Q = _preclean(Q)
                    Q_ori = Q
                    #####################################
                    ########## Annotate question ########
                    #####################################
                    candidates, cond_fields = _match_pairs(
                        Q, Q_ori, keys, reverse_map)
                    Q_head, head2partial = _match_head(Q, Q_ori, smap, all_fields,
                                                       cond_fields)
                    Q, Qpairs = annotate_Q(Q, Q_ori, Q_head, candidates, all_fields,
                                           head2partial, n, target)
                    qu_file.write(Q + '\n')
                    # (word, symbol, kind) triples written to the sym-pairs file.
                    validation_pairs = copy.copy(Qpairs)
                    validation_pairs.append((Q_head, '<f0>', 'head'))
                    for i, f in enumerate(all_fields):
                        validation_pairs.append((f, '<c' + str(i) + '>', 'c'))
                    #####################################
                    ########## Annotate SQL #############
                    #####################################
                    q_sent = Query.from_dict(d['sql'])
                    S, col_names, val_names = q_sent.to_sentence(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S = _preclean(S)
                    S_ori = S
                    S_noparen = q_sent.to_sentence_noparenthesis(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S_noparen = _preclean(S_noparen)
                    col_names = [_preclean(col_name) for col_name in col_names]
                    val_names = [_preclean(val_name) for val_name in val_names]
                    # The last column name is used as the head.
                    HEAD = col_names[-1]
                    S_head = _preclean(HEAD)
                    #annotate for SQL
                    name_pairs = []
                    for col_name, val_name in zip(col_names, val_names):
                        if col_name == val_name:
                            name_pairs.append([_preclean(col_name), 'true'])
                        else:
                            name_pairs.append(
                                [_preclean(col_name), _preclean(val_name)])
                    # sort to compare with candidates
                    name_pairs.sort(key=lambda x: x[1])
                    new_name_pairs = [[
                        '<f' + str(i + 1) + '>', '<v' + str(i + 1) + '>'
                    ] for i, (field, value) in enumerate(name_pairs)]
                    # only annotate S while identified (f,v) pairs are right
                    if _equal(name_pairs, candidates):
                        pairs = []
                        for (f, v), (new_f, new_v) in zip(name_pairs, new_name_pairs):
                            pairs.append((f, new_f, 'f'))
                            pairs.append((v, new_v, 'v'))
                        # sort (word,symbol) pairs by length in descending order
                        pairs.sort(key=lambda x: 100 - len(x[0]))
                        for p, new_p, t in pairs:
                            cp = _backslash(p)
                            # Substitute only symbols that made it into Q.
                            if new_p in Q:
                                if t == 'v':
                                    S = S.replace(p + ' )', new_p + ' )')
                                if t == 'f':
                                    S = re.sub(
                                        '\( ' + cp + ' (equal|less|greater)',
                                        '( ' + new_p + r' \1', S)
                    # only annotate S while identified HEAD is right
                    if S_head == Q_head and '<f0>' in Q:
                        S = S.replace(S_head, '<f0>')
                    # annotate unseen fields (disabled: ADD_FIELDS is False)
                    if ADD_FIELDS:
                        for i, f in enumerate(all_fields):
                            cf = _backslash(f)
                            S = re.sub('(\s|^)' + cf + '(\s|$|s)',
                                       ' <c' + str(i) + '> ', S)
                    S = _clean(S)
                    lon_file.write(S + '\n')
                    ###############################
                    ######### VALIDATION ##########
                    ###############################
                    # recover_S is rebuilt here but not used afterwards.
                    recover_S = S
                    for word, sym, t in validation_pairs:
                        recover_S = recover_S.replace(sym, word)
                        sym_file.write(sym + '=>' + word + '<>')
                    sym_file.write('\n')
                    S_file.write(S_noparen + '\n')
                    #------------------------------------------------------------------------
                    if _equal(name_pairs, candidates):
                        acc_pair += 1
                    if Q_head == S_head:
                        acc += 1
                    if _equal(name_pairs, candidates) and Q_head == S_head:
                        acc_all += 1
                    # An example counts as an error when any token of S is
                    # neither a symbol (starts with '<') nor a keyword.
                    full_anno = True
                    for s in S.split():
                        if s[0] != '<' and s not in [
                                '(', ')', 'where', 'less', 'greater', 'equal', 'max',
                                'min', 'count', 'sum', 'avg', 'and', 'true'
                        ]:
                            error += 1
                            full_anno = False
                            break
                    # Disabled debug dump: `False and ...` short-circuits, so
                    # the condition and body never run.  `head` and
                    # `p2partial` are not assigned anywhere in this function.
                    if False and not (_equal(name_pairs, candidates)
                                      and head == col_names[-1]):
                        print('--------' + str(n) + '-----------')
                        print(Q_ori)
                        print(Q)
                        print(S_ori)
                        print(S)
                        print('head:')
                        print(head)
                        print(head2partial)
                        print('true HEAD')
                        print(Q_head)
                        print('fields:')
                        print(candidates)
                        print(p2partial)
                        print('true fields')
                        print(name_pairs)
                    n += 1
                print('total number of examples:' + str(n))
                print('fully snnotated:' + str(1 - error * 1.0 / n))
                print('accurate all percent:' + str(acc_all * 1.0 / n))
                print('accurate HEAD match percent:' + str(acc * 1.0 / n))
                print('accurate fields pair match percent:' + str(acc_pair * 1.0 / n))
def _gen_files(save_path=save_path):
    """Prepare files used for the Machine Comprehension Binary Classifier.

    For each of the ``train``, ``test`` and ``dev`` splits, reads
    ``<split>.jsonl`` and ``<split>.tables.jsonl`` from ``data_path`` and
    writes under ``save_path``:

    * ``<split>.txt``: ``question<TAB>column<TAB> label`` rows with the column
      passed through ``_truncate(..., max_len=-1)``;
    * ``<split>_model.txt``: the same rows with ``_truncate(..., max_len=3)``;
    * ``<split>.lon``: the SQL sentence with each used column bracketed;
    * ``<split>.ori.qu`` / ``<split>.ori.lon``: pre-cleaned question and SQL
      sentence;
    * ``<split>_dict.npz``: the per-example field->values and value->fields
      maps.

    Label `` 1`` marks columns returned by ``to_sentence`` as ``col_names``,
    `` 0`` marks the table's other headers.

    NOTE(review): the statement nesting below was reconstructed from a
    whitespace-collapsed source; confirm it against the original file.
    """
    for split in ['train', 'test', 'dev']:
        print('------%s-------' % split)
        n = 0
        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        # The bare string below is a no-op expression statement describing the
        # two classifier files (<split>.txt and <split>_model.txt).
        """ test.txt original column content w/o truncate test_model.txt column name truncate or pad to length 3 """
        with open(fsplit) as fs, open(ftable) as ft, \
                open(os.path.join(save_path, split+'.txt'), mode='w') as fw, \
                open(os.path.join(save_path, '%s_model.txt'%split), mode='w') as fw_model, \
                open(os.path.join(save_path, split+'.lon'), mode='w') as fsql, \
                open(os.path.join(save_path, '%s.ori.qu'%split), 'w') as qu_file, \
                open(os.path.join(save_path, '%s.ori.lon'%split), 'w') as lon_file:
            print('loading tables...')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')
            print('loading examples')
            f2v_all, v2f_all = [], []
            for line in tqdm(fs, total=count_lines(fsplit)):
                d = json.loads(line)
                Q = d['question']
                # Tabs are removed because tab is the column separator of the
                # classifier files.
                Q = _preclean(Q).replace('\t', '')
                qu_file.write(Q + '\n')
                q_sent = Query.from_dict(d['sql'])
                rows = tables[d['table_id']]['rows']
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)
                lon_file.write(S + '\n')
                rows = np.asarray(rows)
                # `fs` is rebound to the header list; the tqdm iterator above
                # keeps its own reference to the open file.
                fs = tables[d['table_id']]['header']
                all_fields = [_preclean(f) for f in fs]
                # all fields are sorted by length in descending order
                # for string match purpose
                headers = sorted(all_fields, key=len, reverse=True)
                f2v = defaultdict(list)  #f2v
                v2f = defaultdict(list)  #v2f
                for row in rows:
                    for i in range(len(fs)):
                        cur_f = _preclean(str(fs[i]))
                        cur_row = _preclean(str(row[i]))
                        #cur_f = cur_f.replace('\u2003',' ')
                        f2v[cur_f].append(cur_row)
                        if cur_f not in v2f[cur_row]:
                            v2f[cur_row].append(cur_f)
                f2v_all.append(f2v)
                v2f_all.append(v2f)
                #####################################
                ########## Annotate SQL #############
                #####################################
                # The SQL sentence is computed a second time, now from the
                # ndarray form of `rows`.
                q_sent = Query.from_dict(d['sql'])
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)
                S_noparen = q_sent.to_sentence_noparenthesis(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S_noparen = _preclean(S_noparen)
                col_names = [_preclean(col_name) for col_name in col_names]
                val_names = [_preclean(val_name) for val_name in val_names]
                HEAD = col_names[-1]
                S_head = _preclean(HEAD)
                #annotate for SQL
                # name_pairs is built and sorted but not used afterwards.
                name_pairs = []
                for col_name, val_name in zip(col_names, val_names):
                    if col_name == val_name:
                        name_pairs.append([_preclean(col_name), 'true'])
                    else:
                        name_pairs.append(
                            [_preclean(col_name), _preclean(val_name)])
                # sort to compare with candidates
                name_pairs.sort(key=lambda x: x[1])
                # '#<n>' opens the record of example n in both files.
                fsql.write('#%d\n' % n)
                fw.write('#%d\n' % n)
                # Positive rows: columns used by the query.
                for f in col_names:
                    fsql.write(S.replace(f, '[' + f + ']') + '\n')
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 1')
                    # Guards against a stray tab in the question or column.
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                # Negative rows: remaining headers of the table.
                for f in [f for f in headers if f not in col_names]:
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                fsql.write(S + '\n')
                #if '\u2003' in Q:
                #    print('u2003: '+Q)
                #if '\xa0' in Q:
                #    print(n)
                #    print('xa0: '+Q)
                #    print(S)
                # Same rows for the model file: non-breaking spaces and tabs
                # are normalised and the column is passed max_len=3.
                for f in col_names:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')
                for f in [f for f in headers if f not in col_names]:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')
                n += 1
            # Closing marker carrying the total record count.
            fsql.write('#%d\n' % n)
            fw.write('#%d\n' % n)
            # NOTE(review): ``savez`` is a NumPy function reached through the
            # root scipy namespace, which SciPy has deprecated; confirm the
            # installed SciPy still provides it.
            scipy.savez(os.path.join(save_path, '%s_dict.npz' % split),
                        f2v_all=f2v_all, v2f_all=v2f_all)
            print('num of records:%d' % n)
def index_data(data_file):
    """Build the id <-> token vocabulary from a list of data files.

    The vocabulary starts with 41 fixed entries (ids 0-40).  Files whose path
    contains ``'tables'`` are read as JSON lines and contribute one
    ``table_<id>`` token per record with a truthy ``id`` (dashes replaced by
    underscores).  Every other file is split with a separator regex and
    contributes its lower-cased pieces.  New tokens get consecutive ids from
    41 on, in first-seen order.

    Args:
        data_file: iterable of file paths.

    Returns:
        ``(id_, inv_)``: the id -> token dict and its inverse.
    """
    base_tokens = [
        '', '<EOS>', 'max', 'min', 'count', 'sum', 'avg', '=', '>', '<', 'op',
        'select', 'where', 'and', 'col', 'table', 'caption', 'page', 'section',
        'cond', 'question', 'agg', 'aggops', 'condops',
    ]
    base_tokens.extend('col%d' % k for k in range(16))
    base_tokens.append('svaha')

    id_ = dict(enumerate(base_tokens))
    inv_ = {token: index for index, token in id_.items()}
    next_id = 41

    for path in data_file:
        is_table_file = 'tables' in path
        with open(path) as handle:
            for raw in tqdm(handle, total=count_lines(path)):
                if is_table_file:
                    record = json.loads(raw)
                    if not record.get('id'):
                        continue
                    tokens = ['table_{}'.format(record['id'].replace('-', '_'))]
                else:
                    # The capturing group keeps the separators as tokens too.
                    tokens = (str.lower(piece) for piece in re.split(
                        r'(,|;|/|/\|!|@|#|$|\"|\(|\)|`|=|\s)\s*', raw))
                for token in tokens:
                    if token in inv_:
                        continue
                    id_[next_id] = token
                    inv_[token] = next_id
                    next_id += 1
    return id_, inv_
parser.add_argument('pred_file', help='predictions by the model') parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions') parser.add_argument('--topk', type=int, default=3, help='k of top_k') args = parser.parse_args() engine = DBEngine(args.db_file) temp = [] with open(args.source_file) as fs, open(args.pred_file) as fp: grades = [] exact_match = [] for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)): eg = json.loads(ls) qg = Query.from_dict(eg['sql'], ordered=args.ordered) gold = engine.execute_query(eg['table_id'], qg, lower=True) pred_topk = [] qp_topk = [] ep = json.loads(lp) pred = ep.get('error', None) qp = None for i in range(args.topk): if not ep.get('error', None): try: if ep['query'][str(i)]['conds'] == [[]]:
parser.add_argument('--dout', default='annotated', help='output directory') args = parser.parse_args() if not os.path.isdir(args.dout): os.makedirs(args.dout) for split in ['train', 'dev', 'test']: fsplit = os.path.join(args.din, split) + '.jsonl' ftable = os.path.join(args.din, split) + '.tables.jsonl' fout = os.path.join(args.dout, split) + '.jsonl' print('annotating {}'.format(fsplit)) with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo: print('loading tables') tables = {} for line in tqdm(ft, total=count_lines(ftable)): d = json.loads(line) tables[d['id']] = d print('loading examples') n_written = 0 for line in tqdm(fs, total=count_lines(fsplit)): d = json.loads(line) a = annotate_example(d, tables[d['table_id']]) if not is_valid_example(a): raise Exception(str(a)) gold = Query.from_tokenized_dict(a['query']) reconstruct = Query.from_sequence(a['seq_output'], a['table'], lowercase=True) if gold.lower() != reconstruct.lower():
from lib.dbengine import DBEngine from lib.common import count_lines if __name__=="__main__": parser = argparse.ArgumentParser() parser.add_argument("--source_file") parser.add_argument("--db_file") parser.add_argument("--pred_file") parser.add_argument("--ordered", action='store_true') args = parser.parse_args() engine = DBEngine(args.db_file) ex_acc_list = [] lf_acc_list = [] with open(args.source_file) as sf, open(args.pred_file) as pf: for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)): # line별 정답과 예측 샘플 가져오기 gold_example = json.loads(source_line) pred_example = json.loads(pred_line) # 정답 샘플 lf, ex 구하기 lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered) ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True) # error가 아닌 경우 예측 샘플 lf, ex 구하기 lf_pred_query = None ex_pred = pred_example.get('error', None) if not ex_pred: try: lf_pred_query = Query.from_dict(pred_example['query'], ordered=args.ordered) ex_pred = engine.execute_query(gold_example['table_id'], lf_pred_query, lower=True)