Exemple #1
0
    def syn_dev(self, model, leng=0):
        """Score constituency and dependency predictions on a dev split.

        Feeds every batch of the ``'dev_synconst' + suffix`` dataloader to
        ``model``, collects the predicted trees, and stores three results in
        ``self.summary_dict``: ``'dev_synconst' + suffix`` (value returned by
        ``evaluate.evalb``) and ``'dev_syndep_uas' + suffix`` /
        ``'dev_syndep_las' + suffix`` (pair returned by ``dep_eval.eval``).

        Args:
            model: Callable invoked as ``model(sentences=..., bert_data=...)``.
                It must return three values; the first is an iterable of
                trees exposing ``leaves()`` whose leaves carry ``father`` and
                ``type`` attributes.
            leng: Split selector. ``0`` selects the unsuffixed keys; any
                other value is appended (as ``str(leng)``) to every dataset
                and summary key used here.

        Raises:
            AssertionError: If ``'dev_synconst'`` is not in ``self.task_list``
                or the number of predictions does not match the number of
                entries in ``'dev_syndep_type' + suffix``.
        """
        assert 'dev_synconst' in self.task_list

        str_leng = "" if leng == 0 else str(leng)

        syntree_pred = []
        for batch in tqdm(self.task_dataloader['dev_synconst' + str_leng],
                          desc="Syntax Dev"):
            (input_ids, input_mask, word_start_mask, word_end_mask,
             segment_ids, lm_label_ids, is_next, sent) = batch
            # Position of each example inside the batch, sent along with the
            # other tensors as the first element of ``bert_data``.
            dis_idx = torch.tensor(list(range(len(input_ids))))
            tensors = (dis_idx, input_ids, input_mask, word_start_mask,
                       word_end_mask, segment_ids, lm_label_ids, is_next)
            bert_data = tuple(t.to(self.device) for t in tensors)
            # Each entry of ``sent`` is a JSON-encoded sentence.
            sentences = [json.loads(sent_str) for sent_str in sent]
            syntree, _, _ = model(sentences=sentences, bert_data=bert_data)
            syntree_pred.extend(syntree)

        dataset_trees = self.ptb_dataset['dev_synconst_tree' + str_leng]

        # const parsing:
        self.summary_dict['dev_synconst' + str_leng] = evaluate.evalb(
            self.evalb_dir, dataset_trees, syntree_pred)

        # dep parsing: heads and labels are read off the predicted trees,
        # POS tags off the dataset trees.
        dev_pred_head = [[leaf.father for leaf in tree.leaves()]
                         for tree in syntree_pred]
        dev_pred_type = [[leaf.type for leaf in tree.leaves()]
                         for tree in syntree_pred]
        syndep_dev_pos = [[leaf.tag for leaf in tree.leaves()]
                          for tree in dataset_trees]
        assert len(dev_pred_head) == len(dev_pred_type)
        assert len(dev_pred_type) == len(
            self.ptb_dataset['dev_syndep_type' + str_leng])

        uas, las = dep_eval.eval(
            len(dev_pred_head),
            self.ptb_dataset['dev_syndep_sent' + str_leng],
            syndep_dev_pos,
            dev_pred_head,
            dev_pred_type,
            self.ptb_dataset['dev_syndep_head' + str_leng],
            self.ptb_dataset['dev_syndep_type' + str_leng],
            punct_set=self.hparams.punctuation,
            symbolic_root=False)
        self.summary_dict['dev_syndep_uas' + str_leng] = uas
        self.summary_dict['dev_syndep_las' + str_leng] = las
