コード例 #1
0
def hello():
    """Score predicted queries against gold queries passed as request args.

    Reads three entries from ``request.args``:

    * ``pred``: JSON-encoded list of prediction dicts. Each has an ``'error'``
      entry and, when that entry is falsy, a ``'query'`` dict.
    * ``gold``: JSON-encoded list of gold dicts with ``'sql'`` and
      ``'table_id'`` entries.
    * ``dataset``: database name; ``<DATABASE_PATH>/<dataset>.db`` is opened.

    Predictions and gold examples are paired positionally with ``zip``, so
    surplus items in the longer list are ignored.

    Returns:
        A JSON string with ``ex_accuracy`` (fraction of pairs whose execution
        results compare equal) and ``lf_accuracy`` (fraction of pairs whose
        ``Query`` objects compare equal). Both are ``0.0`` when no pair was
        scored.
    """
    pred_json = json.loads(request.args.get('pred'))
    gold_json = json.loads(request.args.get('gold'))
    # NOTE(review): ``dataset`` comes straight from the request and is
    # interpolated into a filesystem path; consider validating it against the
    # set of known dataset names.
    dataset = request.args.get('dataset')

    engine = DBEngine(os.path.join(DATABASE_PATH, "{}.db".format(dataset)))

    exact_match = []
    grades = []

    for ep, eg in tqdm(zip(pred_json, gold_json), total=len(gold_json)):
        qg = Query.from_dict(eg['sql'])
        gold = engine.execute_query(eg['table_id'], qg, lower=True)

        # A truthy error recorded by the predictor stands in for the
        # execution result, so it is compared against the gold result as-is.
        pred = ep['error']
        qp = None
        if not pred:
            try:
                qp = Query.from_dict(ep['query'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                # A malformed or non-executable prediction is scored using the
                # error text instead of aborting the whole request.
                pred = repr(e)

        grades.append(pred == gold)
        exact_match.append(qp == qg)

    # Guard against empty input: dividing by len([]) would raise
    # ZeroDivisionError.
    ex_accuracy = sum(grades) / len(grades) if grades else 0.0
    lf_accuracy = sum(exact_match) / len(exact_match) if exact_match else 0.0
    return json.dumps({"ex_accuracy": ex_accuracy, "lf_accuracy": lf_accuracy})
コード例 #2
0
ファイル: evaluate.py プロジェクト: lukovnikov/WikiSQL
def eval_one_qelos(db_file, pred_file, source_file):
    """Compute execution and logical-form accuracy for one prediction file.

    ``source_file`` and ``pred_file`` are read line by line in lockstep; each
    line is parsed as JSON. A source line must provide ``'sql'`` and
    ``'table_id'``; a prediction line is passed whole to ``Query.from_dict``.

    Args:
        db_file: Path handed to ``DBEngine``.
        pred_file: Path of the file with one JSON prediction per line.
        source_file: Path of the file with one JSON gold example per line.

    Returns:
        A dict with ``'ex_accuracy'`` (fraction of examples whose execution
        results compare equal) and ``'lf_accuracy'`` (fraction whose ``Query``
        objects compare equal). Both are ``0.0`` when no example was read.
    """
    engine = DBEngine(db_file)
    exact_match = []
    grades = []
    with open(source_file) as fs, open(pred_file) as fp:
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            qp = None
            try:
                qp = Query.from_dict(ep)
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                # A prediction that cannot be built or executed is scored
                # using the error text, which is compared to the gold result.
                pred = repr(e)
            grades.append(pred == gold)
            exact_match.append(qp == qg)

    # Guard against empty files: dividing by len([]) would raise
    # ZeroDivisionError.
    return {
        'ex_accuracy': sum(grades) / len(grades) if grades else 0.0,
        'lf_accuracy': (sum(exact_match) / len(exact_match)
                        if exact_match else 0.0),
    }
コード例 #3
0
ファイル: evaluate_ours.py プロジェクト: jzl0166/NLIDB
def main(argv):
    """Evaluate parsed predicted SQL against parsed reference SQL.

    File locations come from ``FLAGS`` (``data_root`` joined with ``db_file``,
    ``parsed_std_sql_file`` and ``parsed_pred_sql_file``). Both SQL files are
    read line by line in lockstep, each line being a JSON object; reference
    lines provide ``'sql'`` and ``'table_id'``, prediction lines ``'sql'``.

    Prints every pair whose execution results agree while the queries differ,
    then prints a JSON object with ``ex_accuracy`` and ``lf_accuracy``
    (both ``0.0`` when no example was read).

    Args:
        argv: Unused.
    """
    del argv  # Unused.

    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)

    engine = DBEngine(db_file)
    exact_match = []
    grades = []

    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        for ls, lp in tqdm(zip(fs, fp),
                           total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)

            # Reset per example: if building the gold query fails, ``qg`` must
            # not be left unbound or carry over from the previous example.
            qg = None
            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)

            qp = None
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)

            correct = pred == gold
            # Without a gold query there is nothing to match against; avoid
            # counting ``None == None`` as a logical-form match.
            match = qg is not None and qp == qg
            if correct and qp != qg:
                # Same execution result from different queries: show both.
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)

    # Guard against empty files: dividing by len([]) would raise
    # ZeroDivisionError.
    print(
        json.dumps(
            {
                'ex_accuracy': sum(grades) / len(grades) if grades else 0.0,
                'lf_accuracy': (sum(exact_match) / len(exact_match)
                                if exact_match else 0.0),
            },
            indent=2))
