Code example #1
File: evaluate.py  Project: dwtcourses/NLQ_to_SQL
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    # engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):
        print(fn_model)
        print(opt.anno)
        opt.model = fn_model

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))
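The dummy_parser idiom that opens each of these snippets registers the full set of training options on a throwaway parser and then calls parse_known_args([]) on an empty argument list, which yields a namespace holding nothing but the registered defaults; dummy_opt.__dict__ is then handed to the Translator. A minimal self-contained sketch of the idiom (the registered options here are hypothetical stand-ins for opts.model_opts / opts.train_opts):

import argparse

# throwaway parser used only to materialize default option values
dummy_parser = argparse.ArgumentParser(description='train.py')
dummy_parser.add_argument('--layers', type=int, default=2)       # hypothetical option
dummy_parser.add_argument('--dropout', type=float, default=0.3)  # hypothetical option

# parse an empty argv: every option falls back to its default
dummy_opt = dummy_parser.parse_known_args([])[0]
print(dummy_opt.__dict__)  # {'layers': 2, 'dropout': 0.3}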
Code example #2
File: test_single.py  Project: code-gen/coarse2fine
def do_test(example_list):
    metric_name_list = ['lay-token', 'lay', 'tgt-token', 'tgt']

    args.model = args.model_path  # TODO??
    translator = table.Translator(args)
    data = table.IO.TableDataset(example_list, translator.fields, 0, None,
                                 False)

    test_data = table.IO.OrderedIterator(
        dataset=data,
        device=args.gpu_id[0] if args.cuda else -1,
        batch_size=args.batch_size,
        train=False,
        sort=True,
        sort_within_batch=False)

    out_list = []
    for i, batch in enumerate(test_data):
        r = translator.translate(batch)
        logger.info(r[0])
        out_list += r

    out_list.sort(key=lambda x: x.idx)
    assert len(out_list) == len(
        example_list), 'len(out_list) != len(example_list): {} != {}'.format(
            len(out_list), len(example_list))

    # evaluation
    for pred, gold in zip(out_list, example_list):
        pred.eval(gold)

    for metric_name in metric_name_list:
        if metric_name.endswith("-token"):
            # token-level accuracy: a prediction counts as correct when its
            # token set is a subset of the gold token set
            c_correct = sum([
                len(
                    set(x.get_by_name(metric_name)) -
                    set(y[metric_name.split("-")[0]])) == 0
                for x, y in zip(out_list, example_list)
            ])
        else:
            c_correct = sum(x.correct[metric_name] for x in out_list)

        acc = c_correct / len(out_list)
        out_str = '{}: {} / {} = {:.2%}'.format(metric_name.upper(),
                                                c_correct, len(out_list), acc)
        logger.info(out_str)

        if not metric_name.endswith("-token"):
            # log each incorrect prediction next to its gold target
            for x in out_list:
                for prd, tgt in x.incorrect[metric_name]:
                    logger.warning("\nprd: %s\ntgt: %s" %
                                   (" ".join(prd), " ".join(tgt)))
Code example #3
File: evaluate.py  Project: zenghanfu/coarse2fine
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):
        print(fn_model)
        print(opt.anno)
        opt.model = fn_model

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        # evaluation
        for pred, gold, sql_gold in zip(r_list, js_list, sql_list):
            pred.eval(gold, sql_gold, engine)
        print('Results:')
        for metric_name in ('all', 'exe'):
            c_correct = sum((x.correct[metric_name] for x in r_list))
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list),
                                                c_correct / len(r_list)))
            if metric_name == 'all' and (prev_best[0] is None
                                         or c_correct > prev_best[1]):
                prev_best = (fn_model, c_correct)

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.data_path, 'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
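sql_list at the top of this example assumes opt.source_file is JSON Lines: one object per line, each carrying a 'sql' key. A small sketch of the expected layout (the field contents are illustrative, loosely following the WikiSQL 'sql' structure):

import json

# two illustrative JSONL records
lines = [
    '{"question": "how many players?", "sql": {"sel": 0, "agg": 3, "conds": []}}',
    '{"question": "who won in 2010?", "sql": {"sel": 1, "agg": 0, "conds": [[2, 0, "2010"]]}}',
]
sql_list = [json.loads(line)['sql'] for line in lines]
print(sql_list[0])  # {'sel': 0, 'agg': 3, 'conds': []}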
Code example #4
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    js_list = table.IO.read_anno_json(opt.anno, opt)

    metric_name_list = ['tgt']
    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):
        opt.model = fn_model
        print(fn_model)
        print(opt.anno)

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, 0, None,
                                     False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r = translator.translate(batch)
            r_list += r
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        # evaluation
        for pred, gold in zip(r_list, js_list):
            pred.eval(gold)
        print('Results:')
        for metric_name in metric_name_list:
            c_correct = sum((x.correct[metric_name] for x in r_list))
            acc = c_correct / len(r_list)
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list), acc))
            if metric_name == 'tgt' and (prev_best[0] is None
                                         or acc > prev_best[1]):
                prev_best = (fn_model, acc)

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.root_dir, opt.dataset,
                                      'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
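The bookkeeping shared by these examples globs a checkpoint pattern, scores each matching model, keeps the best (path, accuracy) pair, and writes the winner to dev_best.txt on the dev split. A stripped-down sketch of just that selection logic (the per-checkpoint accuracies are hypothetical and stand in for the full translate-and-eval loop):

# hypothetical accuracies, standing in for the evaluation loop above
accuracy_by_model = {'m_1.pt': 0.61, 'm_2.pt': 0.67, 'm_3.pt': 0.64}

prev_best = (None, None)
for fn_model, acc in accuracy_by_model.items():
    if prev_best[0] is None or acc > prev_best[1]:
        prev_best = (fn_model, acc)

if prev_best[0] is not None:
    with open('dev_best.txt', 'w') as f_out:
        f_out.write('{}\n'.format(prev_best[0]))  # m_2.pt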
Code example #5
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    js_list = table.IO.read_anno_json(opt.anno, opt)
    # metric_name_list = ['tgt']
    prev_best = (None, None)
    # print(opt.model_path)
    for fn_model in glob.glob(opt.model_path):
        opt.model = fn_model
        print(fn_model)
        with torch.no_grad():
            translator = table.Translator(opt, dummy_opt.__dict__)
            data = table.IO.TableDataset(js_list, translator.fields, 0, None,
                                         False)
            test_data = table.IO.OrderedIterator(dataset=data,
                                                 device=opt.gpu,
                                                 batch_size=opt.batch_size,
                                                 train=False,
                                                 sort=True,
                                                 sort_within_batch=False)
            # inference
            r_list = []
            for batch in test_data:
                r = translator.translate(batch)
                r_list += r

        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        metric, _ = com_metric(js_list, r_list)

        # track the best checkpoint by BLEU-1
        if prev_best[0] is None or float(metric['Bleu_1']) > prev_best[1]:
            prev_best = (fn_model, float(metric['Bleu_1']))

    if opt.split == 'test':
        # report metrics bucketed by reference length
        ref_dic, pre_dict = effect_len(js_list, r_list)
        for i in range(len(ref_dic)):
            js_list = ref_dic[i]
            r_list = pre_dict[i]
            print("the effect of length {}".format(i))
            metric, _ = com_metric(js_list, r_list)

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.root_dir, opt.dataset,
                                      'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
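This variant is the only one that wraps inference in torch.no_grad(), which disables autograd tracking so the forward passes build no computation graph and hold no gradient buffers. A minimal sketch of the effect:

import torch

x = torch.ones(3, requires_grad=True)

y = x * 2
print(y.requires_grad)  # True: operations are being recorded

with torch.no_grad():
    z = x * 2
print(z.requires_grad)  # False: no graph is built during inference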
Code example #6
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    print(opt.split, opt.model_path)

    num_models = 0

    f_out = open('Two-stream-' + opt.unseen_table + '-out-case', 'w')

    for fn_model in glob.glob(opt.model_path):
        num_models += 1
        sys.stdout.flush()
        print(fn_model)
        print(opt.anno)
        opt.model = fn_model

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        #torch.save(data, open( 'data.pt', 'wb'))
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))
        # evaluation
        error_cases = []
        for pred, gold, sql_gold in zip(r_list, js_list, sql_list):
            error_cases.append(pred.eval(opt.split, gold, sql_gold, engine))
#            error_cases.append(pred.eval(opt.split, gold, sql_gold))
        print('Results:')
        for metric_name in ('all', 'exe', 'agg', 'sel', 'where', 'col', 'span',
                            'lay', 'BIO', 'BIO_col'):
            c_correct = sum((x.correct[metric_name] for x in r_list))
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list),
                                                c_correct / len(r_list)))
            if metric_name == 'all':
                all_acc = c_correct
            if metric_name == 'exe':
                exe_acc = c_correct
        if (prev_best[0] is None
                or all_acc + exe_acc > prev_best[1] + prev_best[2]):
            prev_best = (fn_model, all_acc, exe_acc)

#        random.shuffle(error_cases)
        for error_case in error_cases:
            if len(error_case) == 0:
                continue
            json.dump(error_case, f_out)
            f_out.write('\n')


#            print('table_id:\t', error_case['table_id'])
#            print('question_id:\t',error_case['question_id'])
#            print('question:\t', error_case['question'])
#            print('table_head:\t', error_case['table_head'])
#            print('table_content:\t', error_case['table_content'])
#            print()

#            print(error_case['BIO'])
#            print(error_case['BIO_col'])
#            print()

#            print('gold:','agg:',error_case['gold']['agg'],'sel:',error_case['predict']['sel'])
#            for i in range(len(error_case['gold']['conds'])):
#                print(error_case['gold']['conds'][i])

#           print('predict:','agg:',error_case['predict']['agg'],'sel:',error_case['predict']['sel'])
#           for i in range(len(error_case['predict']['conds'])):
#               print(error_case['predict']['conds'][i])
#           print('\n\n')

    print(prev_best)
    if (opt.split == 'dev') and (prev_best[0] is not None) and num_models != 1:
        if opt.unseen_table == 'full':
            with codecs.open(os.path.join(opt.save_path, 'dev_best.txt'),
                             'w',
                             encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))
        else:
            with codecs.open(os.path.join(
                    opt.save_path, 'dev_best_' + opt.unseen_table + '.txt'),
                             'w',
                             encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))
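The error cases above are written one JSON object per line (a json.dump followed by a newline), which makes the output file plain JSON Lines. A short sketch of the round trip, with a hypothetical file name:

import json

error_cases = [{'table_id': 't1', 'question': 'q1'},
               {'table_id': 't2', 'question': 'q2'}]

# write: one JSON object per line, as in the loop above
with open('error_cases.jsonl', 'w') as f_out:
    for error_case in error_cases:
        json.dump(error_case, f_out)
        f_out.write('\n')

# read back line by line
with open('error_cases.jsonl') as f_in:
    loaded = [json.loads(line) for line in f_in]
assert loaded == error_cases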
Code example #7
def main():
    js_list = table.IO.read_anno_json(args.anno)

    metric_name_list = ['tgt-token', 'lay-token', 'tgt', 'lay']

    prev_best = (None, None)

    model_range = range(10, 101, 5)

    if os.path.isfile(args.model_path):
        model_list = [args.model_path]
    elif os.path.isdir(args.model_path):
        model_list = sorted(
            glob.glob('%s/**/*.pt' % args.model_path, recursive=True))
    else:
        raise RuntimeError('Incorrect model path')

    for i, cur_model in enumerate(model_list):
        assert cur_model.endswith(".pt")

        # TODO: make better
        # if int(os.path.basename(cur_model)[2:4]) not in model_range:
        #     continue

        exp_name = get_exp_name(cur_model)

        args.model = cur_model
        logger.info(" * evaluating model [%s]" % cur_model)

        checkpoint = torch.load(args.model,
                                map_location=lambda storage, loc: storage)
        model_args = checkpoint['opt']

        fp = open(
            "./experiments/%s/%s-%s-eval.txt" %
            (exp_name, args.model.split("/")[-1], args.split), "wt")

        # translator model
        translator = table.Translator(args, checkpoint)
        test_data = table.IO.OrderedIterator(
            dataset=table.IO.TableDataset(js_list, translator.fields, 0, None,
                                          False),
            device=args.gpu_id[0] if args.cuda else -1,  # -1 is CPU
            batch_size=args.batch_size,
            train=False,
            sort=True,
            sort_within_batch=False)

        r_list = []
        for batch in tqdm(test_data, desc="Inference"):
            r = translator.translate(batch)
            r_list += r

        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        for pred, gold in tqdm(zip(r_list, js_list),
                               total=len(r_list),
                               desc="Evaluation"):
            pred.eval(gold)

        for metric_name in tqdm(metric_name_list,
                                desc="Dump results by metric"):

            if metric_name.endswith("-token"):
                # token-level accuracy: a prediction counts as correct when
                # its token set is a subset of the gold token set
                c_correct = sum([
                    len(
                        set(x.get_by_name(metric_name)) -
                        set(y[metric_name.split("-")[0]])) == 0
                    for x, y in zip(r_list, js_list)
                ])
            else:
                c_correct = sum(x.correct[metric_name] for x in r_list)

            acc = c_correct / len(r_list)
            out_str = 'result: {}: {} / {} = {:.2%}'.format(
                metric_name, c_correct, len(r_list), acc)
            fp.write(out_str + "\n")
            print(out_str)

            if not metric_name.endswith("-token"):
                # dump incorrect examples
                for x in r_list:
                    for prd, tgt in x.incorrect[metric_name]:
                        fp.write("\tprd: %s\n\ttgt: %s\n\n" %
                                 (" ".join(prd), " ".join(tgt)))

            if metric_name == 'tgt' and (prev_best[0] is None
                                         or acc > prev_best[1]):
                prev_best = (cur_model, acc)
        # ---

        # save model args
        fp.write("\n\n")
        dump_cfg(fp, cfg=dict_update(args.__dict__, model_args.__dict__))
        fp.close()
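Unlike the other variants, this one loads the checkpoint itself and hands it to the Translator; the map_location=lambda storage, loc: storage argument tells torch.load to keep every tensor on the CPU even if the checkpoint was saved from GPU memory. A minimal sketch (the path and checkpoint keys are illustrative):

import torch

# save a checkpoint-like dict with illustrative keys
torch.save({'opt': {'layers': 2}, 'weights': torch.randn(4)}, 'model.pt')

# load it onto the CPU regardless of where it was saved from
checkpoint = torch.load('model.pt',
                        map_location=lambda storage, loc: storage)
print(checkpoint['opt'])  # {'layers': 2}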
Code example #8
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    js_list = table.IO.read_anno_json(opt.anno, opt)

    metric_name_list = ['tgt']
    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):
        opt.model = fn_model
        print(fn_model)
        print(opt.anno)

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, 0, None,
                                     False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r = translator.translate(batch)
            r_list += r
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        # evaluation
        for pred, gold in zip(r_list, js_list):
            print("pred tgt: ", pred.tgt)
            print("pred lay: ", pred.lay)
            print("gold:", gold)

            pred.eval(gold)
        print('Results:')
        for metric_name in metric_name_list:
            c_correct = sum((x.correct[metric_name] for x in r_list))
            acc = c_correct / len(r_list)
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list), acc))
            if metric_name == 'tgt' and (prev_best[0] is None
                                         or acc > prev_best[1]):
                prev_best = (fn_model, acc)

        # calculate BLEU score
        pred_tgt_tokens = [pred.tgt for pred in r_list]
        gold_tgt_tokens = [gold['tgt'] for gold in js_list]
        # print('pred_tgt_tokens[0]', pred_tgt_tokens[0])
        # print('gold_tgt_tokens[0]', gold_tgt_tokens[0])
        bleu_score = table.modules.bleu_score.compute_bleu(gold_tgt_tokens,
                                                           pred_tgt_tokens,
                                                           smooth=False)
        bleu_score = bleu_score[0]

        # nltk's corpus_bleu expects a list of reference lists (one list of
        # references per hypothesis), so wrap each gold sequence in a list
        bleu_score_nltk = corpus_bleu(
            [[gold] for gold in gold_tgt_tokens],
            pred_tgt_tokens,
            smoothing_function=SmoothingFunction().method3)

        print('{}: = {:.4}'.format('tgt BLEU score', bleu_score))

        print('{}: = {:.4}'.format('tgt nltk BLEU score', bleu_score_nltk))

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.root_dir, opt.dataset,
                                      'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
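As the fixed call above notes, nltk's corpus_bleu takes one list of references per hypothesis, whereas compute_bleu-style helpers typically take a single reference sequence per hypothesis. A small self-contained check of the nesting:

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

gold_tgt_tokens = [['the', 'cat', 'sat', 'down'], ['a', 'dog', 'ran', 'off']]
pred_tgt_tokens = [['the', 'cat', 'sat', 'down'], ['a', 'dog', 'ran', 'off']]

# each hypothesis gets a *list* of references, here a singleton list
score = corpus_bleu([[gold] for gold in gold_tgt_tokens],
                    pred_tgt_tokens,
                    smoothing_function=SmoothingFunction().method3)
print('{:.4f}'.format(score))  # 1.0000 for an exact match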