Exemple #2
0
def run_test(args):
    """Load a saved parser and evaluate it on the WSJ and Brown test sets.

    Reads the dependency, SRL-span and SRL-dependency test files plus the
    constituency treebank, restores the parser from ``args.model_path_base``,
    runs batched inference and calls the ``evaluate`` / ``dep_eval`` /
    ``srl_eval`` scorers. Progress and timing are printed; nothing is
    returned.

    Args:
        args: Namespace providing ``synconst_test_ptb_path``,
            ``syndep_test_ptb_path``, ``srlspan_test_ptb_path``,
            ``srlspan_test_brown_path``, ``srldep_test_ptb_path``,
            ``srldep_test_brown_path``, ``model_path_base``,
            ``eval_batch_size`` and ``evalb_dir``.

    Raises:
        AssertionError: If the model path does not end in ``.pt``, the
            savefile spec has no ``hparams`` entry, or the number of
            predicted trees differs from the number of gold dependency
            label sequences.
    """
    synconst_test_path = args.synconst_test_ptb_path

    syndep_test_path = args.syndep_test_ptb_path

    srlspan_test_path = args.srlspan_test_ptb_path
    srlspan_brown_path = args.srlspan_test_brown_path

    srldep_test_path = args.srldep_test_ptb_path
    srldep_brown_path = args.srldep_test_brown_path

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"

    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = Zparser.ChartParser.from_spec(info['spec'], info['state_dict'])

    (syndep_test_sent, syndep_test_pos, syndep_test_heads,
     syndep_test_types) = syndep_reader.read_syndep(syndep_test_path)

    (srlspan_test_sent, srlspan_test_verb, srlspan_test_dict,
     srlspan_test_predpos, srlspan_test_goldpos, srlspan_test_label,
     srlspan_test_label_start,
     srlspan_test_heads) = srlspan_reader.read_srlspan(srlspan_test_path)

    (srlspan_brown_sent, srlspan_brown_verb, srlspan_brown_dict,
     srlspan_brown_predpos, srlspan_brown_goldpos, srlspan_brown_label,
     srlspan_brown_label_start,
     srlspan_brown_heads) = srlspan_reader.read_srlspan(srlspan_brown_path)

    (srldep_test_sent, srldep_test_predpos, srldep_test_verb,
     srldep_test_dict,
     srldep_test_heads) = srldep_reader.read_srldep(srldep_test_path)
    (srldep_brown_sent, srldep_brown_predpos, srldep_brown_verb,
     srldep_brown_dict,
     srldep_brown_heads) = srldep_reader.read_srldep(srldep_brown_path)

    print("Loading test trees from {}...".format(synconst_test_path))
    test_treebank = trees.load_trees(synconst_test_path, syndep_test_heads,
                                     syndep_test_types, srlspan_test_label,
                                     srlspan_test_label_start)

    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Parsing test sentences...")
    start_time = time.time()

    # NOTE(review): adjacent literals concatenate, so this is the single
    # string ".``'':," rather than a collection of five tags. Kept unchanged
    # because how dep_eval.eval consumes ``punct_set`` is not visible here.
    punct_set = '.' '``' "''" ':' ','

    parser.eval()
    print("Start test eval:")
    test_start_time = time.time()

    batch_size = args.eval_batch_size
    short_rule = "==============================================="
    long_rule = '============================================================================================================================'

    def wsj_tree_batch(start, stop):
        # (tag, word) pairs read from the leaves of the gold test trees.
        return [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                for tree in test_treebank[start:stop]]

    # span srl and syn have same test data
    syntree_pred = []
    srlspan_pred = []
    for syntree, srlspan_dict, _ in _iter_parsed_batches(
            parser, len(test_treebank), batch_size, wsj_tree_batch,
            srlspan_test_verb):
        syntree_pred.extend(syntree)
        srlspan_pred.extend(srlspan_dict)

    srldep_pred = []
    for _, _, srldep_dict in _iter_parsed_batches(
            parser, len(srldep_test_sent), batch_size,
            lambda start, stop: _tagged_batch(
                srldep_test_predpos, srldep_test_sent, start, stop),
            srldep_test_verb):
        srldep_pred.extend(srldep_dict)

    # const parsing:
    # NOTE: test_fscore / test_uas / test_las are computed but not printed or
    # returned by this function; any reporting happens inside the scorers.
    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, syntree_pred)

    # dep parsing:
    test_pred_head = [[leaf.father for leaf in tree.leaves()]
                      for tree in syntree_pred]
    test_pred_type = [[leaf.type for leaf in tree.leaves()]
                      for tree in syntree_pred]

    assert len(test_pred_head) == len(test_pred_type)
    assert len(test_pred_type) == len(syndep_test_types)
    test_uas, test_las = dep_eval.eval(len(test_pred_head),
                                       syndep_test_sent,
                                       syndep_test_pos,
                                       test_pred_head,
                                       test_pred_type,
                                       syndep_test_heads,
                                       syndep_test_types,
                                       punct_set=punct_set,
                                       symbolic_root=False)

    print(short_rule)
    print("wsj srl span test eval:")
    precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
        srl_eval.compute_srl_f1(srlspan_test_sent,
                                srlspan_test_dict,
                                srlspan_pred,
                                srl_conll_eval_path=False))
    print(short_rule)
    print("wsj srl dep test eval:")
    precision, recall, f1 = (srl_eval.compute_dependency_f1(
        srldep_test_sent,
        srldep_test_dict,
        srldep_pred,
        srl_conll_eval_path=False))
    print(short_rule)

    print(long_rule)

    # Brown: only the SRL outputs are scored, so predicted trees are not kept.
    srlspan_pred = []
    for _, srlspan_dict, _ in _iter_parsed_batches(
            parser, len(srlspan_brown_sent), batch_size,
            lambda start, stop: _tagged_batch(
                srlspan_brown_predpos, srlspan_brown_sent, start, stop),
            srlspan_brown_verb):
        srlspan_pred.extend(srlspan_dict)

    srldep_pred = []
    for _, _, srldep_dict in _iter_parsed_batches(
            parser, len(srldep_brown_sent), batch_size,
            lambda start, stop: _tagged_batch(
                srldep_brown_predpos, srldep_brown_sent, start, stop),
            srldep_brown_verb):
        srldep_pred.extend(srldep_dict)

    print(short_rule)
    print("brown srl span test eval:")
    precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
        srl_eval.compute_srl_f1(srlspan_brown_sent,
                                srlspan_brown_dict,
                                srlspan_pred,
                                srl_conll_eval_path=False))
    print(short_rule)
    print("brown srl dep test eval:")
    precision, recall, f1 = (srl_eval.compute_dependency_f1(
        srldep_brown_sent,
        srldep_brown_dict,
        srldep_pred,
        srl_conll_eval_path=False))
    print(short_rule)

    print("test-elapsed {} "
          "total-elapsed {}".format(
              format_elapsed(test_start_time),
              format_elapsed(start_time),
          ))

    print(long_rule)


