예제 #1
0
def eval_one_qelos(db_file, pred_file, source_file):
    """Score one prediction file against its gold source file.

    Line ``i`` of ``source_file`` is a JSON gold example (with ``sql`` and
    ``table_id``); line ``i`` of ``pred_file`` is the predicted query dict.
    Both queries are executed on ``db_file``. A prediction that fails to
    parse or execute is represented by the ``repr`` of the exception, so it
    only counts as correct if the gold result happens to equal that string.

    Returns:
        dict with ``ex_accuracy`` (share of identical execution results) and
        ``lf_accuracy`` (share of equal ``Query`` objects).
    """
    engine = DBEngine(db_file)
    exec_hits = []
    form_hits = []
    with open(source_file) as gold_handle, open(pred_file) as pred_handle:
        line_pairs = zip(gold_handle, pred_handle)
        for gold_line, pred_line in tqdm(line_pairs,
                                         total=count_lines(source_file)):
            gold_example = json.loads(gold_line)
            pred_dict = json.loads(pred_line)
            gold_query = Query.from_dict(gold_example['sql'])
            table_id = gold_example['table_id']
            gold_answer = engine.execute_query(table_id, gold_query,
                                               lower=True)
            pred_query = None
            try:
                pred_query = Query.from_dict(pred_dict)
                pred_answer = engine.execute_query(table_id, pred_query,
                                                   lower=True)
            except Exception as err:
                pred_answer = repr(err)
            exec_hits.append(pred_answer == gold_answer)
            form_hits.append(pred_query == gold_query)
        return {
            'ex_accuracy': sum(exec_hits) / len(exec_hits),
            'lf_accuracy': sum(form_hits) / len(form_hits),
        }
