def process_examples(fexample, ftable, fout):
    """Annotate the examples in ``fexample`` and write the valid ones to ``fout``.

    Tables are loaded from ``ftable`` (one JSON object per line) and indexed
    by their ``id``.  Every example is annotated against its table; invalid
    annotations are printed and skipped.  An ``Exception`` is raised when the
    query rebuilt from the annotated output sequence differs (ignoring case)
    from the gold query.
    """
    print('annotating {}'.format(fexample))
    with open(fexample) as example_file, open(ftable) as table_file, \
            open(fout, 'wt') as out_file:
        print('loading tables')
        tables = {}
        for raw in tqdm(table_file, total=count_lines(ftable)):
            table = json.loads(raw)
            tables[table['id']] = table

        print('loading examples')
        n_written = 0
        for raw in tqdm(example_file, total=count_lines(fexample)):
            example = json.loads(raw)
            annotated = annotate_example(example, tables[example['table_id']])
            if is_valid_example(annotated):
                gold = Query.from_tokenized_dict(annotated['query'])
                reconstruct = Query.from_sequence(
                    annotated['seq_output'], annotated['table'], lowercase=True)
                if gold.lower() != reconstruct.lower():
                    raise Exception(
                        'Expected:\n{}\nGot:\n{}'.format(gold, reconstruct))
                out_file.write(json.dumps(annotated) + '\n')
                n_written += 1
            else:
                # Invalid annotations are reported, not fatal.
                print('Invalid example: {}'.format(str(annotated)))
        print('wrote {} examples'.format(n_written))
def generate_query(self, db, max_cond=4):
    """Sample a random query over this table together with its result.

    A select column is drawn first.  Up to ``max_cond`` WHERE conditions are
    then added one at a time; a condition is kept only when it changes the
    values of the select column in the result set, otherwise it is dropped.
    Finally an aggregation operator is drawn (text columns always get the
    empty aggregation).

    Args:
        db: database handle forwarded to ``self.execute_query``.
        max_cond: upper bound on the number of sampling attempts; capped at
            the number of columns.

    Returns:
        ``(query, results)`` where ``results`` is what ``self.execute_query``
        returns for the final query.
    """
    n_cols = len(self.header)
    max_cond = min(n_cols, max_cond)

    # sample a select column
    sel_index = random.choice(range(n_cols))

    # sample where conditions, starting from a query without conditions
    query = Query(-1, Query.agg_ops.index(''))
    results = self.execute_query(db, query)
    condition_options = [i for i in range(n_cols) if i != sel_index]
    for _ in range(max_cond):
        # Stop when nothing is left to filter, or when every non-select
        # column already carries a condition (random.choice would raise
        # IndexError on the empty list).
        if not results or not condition_options:
            break
        cond_index = random.choice(condition_options)
        if self.types[cond_index] == 'text':
            cond_op = Query.cond_ops.index('=')
        else:
            cond_op = random.choice(range(len(Query.cond_ops)))
        cond_val = random.choice([r[cond_index] for r in results])
        query.conditions.append((cond_index, cond_op, cond_val))
        new_results = self.execute_query(db, query)
        if [r[sel_index] for r in new_results] != [r[sel_index] for r in results]:
            # The condition narrowed the answer: keep it and retire the column.
            condition_options.remove(cond_index)
            results = new_results
        else:
            query.conditions.pop()

    # sample an aggregation operation
    if self.types[sel_index] == 'text':
        query.agg_index = Query.agg_ops.index('')
    else:
        query.agg_index = random.choice(range(len(Query.agg_ops)))
    query.sel_index = sel_index
    results = self.execute_query(db, query)
    return query, results
def eval_one_qelos(db_file, pred_file, source_file):
    """Compute execution and logical-form accuracy of predicted queries.

    ``source_file`` and ``pred_file`` are read line by line in lockstep; each
    line is a JSON object.  The gold query comes from the ``sql`` field of the
    source line, the predicted query from the whole prediction line.

    Args:
        db_file: path handed to ``DBEngine``.
        pred_file: JSON-lines file with one predicted query dict per line.
        source_file: JSON-lines file with ``sql`` and ``table_id`` fields.

    Returns:
        A dict with ``ex_accuracy`` (execution results match) and
        ``lf_accuracy`` (query objects compare equal).  Both are ``0.0``
        when there are no examples.
    """
    engine = DBEngine(db_file)
    grades = []
    exact_match = []
    with open(source_file) as fs, open(pred_file) as fp:
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            qp = None
            try:
                qp = Query.from_dict(ep)
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                # A prediction that cannot be parsed or executed is graded
                # against the error text, so it never equals the gold result.
                pred = repr(e)
            grades.append(pred == gold)
            exact_match.append(qp == qg)

    # Empty input used to raise ZeroDivisionError; report 0.0 instead.
    total = len(grades) or 1
    return {
        'ex_accuracy': sum(grades) / total,
        'lf_accuracy': sum(exact_match) / total,
    }