def _iter_parsed_batches(parser, num_sentences, batch_size, build_batch,
                         gold_verbs):
    """Yield ``parser.parse_batch`` results for consecutive dataset slices.

    ``build_batch(start, stop)`` must return the sentences for one slice.
    The matching slice of ``gold_verbs`` is passed to the parser only when
    ``parser.hparams.use_gold_predicate`` is truthy.
    """
    for start in range(0, num_sentences, batch_size):
        stop = start + batch_size
        sentences = build_batch(start, stop)
        if parser.hparams.use_gold_predicate:
            yield parser.parse_batch(sentences,
                                     gold_verbs=gold_verbs[start:stop])
        else:
            yield parser.parse_batch(sentences)


def _tagged_batch(pos_tags, words, start, stop):
    """Return ``[(tag, word), ...]`` lists for sentences ``start:stop``."""
    return [list(zip(tags, tokens))
            for tags, tokens in zip(pos_tags[start:stop], words[start:stop])]
Exemple #3
0
    def make_check(self, model, optimizer, epoch_num):
        """Evaluate ``model`` on dev/test data, checkpoint it, and log the epoch.

        Steps, in order:
          1. Put the model in eval mode and score the dev sets that are
             enabled through ``self.hparams.joint_syn`` / ``joint_srl`` /
             ``joint_pos``.
          2. Score the WSJ test set and, when ``joint_srl`` is set, the
             Brown SRL sets.
          3. If the summed dev score (constituency F-score + dependency LAS +
             SRL-span F1 + SRL-dep F1 + POS score) beats
             ``self.best_dev_score``, delete the previously tracked
             checkpoint, save a new one, and remember its path.
          4. Prepend one summary line to the file at ``self.log_path``.

        Args:
            model: Parser being trained. Used via ``model.parse_batch(...)``,
                ``model(sentences, gold_verbs=...)``, ``model.spec`` and
                ``model.state_dict()``.
            optimizer: Optimizer whose ``state_dict()`` is stored in the
                checkpoint.
            epoch_num: Value written at the start of the log line.

        Raises:
            AssertionError: If the number of predicted trees does not match
                the number of gold dependency label sequences.
        """
        long_rule = '============================================================================================================================'
        batch_size = self.eval_batch_size

        print("Start dev eval:")
        dev_start_time = time.time()

        # Defaults keep the log line and dev score well-defined for tasks
        # that are switched off.
        summary_dict = {
            "synconst dev F1": evaluate.FScore(0, 0, 0),
            "syndep dev uas": 0,
            "syndep dev las": 0,
            "pos dev": 0,
            "synconst test F1": evaluate.FScore(0, 0, 0),
            "syndep test uas": 0,
            "syndep test las": 0,
            "pos test": 0,
            "srlspan dev F1": 0,
            "srldep dev F1": 0,
            "srlspan test F1": 0,
            "srlspan brown F1": 0,
            "srldep test F1": 0,
            "srldep brown F1": 0,
        }

        model.eval()

        if self.hparams.joint_syn:
            dev_trees = self.ptb_dataset['dev_synconst_tree']
            syntree_pred = []
            for start in range(0, len(dev_trees), batch_size):
                sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                             for tree in dev_trees[start:start + batch_size]]
                syntree, _, _ = model.parse_batch(sentences)
                syntree_pred.extend(syntree)

            # const parsing:
            summary_dict["synconst dev F1"] = evaluate.evalb(
                self.evalb_dir, dev_trees, syntree_pred)

            # dep parsing:
            summary_dict["syndep dev uas"], summary_dict["syndep dev las"] = \
                self._score_dependencies(syntree_pred, 'dev')

        # for srl different dev set
        if self.hparams.joint_srl or self.hparams.joint_pos:
            srlspan_pred = []
            pos_pred = []
            for srlspan_tree, srlspan_dict, _ in self._iter_tagged_batches(
                    model, 'dev_srlspan_sent', 'dev_srlspan_pos',
                    'dev_srlspan_verb'):
                srlspan_pred.extend(srlspan_dict)
                # The model returns one tree per sentence, so tags are
                # collected per tree rather than from the batch object.
                pos_pred.extend([leaf.goldtag for leaf in tree.leaves()]
                                for tree in srlspan_tree)

            srldep_pred = []
            # Dependency SRL is the third model output, as in the test and
            # Brown loops below.
            for _, _, srldep_dict in self._iter_tagged_batches(
                    model, 'dev_srldep_sent', 'dev_srldep_pos',
                    'dev_srldep_verb'):
                srldep_pred.extend(srldep_dict)

            if self.hparams.joint_srl:
                self._record_srl_scores(summary_dict, 'dev', srlspan_pred,
                                        srldep_pred, "srl span dev eval:",
                                        "srl dep dev eval:")

            if self.hparams.joint_pos:
                summary_dict["pos dev"] = pos_eval.eval(
                    self.ptb_dataset['dev_srlspan_goldpos'], pos_pred)

        print("dev-elapsed {} ".format(format_elapsed(dev_start_time)))
        print(long_rule)

        print("Start test eval:")
        test_start_time = time.time()

        test_trees = self.ptb_dataset['test_synconst_tree']
        syntree_pred = []
        srlspan_pred = []
        pos_pred = []
        for start in range(0, len(test_trees), batch_size):
            stop = start + batch_size
            sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                         for tree in test_trees[start:stop]]
            syntree, srlspan_dict, _ = model(
                sentences,
                gold_verbs=self.ptb_dataset['test_srlspan_verb'][start:stop])
            syntree_pred.extend(syntree)
            srlspan_pred.extend(srlspan_dict)
            pos_pred.extend([leaf.goldtag for leaf in tree.leaves()]
                            for tree in syntree)

        srldep_pred = []
        if self.hparams.joint_srl:
            # Sentences, tags and verbs all come from the SRL-dep test set so
            # predictions line up with 'test_srldep_sent' at scoring time.
            # NOTE(review): assumes a 'test_srldep_pos' entry exists, by
            # analogy with 'dev_srldep_pos' - confirm against the loader.
            for _, _, srldep_dict in self._iter_tagged_batches(
                    model, 'test_srldep_sent', 'test_srldep_pos',
                    'test_srldep_verb'):
                srldep_pred.extend(srldep_dict)

        if self.hparams.joint_syn:
            # const parsing:
            summary_dict["synconst test F1"] = evaluate.evalb(
                self.evalb_dir, test_trees, syntree_pred)

            # dep parsing:
            summary_dict["syndep test uas"], summary_dict["syndep test las"] = \
                self._score_dependencies(syntree_pred, 'test')

        if self.hparams.joint_pos:
            summary_dict["pos test"] = pos_eval.eval(
                self.ptb_dataset['test_srlspan_goldpos'], pos_pred)

        if self.hparams.joint_srl:
            self._record_srl_scores(summary_dict, 'test', srlspan_pred,
                                    srldep_pred, "wsj srl span test eval:",
                                    "wsj srl dep test eval:")
            print(long_rule)

            brown_span_pred = []
            for _, srlspan_dict, _ in self._iter_tagged_batches(
                    model, 'brown_srlspan_sent', 'brown_srlspan_pos',
                    'brown_srlspan_verb'):
                brown_span_pred.extend(srlspan_dict)

            brown_dep_pred = []
            # NOTE(review): assumes a 'brown_srldep_pos' entry exists, by
            # analogy with 'dev_srldep_pos' - confirm against the loader.
            for _, _, srldep_dict in self._iter_tagged_batches(
                    model, 'brown_srldep_sent', 'brown_srldep_pos',
                    'brown_srldep_verb'):
                brown_dep_pred.extend(srldep_dict)

            self._record_srl_scores(summary_dict, 'brown', brown_span_pred,
                                    brown_dep_pred,
                                    "brown srl span test eval:",
                                    "brown srl dep test eval:")

        print("test-elapsed {} ".format(format_elapsed(test_start_time)))
        print(long_rule)

        dev_fscore = summary_dict['synconst dev F1'].fscore
        dev_score = (dev_fscore + summary_dict['syndep dev las'] +
                     summary_dict["srlspan dev F1"] +
                     summary_dict["srldep dev F1"] + summary_dict['pos dev'])

        if dev_score > self.best_dev_score:
            if self.best_model_path is not None:
                path = self.best_model_path + ".pt"
                if os.path.exists(path):
                    print("Removing previous model file {}...".format(path))
                    os.remove(path)

            self.best_dev_score = dev_score
            # The numeric ``fscore`` attribute is formatted here because
            # ``{:.2f}`` cannot be applied to the FScore object itself.
            best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}_devsrlspan={:.2f}_devsrldep={:.2f}".format(
                self.model_path_base, dev_fscore,
                summary_dict['syndep dev uas'], summary_dict['syndep dev las'],
                summary_dict["srlspan dev F1"], summary_dict["srldep dev F1"])
            print("Saving new best model to {}...".format(best_model_path))
            torch.save({
                'spec': model.spec,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, best_model_path + ".pt")
            # Remember the checkpoint so the next improvement can remove it.
            self.best_model_path = best_model_path

        log_data = "{} epoch, dev-fscore {:},test-fscore {:}, dev-uas {:.2f}, dev-las {:.2f}," \
                   "test-uas {:.2f}, test-las {:.2f}, dev-srlspan {:.2f}, test-wsj-srlspan {:.2f}, test-brown-srlspan {:.2f}," \
                   " dev-srldep {:.2f},  test-wsj-srldep {:.2f}, test-brown-srldep {:.2f}, dev-pos {:.2f}, test-pos {:.2f}," \
                   "dev_score {:.2f}, best_dev_score {:.2f}" \
            .format(epoch_num, summary_dict["synconst dev F1"], summary_dict["synconst test F1"],
                    summary_dict["syndep dev uas"], summary_dict["syndep dev las"],
                    summary_dict["syndep test uas"], summary_dict["syndep test las"],
                    summary_dict["srlspan dev F1"], summary_dict["srlspan test F1"], summary_dict["srlspan brown F1"],
                    summary_dict["srldep dev F1"], summary_dict["srldep test F1"], summary_dict["srldep brown F1"],
                    summary_dict["pos dev"], summary_dict["pos test"],
                    dev_score,
                    self.best_dev_score)

        # Newest entry goes first: read what is already there, then rewrite
        # the file with the new line on top. Handles are closed by ``with``.
        previous_log = ""
        if os.path.exists(self.log_path):
            with open(self.log_path, 'r') as flog:
                previous_log = flog.read()
        with open(self.log_path, 'w') as flog:
            flog.write(log_data + '\n' + previous_log)

    def _iter_tagged_batches(self, model, sent_key, pos_key, verb_key):
        """Yield ``model`` outputs for consecutive eval batches of a dataset.

        Words, tags and gold verbs are read from ``self.ptb_dataset`` under
        the given keys; each sentence is passed as a list of (tag, word).
        """
        words = self.ptb_dataset[sent_key]
        tags = self.ptb_dataset[pos_key]
        verbs = self.ptb_dataset[verb_key]
        for start in range(0, len(words), self.eval_batch_size):
            stop = start + self.eval_batch_size
            sentences = [list(zip(sent_tags, sent_words))
                         for sent_tags, sent_words
                         in zip(tags[start:stop], words[start:stop])]
            yield model(sentences, gold_verbs=verbs[start:stop])

    def _score_dependencies(self, pred_trees, split):
        """Return the ``dep_eval.eval`` result for trees of ``split``."""
        pred_head = [[leaf.father for leaf in tree.leaves()]
                     for tree in pred_trees]
        pred_type = [[leaf.type for leaf in tree.leaves()]
                     for tree in pred_trees]
        assert len(pred_head) == len(pred_type)
        assert len(pred_type) == len(self.ptb_dataset[split + '_syndep_type'])
        return dep_eval.eval(len(pred_head),
                             self.ptb_dataset[split + '_syndep_sent'],
                             self.ptb_dataset[split + '_syndep_pos'],
                             pred_head,
                             pred_type,
                             self.ptb_dataset[split + '_syndep_head'],
                             self.ptb_dataset[split + '_syndep_type'],
                             punct_set=self.hparams.punct_set,
                             symbolic_root=False)

    def _record_srl_scores(self, summary_dict, split, span_pred, dep_pred,
                           span_title, dep_title):
        """Score span and dependency SRL for ``split`` into ``summary_dict``."""
        short_rule = "==============================================="

        print(short_rule)
        print(span_title)
        precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
            srl_eval.compute_srl_f1(self.ptb_dataset[split + '_srlspan_sent'],
                                    self.ptb_dataset[split + '_srlspan_dict'],
                                    span_pred,
                                    srl_conll_eval_path=False))
        summary_dict["srlspan {} F1".format(split)] = f1
        summary_dict["srlspan {} precision".format(split)] = precision
        summary_dict["srlspan {} recall".format(split)] = recall

        print(short_rule)
        print(dep_title)
        precision, recall, f1 = (
            srl_eval.compute_dependency_f1(
                self.ptb_dataset[split + '_srldep_sent'],
                self.ptb_dataset[split + '_srldep_dict'],
                dep_pred,
                srl_conll_eval_path=False,
                use_gold=self.hparams.use_gold_predicate))
        summary_dict["srldep {} F1".format(split)] = f1
        summary_dict["srldep {} precision".format(split)] = precision
        summary_dict["srldep {} recall".format(split)] = recall
        print(short_rule)
