Example #1
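# Context assumed by this snippet (the imports live elsewhere in the project):
# numpy as np, time, os, torch, plus the project modules Zparser, trees,
# evaluate, dep_eval and the helpers torch_load, CoNLLXReader, format_elapsed.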
def run_test(args):

    const_test_path = args.consttest_ptb_path

    dep_test_path = args.deptest_ptb_path

    if args.dataset == 'ctb':
        const_test_path = args.consttest_ctb_path
        dep_test_path = args.deptest_ctb_path

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"

    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = Zparser.ChartParser.from_spec(info['spec'], info['state_dict'])
    parser.eval()

    dep_test_reader = CoNLLXReader(dep_test_path, parser.type_vocab)
    print('Reading dependency parsing data from %s' % dep_test_path)

    dep_test_data = []
    test_inst = dep_test_reader.getNext()
    dep_test_headid = np.zeros([40000, 300], dtype=int)
    dep_test_type = []
    dep_test_word = []
    dep_test_pos = []
    dep_test_lengs = np.zeros(40000, dtype=int)
    cun = 0
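    # Read CoNLL-X instances one at a time, caching gold heads, dependency
    # types, words and POS tags per sentence; the buffers above are
    # preallocated for up to 40,000 sentences of at most 300 tokens each.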
    while test_inst is not None:
        inst_size = test_inst.length()
        dep_test_lengs[cun] = inst_size
        sent = test_inst.sentence
        dep_test_data.append((sent.words, test_inst.postags, test_inst.heads, test_inst.types))
        for i in range(inst_size):
            dep_test_headid[cun][i] = test_inst.heads[i]
        dep_test_type.append(test_inst.types)
        dep_test_word.append(sent.words)
        dep_test_pos.append(sent.postags)
        # dep_sentences.append([(tag, word) for i, (word, tag) in enumerate(zip(sent.words, sent.postags))])
        test_inst = dep_test_reader.getNext()
        cun += 1

    dep_test_reader.close()

    print("Loading test trees from {}...".format(const_test_path))
    test_treebank = trees.load_trees(const_test_path, dep_test_headid, dep_test_type, dep_test_word)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Parsing test sentences...")
    start_time = time.time()

    punct_set = {'.', '``', "''", ':', ','}  # POS tags treated as punctuation

    parser.eval()
    test_predicted = []
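    # Parse the treebank in batches of eval_batch_size; the parser consumes
    # each sentence as a list of (POS tag, word) pairs taken from the leaves.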
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index + args.eval_batch_size]

        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees]

        predicted, _ = parser.parse_batch(subbatch_sentences)
        del _
        test_predicted.extend([p.convert() for p in predicted])

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted)
    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )

    test_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in test_predicted]
    test_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in test_predicted]
    assert len(test_pred_head) == len(test_pred_type)
    assert len(test_pred_type) == len(dep_test_type)
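    # Dependency evaluation against the gold CoNLL-X annotations: the returned
    # tuples carry unlabeled/labeled attachment counts (with and without
    # punctuation), complete-match counts, and root-accuracy counts.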
    stats, stats_nopunc, stats_root, test_total_inst = dep_eval.eval(len(test_pred_head), dep_test_word, dep_test_pos,
                                                                     test_pred_head,
                                                                     test_pred_type, dep_test_headid, dep_test_type,
                                                                     dep_test_lengs, punct_set=punct_set,
                                                                     symbolic_root=False)

    test_ucorrect, test_lcorrect, test_total, test_ucomplete_match, test_lcomplete_match = stats
    test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc, test_ucomplete_match_nopunc, test_lcomplete_match_nopunc = stats_nopunc
    test_root_correct, test_total_root = stats_root

    print(
        'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total,
            test_lcorrect * 100 / test_total,
            test_ucomplete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst
            ))
    print(
        'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
            test_ucorrect_nopunc * 100 / test_total_nopunc,
            test_lcorrect_nopunc * 100 / test_total_nopunc,
            test_ucomplete_match_nopunc * 100 / test_total_inst,
            test_lcomplete_match_nopunc * 100 / test_total_inst))
    print('best test Root: corr: %d, total: %d, acc: %.2f%%' % (
        test_root_correct, test_total_root, test_root_correct * 100 / test_total_root))
    print(
        '============================================================================================================================')
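
A minimal invocation sketch (hypothetical paths and values; the real CLI flags
are defined by the project's argument parser):

from types import SimpleNamespace

args = SimpleNamespace(
    dataset='ptb',                             # or 'ctb'
    consttest_ptb_path='data/23.auto.clean',   # hypothetical path
    deptest_ptb_path='data/ptb_test.conllx',   # hypothetical path
    consttest_ctb_path=None,
    deptest_ctb_path=None,
    model_path_base='models/joint_parser.pt',  # must end in .pt
    eval_batch_size=100,
    evalb_dir='EVALB/',
)
run_test(args)
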
Example #2
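    # Assumed context: check_dev is a closure defined inside the project's
    # training loop; it reads dev_treebank, parser, args, trainer, punct_set,
    # start_time and the dep_dev_* arrays from the enclosing scope, hence the
    # nonlocal declarations below.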
    def check_dev(epoch_num):
        nonlocal best_dev_score
        nonlocal best_model_path

        dev_start_time = time.time()

        parser.eval()

        dev_predicted = []

        for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index+args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees]

            predicted, _ = parser.parse_batch(subbatch_sentences)
            del _

            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted)

        print(
            "dev-fscore {} "
            "dev-elapsed {} "
            "total-elapsed {}".format(
                dev_fscore,
                format_elapsed(dev_start_time),
                format_elapsed(start_time),
            )
        )

        dev_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in dev_predicted]
        dev_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in dev_predicted]
        assert len(dev_pred_head) == len(dev_pred_type)
        assert len(dev_pred_type) == len(dep_dev_type)
        stats, stats_nopunc, stats_root, num_inst = dep_eval.eval(len(dev_pred_head), dep_dev_word, dep_dev_pos,
                                                                  dev_pred_head, dev_pred_type,
                                                                  dep_dev_headid, dep_dev_type,
                                                                  dep_dev_lengs, punct_set=punct_set,
                                                                  symbolic_root=False)
        dev_ucorr, dev_lcorr, dev_total, dev_ucomplete, dev_lcomplete = stats
        dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucomplete_nopunc, dev_lcomplete_nopunc = stats_nopunc
        dev_root_corr, dev_total_root = stats_root
        dev_total_inst = num_inst
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total,
                dev_ucomplete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
                dev_ucorr_nopunc * 100 / dev_total_nopunc,
                dev_lcorr_nopunc * 100 / dev_total_nopunc,
                dev_ucomplete_nopunc * 100 / dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' % (
            dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root))

        dev_uas = dev_ucorr_nopunc * 100 / dev_total_nopunc
        dev_las = dev_lcorr_nopunc * 100 / dev_total_nopunc

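        # Model selection: a checkpoint is saved whenever constituency F1 plus
        # dependency LAS (without punctuation) beats the best score so far;
        # the previous best checkpoint is removed first.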
        if dev_fscore.fscore + dev_las > best_dev_score:
            if best_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_score = dev_fscore.fscore + dev_las
            best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}".format(
                args.model_path_base, dev_fscore.fscore, dev_uas, dev_las)
            print("Saving new best model to {}...".format(best_model_path))
            torch.save({
                'spec': parser.spec,
                'state_dict': parser.state_dict(),
                'trainer': trainer.state_dict(),
                }, best_model_path + ".pt")