def hello():
    """Grade predicted queries against gold queries passed as request arguments.

    Reads three values from ``request.args``: ``pred`` and ``gold`` (JSON
    encoded lists, paired element by element) and ``dataset`` (name of the
    ``<dataset>.db`` file under ``DATABASE_PATH``).

    Returns:
        A JSON string with ``ex_accuracy`` and ``lf_accuracy``, or a JSON
        string with a single ``error`` key when the arguments are unusable.
    """
    pred_arg = request.args.get('pred')
    gold_arg = request.args.get('gold')
    dataset = request.args.get('dataset')
    if pred_arg is None or gold_arg is None or not dataset:
        return json.dumps({"error": "pred, gold and dataset are required"})
    # `dataset` comes straight from the request: refuse anything containing a
    # path component so it cannot escape DATABASE_PATH.
    if os.path.basename(dataset) != dataset:
        return json.dumps({"error": "invalid dataset name"})

    pred_json = json.loads(pred_arg)
    gold_json = json.loads(gold_arg)
    engine = DBEngine(os.path.join(DATABASE_PATH, "{}.db".format(dataset)))

    exact_match = []
    grades = []
    for ep, eg in tqdm(zip(pred_json, gold_json), total=len(gold_json)):
        qg = Query.from_dict(eg['sql'])
        gold = engine.execute_query(eg['table_id'], qg, lower=True)
        pred = ep['error']
        qp = None
        if not ep['error']:
            try:
                qp = Query.from_dict(ep['query'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                # Unparseable / unexecutable predictions are graded against
                # the error text.
                pred = repr(e)
        grades.append(pred == gold)
        exact_match.append(qp == qg)

    # No examples used to raise ZeroDivisionError; report 0.0 instead.
    total = len(grades) or 1
    ex_accuracy = sum(grades) / total
    lf_accuracy = sum(exact_match) / total
    return json.dumps({"ex_accuracy": ex_accuracy, "lf_accuracy": lf_accuracy})
def eval(self, gold, sql_gold, engine):
    """Score this prediction against ``gold`` and record hits in ``self.correct``.

    Sets ``self.correct[key] = 1`` for each of ``agg``, ``sel``, ``where``,
    ``col``, ``lay``, ``span``, ``all`` and ``exe`` that matches; entries are
    never reset to 0 here.  ``self.exception_raised`` is set when building or
    executing the predicted query fails.

    Args:
        gold: annotated example; reads ``query``, ``question`` and ``table_id``.
        sql_gold: gold query dict accepted by ``Query.from_dict``.
        engine: object providing ``execute_query``.
    """
    self.exception_raised = False  # reset this flag for a new evaluation

    gold_query = gold['query']
    if self.agg == gold_query['agg']:
        self.correct['agg'] = 1
    if self.sel == gold_query['sel']:
        self.correct['sel'] = 1

    # Conditions are compared as sorted (column, operator, span text) triples
    # so that their order does not matter.
    gold_conds = gold_query['conds']
    words = gold['question']['words']
    where_pred = sorted(
        (col, op, ' '.join(words[span[0]:span[1] + 1]))
        for col, op, span in self.cond)
    where_gold = sorted(
        (col, op, ' '.join(span['words'])) for col, op, span in gold_conds)

    if where_pred == where_gold:
        self.correct['where'] = 1
    # Component-wise agreement: position 0 = column, 1 = operator, 2 = span.
    for key, pos in (('col', 0), ('lay', 1), ('span', 2)):
        if [item[pos] for item in where_pred] == [item[pos] for item in where_gold]:
            self.correct[key] = 1
    if all(self.correct[key] == 1 for key in ('agg', 'sel', 'where')):
        self.correct['all'] = 1

    # execution
    table_id = gold['table_id']
    ans_gold = engine.execute_query(
        table_id, Query.from_dict(sql_gold), lower=True)
    try:
        sql_pred = {
            'agg': self.agg,
            'sel': self.sel,
            'conds': self.recover_cond_to_gloss(gold),
        }
        ans_pred = engine.execute_query(
            table_id, Query.from_dict(sql_pred), lower=True)
    except Exception as e:
        self.exception_raised = True
        ans_pred = repr(e)
    if set(ans_gold) == set(ans_pred):
        self.correct['exe'] = 1
def main(anno_file_name, col_headers, raw_args=None, verbose=True):
    """Translate the last annotated question into SQL and execute it.

    Every checkpoint matching ``opt.model_path`` is loaded in turn; the value
    returned comes from the last checkpoint processed.

    Args:
        anno_file_name: annotation file read with ``table.IO.read_anno_json``.
        col_headers: column headers passed to ``Query.get_complete_query``.
        raw_args: argument list for the option parser (``None`` = ``sys.argv``).
        verbose: print the predicted query and headers, and forward the flag
            to ``engine.execute_query``.

    Returns:
        ``(complete_query, ans_pred)``; ``ans_pred`` is ``None`` when the
        predicted query could not be executed.

    Raises:
        FileNotFoundError: no file matches ``opt.model_path``.
    """
    parser = argparse.ArgumentParser(description='evaluate.py')
    opts.translate_opts(parser)
    opt = parser.parse_args(raw_args)
    torch.cuda.set_device(opt.gpu)
    opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split))
    opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding')

    # Model/training defaults are needed by Translator; parse an empty
    # argument list to obtain them.
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    opt.anno = anno_file_name
    engine = DBEngine(opt.db_file)
    js_list = table.IO.read_anno_json(opt.anno)

    model_files = glob.glob(opt.model_path)
    if not model_files:
        # Previously this fell through to an AttributeError/NameError on the
        # return line; fail with an explicit message instead.
        raise FileNotFoundError(
            'no model file matches {!r}'.format(opt.model_path))

    sql_query = None
    ans_pred = None
    for fn_model in model_files:
        opt.model = fn_model
        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(
            dataset=data, device=opt.gpu, batch_size=opt.batch_size,
            train=False, sort=True, sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)

        # Only the last example (highest idx) is turned into a query.
        pred = r_list[-1]
        sql_pred = {
            'agg': pred.agg,
            'sel': pred.sel,
            'conds': pred.recover_cond_to_gloss(js_list[-1]),
        }
        if verbose:
            print('\n sql_pred: ', sql_pred, '\n')
            print('\n col_headers: ', col_headers, '\n')
        sql_query = Query(sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
        try:
            ans_pred = engine.execute_query(
                js_list[-1]['table_id'], Query.from_dict(sql_pred),
                lower=True, verbose=verbose)
        except Exception as e:
            # Best effort: an unexecutable prediction yields no answer.
            if verbose:
                print('\n execution failed: ', repr(e), '\n')
            ans_pred = None
    return sql_query.get_complete_query(col_headers), ans_pred
def test_query_all(self):
    """``MatchType.all`` returns only files carrying every queried tag."""
    self._prepare_query_test()
    try:
        files_2_3 = self.tie_core.query(
            Query([QUERY_TAG_2, QUERY_TAG_3], MatchType.all))
        self.assertEqual(
            [os.path.abspath(QUERY_FILE_2)], files_2_3,
            "Query all with 2 tags did not find the expected files")
        files_3 = self.tie_core.query(Query([QUERY_TAG_3], MatchType.all))
        self.assertEqual(
            [os.path.abspath(QUERY_FILE_2), os.path.abspath(QUERY_FILE_3)],
            files_3,
            "Query all with 1 tag did not find the expected files")
    finally:
        # Clean up even when an assertion fails, so later tests start fresh.
        _clean_after_query_test()
def _query(core, tags, match_type, frontend: Frontend):
    """Run a tag query on ``core`` and report the matching files.

    Shows a message through ``frontend`` when nothing matches; otherwise the
    result list is handed to ``print_out_list``.
    """
    result_files = core.query(Query(tags, match_type))
    if len(result_files) == 0:
        frontend.show_message("No files found for your query!")
        return
    print_out_list(result_files)
def test_query_all_no_match(self):
    """``MatchType.all`` with an unknown tag yields an empty result."""
    self._prepare_query_test()
    try:
        result = self.tie_core.query(
            Query([QUERY_TAG_2, "some random non existant tag"],
                  MatchType.all))
        self.assertEqual([], result, "Query with no result")
    finally:
        # Clean up even when the assertion fails.
        _clean_after_query_test()
def test():
    """Print one query rendered three ways: bare, via a DB table, via a hand-built table."""
    # convert query dict to text (without correct column references)
    details = {"sel": 5, "conds": [[3, 0, "SOUTH AUSTRALIA"]], "agg": 0}
    query = Query(details["sel"], details["agg"], details["conds"])
    print(query)

    # convert query dict to text with table reference (still does not give
    # the correct columns) because header is not supplied
    conn = records.Database('sqlite:///data/train.db').get_connection()
    db_table = Table.from_db(conn, "1-1000181-1")
    print(db_table.query_str(query))

    # convert query dict to text with table reference after supplying headers
    header = [
        "State/territory",
        "Text/background colour",
        "Format",
        "Current slogan",
        "Current series",
        "Notes",
    ]
    table_data = {"id": "1-1000181-1", "header": header, "types": [], "rows": []}
    with_headers = Table(table_data["id"], table_data["header"],
                         table_data["types"], table_data["rows"])
    print(with_headers.query_str(query))
def get_query_from_json(self, json_line):
    """Build the ``Table`` and ``Query`` objects described by ``json_line``.

    ``json_line`` must provide ``sql`` (query dict) and ``table_id`` (key into
    ``self.table_map``).  The table is created with an empty id.

    Returns:
        ``(table, query)`` - note the table comes first.
    """
    query = Query.from_dict(json_line["sql"])
    spec = self.table_map[json_line["table_id"]]
    return Table("", spec["header"], spec["types"], spec["rows"]), query
def toQueryStr(file_name, table_arr, type=0, test_batch_size=1000):
    """Write ``<CODESPLIT>``-separated (label, question, query text) examples.

    Reads ``DATA_DIR/<file_name>.jsonl``, shuffles the lines with a fixed
    seed and processes them in batches of ``test_batch_size`` (a trailing
    partial batch is dropped).  For every line the query is rendered with
    the table whose ``table_id`` matches.  A coin flip then decides the
    label: ``"1"`` keeps the line's own question, ``"0"`` swaps in the
    question of a randomly chosen line.

    Args:
        file_name: stem of the input ``.jsonl`` file.
        table_arr: iterable of tables exposing ``table_id`` and ``query_str``.
        type: ``0`` writes ``train.txt``, anything else ``valid.txt``.
        test_batch_size: number of lines per batch.
    """
    path = os.path.join(DATA_DIR, '{}.jsonl'.format(file_name))
    print(path)
    with open(path, 'r', encoding='utf-8') as pf:
        data = pf.readlines()

    idxs = np.arange(len(data))
    # `np.object` was removed from NumPy; the builtin `object` is equivalent.
    data = np.array(data, dtype=object)
    np.random.seed(0)  # set random seed so that random things are reproducible
    np.random.shuffle(idxs)
    data = data[idxs]
    batched_data = chunked(data, test_batch_size)

    # Index the tables once instead of scanning the list for every example.
    # setdefault keeps the first table on duplicate ids, like a linear scan.
    tables_by_id = {}
    for candidate in table_arr:
        tables_by_id.setdefault(candidate.table_id, candidate)

    print("start processing")
    examples = []
    for batch_data in batched_data:
        if len(batch_data) < test_batch_size:
            break  # the last batch is smaller than the others, exclude.
        for d in batch_data:
            # json.loads no longer accepts an `encoding` argument (3.9+).
            line = json.loads(str(d))
            doc_token = line['question']
            code_arr = line['sql']
            query = Query(code_arr['sel'], code_arr['agg'], code_arr['conds'])
            table = tables_by_id.get(line['table_id'])
            if table is None:
                # NOTE(review): lines without a matching table are skipped;
                # the source layout was ambiguous here - confirm.
                continue
            code_token = table.query_str(query)

            # 0 => negative pair: pair the query with a random question.
            isNegative = np.random.randint(2)
            if isNegative == 0:
                random_line_num = np.random.randint(len(data))
                doc_token = json.loads(str(data[random_line_num]))['question']
            examples.append('<CODESPLIT>'.join(
                (str(isNegative), "nothing", "nothing", doc_token, code_token)))

    data_path = os.path.join(DATA_DIR, 'train_valid/wiki_sql')
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    output_file_name = 'train.txt' if type == 0 else 'valid.txt'
    file_path = os.path.join(data_path, output_file_name)
    print(file_path)
    with open(file_path, 'w', encoding='utf-8') as f:
        # A single write; writelines() on a str iterates character by character.
        f.write('\n'.join(examples))
def main(argv):
    """Print execution / logical-form accuracy of parsed predicted SQL.

    The gold and predicted files named by ``FLAGS`` are read in lockstep; both
    hold one JSON object per line with the query under ``sql``.  The result is
    printed as indented JSON.
    """
    del argv  # Unused.
    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)
    engine = DBEngine(db_file)

    exact_match = []
    grades = []
    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)

            # Reset per row: if the gold query fails to parse, `qg` must not
            # be unbound or carry over from the previous row.
            qg = None
            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)

            qp = None
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)

            correct = pred == gold
            # An unparseable gold query can never be an exact match.
            match = qg is not None and qp == qg
            if pred == gold and qp != qg:
                # Same execution result from a different logical form.
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)

    # No rows used to raise ZeroDivisionError; report 0.0 instead.
    total = len(grades) or 1
    print(
        json.dumps(
            {
                'ex_accuracy': sum(grades) / total,
                'lf_accuracy': sum(exact_match) / total,
            },
            indent=2))