예제 #2
0
def hello():
    """Score predicted SQL queries against gold queries for one dataset.

    Query-string parameters:
      pred: JSON list of predictions. Each is a dict whose optional ``error``
        entry is falsy when the prediction parsed, plus a ``query`` dict
        accepted by ``Query.from_dict``.
      gold: JSON list of gold examples with ``sql`` and ``table_id`` keys.
      dataset: bare name of the ``<dataset>.db`` file under ``DATABASE_PATH``.

    Returns:
      A JSON string with ``ex_accuracy`` (execution) and ``lf_accuracy``
      (logical form). On a missing or invalid parameter, or when there is
      nothing to score, it instead returns a ``(json_body, 400)`` tuple
      (Flask-style status response).
    """
    pred_arg = request.args.get('pred')
    gold_arg = request.args.get('gold')
    dataset = request.args.get('dataset')
    if pred_arg is None or gold_arg is None:
        return json.dumps({"error": "missing 'pred' or 'gold' parameter"}), 400
    # ``dataset`` is user input that is joined into a filesystem path. Only a
    # plain file name is accepted, so a request cannot point the engine at a
    # database outside DATABASE_PATH (e.g. ``dataset=../../elsewhere``).
    if (not dataset or dataset != os.path.basename(dataset)
            or dataset.startswith('.')):
        return json.dumps({"error": "invalid 'dataset' parameter"}), 400

    pred_json = json.loads(pred_arg)
    gold_json = json.loads(gold_arg)

    engine = DBEngine(os.path.join(DATABASE_PATH, "{}.db".format(dataset)))

    exact_match = []
    grades = []

    for ep, eg in tqdm(zip(pred_json, gold_json), total=len(gold_json)):
        qg = Query.from_dict(eg['sql'])
        gold = engine.execute_query(eg['table_id'], qg, lower=True)
        # A truthy 'error' marks a prediction that already failed upstream.
        # That value is then compared with the gold result as-is.
        pred = ep.get('error')
        qp = None
        if not pred:
            try:
                qp = Query.from_dict(ep['query'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
        grades.append(pred == gold)
        exact_match.append(qp == qg)

    if not grades:
        return json.dumps({"error": "no examples to score"}), 400

    ex_accuracy = sum(grades) / len(grades)
    lf_accuracy = sum(exact_match) / len(exact_match)
    return json.dumps({"ex_accuracy": ex_accuracy, "lf_accuracy": lf_accuracy})
예제 #3
0
    def eval(self, gold, sql_gold, engine):
        """Compare this prediction with one gold example, component by component.

        Sets ``self.correct[key] = 1`` for each matching component:
        'agg', 'sel', 'where' (sorted condition triples), 'col', 'lay'
        (operators), 'span', 'all' (agg, sel and where together) and 'exe'
        (same set of execution results). Entries are only ever set to 1
        here, never cleared. ``self.exception_raised`` is reset on entry and
        set to True when building or executing the predicted query raises.

        Args:
            gold: gold example dict (reads ``query``, ``question`` and
                ``table_id``).
            sql_gold: gold SQL dict accepted by ``Query.from_dict``.
            engine: object providing ``execute_query``.
        """
        self.exception_raised = False  # cleared for every new evaluation
        gold_query = gold['query']
        if self.agg == gold_query['agg']:
            self.correct['agg'] = 1
        if self.sel == gold_query['sel']:
            self.correct['sel'] = 1

        gold_conds = gold_query['conds']
        pred_ops = [o for _, o, _ in self.cond]
        gold_ops = [o for _, o, _ in gold_conds]
        pred_cols = [c for c, _, _ in self.cond]
        gold_cols = [c for c, _, _ in gold_conds]

        # Predicted spans are inclusive [start, end] indices into the
        # question words; gold conditions carry their own 'words'.
        words = gold['question']['words']
        pred_spans = [' '.join(words[s[0]:s[1] + 1]) for _, _, s in self.cond]
        gold_spans = [' '.join(s['words']) for _, _, s in gold_conds]

        # Sorting makes the comparison independent of condition order.
        pred_where = sorted(zip(pred_cols, pred_ops, pred_spans))
        gold_where = sorted(zip(gold_cols, gold_ops, gold_spans))
        same_sizes = (len(pred_cols) == len(gold_cols)
                      and len(pred_ops) == len(gold_ops)
                      and len(pred_spans) == len(gold_spans))
        if pred_where == gold_where and same_sizes:
            self.correct['where'] = 1

        # Each component is compared position-wise over the sorted triples.
        component_checks = (
            ('col', pred_cols, gold_cols),
            ('lay', pred_ops, gold_ops),
            ('span', pred_spans, gold_spans),
        )
        for pos, (key, pred_part, gold_part) in enumerate(component_checks):
            if len(pred_part) != len(gold_part):
                continue
            if [t[pos] for t in pred_where] == [t[pos] for t in gold_where]:
                self.correct[key] = 1

        if all(self.correct[name] == 1 for name in ('agg', 'sel', 'where')):
            self.correct['all'] = 1

        # execution
        table_id = gold['table_id']
        gold_answer = engine.execute_query(
            table_id, Query.from_dict(sql_gold), lower=True)
        try:
            pred_sql = {
                'agg': self.agg,
                'sel': self.sel,
                'conds': self.recover_cond_to_gloss(gold),
            }
            pred_answer = engine.execute_query(
                table_id, Query.from_dict(pred_sql), lower=True)
        except Exception as err:
            self.exception_raised = True
            pred_answer = repr(err)
        if set(gold_answer) == set(pred_answer):
            self.correct['exe'] = 1
예제 #4
0
 def get_query_from_json(self, json_line):
     """Build the gold ``Query`` and its ``Table`` for one JSON example.

     Returns a ``(table, query)`` tuple. The ``Table`` is built (with an
     empty first argument) from the ``header``, ``types`` and ``rows`` of the
     ``self.table_map`` entry keyed by the example's ``table_id``. The
     ``Query`` is built from the example's ``sql`` dict.
     """
     query = Query.from_dict(json_line["sql"])
     table_info = self.table_map[json_line["table_id"]]
     tbl = Table("", table_info["header"], table_info["types"],
                 table_info["rows"])
     return tbl, query
예제 #5
0
def main(argv):
    """Evaluate predicted SQL against gold SQL by execution and logical form.

    Reads the gold (``FLAGS.parsed_std_sql_file``) and predicted
    (``FLAGS.parsed_pred_sql_file``) JSON-lines files under
    ``FLAGS.data_root`` in lockstep. It executes both queries against
    ``FLAGS.db_file`` and prints a JSON object with ``ex_accuracy`` and
    ``lf_accuracy``. Pairs whose execution results agree but whose logical
    forms differ are printed for inspection.

    Args:
      argv: Unused command-line arguments.

    Raises:
      ValueError: if the input files contain no examples.
    """
    del argv  # Unused.

    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)

    engine = DBEngine(db_file)
    grades = []
    exact_match = []

    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        for ls, lp in tqdm(zip(fs, fp),
                           total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)

            # Reset per example. Otherwise a gold query that fails to parse
            # would leave qg unbound, or still holding the previous example's
            # query, which corrupts the logical-form comparison.
            qg = None
            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)

            qp = None
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)

            correct = pred == gold
            # An unparseable gold query cannot be matched by any prediction.
            match = qg is not None and qp == qg
            if correct and not match:
                # Same answer reached through a different logical form.
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)

    if not grades:
        raise ValueError('no examples found in {}'.format(parsed_std_sql_file))
    print(
        json.dumps(
            {
                'ex_accuracy': sum(grades) / len(grades),
                'lf_accuracy': sum(exact_match) / len(exact_match),
            },
            indent=2))
