Example 1
def process_examples(fexample, ftable, fout):
    print('annotating {}'.format(fexample))
    with open(fexample) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
        print('loading tables')
        tables = {}
        for line in tqdm(ft, total=count_lines(ftable)):
            d = json.loads(line)
            tables[d['id']] = d
        print('loading examples')
        n_written = 0
        for line in tqdm(fs, total=count_lines(fexample)):
            d = json.loads(line)
            a = annotate_example(d, tables[d['table_id']])
            if not is_valid_example(a):
                # raise Exception(str(a))
                print('Invalid example: {}'.format(str(a)))
                continue

            gold = Query.from_tokenized_dict(a['query'])
            reconstruct = Query.from_sequence(a['seq_output'], a['table'], lowercase=True)
            if gold.lower() != reconstruct.lower():
                raise Exception('Expected:\n{}\nGot:\n{}'.format(gold, reconstruct))
            fo.write(json.dumps(a) + '\n')
            n_written += 1
        print('wrote {} examples'.format(n_written))
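These snippets all lean on WikiSQL-style helpers; here is a minimal stand-in for the count_lines helper they import (an assumption — the real lib.common implementation may differ):

# Counts the lines in a file so tqdm can show an accurate progress total.
def count_lines(fname):
    with open(fname) as f:
        return sum(1 for _ in f)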
Example 2
def build_sql_train(filename):
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']
    syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']
    train_file = open(filename,"w")
    db = MySQLdb.connect(host="localhost",  # your host
                     user="******",       # username
                     passwd="cfhoCPkr",     # password
                     db="wikisql")   # name of the database
    data = []
    cur = db.cursor()
    table_header={}
    
    with open(OPTIONS.train_table_file) as tb:
        for ls in tqdm(tb, total=count_lines(OPTIONS.train_table_file)):
            eg = json.loads(ls)
            table_header[eg['id']] = [re.split(r'[/,;|\s]\s*', h) for h in eg['header']]
    with open(OPTIONS.train_source_file) as fs:
        for ls in tqdm(fs, total=count_lines(OPTIONS.train_source_file)):
            eg = json.loads(ls)
            table_id = eg['table_id']
            cq = ''
            i = 0
            for cond in eg['sql']['conds']:
                col = cond[0]
                op = cond[1]
                if isinstance(cond[2], str):
                    # normalize unicode and drop non-ascii characters
                    cval = unicodedata.normalize('NFKD', cond[2]).encode('ascii', 'ignore').decode('ascii')
                else:
                    cval = str(cond[2])
                cq = cq + 'col' + str(col) + cond_ops[op] + "'" + cval + "' "
                if i < len(eg['sql']['conds']) - 1:
                    i = i + 1
                    cq = cq + 'AND '

            if eg['sql']['agg'] > 0:
                agg = agg_ops[eg['sql']['agg']]
                query = 'SELECT ' + agg + '(col' + str(eg['sql']['sel']) + ')' + ' FROM table_{}'.format(table_id.replace('-', '_')) + ' WHERE ' + cq
            else:
                query = 'SELECT col' + str(eg['sql']['sel']) + ' FROM table_{}'.format(table_id.replace('-', '_')) + ' WHERE ' + cq
            question = eg['question']
            q = ''
            for t in re.split(r'(,|;|/|\||!|@|#|\$|\"|\(|\)|`|=|\s)\s*', question):
                # normalize unicode and drop non-ascii characters (this also maps ā to a)
                q = q + unicodedata.normalize('NFKD', t).encode('ascii', 'ignore').decode('ascii')
            question = re.sub('\t+', ' ', q)
            
            train_file.write(question+'\t'+query+'\n')
            #gold = cur.execute(query)
    train_file.close()
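A made-up illustration of the column-index SQL built above (the sql dict and table id are invented, not taken from the dataset):

agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']
sql = {'sel': 1, 'agg': 4, 'conds': [[0, 0, 'France']]}
table_id = '1-1000181-1'
conds = ' AND '.join("col%d%s'%s'" % (c, cond_ops[o], v) for c, o, v in sql['conds'])
sel = 'col%d' % sql['sel']
if sql['agg'] > 0:
    sel = '%s(%s)' % (agg_ops[sql['agg']], sel)
print('SELECT %s FROM table_%s WHERE %s' % (sel, table_id.replace('-', '_'), conds))
# -> SELECT SUM(col1) FROM table_1_1000181_1 WHERE col0='France'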
Example 3
def process_tables(ftable, fout=None):
    # join words together
    def _join_words(entry):
        result = ""
        for i in range(len(entry["words"])):
            # render single-space separators as '^'
            result += entry["words"][i] + (entry["after"][i] if entry["after"][i] != " " else "^")
        return result

    tables = []
    with open(ftable) as ft:
        for line in tqdm(ft, total=count_lines(ftable)):
            raw_table = json.loads(line)
            try:
                table = {
                    "id": raw_table["id"],
                    "header": [ _join_words(annotate(h)) for h in raw_table["header"]],
                    "rows": [[ _join_words(annotate(str(tok))) for tok in l] for l in raw_table["rows"]]
                }
                tables.append(table)
            except Exception:
                print(line)
                break

    if fout is not None:
        with open(fout, "w+") as fo:
            for table in tables:
                fo.write(json.dumps(table) + "\n")
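A small runnable illustration of the _join_words convention (the entry dict is hand-built to mimic what annotate presumably returns):

entry = {'words': ['New', 'York'], 'after': [' ', '']}
joined = ''.join(w + (a if a != ' ' else '^')
                 for w, a in zip(entry['words'], entry['after']))
print(joined)  # -> New^York  (single-space separators become '^')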
Example 4
def eval_one_qelos(db_file, pred_file, source_file):
    engine = DBEngine(db_file)
    exact_match = []
    with open(source_file) as fs, open(pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            qp = None
            try:
                qp = Query.from_dict(ep)
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            grades.append(correct)
            exact_match.append(match)
        result = {
            'ex_accuracy': sum(grades) / len(grades),
            'lf_accuracy': sum(exact_match) / len(exact_match),
        }
        return result
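A hypothetical invocation (paths are placeholders, not from the source):

# metrics = eval_one_qelos('data/dev.db', 'preds/dev_pred.jsonl', 'data/dev.jsonl')
# -> {'ex_accuracy': ..., 'lf_accuracy': ...}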
Example 5
def load_seq2sql_file(train_file):
    data = []
    with open(train_file) as lf:
        for ls in tqdm(lf, total=count_lines(train_file)):
            data.append(ast.literal_eval(ls.replace(' ', '')))
    return data
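Usage sketch (the path is an assumption). Note that stripping every space before ast.literal_eval also removes spaces inside quoted strings, so this is only safe when the serialized lines contain none:

# examples = load_seq2sql_file('data/train_seq2sql.txt')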
Example 6
def load_word2vec(file_loc):
    w2v = {}
    with open(file_loc) as w2v_file:
        for ls in tqdm(w2v_file, total=count_lines(file_loc)):
            lv = ls.split(' ')
            word = lv[0]
            # the remaining fields are the vector components (kept as strings)
            w2v[word] = lv[1:]
    return w2v
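Usage sketch (the file name is an assumption); the loader expects GloVe-style text, one word followed by its vector components per line:

# w2v = load_word2vec('glove.6B.50d.txt')
# len(w2v['the'])  # -> 50, with components kept as strings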
Example 7
def main(argv):
    del argv  # Unused.

    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)

    engine = DBEngine(db_file)
    exact_match = []

    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp),
                           total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)

            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)

            qp = None
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            if pred == gold and qp != qg:
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)
        print(
            json.dumps(
                {
                    'ex_accuracy': sum(grades) / len(grades),
                    'lf_accuracy': sum(exact_match) / len(exact_match),
                },
                indent=2))
Example 8
    class Options:
        def __init__(self):
            self.data_dir = ''

    opts = Options()
    fill_from_args(opts)

    for split in ['train', 'dev', 'test']:
        orig = os.path.join(opts.data_dir, f'{split}.jsonl')
        db_file = os.path.join(opts.data_dir, f'{split}.db')
        ans_file = os.path.join(opts.data_dir, f"{split}_ans.jsonl.gz")
        tbl_file = os.path.join(opts.data_dir, f"{split}.tables.jsonl")
        engine = DBEngine(db_file)
        exact_match = []
        with open(orig) as fs, write_open(ans_file) as fo:
            grades = []
            for ls in tqdm(fs, total=count_lines(orig)):
                eg = json.loads(ls)
                sql = eg['sql']
                qg = Query.from_dict(sql, ordered=False)
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
                assert isinstance(gold, list)
                #if len(gold) != 1:
                #    print(f'for {sql} : {gold}')
                eg['answer'] = gold
                eg['rowids'] = engine.execute_query_rowid(eg['table_id'],
                                                          qg,
                                                          lower=True)
                # CONSIDER: if it is not an agg query, somehow identify the particular cell
                fo.write(json.dumps(eg) + '\n')

        convert(jsonl_lines(ans_file),
Example 9
def main():
    for split in ['test', 'dev']:

        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        with open(fsplit) as fs, open(ftable) as ft:
            print('loading tables')
            tables = {}

            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')

        with open(data_path+'%s_infer.txt'%split, 'r') as in_file, \
            open(save_path+'%s_infer.txt'%split, 'w') as out_file,\
            open(save_path+'%s_ground_truth.txt'%split,'r') as true0,\
            open(save_path+'%s_ground_truth_mark.txt'%split,'r') as true,\
            open(save_path+'%s_SQL2tableid.txt'%split,'r') as id_file,\
            open(save_path+'%s_sym_pairs.txt'%split,'r') as pairs_file:

            lines = in_file.readlines()
            gt_lines = true0.readlines()
            gt_mark_lines = true.readlines()
            ids = id_file.readlines()
            pairs = pairs_file.readlines()

            idx = 1
            for line, gt, gt_mark, t_id, pair in zip(lines, gt_lines,
                                                     gt_mark_lines, ids,
                                                     pairs):

                t_id = t_id.replace('\n', '')

                new = ''

                inf = line

                rows = tables[t_id]['rows']
                rows = np.asarray(rows)
                fs = tables[t_id]['header']

                all_f = []
                for f in fs:
                    f = _preclean(str(f))
                    all_f.append(f)

                for row in rows:
                    for i in range(len(fs)):
                        f = _preclean(str(row[i]))
                        all_f.append(f)

                # all fields (including values) are sorted by length in descending order
                all_f.sort(key=len, reverse=True)

                added = []  # collect phrases with a SYMBOL word (e.g., 'MAX') in them

                for f in all_f:
                    PASS = False  # substring of an added phrase, or absent from the line
                    ADD = False   # whether a SYMBOL word occurs in the phrase

                    for a in added:
                        if f in a:
                            PASS = True
                            break

                    if len(f.split()) == 1 and f not in line.split():
                        PASS = True

                    # whether SYMBOL is in it
                    for s in symbols:
                        if s in f.split():
                            ADD = True
                            break

                    if f in line and not PASS and ADD and f != 'where':
                        added.append(f)

                # added phrases are replaced away
                for iden, phrase in enumerate(added):
                    line = line.replace(phrase, '<p' + str(iden) + '>')

                tokens = line.split()
                for i in range(len(tokens)):
                    token = tokens[i]

                    if i == 0:
                        if token not in symbols:
                            new += '\"'

                    elif tokens[i - 1] in symbols and token not in symbols:
                        new += '\"'

                    new += token

                    if i == len(tokens) - 1:
                        if token not in symbols:
                            new += '\"'
                    elif tokens[i + 1] in symbols and token not in symbols:
                        new += '\"'

                    new += ' '

                new = new.strip()

                # replace back
                for iden, phrase in enumerate(added):
                    new = new.replace('<p' + str(iden) + '>', phrase)

                gt_mark = gt_mark.replace('\n', '')
                if gt == inf and gt_mark != new:

                    print(added)
                    print('inf:' + line)
                    print('gt:' + gt)
                    print('gt_mark:' + gt_mark + '.')
                    print('inf_mark:' + new + '.')

                out_file.write(new + '\n')
                idx += 1
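A condensed, runnable sketch of the quoting pass above (the symbols set here is an invented stand-in):

symbols = {'select', 'where', 'equal'}
line = 'select points where opponent equal real madrid'
tokens, new = line.split(), ''
for i, token in enumerate(tokens):
    if token not in symbols and (i == 0 or tokens[i - 1] in symbols):
        new += '"'  # open a quote at the start of a non-symbol run
    new += token
    if token not in symbols and (i == len(tokens) - 1 or tokens[i + 1] in symbols):
        new += '"'  # close the quote at the end of the run
    new += ' '
print(new.strip())  # -> select "points" where "opponent" equal "real madrid"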
Example 10
import json

from argparse import ArgumentParser
from tqdm import tqdm

from lib.dbengine import DBEngine
from lib.query import Query
from lib.common import count_lines

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.pred_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            #print qg
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            pred = ep['error']
            qp = None
            if not ep['error']:
                try:
                    qp = Query.from_dict(ep['query'])
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    pred = repr(e)
            correct = pred == gold
            match = qp == qg
Example 11
def data_(inv_, args):
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']
    syms = [
        'SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION',
        'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS'
    ]
    db = MySQLdb.connect(
        host="localhost",  # your host
        user="******",  # username
        passwd="******",  # password
        db="wikisql")  # name of the database
    data = []
    cur = db.cursor()
    table_file = args.source_file.split(
        '.')[0] + '.tables.' + args.source_file.split('.')[1]
    table_header = {}

    with open(table_file) as tb:
        for ls in tqdm(tb, total=count_lines(table_file)):
            eg = json.loads(ls)
            table_header[eg['id']] = [
                re.split(r'[/,;|\s]\s*', h) for h in eg['header']
            ]
    with open(args.source_file) as fs:
        for ls in tqdm(fs, total=count_lines(args.source_file)):
            eg = json.loads(ls)
            table_id = eg['table_id']
            cq = ''
            i = 0
            for cond in eg['sql']['conds']:
                col = cond[0]
                op = cond[1]
                # keep the condition value as text
                cval = cond[2] if isinstance(cond[2], str) else str(cond[2])
                cq = cq + 'col' + str(col) + cond_ops[op] + "'" + cval + "' "
                if i < len(eg['sql']['conds']) - 1:
                    i = i + 1
                    cq = cq + 'AND '

            if eg['sql']['agg'] > 0:
                agg = agg_ops[eg['sql']['agg']]
                query = ('SELECT ' + agg + '(col' + str(eg['sql']['sel']) + ')' +
                         ' FROM table_{}'.format(table_id.replace('-', '_')) +
                         ' WHERE ' + cq)
            else:
                query = ('SELECT col' + str(eg['sql']['sel']) +
                         ' FROM table_{}'.format(table_id.replace('-', '_')) +
                         ' WHERE ' + cq)
            # normalize dashes and strip wiki markup characters
            query = (query.replace('\u2013', '-').replace('\u0101', 'a')
                     .replace('\u1e43', 'm').replace('[[', '').replace(',', '')
                     .replace('||', '').replace(']]', '').replace('|', ''))
            #gold = cur.execute(query)
            ## TO DO .. compare pred with gold
            hl = []  # ids of known header tokens
            for header_tokens in table_header[table_id.replace('table_', '').replace('_', '-')]:
                for ap in header_tokens:
                    t = re.sub(r'(\(|\)|,|;|!)', '', ap.lower())
                    if t in inv_:
                        hl.append(inv_[t])

            qv = re.split(r'[()!,?\-.|\\/{}\[\]#$&\s]\s*',
                          eg['question'].lower())
            qvu = []
            for w in qv:
                if w in inv_ and w not in qvu:
                    qvu.append(w)
            qs = re.split(r"(\(|\)|'|=|/|\s)\s*", query)
            for q in qs:
                q = q.lower()
                if q not in inv_:
                    inv_[q] = len(inv_)
                    id_[inv_[q]] = q  # id_ is assumed to be in scope; keep it consistent with inv_
            input_ = hl + [inv_[w] for w in qvu] + [inv_[s.lower()] for s in syms]
            data.append((input_, [
                inv_[q.lower()]
                for q in re.split(r"[()'=|/\s]\s*", query)
            ]))
    return data
Example 12
def main():
    parser = ArgumentParser()
    parser.add_argument('--din', default=data_path, help='data directory')
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()

    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)

    for split in ['dev']:
        #for split in ['train', 'test', 'dev']:
        with open(save_path+'%s.qu'%split, 'w') as qu_file, open(save_path+'%s.lon'%split, 'w') as lon_file, \
            open(save_path+'%s.out'%split, 'w') as out, open(save_path+'%s_sym_pairs.txt'%split, 'w') as sym_file, \
            open(save_path+'%s_ground_truth.txt'%split, 'w') as S_file:

            fsplit = os.path.join(args.din, split) + '.jsonl'
            ftable = os.path.join(args.din, split) + '.tables.jsonl'

            with open(fsplit) as fs, open(ftable) as ft:
                print('loading tables')
                tables = {}
                for line in tqdm(ft, total=count_lines(ftable)):
                    d = json.loads(line)
                    tables[d['id']] = d
                print('loading tables done.')

                print('loading examples')
                n, acc, acc_pair, acc_all, error = 0, 0, 0, 0, 0
                target = -1

                ADD_FIELDS = False
                step2 = True

                for line in tqdm(fs, total=count_lines(fsplit)):
                    ADD_TO_FILE = True
                    d = json.loads(line)
                    Q = d['question']

                    rows = tables[d['table_id']]['rows']
                    rows = np.asarray(rows)
                    fs = tables[d['table_id']]['header']

                    all_fields = []
                    for f in fs:
                        all_fields.append(_preclean(f))
                    # all fields are sorted by length in descending order
                    # for string match purpose
                    all_fields = sorted(all_fields, key=len, reverse=True)

                    smap = defaultdict(list)  # field -> values
                    reverse_map = defaultdict(list)  # value -> fields
                    for row in rows:
                        for i in range(len(fs)):
                            cur_f = _preclean(str(fs[i]))
                            cur_row = _preclean(str(row[i]))
                            smap[cur_f].append(cur_row)
                            if cur_f not in reverse_map[cur_row]:
                                reverse_map[cur_row].append(cur_f)

                    #----------------------------------------------------------
                    # all values are sorted by length in descending order
                    # for string match purpose
                    keys = sorted(reverse_map.keys(), key=len, reverse=True)

                    Q = _preclean(Q)
                    Q_ori = Q

                    #####################################
                    ########## Annotate question ########
                    #####################################
                    candidates, cond_fields = _match_pairs(
                        Q, Q_ori, keys, reverse_map)
                    Q_head, head2partial = _match_head(Q, Q_ori, smap,
                                                       all_fields, cond_fields)

                    Q, Qpairs = annotate_Q(Q, Q_ori, Q_head, candidates,
                                           all_fields, head2partial, n, target)
                    qu_file.write(Q + '\n')

                    validation_pairs = copy.copy(Qpairs)
                    validation_pairs.append((Q_head, '<f0>', 'head'))
                    for i, f in enumerate(all_fields):
                        validation_pairs.append((f, '<c' + str(i) + '>', 'c'))

                    #####################################
                    ########## Annotate SQL #############
                    #####################################
                    q_sent = Query.from_dict(d['sql'])
                    S, col_names, val_names = q_sent.to_sentence(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S = _preclean(S)
                    S_ori = S

                    S_noparen = q_sent.to_sentence_noparenthesis(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S_noparen = _preclean(S_noparen)

                    col_names = [_preclean(col_name) for col_name in col_names]
                    val_names = [_preclean(val_name) for val_name in val_names]

                    HEAD = col_names[-1]
                    S_head = _preclean(HEAD)

                    #annotate for SQL
                    name_pairs = []
                    for col_name, val_name in zip(col_names, val_names):
                        if col_name == val_name:
                            name_pairs.append([_preclean(col_name), 'true'])
                        else:
                            name_pairs.append(
                                [_preclean(col_name),
                                 _preclean(val_name)])

                    # sort to compare with candidates
                    name_pairs.sort(key=lambda x: x[1])
                    new_name_pairs = [[
                        '<f' + str(i + 1) + '>', '<v' + str(i + 1) + '>'
                    ] for i, (field, value) in enumerate(name_pairs)]

                    # only annotate S while identified (f,v) pairs are right
                    if _equal(name_pairs, candidates):
                        pairs = []
                        for (f, v), (new_f,
                                     new_v) in zip(name_pairs, new_name_pairs):
                            pairs.append((f, new_f, 'f'))
                            pairs.append((v, new_v, 'v'))
                        # sort (word,symbol) pairs by length in descending order
                        pairs.sort(key=lambda x: len(x[0]), reverse=True)

                        for p, new_p, t in pairs:
                            cp = _backslash(p)

                            if new_p in Q:
                                if t == 'v':
                                    S = S.replace(p + ' )', new_p + ' )')
                                if t == 'f':
                                    S = re.sub(
                                        r'\( ' + cp + r' (equal|less|greater)',
                                        '( ' + new_p + r' \1', S)

                    # only annotate S while identified HEAD is right
                    if S_head == Q_head and '<f0>' in Q:
                        S = S.replace(S_head, '<f0>')

                    # annotate unseen fields
                    if ADD_FIELDS:
                        for i, f in enumerate(all_fields):
                            cf = _backslash(f)
                            S = re.sub(r'(\s|^)' + cf + r'(\s|$|s)',
                                       ' <c' + str(i) + '> ', S)

                    S = _clean(S)
                    lon_file.write(S + '\n')

                    ###############################
                    ######### VALIDATION ##########
                    ###############################
                    recover_S = S
                    for word, sym, t in validation_pairs:
                        recover_S = recover_S.replace(sym, word)
                        sym_file.write(sym + '=>' + word + '<>')
                    sym_file.write('\n')

                    S_file.write(S_noparen + '\n')
                    #------------------------------------------------------------------------
                    if _equal(name_pairs, candidates):
                        acc_pair += 1

                    if Q_head == S_head:
                        acc += 1

                    if _equal(name_pairs, candidates) and Q_head == S_head:
                        acc_all += 1

                    full_anno = True
                    for s in S.split():
                        if s[0] != '<' and s not in [
                                '(', ')', 'where', 'less', 'greater', 'equal',
                                'max', 'min', 'count', 'sum', 'avg', 'and',
                                'true'
                        ]:
                            error += 1
                            full_anno = False
                            break

                    # disabled debug dump of mismatched annotations
                    if False and not (_equal(name_pairs, candidates)
                                      and Q_head == S_head):
                        print('--------' + str(n) + '-----------')
                        print(Q_ori)
                        print(Q)
                        print(S_ori)
                        print(S)
                        print('head:')
                        print(Q_head)
                        print(head2partial)
                        print('true HEAD')
                        print(S_head)
                        print('fields:')
                        print(candidates)
                        print(head2partial)
                        print('true fields')
                        print(name_pairs)

                    n += 1

                print('total number of examples:' + str(n))
                print('fully annotated:' + str(1 - error * 1.0 / n))
                print('accurate all percent:' + str(acc_all * 1.0 / n))
                print('accurate HEAD match percent:' + str(acc * 1.0 / n))
                print('accurate fields pair match percent:' +
                      str(acc_pair * 1.0 / n))
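A made-up sketch of the placeholder scheme the loop above emits (not an actual dataset example):

# Q: 'what is the <f0> when the <f1> is <v1>'
# S: 'select <f0> where ( <f1> equal <v1> )'
# <f0> marks the selected HEAD column, <fi>/<vi> mark condition
# column/value pairs, and <ci> marks columns annotated via ADD_FIELDS.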
Example 13
def _gen_files(save_path=save_path):
    """
    Prepare files used for the Machine Comprehension Binary Classifier.
    """
    for split in ['train', 'test', 'dev']:
        print('------%s-------' % split)
        n = 0
        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        """
	test.txt original column content w/o truncate
	test_model.txt column name truncate or pad to length 3
	"""
        with open(fsplit) as fs, open(ftable) as ft, \
            open(os.path.join(save_path, split+'.txt'), mode='w') as fw, \
            open(os.path.join(save_path, '%s_model.txt'%split), mode='w') as fw_model, \
            open(os.path.join(save_path, split+'.lon'), mode='w') as fsql, \
            open(os.path.join(save_path, '%s.ori.qu'%split), 'w') as qu_file, \
            open(os.path.join(save_path, '%s.ori.lon'%split), 'w') as lon_file:
            print('loading tables...')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')

            print('loading examples')
            f2v_all, v2f_all = [], []
            for line in tqdm(fs, total=count_lines(fsplit)):

                d = json.loads(line)
                Q = d['question']
                Q = _preclean(Q).replace('\t', '')

                qu_file.write(Q + '\n')

                q_sent = Query.from_dict(d['sql'])
                rows = tables[d['table_id']]['rows']
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)

                lon_file.write(S + '\n')

                rows = np.asarray(rows)
                fs = tables[d['table_id']]['header']
                all_fields = [_preclean(f) for f in fs]
                # all fields are sorted by length in descending order
                # for string match purpose
                headers = sorted(all_fields, key=len, reverse=True)

                f2v = defaultdict(list)  # field -> values
                v2f = defaultdict(list)  # value -> fields
                for row in rows:
                    for i in range(len(fs)):
                        cur_f = _preclean(str(fs[i]))
                        cur_row = _preclean(str(row[i]))
                        #cur_f = cur_f.replace('\u2003',' ')
                        f2v[cur_f].append(cur_row)
                        if cur_f not in v2f[cur_row]:
                            v2f[cur_row].append(cur_f)
                f2v_all.append(f2v)
                v2f_all.append(v2f)

                #####################################
                ########## Annotate SQL #############
                #####################################
                # q_sent, S, col_names and val_names were already computed above

                S_noparen = q_sent.to_sentence_noparenthesis(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S_noparen = _preclean(S_noparen)

                col_names = [_preclean(col_name) for col_name in col_names]
                val_names = [_preclean(val_name) for val_name in val_names]

                HEAD = col_names[-1]
                S_head = _preclean(HEAD)

                #annotate for SQL
                name_pairs = []
                for col_name, val_name in zip(col_names, val_names):
                    if col_name == val_name:
                        name_pairs.append([_preclean(col_name), 'true'])
                    else:
                        name_pairs.append(
                            [_preclean(col_name),
                             _preclean(val_name)])

                # sort to compare with candidates
                name_pairs.sort(key=lambda x: x[1])
                fsql.write('#%d\n' % n)
                fw.write('#%d\n' % n)

                for f in col_names:
                    fsql.write(S.replace(f, '[' + f + ']') + '\n')
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')

                for f in [f for f in headers if f not in col_names]:
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                    fsql.write(S + '\n')

                #if '\u2003' in Q:
                #    print('u2003: '+Q)
                #if '\xa0' in Q:
                #    print(n)
                #    print('xa0: '+Q)
                #    print(S)
                for f in col_names:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')

                for f in [f for f in headers if f not in col_names]:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')

                n += 1
            fsql.write('#%d\n' % n)
            fw.write('#%d\n' % n)

        np.savez(os.path.join(save_path, '%s_dict.npz' % split),
                 f2v_all=f2v_all,
                 v2f_all=v2f_all)
        print('num of records:%d' % n)
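A made-up sketch of the classifier file layout written above (the exact rendering of _truncate is an assumption): each question is paired with every header plus a binary label, tab-separated, between '#n' record markers:

# #0
# what was the score\topponent bos pad\t 1
# what was the score\tdate bos pad\t 0
# #1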
Example 14
def index_data(data_file):
    id_ = {
        0: '',
        1: '<EOS>',
        2: 'max',
        3: 'min',
        4: 'count',
        5: 'sum',
        6: 'avg',
        7: '=',
        8: '>',
        9: '<',
        10: 'op',
        11: 'select',
        12: 'where',
        13: 'and',
        14: 'col',
        15: 'table',
        16: 'caption',
        17: 'page',
        18: 'section',
        19: 'cond',
        20: 'question',
        21: 'agg',
        22: 'aggops',
        23: 'condops',
        24: 'col0',
        25: 'col1',
        26: 'col2',
        27: 'col3',
        28: 'col4',
        29: 'col5',
        30: 'col6',
        31: 'col7',
        32: 'col8',
        33: 'col9',
        34: 'col10',
        35: 'col11',
        36: 'col12',
        37: 'col13',
        38: 'col14',
        39: 'col15',
        40: 'svaha'
    }
    inv_ = {v: k for k, v in id_.items()}
    i = 41
    for df in data_file:
        if 'tables' in df:
            with open(df) as ind:
                for ls in tqdm(ind, total=count_lines(df)):
                    eg = json.loads(ls)
                    if eg.get('id'):
                        table_id = 'table_{}'.format(eg['id'].replace('-', '_'))
                        if table_id not in inv_:
                            id_[i] = table_id
                            inv_[table_id] = i
                            i = i + 1

        else:
            with open(df) as ind:
                for ls in tqdm(ind, total=count_lines(df)):
                    for t in re.split(r'(,|;|/|\||!|@|#|\$|\"|\(|\)|`|=|\s)\s*',
                                      ls):
                        w = str.lower(t)
                        if w not in inv_:
                            id_[i] = w
                            inv_[w] = i
                            i = i + 1
    return id_, inv_
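Usage sketch (the file names are assumptions):

# id_, inv_ = index_data(['data/train.tables.jsonl', 'data/train.jsonl'])
# inv_['select']  # -> 11 (from the seed vocabulary above)
# id_[11]         # -> 'select'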
Example 15
    parser.add_argument('pred_file', help='predictions by the model')
    parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions')

    parser.add_argument('--topk', type=int, default=3, help='k of top_k')
    
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
   
    temp = []
    
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        exact_match = []
        
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):            
            eg = json.loads(ls)
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            
            pred_topk = []
            qp_topk = []
            
            ep = json.loads(lp)
            pred = ep.get('error', None)
            qp = None
            for i in range(args.topk):
                if not ep.get('error', None):
                    try:
                        
                        if ep['query'][str(i)]['conds'] == [[]]:
Example 16
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()

    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)

    for split in ['train', 'dev', 'test']:
        fsplit = os.path.join(args.din, split) + '.jsonl'
        ftable = os.path.join(args.din, split) + '.tables.jsonl'
        fout = os.path.join(args.dout, split) + '.jsonl'

        print('annotating {}'.format(fsplit))
        with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
            print('loading tables')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading examples')
            n_written = 0
            for line in tqdm(fs, total=count_lines(fsplit)):
                d = json.loads(line)
                a = annotate_example(d, tables[d['table_id']])
                if not is_valid_example(a):
                    raise Exception(str(a))

                gold = Query.from_tokenized_dict(a['query'])
                reconstruct = Query.from_sequence(a['seq_output'],
                                                  a['table'],
                                                  lowercase=True)
                if gold.lower() != reconstruct.lower():
Example 17
import argparse
import json

from tqdm import tqdm

from lib.dbengine import DBEngine
from lib.query import Query
from lib.common import count_lines

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--source_file")
    parser.add_argument("--db_file")
    parser.add_argument("--pred_file")
    parser.add_argument("--ordered", action='store_true')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    ex_acc_list = []
    lf_acc_list = []
    with open(args.source_file) as sf, open(args.pred_file) as pf:
        for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)):
            # read the gold and predicted samples for this line
            gold_example = json.loads(source_line)
            pred_example = json.loads(pred_line)

            # compute the logical form (lf) and execution result (ex) for the gold sample
            lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered)
            ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True)

            # if the prediction is not an error, compute lf and ex for the predicted sample
            lf_pred_query = None
            ex_pred = pred_example.get('error', None)
            if not ex_pred:
                try:
                    lf_pred_query = Query.from_dict(pred_example['query'], ordered=args.ordered)
                    ex_pred = engine.execute_query(gold_example['table_id'], lf_pred_query, lower=True)