def test_query_any(self):
    """``MatchType.any`` returns every file carrying at least one queried tag."""
    self._prepare_query_test()
    try:
        files = self.tie_core.query(
            Query([QUERY_TAG_2, QUERY_TAG_3], MatchType.any))
        self.assertEqual([
            os.path.abspath(QUERY_FILE_1),
            os.path.abspath(QUERY_FILE_2),
            os.path.abspath(QUERY_FILE_3),
        ], files, "Query any did not find the expected files")
    finally:
        # Clean up even when the assertion fails.
        _clean_after_query_test()
def main():
    """Start the Qt application on top of the bundled SQLite database."""
    app = QtGui.QApplication(sys.argv)
    db_path = find_data_file('/data.db')
    query = Query(sql.create_engine('sqlite:///' + db_path))
    window = App(query)
    window.show()
    app.exec_()
def toQueryStr(file_name, table_dic, type=0, test_batch_size=1000):
    """Write tokenised (question, query text) pairs as JSON lines.

    Reads ``DATA_DIR/<file_name>.jsonl``, shuffles the lines with a fixed
    seed and processes them in batches of ``test_batch_size`` (a trailing
    partial batch is dropped).  Lines whose ``table_id`` is missing from
    ``table_dic`` are skipped.

    Args:
        file_name: stem of the input ``.jsonl`` file.
        table_dic: mapping from table id to a table exposing ``query_str``.
        type: ``0`` -> ``train.jsonl``, ``2`` -> ``test.jsonl``, anything
            else -> ``valid.jsonl``.
        test_batch_size: number of lines per batch.
    """
    path = os.path.join(DATA_DIR, '{}.jsonl'.format(file_name))
    print(path)
    with open(path, 'r', encoding='utf-8') as pf:
        data = pf.readlines()

    idxs = np.arange(len(data))
    # `np.object` was removed from NumPy; the builtin `object` is equivalent.
    data = np.array(data, dtype=object)
    np.random.seed(0)  # set random seed so that random things are reproducible
    np.random.shuffle(idxs)
    data = data[idxs]
    batched_data = chunked(data, test_batch_size)

    print("start processing")
    examples = []
    for batch_data in batched_data:
        if len(batch_data) < test_batch_size:
            break  # the last batch is smaller than the others, exclude.
        for d in batch_data:
            # json.loads no longer accepts an `encoding` argument (3.9+).
            line = json.loads(str(d))
            doc_str = line['question']
            code_arr = line['sql']
            query = Query(code_arr['sel'], code_arr['agg'], code_arr['conds'])
            table_id = line['table_id']
            if table_id in table_dic:
                code_str = table_dic[table_id].query_str(query)
                example = {
                    'docstring_tokens': tokenize_docstring_from_string(doc_str),
                    'code_tokens': tokenize_docstring_from_string(code_str),
                }
                examples.append(json.dumps(example))

    data_path = os.path.join(DATA_DIR, 'wiki_sql')
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    if type == 0:
        output_file_name = 'train.jsonl'
    elif type == 2:
        output_file_name = 'test.jsonl'
    else:
        output_file_name = 'valid.jsonl'
    file_path = os.path.join(data_path, output_file_name)
    print(file_path)
    with open(file_path, 'w', encoding='utf-8') as f:
        # A single write; writelines() on a str iterates character by character.
        f.write('\n'.join(examples))