예제 #6
0
def main(anno_file_name, col_headers, raw_args=None, verbose=True):
    """Predict SQL for the last annotated example and execute it.

    Translation options are parsed from ``raw_args`` (``sys.argv`` when
    None). Every model file matching ``opt.model_path`` translates the
    examples read from ``anno_file_name``. The prediction for the last
    example is turned into a ``Query`` and executed against
    ``<data_path>/<split>.db``.

    Args:
      anno_file_name: annotation file read with ``table.IO.read_anno_json``.
      col_headers: column headers passed to ``get_complete_query`` (and
        printed when ``verbose``).
      raw_args: argument list for the option parser.
      verbose: print the predicted SQL, headers and execution errors.

    Returns:
      ``(complete_query, ans_pred)`` for the last matching model.
      ``ans_pred`` is ``None`` when executing the predicted query raised.

    Raises:
      ValueError: if the annotation file yields no examples or no model
        file matches ``opt.model_path``.
    """
    parser = argparse.ArgumentParser(description='evaluate.py')
    opts.translate_opts(parser)
    opt = parser.parse_args(raw_args)
    torch.cuda.set_device(opt.gpu)
    opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split))
    opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding')
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    opt.anno = anno_file_name

    engine = DBEngine(opt.db_file)

    js_list = table.IO.read_anno_json(opt.anno)
    if not js_list:
        raise ValueError('no examples found in {!r}'.format(opt.anno))

    # Fail clearly instead of returning through an unset query/answer.
    model_files = glob.glob(opt.model_path)
    if not model_files:
        raise ValueError('no model file matches {!r}'.format(opt.model_path))

    sql_query = None
    ans_pred = None
    for fn_model in model_files:

        opt.model = fn_model

        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        # The iterator is built with sort=True, so restore example order by
        # idx before taking the prediction for the last example.
        r_list.sort(key=lambda x: x.idx)
        pred = r_list[-1]
        sql_pred = {
            'agg': pred.agg,
            'sel': pred.sel,
            'conds': pred.recover_cond_to_gloss(js_list[-1])
        }
        if verbose:
            print('\n sql_pred: ', sql_pred, '\n')
            print('\n col_headers: ', col_headers, '\n')
        sql_query = Query(sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
        try:
            ans_pred = engine.execute_query(js_list[-1]['table_id'],
                                            Query.from_dict(sql_pred),
                                            lower=True,
                                            verbose=verbose)
        except Exception as e:
            # Best effort: an unexecutable prediction yields ``None``.
            ans_pred = None
            if verbose:
                print('\n execution failed: ', repr(e), '\n')
    return sql_query.get_complete_query(col_headers), ans_pred
예제 #7
0
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lib.query import Query
from lib.dbengine import DBEngine
import pytorch

if __name__ == '__main__':
    # Sanity-check every split:
    # - each gold SQL must execute to a truthy result;
    # - the lower-cased text of every condition value must occur in the
    #   lower-cased question.
    for split in ['train', 'dev', 'test']:
        print('checking {}'.format(split))
        engine = DBEngine('data/{}.db'.format(split))
        jsonl_path = 'data/{}.jsonl'.format(split)
        # First pass only counts lines so tqdm can show a total.
        with open(jsonl_path) as handle:
            line_count = sum(1 for _ in handle)
        with open(jsonl_path) as handle:
            for raw_line in tqdm(handle, total=line_count):
                example = json.loads(raw_line)
                query = Query.from_dict(example['sql'])

                # make sure it's executable
                result = engine.execute_query(example['table_id'], query)
                if not result:
                    raise Exception(
                        'Query {} did not execute to a valid result'.format(
                            query))
                for _, _, value in example['sql']['conds']:
                    if str(value).lower() not in example['question'].lower():
                        raise Exception(
                            'Could not find condition {} in question {} for query {}'
                            .format(value, example['question'], query))
예제 #8
0
파일: evaluate.py 프로젝트: TooTouch/SPARTA
    # Top-k evaluation script (fragment). ``parser`` is created above this
    # snippet, and the body of the per-candidate ``try`` is cut off below.
    parser.add_argument('--topk', type=int, default=3, help='k of top_k')

    args = parser.parse_args()

    engine = DBEngine(args.db_file)

    temp = []

    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        exact_match = []

        # Gold and prediction files are read line by line in lockstep.
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
            eg = json.loads(ls)
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)

            pred_topk = []
            qp_topk = []

            ep = json.loads(lp)
            pred = ep.get('error', None)
            qp = None
            # Candidates are looked up as ep['query']['0'] .. ep['query'][str(topk - 1)].
            for i in range(args.topk):
                if not ep.get('error', None):
                    try:
                        # An empty condition list encoded as [[]] is
                        # normalised to [].
                        if ep['query'][str(i)]['conds'] == [[]]:
                            ep['query'][str(i)]['conds'] = []
            
    def eval(self, split, gold, sql_gold, engine=None):
        """Score this prediction against one gold example.

        Sets ``self.correct[key] = 1`` for each matching component ('agg',
        'sel', 'where', 'col', 'lay', 'span', 'all', 'exe'). Entries are
        never reset to 0 here.

        Args:
            split: dataset split name. An error-case report is only built
                when it equals ``'finaltest'``.
            gold: gold example dict (reads ``query``, ``question``,
                ``table_id``, ``id`` and ``table`` entries).
            sql_gold: gold SQL dict passed to ``Query.from_dict``.
            engine: optional DB engine. When ``None``, execution is skipped
                and 'exe' is never marked correct.

        Returns:
            dict: an error-case report for the ``'finaltest'`` split,
            otherwise an empty dict.
        """
        # Exact matches of the aggregation operator and selected column.
        if self.agg == gold['query']['agg']:
            self.correct['agg'] = 1

        if self.sel == gold['query']['sel']:
            self.correct['sel'] = 1

        #gold['BIO_label'].data, gold['BIO_column_label'].data

        # Split predicted and gold WHERE conditions into parallel lists of
        # columns, operators and condition-value strings.
        op_list_pred = [op for col, op, span in self.cond]
        op_list_gold = [op for col, op, span in gold['query']['conds']]

        col_list_pred = [col for col, op, span in self.cond]
        col_list_gold = [col for col, op, span in gold['query']['conds']]

        # Predicted spans are inclusive [start, end] indices into the question
        # words; gold conditions carry their own 'words'.
        q = gold['question']['words']
        span_list_pred = [
            ' '.join(q[span[0]:span[1] + 1]) for col, op, span in self.cond
        ]
        span_list_gold = [
            ' '.join(span['words']) for col, op, span in gold['query']['conds']
        ]

        # Sort the (col, op, span) triples so condition order does not matter.
        where_pred = list(zip(col_list_pred, op_list_pred, span_list_pred))
        where_gold = list(zip(col_list_gold, op_list_gold, span_list_gold))
        where_pred.sort()
        where_gold.sort()
        if where_pred == where_gold and (
                len(col_list_pred) == len(col_list_gold)) and (
                    len(op_list_pred)
                    == len(op_list_gold)) and (len(span_list_pred)
                                               == len(span_list_gold)):
            self.correct['where'] = 1

        # Per-component matches over the sorted triples: columns ('col'),
        # operators ('lay') and value spans ('span').
        if (len(col_list_pred) == len(col_list_gold)) and ([
                it[0] for it in where_pred
        ] == [it[0] for it in where_gold]):
            self.correct['col'] = 1

        if (len(op_list_pred) == len(op_list_gold)) and ([
                it[1] for it in where_pred
        ] == [it[1] for it in where_gold]):
            self.correct['lay'] = 1

        if (len(span_list_pred) == len(span_list_gold)) and ([
                it[2] for it in where_pred
        ] == [it[2] for it in where_gold]):
            self.correct['span'] = 1

        # The logical form counts as fully correct when agg, sel and where
        # all match.
        if all((self.correct[it] == 1 for it in ('agg', 'sel', 'where'))):
            self.correct['all'] = 1

        # execution
        # The placeholders '0' and '1' differ as sets, so 'exe' stays unset
        # when no engine is given.
        table_id = gold['table_id']
        ans_gold = '0'
        ans_pred = '1'
        if engine is not None:
            ans_gold = engine.execute_query(table_id,
                                            Query.from_dict(sql_gold),
                                            lower=True)

            try:
                sql_pred = {
                    'agg': self.agg,
                    'sel': self.sel,
                    'conds': self.recover_cond_to_gloss(gold)
                }
                ans_pred = engine.execute_query(table_id,
                                                Query.from_dict(sql_pred),
                                                lower=True)
            except Exception as e:
                # A prediction that cannot be built or executed is scored by
                # the repr of the exception.
                ans_pred = repr(e)
        else:
            # Redundant: the same placeholders were assigned above.
            ans_gold = '0'
            ans_pred = '1'
        # Results are compared as sets (order and duplicates ignored).
        if set(ans_gold) == set(ans_pred):
            self.correct['exe'] = 1

        # Error-case report, built only for the 'finaltest' split.
        error_case = {}
        if split == 'finaltest':
            #if self.correct['where'] != 1:
            # The filter above is disabled, so every example is reported.
            if True:
                error_case['sel'] = self.correct['sel']
                error_case['where'] = self.correct['where']
                error_case['all'] = self.correct['all']
                error_case['table_id'] = gold['table_id']
                error_case['question_id'] = gold['id']
                error_case['question'] = gold['question']['words']
                error_case['table_head'] = [
                    head['words'] for head in gold['table']['header']
                ]
                #error_case['table_content'] = gold['tbl_content']

                #                for i in range(len(sql_gold['conds'])):
                #                    sql_gold['conds'][i][0] = (
                #                    sql_gold['conds'][i][0], gold['table']['header'][sql_gold['conds'][i][0]]['words'],
                #                    gold['tbl_content'][sql_gold['conds'][i][0]])
                error_case['gold'] = sql_gold

                # NOTE(review): assumes self.agg / self.sel are tensor-like
                # scalars (they are unwrapped with .item()) -- confirm.
                error_case['predict'] = {
                    'agg': self.agg.item(),
                    'sel': self.sel.item(),
                    'conds': self.print_recover_cond_to_gloss(gold)
                }
                # error_case['exe result']=self.correct['exe']

                # Pair each predicted BIO tag with its question word.
                # NOTE(review): assumes BIO holds tensor-like scalars and
                # BIO_col supports .tolist() -- confirm.
                error_case['BIO'] = [(x.item(), y) for x, y in zip(
                    list(self.BIO), list(gold['question']['words']))]
                error_case['BIO_col'] = self.BIO_col.tolist()

        return error_case