Exemple #4
0
    def check_dev(epoch_num):
        nonlocal best_dev_score
        nonlocal best_model_path
        nonlocal best_epoch

        print("Start dev eval:")

        dev_start_time = time.time()
        dev_fscore = evaluate.FScore(0, 0, 0)
        dev_uas = 0
        dev_las = 0
        pos_dev = 0
        summary_dict = {}
        summary_dict["srlspan dev F1"] = 0
        summary_dict["srldep dev F1"] = 0
        parser.eval()

        syntree_pred = []
        srlspan_pred = []
        srldep_pred = []
        pos_pred = []
        if hparams.joint_syn_dep or hparams.joint_syn_const:
            for dev_start_index in range(0, len(dev_treebank),
                                         args.eval_batch_size):
                subbatch_trees = dev_treebank[dev_start_index:dev_start_index +
                                              args.eval_batch_size]
                subbatch_sentences = [[(leaf.tag, leaf.word)
                                       for leaf in tree.leaves()]
                                      for tree in subbatch_trees]

                syntree, _, _ = parser.parse_batch(subbatch_sentences)

                syntree_pred.extend(syntree)

            #const parsing:

            dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank,
                                        syntree_pred)

            #dep parsing:

            dev_pred_head = [[leaf.father for leaf in tree.leaves()]
                             for tree in syntree_pred]
            dev_pred_type = [[leaf.type for leaf in tree.leaves()]
                             for tree in syntree_pred]
            assert len(dev_pred_head) == len(dev_pred_type)
            assert len(dev_pred_type) == len(syndep_dev_types)
            dev_uas, dev_las = dep_eval.eval(len(dev_pred_head),
                                             syndep_dev_sent,
                                             syndep_dev_pos,
                                             dev_pred_head,
                                             dev_pred_type,
                                             syndep_dev_heads,
                                             syndep_dev_types,
                                             punct_set=punct_set,
                                             symbolic_root=False)
        # SRL span and POS tagging use their own dev set (different from the
        # syntax one), so it is decoded separately here, one batch of
        # args.eval_batch_size sentences at a time.
        if hparams.joint_srl_span or hparams.joint_pos:
            for dev_start_index in range(0, len(srlspan_dev_sent),
                                         args.eval_batch_size):
                # Same window is applied to every parallel dev list.
                batch_slice = slice(dev_start_index,
                                    dev_start_index + args.eval_batch_size)
                subbatch_words = srlspan_dev_sent[batch_slice]
                subbatch_pos = srlspan_dev_predpos[batch_slice]
                # Each sentence is handed to the parser as a list of
                # (tag, word) pairs, tags taken from srlspan_dev_predpos.
                subbatch_sentences = [
                    list(zip(tags, words))
                    for tags, words in zip(subbatch_pos, subbatch_words)
                ]

                # Gold predicates are only passed when the hyper-parameter
                # asks for them; dependency heads are always passed.
                parse_kwargs = {
                    "syndep_heads": srlspan_dev_heads[batch_slice]
                }
                if hparams.use_gold_predicate:
                    parse_kwargs["gold_verbs"] = srlspan_dev_verb[batch_slice]
                srlspan_tree, srlspan_dict, _ = parser.parse_batch(
                    subbatch_sentences, **parse_kwargs)

                srlspan_pred.extend(srlspan_dict)

                # POS predictions are read from the `goldtag` attribute of
                # the leaves of the returned trees.
                pos_pred.extend([[leaf.goldtag for leaf in tree.leaves()]
                                 for tree in srlspan_tree])

            if hparams.joint_srl_span:
                print("===============================================")
                print("srl span dev eval:")
                precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
                    srl_eval.compute_srl_f1(srlspan_dev_sent,
                                            srlspan_dev_dict,
                                            srlspan_pred,
                                            srl_conll_eval_path=False))
                summary_dict["srlspan dev F1"] = f1
                summary_dict["srlspan dev precision"] = precision
                # Fixed: this entry used to store `precision` again.
                summary_dict["srlspan dev recall"] = recall

            if hparams.joint_pos:
                pos_dev = pos_eval.eval(srlspan_dev_goldpos, pos_pred)

        if hparams.joint_srl_dep:
            # Decode the dependency-SRL dev set in batches of
            # args.eval_batch_size sentences, collecting the third value
            # returned by parser.parse_batch for every batch.
            for dev_start_index in range(0, len(srldep_dev_sent),
                                         args.eval_batch_size):
                # Same window is applied to every parallel dev list.
                batch_slice = slice(dev_start_index,
                                    dev_start_index + args.eval_batch_size)
                subbatch_words = srldep_dev_sent[batch_slice]
                subbatch_pos = srldep_dev_predpos[batch_slice]
                # Each sentence is handed to the parser as a list of
                # (tag, word) pairs, tags taken from srldep_dev_predpos.
                subbatch_sentences = [
                    list(zip(tags, words))
                    for tags, words in zip(subbatch_pos, subbatch_words)
                ]

                # Gold predicates are only passed when the hyper-parameter
                # asks for them; dependency heads are always passed.
                parse_kwargs = {
                    "syndep_heads": srldep_dev_heads[batch_slice]
                }
                if hparams.use_gold_predicate:
                    parse_kwargs["gold_verbs"] = srldep_dev_verb[batch_slice]
                _, _, srldep_dict = parser.parse_batch(subbatch_sentences,
                                                       **parse_kwargs)

                srldep_pred.extend(srldep_dict)

            print("===============================================")
            print("srl dep dev eval:")
            precision, recall, f1 = (srl_eval.compute_dependency_f1(
                srldep_dev_sent,
                srldep_dev_dict,
                srldep_pred,
                srl_conll_eval_path=False,
                use_gold=hparams.use_gold_predicate))
            summary_dict["srldep dev F1"] = f1
            summary_dict["srldep dev precision"] = precision
            # Fixed: this entry used to store `precision` again.
            summary_dict["srldep dev recall"] = recall
            print("===============================================")

        print("dev-elapsed {} "
              "total-elapsed {}".format(
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ))

        print(
            '============================================================================================================================'
        )

        if dev_fscore.fscore + dev_las + summary_dict[
                "srlspan dev F1"] + summary_dict[
                    "srldep dev F1"] + pos_dev > best_dev_score:
            if best_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_model_path + ext
                    if os.path.exists(path):
                        print(
                            "Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_score = dev_fscore.fscore + dev_las + summary_dict[
                "srlspan dev F1"] + summary_dict["srldep dev F1"] + pos_dev
            best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}_devsrlspan={:.2f}_devsrldep={:.2f}".format(
                args.model_path_base, dev_fscore.fscore, dev_uas, dev_las,
                summary_dict["srlspan dev F1"], summary_dict["srldep dev F1"])
            print("Saving new best model to {}...".format(best_model_path))
            torch.save(
                {
                    'spec': parser.spec,
                    'state_dict': parser.state_dict(),
                    'trainer': trainer.state_dict(),
                }, best_model_path + ".pt")