def _gen_files(save_path=save_path):
    """Prepare files used for Machine Comprehension Binary Classifier.

    For each of the ``train``, ``test`` and ``dev`` splits found under
    ``data_path`` the following files are written to ``save_path``:

    * ``<split>.txt``       - ``question<TAB>column<TAB> label`` lines with
      the original column content (no truncation, ``max_len=-1``);
    * ``<split>_model.txt`` - the same lines with the column name truncated
      or padded to length 3;
    * ``<split>.lon``       - the query sentence, one line per candidate;
    * ``<split>.ori.qu`` / ``<split>.ori.lon`` - cleaned question and query
      sentence, one line per example;
    * ``<split>_dict.npz``  - per-example field->values and value->fields
      maps.

    NOTE(review): the nesting below was reconstructed from a
    whitespace-collapsed listing; statements marked "confirm" may belong one
    level in or out.
    """
    for split in ['train', 'test', 'dev']:
        print('------%s-------' % split)
        n = 0
        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        with open(fsplit) as fs, open(ftable) as ft, \
                open(os.path.join(save_path, split+'.txt'), mode='w') as fw, \
                open(os.path.join(save_path, '%s_model.txt'%split), mode='w') as fw_model, \
                open(os.path.join(save_path, split+'.lon'), mode='w') as fsql, \
                open(os.path.join(save_path, '%s.ori.qu'%split), 'w') as qu_file, \
                open(os.path.join(save_path, '%s.ori.lon'%split), 'w') as lon_file:
            print('loading tables...')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')

            print('loading examples')
            f2v_all, v2f_all = [], []
            for line in tqdm(fs, total=count_lines(fsplit)):
                d = json.loads(line)
                Q = d['question']
                Q = _preclean(Q).replace('\t', '')
                qu_file.write(Q + '\n')

                q_sent = Query.from_dict(d['sql'])
                rows = tables[d['table_id']]['rows']
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)
                lon_file.write(S + '\n')

                rows = np.asarray(rows)
                # Named `fields` (was `fs`) so the open example file handle
                # is not shadowed while it is being iterated.
                fields = tables[d['table_id']]['header']
                all_fields = [_preclean(f) for f in fields]
                # all fields are sorted by length in descending order
                # for string match purpose
                headers = sorted(all_fields, key=len, reverse=True)

                f2v = defaultdict(list)  # field -> values
                v2f = defaultdict(list)  # value -> fields
                for row in rows:
                    for i in range(len(fields)):
                        cur_f = _preclean(str(fields[i]))
                        cur_row = _preclean(str(row[i]))
                        f2v[cur_f].append(cur_row)
                        if cur_f not in v2f[cur_row]:
                            v2f[cur_row].append(cur_f)
                f2v_all.append(f2v)
                v2f_all.append(v2f)

                #####################################
                ########## Annotate SQL #############
                #####################################
                # The sentence is rebuilt here with `rows` as an ndarray
                # (kept as in the original).
                q_sent = Query.from_dict(d['sql'])
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)
                S_noparen = q_sent.to_sentence_noparenthesis(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S_noparen = _preclean(S_noparen)
                col_names = [_preclean(col_name) for col_name in col_names]
                val_names = [_preclean(val_name) for val_name in val_names]
                HEAD = col_names[-1]
                S_head = _preclean(HEAD)

                # annotate for SQL
                name_pairs = []
                for col_name, val_name in zip(col_names, val_names):
                    if col_name == val_name:
                        name_pairs.append([_preclean(col_name), 'true'])
                    else:
                        name_pairs.append(
                            [_preclean(col_name), _preclean(val_name)])
                # sort to compare with candidates
                name_pairs.sort(key=lambda x: x[1])

                # Each example starts with a '#<index>' marker line.
                fsql.write('#%d\n' % n)
                fw.write('#%d\n' % n)
                # Columns used by the query: label 1.
                for f in col_names:
                    fsql.write(S.replace(f, '[' + f + ']') + '\n')
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                # Remaining columns: label 0.
                for f in [f for f in headers if f not in col_names]:
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                    # NOTE(review): assumed inside this loop (one .lon line
                    # per candidate) - confirm.
                    fsql.write(S + '\n')

                # Same candidates for the model file, with non-breaking
                # spaces and tabs removed and names cut/padded to 3 tokens.
                for f in col_names:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')
                for f in [f for f in headers if f not in col_names]:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')
                n += 1

            # NOTE(review): closing marker assumed to be written once per
            # split, after the last example - confirm.
            fsql.write('#%d\n' % n)
            fw.write('#%d\n' % n)

        # `scipy.savez` was only a deprecated alias of the NumPy function.
        np.savez(os.path.join(save_path, '%s_dict.npz' % split),
                 f2v_all=f2v_all, v2f_all=v2f_all)
        print('num of records:%d' % n)
def test_query_all_empty_tags_list(self):
    """``MatchType.all`` with an empty tag list yields an empty result."""
    self._prepare_query_test()
    try:
        result = self.tie_core.query(Query([], MatchType.all))
        self.assertEqual([], result, "Query with no tags")
    finally:
        # Clean up even when the assertion fails.
        _clean_after_query_test()
def main():
    # Annotate questions and SQL sentences of the 'dev' split with
    # placeholder symbols (<f0>, <fN>/<vN>, <cN>) and print match statistics.
    # Output files are created by concatenating `save_path` with a file name.
    # NOTE(review): nesting reconstructed from a whitespace-collapsed
    # listing - confirm block boundaries against the original file.
    parser = ArgumentParser()
    parser.add_argument('--din', default=data_path, help='data directory')
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()
    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)
    for split in ['dev']:
        #for split in ['train', 'test', 'dev']:
        with open(save_path+'%s.qu'%split, 'w') as qu_file, open(save_path+'%s.lon'%split, 'w') as lon_file, \
            open(save_path+'%s.out'%split, 'w') as out, open(save_path+'%s_sym_pairs.txt'%split, 'w') as sym_file, \
            open(save_path+'%s_ground_truth.txt'%split, 'w') as S_file:
            fsplit = os.path.join(args.din, split) + '.jsonl'
            ftable = os.path.join(args.din, split) + '.tables.jsonl'
            with open(fsplit) as fs, open(ftable) as ft:
                print('loading tables')
                # Index every table by its id.
                tables = {}
                for line in tqdm(ft, total=count_lines(ftable)):
                    d = json.loads(line)
                    tables[d['id']] = d
                print('loading tables done.')
                print('loading examples')
                # Counters: examples seen, head matches, pair matches,
                # both matching, sentences not fully annotated.
                n, acc, acc_pair, acc_all, error = 0, 0, 0, 0, 0
                target = -1
                ADD_FIELDS = False
                step2 = True
                for line in tqdm(fs, total=count_lines(fsplit)):
                    ADD_TO_FILE = True
                    d = json.loads(line)
                    Q = d['question']
                    rows = tables[d['table_id']]['rows']
                    rows = np.asarray(rows)
                    # NOTE: rebinds `fs` (the open example file) to the
                    # header list; the running loop keeps its own iterator.
                    fs = tables[d['table_id']]['header']
                    all_fields = []
                    for f in fs:
                        all_fields.append(_preclean(f))
                    # all fields are sorted by length in descending order
                    # for string match purpose
                    all_fields = sorted(all_fields, key=len, reverse=True)
                    smap = defaultdict(list)  # field -> values (f2v)
                    reverse_map = defaultdict(list)  # value -> fields (v2f)
                    for row in rows:
                        for i in range(len(fs)):
                            cur_f = _preclean(str(fs[i]))
                            cur_row = _preclean(str(row[i]))
                            smap[cur_f].append(cur_row)
                            if cur_f not in reverse_map[cur_row]:
                                reverse_map[cur_row].append(cur_f)
                    #----------------------------------------------------------
                    # all values are sorted by length in descending order
                    # for string match purpose
                    keys = sorted(reverse_map.keys(), key=len, reverse=True)
                    Q = _preclean(Q)
                    Q_ori = Q
                    #####################################
                    ########## Annotate question ########
                    #####################################
                    candidates, cond_fields = _match_pairs(
                        Q, Q_ori, keys, reverse_map)
                    Q_head, head2partial = _match_head(Q, Q_ori, smap,
                                                       all_fields, cond_fields)
                    Q, Qpairs = annotate_Q(Q, Q_ori, Q_head, candidates,
                                           all_fields, head2partial, n, target)
                    qu_file.write(Q + '\n')
                    # (word, symbol, kind) triples used below to map the
                    # symbols back to words.
                    validation_pairs = copy.copy(Qpairs)
                    validation_pairs.append((Q_head, '<f0>', 'head'))
                    for i, f in enumerate(all_fields):
                        validation_pairs.append((f, '<c' + str(i) + '>', 'c'))
                    #####################################
                    ########## Annotate SQL #############
                    #####################################
                    q_sent = Query.from_dict(d['sql'])
                    S, col_names, val_names = q_sent.to_sentence(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S = _preclean(S)
                    S_ori = S
                    S_noparen = q_sent.to_sentence_noparenthesis(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S_noparen = _preclean(S_noparen)
                    col_names = [_preclean(col_name) for col_name in col_names]
                    val_names = [_preclean(val_name) for val_name in val_names]
                    # The last column name is treated as the query head.
                    HEAD = col_names[-1]
                    S_head = _preclean(HEAD)
                    #annotate for SQL
                    name_pairs = []
                    for col_name, val_name in zip(col_names, val_names):
                        if col_name == val_name:
                            name_pairs.append([_preclean(col_name), 'true'])
                        else:
                            name_pairs.append(
                                [_preclean(col_name), _preclean(val_name)])
                    # sort to compare with candidates
                    name_pairs.sort(key=lambda x: x[1])
                    # Symbols <f1>/<v1>, <f2>/<v2>, ... in sorted pair order.
                    new_name_pairs = [[
                        '<f' + str(i + 1) + '>', '<v' + str(i + 1) + '>'
                    ] for i, (field, value) in enumerate(name_pairs)]
                    # only annotate S while identified (f,v) pairs are right
                    if _equal(name_pairs, candidates):
                        pairs = []
                        for (f, v), (new_f, new_v) in zip(name_pairs, new_name_pairs):
                            pairs.append((f, new_f, 'f'))
                            pairs.append((v, new_v, 'v'))
                        # sort (word,symbol) pairs by length in descending order
                        pairs.sort(key=lambda x: 100 - len(x[0]))
                        for p, new_p, t in pairs:
                            cp = _backslash(p)
                            # Replace only symbols that also occur in the
                            # annotated question.
                            if new_p in Q:
                                if t == 'v':
                                    S = S.replace(p + ' )', new_p + ' )')
                                if t == 'f':
                                    S = re.sub(
                                        '\( ' + cp + ' (equal|less|greater)',
                                        '( ' + new_p + r' \1', S)
                    # only annotate S while identified HEAD is right
                    if S_head == Q_head and '<f0>' in Q:
                        S = S.replace(S_head, '<f0>')
                    # annotate unseen fields (disabled: ADD_FIELDS is False)
                    if ADD_FIELDS:
                        for i, f in enumerate(all_fields):
                            cf = _backslash(f)
                            S = re.sub('(\s|^)' + cf + '(\s|$|s)',
                                       ' <c' + str(i) + '> ', S)
                    S = _clean(S)
                    lon_file.write(S + '\n')
                    ###############################
                    ######### VALIDATION ##########
                    ###############################
                    # Substitute symbols back and log every symbol=>word
                    # mapping on one line per example.
                    recover_S = S
                    for word, sym, t in validation_pairs:
                        recover_S = recover_S.replace(sym, word)
                        sym_file.write(sym + '=>' + word + '<>')
                    sym_file.write('\n')
                    S_file.write(S_noparen + '\n')
                    #------------------------------------------------------------------------
                    if _equal(name_pairs, candidates):
                        acc_pair += 1
                    if Q_head == S_head:
                        acc += 1
                    if _equal(name_pairs, candidates) and Q_head == S_head:
                        acc_all += 1
                    # A sentence is fully annotated when every token is a
                    # symbol or one of the structural keywords listed here.
                    full_anno = True
                    for s in S.split():
                        if s[0] != '<' and s not in [
                                '(', ')', 'where', 'less', 'greater', 'equal',
                                'max', 'min', 'count', 'sum', 'avg', 'and',
                                'true'
                        ]:
                            error += 1
                            full_anno = False
                            break
                    # Debug dump; never runs because of the leading `False`.
                    # `head` and `p2partial` are not assigned in this function.
                    if False and not (_equal(name_pairs, candidates)
                                      and head == col_names[-1]):
                        print('--------' + str(n) + '-----------')
                        print(Q_ori)
                        print(Q)
                        print(S_ori)
                        print(S)
                        print('head:')
                        print(head)
                        print(head2partial)
                        print('true HEAD')
                        print(Q_head)
                        print('fields:')
                        print(candidates)
                        print(p2partial)
                        print('true fields')
                        print(name_pairs)
                    n += 1
                # Summary statistics; divides by n, so an empty split raises
                # ZeroDivisionError.
                print('total number of examples:' + str(n))
                print('fully snnotated:' + str(1 - error * 1.0 / n))
                print('accurate all percent:' + str(acc_all * 1.0 / n))
                print('accurate HEAD match percent:' + str(acc * 1.0 / n))
                print('accurate fields pair match percent:' + str(acc_pair * 1.0 / n))
# Annotate every split; any invalid or non-reconstructible example aborts
# the run with an exception.
for split in ['train', 'dev', 'test']:
    split_stem = os.path.join(args.din, split)
    fsplit = split_stem + '.jsonl'
    ftable = split_stem + '.tables.jsonl'
    fout = os.path.join(args.dout, split) + '.jsonl'

    print('annotating {}'.format(fsplit))
    with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
        print('loading tables')
        tables = {}
        for raw in tqdm(ft, total=count_lines(ftable)):
            entry = json.loads(raw)
            tables[entry['id']] = entry

        print('loading examples')
        n_written = 0
        for raw in tqdm(fs, total=count_lines(fsplit)):
            example = json.loads(raw)
            annotated = annotate_example(example, tables[example['table_id']])
            if not is_valid_example(annotated):
                raise Exception(str(annotated))

            # The annotated output sequence must rebuild the gold query.
            gold = Query.from_tokenized_dict(annotated['query'])
            reconstruct = Query.from_sequence(
                annotated['seq_output'], annotated['table'], lowercase=True)
            if gold.lower() != reconstruct.lower():
                raise Exception('Expected:\n{}\nGot:\n{}'.format(
                    gold, reconstruct))

            fo.write(json.dumps(annotated) + '\n')
            n_written += 1
        print('wrote {} examples'.format(n_written))
parser.add_argument('--topk', type=int, default=3, help='k of top_k') args = parser.parse_args() engine = DBEngine(args.db_file) temp = [] with open(args.source_file) as fs, open(args.pred_file) as fp: grades = [] exact_match = [] for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)): eg = json.loads(ls) qg = Query.from_dict(eg['sql'], ordered=args.ordered) gold = engine.execute_query(eg['table_id'], qg, lower=True) pred_topk = [] qp_topk = [] ep = json.loads(lp) pred = ep.get('error', None) qp = None for i in range(args.topk): if not ep.get('error', None): try: if ep['query'][str(i)]['conds'] == [[]]: ep['query'][str(i)]['conds'] = []
if __name__ == '__main__':
    # Command line: gold source file, database file and prediction file.
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    args = parser.parse_args()
    engine = DBEngine(args.db_file)
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        # Gold and prediction files are read in lockstep, one JSON object
        # per line.
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.pred_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            #print qg
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            # Start from the prediction's own error text; it is replaced by
            # the execution result when the prediction carries no error.
            pred = ep['error']
            qp = None
            if not ep['error']:
                try:
                    qp = Query.from_dict(ep['query'])
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    # Parsing/execution failures are compared as text.
                    pred = repr(e)
            correct = pred == gold  # execution results agree
            match = qp == qg  # query objects compare equal
            #if not match:
            #    print qp.to_dict()
            #    print qg.to_dict()