예제 #10
0
    opts = Options()
    fill_from_args(opts)

    # For each split, do the following:
    # 1. Execute every gold query.
    # 2. Store its answer and the matching row ids in <split>_ans.jsonl.gz.
    # 3. Pass that file and the split's tables to ``convert``, which writes
    #    <split>_agg.jsonl.gz.
    for split in ['train', 'dev', 'test']:
        orig = os.path.join(opts.data_dir, f'{split}.jsonl')
        db_file = os.path.join(opts.data_dir, f'{split}.db')
        ans_file = os.path.join(opts.data_dir, f"{split}_ans.jsonl.gz")
        tbl_file = os.path.join(opts.data_dir, f"{split}.tables.jsonl")
        engine = DBEngine(db_file)
        with open(orig) as fs, write_open(ans_file) as fo:
            for ls in tqdm(fs, total=count_lines(orig)):
                eg = json.loads(ls)
                qg = Query.from_dict(eg['sql'], ordered=False)
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
                # Explicit check, because ``assert`` is stripped under -O and
                # a malformed answer would then be written silently.
                if not isinstance(gold, list):
                    raise TypeError(
                        f"expected a list answer for table {eg['table_id']}, "
                        f"got {type(gold).__name__}")
                eg['answer'] = gold
                eg['rowids'] = engine.execute_query_rowid(eg['table_id'],
                                                          qg,
                                                          lower=True)
                # CONSIDER: if it is not an agg query, somehow identify the particular cell
                fo.write(json.dumps(eg) + '\n')

        convert(jsonl_lines(ans_file),
                jsonl_lines(tbl_file),
                os.path.join(opts.data_dir, f"{split}_agg.jsonl.gz"),
                skip_aggregation=False)