コード例 #4
0
def main(anno_file_name, col_headers, raw_args=None, verbose=True):
    """Predict SQL for the last annotated example and execute it.

    For every model file matching ``opt.model_path`` the whole annotation
    list is translated; the prediction with the highest ``idx`` is turned into
    a ``Query`` and executed against ``<data_path>/<split>.db``. Only the
    outcome of the last model processed is returned.

    Args:
        anno_file_name: Annotation file path; stored on ``opt.anno`` and read
            with ``table.IO.read_anno_json``.
        col_headers: Passed to ``Query.get_complete_query`` to render the
            returned query; also printed when ``verbose`` is true.
        raw_args: Argument list for ``argparse``; ``None`` makes argparse
            read ``sys.argv``.
        verbose: Print the predicted SQL and headers, and forward the flag to
            ``engine.execute_query``.

    Returns:
        A tuple ``(complete_query, ans_pred)`` where ``complete_query`` is the
        result of ``get_complete_query(col_headers)`` and ``ans_pred`` is the
        execution result, or ``None`` if execution raised.

    Raises:
        ValueError: If no file matches ``opt.model_path``.
    """
    parser = argparse.ArgumentParser(description='evaluate.py')
    opts.translate_opts(parser)
    opt = parser.parse_args(raw_args)
    torch.cuda.set_device(opt.gpu)
    opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split))
    opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding')
    # Model/training options are only needed at their default values, so they
    # are parsed from an empty argument list.
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    opt.anno = anno_file_name

    engine = DBEngine(opt.db_file)

    js_list = table.IO.read_anno_json(opt.anno)

    model_files = glob.glob(opt.model_path)
    if not model_files:
        # Previously this fell through to calling a Query method on an empty
        # list placeholder, which failed with an opaque AttributeError.
        raise ValueError(
            'No model file matches {!r}'.format(opt.model_path))

    sql_query = None
    ans_pred = None
    for fn_model in model_files:

        opt.model = fn_model

        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        # The iterator sorts examples, so put results back in idx order.
        r_list.sort(key=lambda x: x.idx)
        # Highest-idx result, paired with the last annotation entry
        # (presumably idx follows annotation order -- confirm).
        pred = r_list[-1]
        sql_pred = {
            'agg': pred.agg,
            'sel': pred.sel,
            'conds': pred.recover_cond_to_gloss(js_list[-1])
        }
        if verbose:
            print('\n sql_pred: ', sql_pred, '\n')
            print('\n col_headers: ', col_headers, '\n')
        sql_query = Query(sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
        try:
            ans_pred = engine.execute_query(js_list[-1]['table_id'],
                                            Query.from_dict(sql_pred),
                                            lower=True,
                                            verbose=verbose)
        except Exception:
            # Best effort: a query that cannot be executed yields no answer.
            ans_pred = None
    return sql_query.get_complete_query(col_headers), ans_pred
コード例 #5
0
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lib.query import Query
from lib.dbengine import DBEngine
import pytorch

if __name__ == '__main__':
    # Sanity-check every split: each annotated query must execute to a truthy
    # result, and every condition value must occur in the question text.
    for split in ('train', 'dev', 'test'):
        print('checking {}'.format(split))
        engine = DBEngine('data/{}.db'.format(split))
        jsonl_path = 'data/{}.jsonl'.format(split)

        # First pass only counts lines so the progress bar knows its total.
        with open(jsonl_path) as f:
            n_lines = sum(1 for _ in f)

        with open(jsonl_path) as f:
            for line in tqdm(f, total=n_lines):
                example = json.loads(line)
                query = Query.from_dict(example['sql'])

                # make sure it's executable
                result = engine.execute_query(example['table_id'], query)
                if not result:
                    raise Exception(
                        'Query {} did not execute to a valid result'.format(
                            query))

                # Each condition is a triple; only its third element (the
                # value) is checked against the question, case-insensitively.
                for _first, _second, value in example['sql']['conds']:
                    if str(value).lower() in example['question'].lower():
                        continue
                    raise Exception(
                        'Could not find condition {} in question {} for query {}'
                        .format(value, example['question'], query))
コード例 #6
0
ファイル: evaluate.py プロジェクト: TooTouch/SPARTA
    parser.add_argument('--topk', type=int, default=3, help='k of top_k')
    
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
   
    temp = []
    
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        exact_match = []
        
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):            
            eg = json.loads(ls)
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            
            pred_topk = []
            qp_topk = []
            
            ep = json.loads(lp)
            pred = ep.get('error', None)
            qp = None
            for i in range(args.topk):
                if not ep.get('error', None):
                    try:
                        
                        if ep['query'][str(i)]['conds'] == [[]]:
                            ep['query'][str(i)]['conds'] = []
            
                        qp = Query.from_dict(ep['query'][str(i)], ordered=args.ordered)
コード例 #7
0
    parser.add_argument("--pred_file")
    parser.add_argument("--ordered", action='store_true')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    ex_acc_list = []
    lf_acc_list = []
    with open(args.source_file) as sf, open(args.pred_file) as pf:
        for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)):
            # line별 정답과 예측 샘플 가져오기
            gold_example = json.loads(source_line)
            pred_example = json.loads(pred_line)

            # 정답 샘플 lf, ex 구하기
            lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered)
            ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True)

            # error가 아닌 경우 예측 샘플 lf, ex 구하기
            lf_pred_query = None
            ex_pred = pred_example.get('error', None)
            if not ex_pred:
                try:
                    lf_pred_query = Query.from_dict(pred_example['query'], ordered=args.ordered)
                    ex_pred = engine.execute_query(gold_example['table_id'], lf_pred_query, lower=True)
                except Exception as e:
                    ex_pred = repr(e)

            # lf, ex의 gold, pred 매칭결과 구하기
            ex_acc_list.append(ex_pred == ex_gold)
            lf_acc_list.append(lf_pred_query == lf_gold_query) # query의 __eq__를 호출
        print('ex_accuracy {}\n lf_accuracy {}'.format(