def eval(self, split, gold, sql_gold, engine=None):
    """Score this prediction against ``gold`` and return an error-case record.

    Sets ``self.correct[key] = 1`` for each of ``agg``, ``sel``, ``where``,
    ``col``, ``lay``, ``span``, ``all`` and ``exe`` that matches.  Execution
    accuracy is only checked when ``engine`` is given.

    Args:
        split: name of the data split; ``'finaltest'`` enables the record.
        gold: annotated example (reads ``query``, ``question``, ``table_id``
            and, for ``'finaltest'``, ``id`` and ``table``).
        sql_gold: gold query dict accepted by ``Query.from_dict``.
        engine: optional object providing ``execute_query``.

    Returns:
        A dict describing the example when ``split == 'finaltest'``,
        otherwise an empty dict.
    """
    gold_query = gold['query']
    if self.agg == gold_query['agg']:
        self.correct['agg'] = 1
    if self.sel == gold_query['sel']:
        self.correct['sel'] = 1

    # Conditions are compared as sorted (column, operator, span text) triples
    # so that their order does not matter.
    gold_conds = gold_query['conds']
    words = gold['question']['words']
    where_pred = sorted(
        (col, op, ' '.join(words[span[0]:span[1] + 1]))
        for col, op, span in self.cond)
    where_gold = sorted(
        (col, op, ' '.join(span['words'])) for col, op, span in gold_conds)

    if where_pred == where_gold:
        self.correct['where'] = 1
    # Component-wise agreement: position 0 = column, 1 = operator, 2 = span.
    for key, pos in (('col', 0), ('lay', 1), ('span', 2)):
        if [item[pos] for item in where_pred] == [item[pos] for item in where_gold]:
            self.correct[key] = 1
    if all(self.correct[key] == 1 for key in ('agg', 'sel', 'where')):
        self.correct['all'] = 1

    # execution
    table_id = gold['table_id']
    # Placeholders that never compare equal, used when no engine is given.
    ans_gold, ans_pred = '0', '1'
    if engine is not None:
        ans_gold = engine.execute_query(
            table_id, Query.from_dict(sql_gold), lower=True)
        try:
            sql_pred = {
                'agg': self.agg,
                'sel': self.sel,
                'conds': self.recover_cond_to_gloss(gold),
            }
            ans_pred = engine.execute_query(
                table_id, Query.from_dict(sql_pred), lower=True)
        except Exception as e:
            ans_pred = repr(e)
    if set(ans_gold) == set(ans_pred):
        self.correct['exe'] = 1

    if split != 'finaltest':
        return {}
    return {
        'sel': self.correct['sel'],
        'where': self.correct['where'],
        'all': self.correct['all'],
        'table_id': gold['table_id'],
        'question_id': gold['id'],
        'question': gold['question']['words'],
        'table_head': [head['words'] for head in gold['table']['header']],
        'gold': sql_gold,
        'predict': {
            'agg': self.agg.item(),
            'sel': self.sel.item(),
            'conds': self.print_recover_cond_to_gloss(gold),
        },
        'BIO': [(x.item(), y) for x, y in zip(
            list(self.BIO), list(gold['question']['words']))],
        'BIO_col': self.BIO_col.tolist(),
    }