예제 #11
0
if __name__ == '__main__':
    # Evaluate predictions against gold examples by execution result (ex)
    # and by logical form (lf).
    # NOTE(review): this snippet stops after computing ``correct``/``match``.
    # Appending them to ``grades``/``exact_match`` and reporting accuracies
    # is not visible here.
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        # The progress total counts the prediction file; zip stops at the
        # shorter of the two files.
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.pred_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            #print qg
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            # A prediction with a truthy 'error' is scored by that value.
            # Otherwise it is parsed and executed, and any exception's repr
            # becomes the predicted result.
            pred = ep['error']
            qp = None
            if not ep['error']:
                try:
                    qp = Query.from_dict(ep['query'])
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    pred = repr(e)
            correct = pred == gold
            match = qp == qg
            #if not match:
            #    print qp.to_dict()
            #    print qg.to_dict()
예제 #12
0
def main():
    """Annotate questions and SQL sentences with placeholder symbols.

    For each processed split, reads ``<split>.jsonl`` and
    ``<split>.tables.jsonl`` from ``--din`` and writes under ``save_path``:

    * ``<split>.qu`` -- annotated questions, one per line;
    * ``<split>.lon`` -- SQL sentences where recognised fields/values are
      replaced by ``<f0>`` (head), ``<fN>`` and ``<vN>`` symbols;
    * ``<split>.out`` -- opened but not written by this function;
    * ``<split>_sym_pairs.txt`` -- ``symbol=>word<>`` pairs per example;
    * ``<split>_ground_truth.txt`` -- SQL sentences without parentheses.

    Finally it prints:

    * how many examples were fully annotated;
    * how often the detected head column matches the gold SQL;
    * how often the detected field/value pairs match the gold SQL.

    NOTE(review): ``--dout`` is created, but output is written to the
    module-level ``save_path``, not to ``--dout``.
    """
    parser = ArgumentParser()
    parser.add_argument('--din', default=data_path, help='data directory')
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()

    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)

    # Tokens allowed to stay un-symbolised in a fully annotated SQL sentence.
    sql_keywords = {
        '(', ')', 'where', 'less', 'greater', 'equal', 'max', 'min', 'count',
        'sum', 'avg', 'and', 'true'
    }
    # When True, every other header name found in S is replaced by <cN>.
    ADD_FIELDS = False

    # Only the dev split is processed; add 'train' / 'test' to process them.
    for split in ['dev']:
        with open(save_path + '%s.qu' % split, 'w') as qu_file, \
                open(save_path + '%s.lon' % split, 'w') as lon_file, \
                open(save_path + '%s.out' % split, 'w') as out, \
                open(save_path + '%s_sym_pairs.txt' % split, 'w') as sym_file, \
                open(save_path + '%s_ground_truth.txt' % split, 'w') as S_file:

            fsplit = os.path.join(args.din, split) + '.jsonl'
            ftable = os.path.join(args.din, split) + '.tables.jsonl'

            with open(fsplit) as fs, open(ftable) as ft:
                print('loading tables')
                tables = {}
                for line in tqdm(ft, total=count_lines(ftable)):
                    d = json.loads(line)
                    tables[d['id']] = d
                print('loading tables done.')

                print('loading examples')
                n, acc, acc_pair, acc_all, error = 0, 0, 0, 0, 0
                target = -1

                for line in tqdm(fs, total=count_lines(fsplit)):
                    d = json.loads(line)
                    Q = d['question']
                    tbl = tables[d['table_id']]

                    rows = np.asarray(tbl['rows'])
                    header = tbl['header']

                    # all fields are sorted by length in descending order
                    # for string match purpose
                    all_fields = sorted((_preclean(f) for f in header),
                                        key=len, reverse=True)

                    smap = defaultdict(list)  # field -> values
                    reverse_map = defaultdict(list)  # value -> fields
                    for row in rows:
                        for i in range(len(header)):
                            cur_f = _preclean(str(header[i]))
                            cur_row = _preclean(str(row[i]))
                            smap[cur_f].append(cur_row)
                            if cur_f not in reverse_map[cur_row]:
                                reverse_map[cur_row].append(cur_f)

                    # all values are sorted by length in descending order
                    # for string match purpose
                    keys = sorted(reverse_map.keys(), key=len, reverse=True)

                    Q = _preclean(Q)
                    Q_ori = Q

                    # ---------------- Annotate question ----------------
                    candidates, cond_fields = _match_pairs(
                        Q, Q_ori, keys, reverse_map)
                    Q_head, head2partial = _match_head(Q, Q_ori, smap,
                                                       all_fields, cond_fields)

                    Q, Qpairs = annotate_Q(Q, Q_ori, Q_head, candidates,
                                           all_fields, head2partial, n, target)
                    qu_file.write(Q + '\n')

                    validation_pairs = copy.copy(Qpairs)
                    validation_pairs.append((Q_head, '<f0>', 'head'))
                    for i, f in enumerate(all_fields):
                        validation_pairs.append((f, '<c' + str(i) + '>', 'c'))

                    # ---------------- Annotate SQL ----------------
                    q_sent = Query.from_dict(d['sql'])
                    S, col_names, val_names = q_sent.to_sentence(
                        tbl['header'], rows, tbl['types'])
                    S = _preclean(S)

                    S_noparen = _preclean(q_sent.to_sentence_noparenthesis(
                        tbl['header'], rows, tbl['types']))

                    col_names = [_preclean(col_name) for col_name in col_names]
                    val_names = [_preclean(val_name) for val_name in val_names]

                    S_head = _preclean(col_names[-1])

                    # annotate for SQL
                    name_pairs = []
                    for col_name, val_name in zip(col_names, val_names):
                        if col_name == val_name:
                            name_pairs.append([_preclean(col_name), 'true'])
                        else:
                            name_pairs.append(
                                [_preclean(col_name),
                                 _preclean(val_name)])

                    # sort to compare with candidates
                    name_pairs.sort(key=lambda x: x[1])
                    new_name_pairs = [
                        ['<f' + str(i) + '>', '<v' + str(i) + '>']
                        for i, _ in enumerate(name_pairs, start=1)
                    ]
                    pairs_ok = _equal(name_pairs, candidates)

                    # only annotate S while identified (f, v) pairs are right
                    if pairs_ok:
                        pairs = []
                        for (f, v), (new_f,
                                     new_v) in zip(name_pairs, new_name_pairs):
                            pairs.append((f, new_f, 'f'))
                            pairs.append((v, new_v, 'v'))
                        # sort (word, symbol) pairs by length in descending order
                        pairs.sort(key=lambda x: 100 - len(x[0]))

                        for p, new_p, t in pairs:
                            cp = _backslash(p)

                            if new_p in Q:
                                if t == 'v':
                                    S = S.replace(p + ' )', new_p + ' )')
                                if t == 'f':
                                    S = re.sub(
                                        r'\( ' + cp + ' (equal|less|greater)',
                                        '( ' + new_p + r' \1', S)

                    # only annotate S while identified HEAD is right
                    if S_head == Q_head and '<f0>' in Q:
                        S = S.replace(S_head, '<f0>')

                    # annotate unseen fields
                    if ADD_FIELDS:
                        for i, f in enumerate(all_fields):
                            cf = _backslash(f)
                            S = re.sub(r'(\s|^)' + cf + r'(\s|$|s)',
                                       ' <c' + str(i) + '> ', S)

                    S = _clean(S)
                    lon_file.write(S + '\n')

                    # ---------------- Validation ----------------
                    for word, sym, _ in validation_pairs:
                        sym_file.write(sym + '=>' + word + '<>')
                    sym_file.write('\n')

                    S_file.write(S_noparen + '\n')

                    if pairs_ok:
                        acc_pair += 1

                    if Q_head == S_head:
                        acc += 1

                    if pairs_ok and Q_head == S_head:
                        acc_all += 1

                    # Not fully annotated if any token is neither a <...>
                    # symbol nor an SQL keyword.
                    if any(s[0] != '<' and s not in sql_keywords
                           for s in S.split()):
                        error += 1

                    n += 1

                print('total number of examples:' + str(n))
                if n == 0:
                    # No examples: the ratios below are undefined.
                    continue
                print('fully annotated:' + str(1 - error * 1.0 / n))
                print('accurate all percent:' + str(acc_all * 1.0 / n))
                print('accurate HEAD match percent:' + str(acc * 1.0 / n))
                print('accurate fields pair match percent:' +
                      str(acc_pair * 1.0 / n))