if __name__ == "__main__":
    # Usage: first command-line argument is the file prefix; reads
    # "<prefix>.tables.jsonl" and "<prefix>_tok.jsonl".
    lines = 0
    file_type = sys.argv[1]
    # table id -> '{"header": [...]}' string (first occurrence of an id wins)
    header_mapping_dictionary = defaultdict(lambda: [])
    for idx, line in enumerate(open(file_type + ".tables.jsonl")):
        tableJson = json.loads(line)
        if len(header_mapping_dictionary[tableJson["id"]]) == 0:
            header_mapping_dictionary[tableJson["id"]] = "{\"header\": " + str(
                tableJson["header"]) + "}"
    for line in open(file_type + "_tok.jsonl"):
        linesJson = json.loads(line)
        jsonQuery = linesJson["sql"]
        # Textual form of the query, as produced by Query.__repr__.
        text = Query(jsonQuery["sel"], jsonQuery["agg"],
                     conditions=jsonQuery["conds"]).__repr__()
        # NOTE: this first pattern is overwritten two statements below and
        # never used.
        sql_pattern = '''SELECT\s*(?P<agg>^(\s+).*)?\s*(?P<select_col_1>^(\s+).*)|(?P<select_col>^(\s+).*)\s*FROM\s*(?P<table_name>^(\s+).*)\s*WHERE\s*(?P<where_col>[\w]*)\s*(?P<where_op>[\S><&|+\-\%*!=]{1,2})\s*(?P<where_cond>.*)''' if False else '''SELECT\s*(?P<agg>^(\s+).*)?\s*(?P<select_col_1>^(\s+).*)|(?P<select_col>^(\s+).*)\s*FROM\s*(?P<table_name>^(\s+).*)\s*WHERE\s*(?P<where_col>[\w]*)\s*(?P<where_op>[\S><&|+\-\%*!=]{1,2})\s*(?P<where_cond>[\w\s]*)'''
        #text = "SELECT col1 FROM tblName WHERE col2 op val"
        # Pattern actually used: optional aggregate, select column, table
        # name and a single WHERE clause (column, operator, value).
        sql_pattern = '''SELECT\s+(?P<agg>COUNT|SUM|AVG|MAX|MIN)?\s+(?P<select_col>[^\s]+)\s+FROM\s+(?P<table_name>[^\s]+)\s*WHERE\s+(?P<where_col>[\w]*)\s+(?P<where_op>[\S><&|+\-\%*!=]{1,2})\s+(?P<where_cond>.*)'''
        matches = re.search(sql_pattern, text, re.IGNORECASE)
        # Split the text around the first occurrence of "table"
        # (5 characters long, hence the + 5).
        index = text.find("table")
        first_part = text[0:index]
        second_part = text[index + 5:]
        if matches:
            # Aggregate keyword with surrounding whitespace removed, or None.
            agg = matches.group("agg").strip() if matches.group(
                "agg") else None
# For every split: execute each gold query, store its answer and matching
# row ids next to the example, then convert the result to the aggregation
# format.
opts = Options()
fill_from_args(opts)
for split in ['train', 'dev', 'test']:
    data_dir = opts.data_dir
    orig = os.path.join(data_dir, f'{split}.jsonl')
    db_file = os.path.join(data_dir, f'{split}.db')
    ans_file = os.path.join(data_dir, f"{split}_ans.jsonl.gz")
    tbl_file = os.path.join(data_dir, f"{split}.tables.jsonl")
    engine = DBEngine(db_file)
    exact_match = []
    with open(orig) as fs, write_open(ans_file) as fo:
        grades = []
        for ls in tqdm(fs, total=count_lines(orig)):
            eg = json.loads(ls)
            table_id = eg['table_id']
            qg = Query.from_dict(eg['sql'], ordered=False)
            gold = engine.execute_query(table_id, qg, lower=True)
            assert isinstance(gold, list)
            eg['answer'] = gold
            eg['rowids'] = engine.execute_query_rowid(table_id, qg, lower=True)
            # CONSIDER: if it is not an agg query, somehow identify the particular cell
            fo.write(json.dumps(eg) + '\n')
    # Runs after the answer file has been closed.
    agg_file = os.path.join(opts.data_dir, f"{split}_agg.jsonl.gz")
    convert(jsonl_lines(ans_file), jsonl_lines(tbl_file), agg_file,
            skip_aggregation=False)