예제 #13
0
def _gen_files(save_path=save_path):
    """Prepare files used for the Machine Comprehension binary classifier.

    For each split ('train', 'test', 'dev') this reads ``<split>.jsonl`` and
    ``<split>.tables.jsonl`` from ``data_path`` and writes to ``save_path``:

    * ``<split>.txt``: tab-separated ``question, column, label`` lines with
      the original column content (not truncated). Label 1 marks the column
      names returned by ``Query.to_sentence``; 0 marks the other headers.
      Each example is preceded by a ``#<index>`` line.
    * ``<split>_model.txt``: the same pairs, with non-breaking spaces and
      tabs cleaned and column names truncated or padded to length 3.
    * ``<split>.lon``: the SQL sentence once per pair, with used columns
      wrapped in ``[...]``; ``#<index>`` separators as above.
    * ``<split>.ori.qu`` / ``<split>.ori.lon``: the cleaned question and SQL
      sentence of every example.
    * ``<split>_dict.npz``: per-example field->values (``f2v_all``) and
      value->fields (``v2f_all``) maps.

    Args:
        save_path: output directory (defaults to the module-level
            ``save_path``).
    """
    for split in ['train', 'test', 'dev']:
        print('------%s-------' % split)
        n = 0
        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        with open(fsplit) as fs, open(ftable) as ft, \
                open(os.path.join(save_path, split + '.txt'), mode='w') as fw, \
                open(os.path.join(save_path, '%s_model.txt' % split), mode='w') as fw_model, \
                open(os.path.join(save_path, split + '.lon'), mode='w') as fsql, \
                open(os.path.join(save_path, '%s.ori.qu' % split), 'w') as qu_file, \
                open(os.path.join(save_path, '%s.ori.lon' % split), 'w') as lon_file:
            print('loading tables...')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')

            print('loading examples')
            f2v_all, v2f_all = [], []
            for line in tqdm(fs, total=count_lines(fsplit)):

                d = json.loads(line)
                Q = _preclean(d['question']).replace('\t', '')
                qu_file.write(Q + '\n')

                tbl = tables[d['table_id']]
                header = tbl['header']

                # SQL sentence rendered from the raw row lists (.ori.lon).
                q_sent = Query.from_dict(d['sql'])
                rows = tbl['rows']
                S, _, _ = q_sent.to_sentence(header, rows, tbl['types'])
                lon_file.write(_preclean(S) + '\n')

                rows = np.asarray(rows)
                # all fields are sorted by length in descending order
                # for string match purpose
                headers = sorted((_preclean(f) for f in header),
                                 key=len, reverse=True)

                f2v = defaultdict(list)  # field -> values
                v2f = defaultdict(list)  # value -> fields
                for row in rows:
                    for i in range(len(header)):
                        cur_f = _preclean(str(header[i]))
                        cur_row = _preclean(str(row[i]))
                        f2v[cur_f].append(cur_row)
                        if cur_f not in v2f[cur_row]:
                            v2f[cur_row].append(cur_f)
                f2v_all.append(f2v)
                v2f_all.append(v2f)

                # Annotate SQL: re-render with the numpy rows to get the
                # column names the query uses.
                # NOTE(review): a fresh Query is built; reusing q_sent is only
                # safe if to_sentence does not mutate it -- not verified.
                q_sent = Query.from_dict(d['sql'])
                S, col_names, _ = q_sent.to_sentence(header, rows,
                                                     tbl['types'])
                S = _preclean(S)
                col_names = [_preclean(col_name) for col_name in col_names]
                other_headers = [f for f in headers if f not in col_names]

                fsql.write('#%d\n' % n)
                fw.write('#%d\n' % n)

                # Untruncated pairs: label 1 for SQL columns, 0 for the rest.
                for f in col_names:
                    fsql.write(S.replace(f, '[' + f + ']') + '\n')
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = Q + '\t' + f + '\t 1'
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')

                for f in other_headers:
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = Q + '\t' + f + '\t 0'
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                    fsql.write(S + '\n')

                # Model pairs: strip non-breaking spaces and tabs, then
                # truncate/pad column names to length 3.
                Q_model = Q.replace(u'\xa0', u' ').replace('\t', '')
                for label, names in (('1', col_names), ('0', other_headers)):
                    for f in names:
                        f = f.replace(u'\xa0', u' ').replace('\t', '')
                        f = _truncate(f, END='bos', PAD='pad', max_len=3)
                        s = Q_model + '\t' + f + '\t ' + label
                        assert len(s.split('\t')) == 3
                        fw_model.write(s + '\n')

                n += 1
            fsql.write('#%d\n' % n)
            fw.write('#%d\n' % n)

        # numpy.savez is the real implementation; scipy.savez was only a
        # re-export of it and is not available in recent SciPy releases.
        np.savez(os.path.join(save_path, '%s_dict.npz' % split),
                 f2v_all=f2v_all,
                 v2f_all=v2f_all)
        print('num of records:%d' % n)
예제 #14
0
    parser.add_argument("--db_file")
    parser.add_argument("--pred_file")
    parser.add_argument("--ordered", action='store_true')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    ex_acc_list = []
    lf_acc_list = []
    with open(args.source_file) as sf, open(args.pred_file) as pf:
        for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)):
            # line별 정답과 예측 샘플 가져오기
            gold_example = json.loads(source_line)
            pred_example = json.loads(pred_line)

            # 정답 샘플 lf, ex 구하기
            lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered)
            ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True)

            # error가 아닌 경우 예측 샘플 lf, ex 구하기
            lf_pred_query = None
            ex_pred = pred_example.get('error', None)
            if not ex_pred:
                try:
                    lf_pred_query = Query.from_dict(pred_example['query'], ordered=args.ordered)
                    ex_pred = engine.execute_query(gold_example['table_id'], lf_pred_query, lower=True)
                except Exception as e:
                    ex_pred = repr(e)

            # lf, ex의 gold, pred 매칭결과 구하기
            ex_acc_list.append(ex_pred == ex_gold)
            lf_acc_list.append(lf_pred_query == lf_gold_query) # query의 __eq__를 호출