parser.add_argument("--db_file")
parser.add_argument("--pred_file")
parser.add_argument("--ordered", action='store_true')
args = parser.parse_args()
engine = DBEngine(args.db_file)
# One boolean per example: execution match (ex) and logical-form match (lf).
ex_acc_list = []
lf_acc_list = []
with open(args.source_file) as sf, open(args.pred_file) as pf:
    for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)):
        # Load the gold and predicted samples for this line.
        gold_example = json.loads(source_line)
        pred_example = json.loads(pred_line)
        # Build the gold logical form (lf) and its execution result (ex).
        lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered)
        ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True)
        # When the prediction carries no error, build its lf and ex as well.
        lf_pred_query = None
        ex_pred = pred_example.get('error', None)
        if not ex_pred:
            try:
                lf_pred_query = Query.from_dict(pred_example['query'], ordered=args.ordered)
                ex_pred = engine.execute_query(gold_example['table_id'], lf_pred_query, lower=True)
            except Exception as e:
                # Failures are compared as text against the gold result.
                ex_pred = repr(e)
        # Record whether gold and prediction agree on ex and on lf.
        ex_acc_list.append(ex_pred == ex_gold)
        lf_acc_list.append(lf_pred_query == lf_gold_query)  # invokes Query.__eq__
import pytest from lib.query import Query, LogicalOperator from tests.test_lib.conftest import does_not_raise @pytest.mark.parametrize('queries, expected_result, expected_exception', ( ( [ Query( field_name='semester', value='1', operator=LogicalOperator.OR, ), ], {1, 2, 3}, does_not_raise(), ), ( [ Query( field_name='last_name', value='one', operator=LogicalOperator.OR, ), Query( field_name='last_name', value='two', operator=LogicalOperator.OR, ), ],
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lib.query import Query
from lib.dbengine import DBEngine
import pytorch

if __name__ == '__main__':
    # Sanity check: every query must execute to a non-empty result and every
    # condition value must occur (case-insensitively) in the question.
    for split in ['train', 'dev', 'test']:
        print('checking {}'.format(split))
        engine = DBEngine('data/{}.db'.format(split))
        examples_path = 'data/{}.jsonl'.format(split)

        # First pass only counts lines for the progress bar.
        with open(examples_path) as f:
            n_lines = sum(1 for _ in f)

        with open(examples_path) as f:
            for raw in tqdm(f, total=n_lines):
                d = json.loads(raw)
                query = Query.from_dict(d['sql'])
                # make sure it's executable
                result = engine.execute_query(d['table_id'], query)
                if not result:
                    raise Exception(
                        'Query {} did not execute to a valid result'.format(
                            query))
                question = d['question']
                for _col, _op, value in d['sql']['conds']:
                    if str(value).lower() not in question.lower():
                        raise Exception(
                            'Could not find condition {} in question {} for query {}'
                            .format(value, question, query))
line_split = line.split('|') source_idx = int(line_split[-1]) if len(line_split[:-1]) > 1: seq_in = '|'.join(line_split[:-1]) else: seq_in = line_split[0] json_line = {} json_line['sample_id'] = source_idx json_line['query'] = "" json_line['error'] = "" seq_in = seq_in.decode('utf-8') seq_in = seq_in.split(' ') if '_EOS' in seq_in: seq_in = seq_in[:seq_in.index('_EOS')] json_line['seq'] = ' '.join(seq_in) sj = json.loads(data_source[source_idx]) try: output_with_gloss = get_output_with_gloss( seq_in, sj['seq_input']) q = Query.from_sequence(output_with_gloss, sj['table']).to_dict() except Exception as e: json_line['error'] = repr(e) else: json_line['query'] = q orig_list.append(json_line) sortedlist = sorted(orig_list, key=lambda k: k['sample_id']) with open(args.dout, 'wt') as fo: for item in sortedlist: fo.write(json.dumps(item) + '\n')