def syn_dev(self, model, leng=0):
    """Run syntactic dev-set evaluation (constituency + dependency).

    Feeds the 'dev_synconst' dataloader batches through `model`, collects the
    predicted trees, then scores constituency F1 with EVALB and dependency
    UAS/LAS with dep_eval.  Results are stored in ``self.summary_dict``.

    model: parser that accepts (sentences=..., bert_data=...) and returns
        (trees, _, _) — assumed from the call below; confirm against model.
    leng: 0 means the full dev set; otherwise selects the length-restricted
        split keyed by str(leng).
    """
    syntree_pred = []
    assert 'dev_synconst' in self.task_list
    # Suffix for dataset/dataloader keys: "" for the full set, e.g. "40" for
    # a length-40 split.
    if leng == 0:
        str_leng = ""
    else:
        str_leng = str(leng)
    dev_pred_head = []
    dev_pred_type = []
    for step, batch in enumerate(
            tqdm(self.task_dataloader['dev_synconst' + str_leng],
                 desc="Syntax Dev")):
        input_ids, input_mask, word_start_mask, word_end_mask, segment_ids, lm_label_ids, is_next, sent = batch
        # Prepend a per-example index tensor to the BERT inputs.
        dis_idx = [i for i in range(len(input_ids))]
        dis_idx = torch.tensor(dis_idx)
        batch = dis_idx, input_ids, input_mask, word_start_mask, word_end_mask, segment_ids, lm_label_ids, is_next
        bert_data = tuple(t.to(self.device) for t in batch)
        # Sentences were serialized as JSON strings in the dataset.
        sentences = [json.loads(sent_str) for sent_str in sent]
        # linz, head, type, _, _, _ = model(sentences=sentences, bert_data=bert_data)
        # dev_pred_head.extend([json.loads(head_str) for head_str in head])
        # dev_pred_type.extend([json.loads(type_str) for type_str in type])
        # syntree_pred.extend(linz)
        syntree, _, _ = model(sentences=sentences, bert_data=bert_data)
        syntree_pred.extend(syntree)
    # const parsing:
    self.summary_dict['dev_synconst' + str_leng] = evaluate.evalb(
        self.evalb_dir,
        self.ptb_dataset['dev_synconst_tree' + str_leng],
        syntree_pred)
    # dep parsing: dependency heads/labels are read off the predicted trees'
    # leaves; gold POS tags come from the gold dev trees.
    dev_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in syntree_pred]
    dev_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in syntree_pred]
    syndep_dev_pos = [[leaf.tag for leaf in tree.leaves()]
                      for tree in self.ptb_dataset['dev_synconst_tree' + str_leng]]
    assert len(dev_pred_head) == len(dev_pred_type)
    assert len(dev_pred_type) == len(self.ptb_dataset['dev_syndep_type' + str_leng])
    self.summary_dict['dev_syndep_uas' + str_leng], self.summary_dict['dev_syndep_las' + str_leng] = \
        dep_eval.eval(len(dev_pred_head),
                      self.ptb_dataset['dev_syndep_sent' + str_leng],
                      syndep_dev_pos,
                      dev_pred_head,
                      dev_pred_type,
                      self.ptb_dataset['dev_syndep_head' + str_leng],
                      self.ptb_dataset['dev_syndep_type' + str_leng],
                      punct_set=self.hparams.punctuation,
                      symbolic_root=False)
def run_test(args):
    """Stand-alone test-set evaluation entry point.

    Loads a saved parser checkpoint, then evaluates:
      * constituency (EVALB) and dependency (UAS/LAS) on the WSJ test set,
      * span SRL and dependency SRL on the WSJ test set,
      * span SRL and dependency SRL on the Brown out-of-domain set.
    All metrics are printed; nothing is returned.
    """
    synconst_test_path = args.synconst_test_ptb_path
    syndep_test_path = args.syndep_test_ptb_path
    srlspan_test_path = args.srlspan_test_ptb_path
    srlspan_brown_path = args.srlspan_test_brown_path
    srldep_test_path = args.srldep_test_ptb_path
    srldep_brown_path = args.srldep_test_brown_path

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"

    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = Zparser.ChartParser.from_spec(info['spec'], info['state_dict'])

    # Gold annotations for each task/domain.
    syndep_test_sent, syndep_test_pos, syndep_test_heads, syndep_test_types = syndep_reader.read_syndep(
        syndep_test_path)

    srlspan_test_sent, srlspan_test_verb, srlspan_test_dict, srlspan_test_predpos, srlspan_test_goldpos, \
        srlspan_test_label, srlspan_test_label_start, srlspan_test_heads = srlspan_reader.read_srlspan(srlspan_test_path)

    srlspan_brown_sent, srlspan_brown_verb, srlspan_brown_dict, srlspan_brown_predpos, srlspan_brown_goldpos, \
        srlspan_brown_label, srlspan_brown_label_start, srlspan_brown_heads = srlspan_reader.read_srlspan(srlspan_brown_path)

    srldep_test_sent, srldep_test_predpos, srldep_test_verb, srldep_test_dict, srldep_test_heads = srldep_reader.read_srldep(
        srldep_test_path)
    srldep_brown_sent, srldep_brown_predpos, srldep_brown_verb, srldep_brown_dict, srldep_brown_heads = srldep_reader.read_srldep(
        srldep_brown_path)

    print("Loading test trees from {}...".format(synconst_test_path))
    test_treebank = trees.load_trees(synconst_test_path, syndep_test_heads,
                                     syndep_test_types, srlspan_test_label,
                                     srlspan_test_label_start)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Parsing test sentences...")
    start_time = time.time()

    # NOTE(review): adjacent string literals concatenate into the single
    # string ".``'':,"; membership tests against it work via substring
    # semantics.  Kept as-is to preserve behavior.
    punct_set = '.' '``' "''" ':' ','

    parser.eval()
    print("Start test eval:")
    test_start_time = time.time()

    syntree_pred = []
    srlspan_pred = []
    srldep_pred = []
    # span srl and syn have same test data
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index + args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        if parser.hparams.use_gold_predicate:
            syntree, srlspan_dict, _ = parser.parse_batch(
                subbatch_sentences,
                gold_verbs=srlspan_test_verb[start_index:start_index + args.eval_batch_size])
        else:
            syntree, srlspan_dict, _ = parser.parse_batch(subbatch_sentences)
        syntree_pred.extend(syntree)
        srlspan_pred.extend(srlspan_dict)

    # Dependency SRL uses its own sentence list (CoNLL-2009 style data).
    for start_index in range(0, len(srldep_test_sent), args.eval_batch_size):
        subbatch_words_srldep = srldep_test_sent[start_index:start_index + args.eval_batch_size]
        subbatch_pos_srldep = srldep_test_predpos[start_index:start_index + args.eval_batch_size]
        subbatch_sentences_srldep = [[
            (tag, word) for j, (tag, word) in enumerate(zip(tags, words))
        ] for i, (tags, words) in enumerate(
            zip(subbatch_pos_srldep, subbatch_words_srldep))]

        if parser.hparams.use_gold_predicate:
            _, _, srldep_dict = parser.parse_batch(
                subbatch_sentences_srldep,
                gold_verbs=srldep_test_verb[start_index:start_index + args.eval_batch_size])
        else:
            _, _, srldep_dict = parser.parse_batch(subbatch_sentences_srldep)
        srldep_pred.extend(srldep_dict)

    # const parsing:
    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, syntree_pred)

    # dep parsing: heads/labels read off the predicted trees' leaves.
    test_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in syntree_pred]
    test_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in syntree_pred]
    assert len(test_pred_head) == len(test_pred_type)
    assert len(test_pred_type) == len(syndep_test_types)
    test_uas, test_las = dep_eval.eval(len(test_pred_head), syndep_test_sent,
                                       syndep_test_pos, test_pred_head,
                                       test_pred_type, syndep_test_heads,
                                       syndep_test_types,
                                       punct_set=punct_set,
                                       symbolic_root=False)

    print("===============================================")
    print("wsj srl span test eval:")
    precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
        srl_eval.compute_srl_f1(srlspan_test_sent, srlspan_test_dict,
                                srlspan_pred, srl_conll_eval_path=False))
    print("===============================================")
    print("wsj srl dep test eval:")
    precision, recall, f1 = (srl_eval.compute_dependency_f1(
        srldep_test_sent, srldep_test_dict, srldep_pred,
        srl_conll_eval_path=False))
    print("===============================================")
    print(
        '============================================================================================================================'
    )

    # Brown (out-of-domain) evaluation: reset prediction buffers.
    syntree_pred = []
    srlspan_pred = []
    srldep_pred = []
    for start_index in range(0, len(srlspan_brown_sent), args.eval_batch_size):
        subbatch_words = srlspan_brown_sent[start_index:start_index + args.eval_batch_size]
        subbatch_pos = srlspan_brown_predpos[start_index:start_index + args.eval_batch_size]
        subbatch_sentences = [[
            (tag, word) for j, (tag, word) in enumerate(zip(tags, words))
        ] for i, (tags, words) in enumerate(zip(subbatch_pos, subbatch_words))]
        if parser.hparams.use_gold_predicate:
            syntree, srlspan_dict, _ = parser.parse_batch(
                subbatch_sentences,
                gold_verbs=srlspan_brown_verb[start_index:start_index + args.eval_batch_size])
        else:
            syntree, srlspan_dict, _ = parser.parse_batch(subbatch_sentences)
        syntree_pred.extend(syntree)
        srlspan_pred.extend(srlspan_dict)

    for start_index in range(0, len(srldep_brown_sent), args.eval_batch_size):
        subbatch_words_srldep = srldep_brown_sent[start_index:start_index + args.eval_batch_size]
        subbatch_pos_srldep = srldep_brown_predpos[start_index:start_index + args.eval_batch_size]
        subbatch_sentences_srldep = [[
            (tag, word) for j, (tag, word) in enumerate(zip(tags, words))
        ] for i, (tags, words) in enumerate(
            zip(subbatch_pos_srldep, subbatch_words_srldep))]
        if parser.hparams.use_gold_predicate:
            _, _, srldep_dict = parser.parse_batch(
                subbatch_sentences_srldep,
                gold_verbs=srldep_brown_verb[start_index:start_index + args.eval_batch_size])
        else:
            _, _, srldep_dict = parser.parse_batch(subbatch_sentences_srldep)
        srldep_pred.extend(srldep_dict)

    print("===============================================")
    print("brown srl span test eval:")
    precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
        srl_eval.compute_srl_f1(srlspan_brown_sent, srlspan_brown_dict,
                                srlspan_pred, srl_conll_eval_path=False))
    print("===============================================")
    print("brown srl dep test eval:")
    precision, recall, f1 = (srl_eval.compute_dependency_f1(
        srldep_brown_sent, srldep_brown_dict, srldep_pred,
        srl_conll_eval_path=False))
    print("===============================================")

    print("test-elapsed {} "
          "total-elapsed {}".format(
              format_elapsed(test_start_time),
              format_elapsed(start_time),
          ))
    print(
        '============================================================================================================================'
    )
def make_check(self, model, optimizer, epoch_num):
    """Evaluate `model` for one epoch on dev/test (WSJ) and Brown sets.

    Runs constituency (EVALB), dependency (UAS/LAS), span/dependency SRL and
    POS evaluation according to the ``self.hparams.joint_*`` flags, prints the
    metrics, saves a checkpoint when the combined dev score improves, and
    prepends one summary line to ``self.log_path``.

    Side effects: may update ``self.best_dev_score`` / ``self.best_model_path``
    and write ``self.best_model_path + ".pt"``.  Returns None.
    """
    print("Start dev eval:")
    summary_dict = {}
    dev_start_time = time.time()
    # Defaults so scoring/logging below works even when a joint task is off.
    summary_dict["synconst dev F1"] = evaluate.FScore(0, 0, 0)
    summary_dict["syndep dev uas"] = 0
    summary_dict["syndep dev las"] = 0
    summary_dict["pos dev"] = 0
    summary_dict["synconst test F1"] = evaluate.FScore(0, 0, 0)
    summary_dict["syndep test uas"] = 0
    summary_dict["syndep test las"] = 0
    summary_dict["pos test"] = 0
    summary_dict["srlspan dev F1"] = 0
    summary_dict["srldep dev F1"] = 0
    summary_dict["srlspan test F1"] = 0
    summary_dict["srlspan brown F1"] = 0
    summary_dict["srldep test F1"] = 0
    summary_dict["srldep brown F1"] = 0

    model.eval()
    syntree_pred = []
    srlspan_pred = []
    srldep_pred = []
    pos_pred = []

    if self.hparams.joint_syn:
        for start_index in range(0, len(self.ptb_dataset['dev_synconst_tree']),
                                 self.eval_batch_size):
            subbatch_trees = self.ptb_dataset['dev_synconst_tree'][
                start_index:start_index + self.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            syntree, _, _ = model.parse_batch(subbatch_sentences)
            syntree_pred.extend(syntree)

        # const parsing:
        summary_dict["synconst dev F1"] = evaluate.evalb(
            self.evalb_dir, self.ptb_dataset['dev_synconst_tree'], syntree_pred)

        # dep parsing: heads/labels are read off the predicted trees' leaves.
        dev_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in syntree_pred]
        dev_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in syntree_pred]
        assert len(dev_pred_head) == len(dev_pred_type)
        assert len(dev_pred_type) == len(self.ptb_dataset['dev_syndep_type'])
        summary_dict["syndep dev uas"], summary_dict["syndep dev las"] = dep_eval.eval(
            len(dev_pred_head), self.ptb_dataset['dev_syndep_sent'],
            self.ptb_dataset['dev_syndep_pos'], dev_pred_head, dev_pred_type,
            self.ptb_dataset['dev_syndep_head'], self.ptb_dataset['dev_syndep_type'],
            punct_set=self.hparams.punct_set, symbolic_root=False)

    # for srl different dev set
    if self.hparams.joint_srl or self.hparams.joint_pos:
        for start_index in range(0, len(self.ptb_dataset['dev_srlspan_sent']),
                                 self.eval_batch_size):
            subbatch_words = self.ptb_dataset['dev_srlspan_sent'][
                start_index:start_index + self.eval_batch_size]
            subbatch_pos = self.ptb_dataset['dev_srlspan_pos'][
                start_index:start_index + self.eval_batch_size]
            subbatch_sentences = [[(tag, word) for j, (tag, word) in enumerate(zip(tags, words))]
                                  for i, (tags, words) in enumerate(zip(subbatch_pos, subbatch_words))]
            srlspan_tree, srlspan_dict, _ = \
                model(subbatch_sentences,
                      gold_verbs=self.ptb_dataset['dev_srlspan_verb'][
                          start_index:start_index + self.eval_batch_size])
            srlspan_pred.extend(srlspan_dict)
            # FIX: model(...) returns a batch (list) of trees; the original
            # called .leaves() on the list itself.  Collect one tag list per
            # sentence, matching check_dev and pos_eval.eval's input.
            pos_pred.extend([[leaf.goldtag for leaf in tree.leaves()]
                             for tree in srlspan_tree])

        for start_index in range(0, len(self.ptb_dataset['dev_srldep_sent']),
                                 self.eval_batch_size):
            subbatch_words = self.ptb_dataset['dev_srldep_sent'][
                start_index:start_index + self.eval_batch_size]
            subbatch_pos = self.ptb_dataset['dev_srldep_pos'][
                start_index:start_index + self.eval_batch_size]
            subbatch_sentences = [[(tag, word) for j, (tag, word) in enumerate(zip(tags, words))]
                                  for i, (tags, words) in enumerate(zip(subbatch_pos, subbatch_words))]
            # FIX: the dependency-SRL dict is the THIRD return value (see the
            # test/brown loops and run_test); the original unpacked the second
            # slot, which is the span-SRL dict.
            _, _, srldep_dict = \
                model(subbatch_sentences,
                      gold_verbs=self.ptb_dataset['dev_srldep_verb'][
                          start_index:start_index + self.eval_batch_size])
            srldep_pred.extend(srldep_dict)

        if self.hparams.joint_srl:
            # srl span:
            print("===============================================")
            print("srl span dev eval:")
            precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
                srl_eval.compute_srl_f1(self.ptb_dataset['dev_srlspan_sent'],
                                        self.ptb_dataset['dev_srlspan_dict'],
                                        srlspan_pred, srl_conll_eval_path=False))
            summary_dict["srlspan dev F1"] = f1
            summary_dict["srlspan dev precision"] = precision
            summary_dict["srlspan dev recall"] = recall  # FIX: was set to precision
            print("===============================================")
            print("srl dep dev eval:")
            precision, recall, f1 = (
                srl_eval.compute_dependency_f1(self.ptb_dataset['dev_srldep_sent'],
                                               self.ptb_dataset['dev_srldep_dict'],
                                               srldep_pred, srl_conll_eval_path=False,
                                               use_gold=self.hparams.use_gold_predicate))
            summary_dict["srldep dev F1"] = f1
            summary_dict["srldep dev precision"] = precision
            summary_dict["srldep dev recall"] = recall  # FIX: was set to precision
            print("===============================================")

    if self.hparams.joint_pos:
        summary_dict["pos dev"] = pos_eval.eval(
            self.ptb_dataset['dev_srlspan_goldpos'], pos_pred)

    print("dev-elapsed {} ".format(format_elapsed(dev_start_time),))
    print(
        '============================================================================================================================')
    print("Start test eval:")
    test_start_time = time.time()

    syntree_pred = []
    srlspan_pred = []
    srldep_pred = []
    pos_pred = []

    # NOTE(review): gold verbs are always passed here (no use_gold_predicate
    # branch, unlike run_test) — kept as-is to preserve behavior.
    for start_index in range(0, len(self.ptb_dataset['test_synconst_tree']),
                             self.eval_batch_size):
        subbatch_trees = self.ptb_dataset['test_synconst_tree'][
            start_index:start_index + self.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        syntree, srlspan_dict, _ = \
            model(subbatch_sentences,
                  gold_verbs=self.ptb_dataset['test_srlspan_verb'][
                      start_index:start_index + self.eval_batch_size])
        syntree_pred.extend(syntree)
        srlspan_pred.extend(srlspan_dict)
        # FIX: per-sentence tag lists from the list of trees (see dev loop).
        pos_pred.extend([[leaf.goldtag for leaf in tree.leaves()] for tree in syntree])

    if self.hparams.joint_srl:
        # FIX: this loop feeds the dependency-SRL data, so it must slice the
        # srldep test sentences/POS — the original sliced 'test_srlspan_*'
        # while already passing 'test_srldep_verb' gold verbs and evaluating
        # against 'test_srldep_sent' below.  (Keys assumed symmetric with the
        # 'dev_srldep_*' entries — TODO confirm in ptb_dataset construction.)
        for start_index in range(0, len(self.ptb_dataset['test_srldep_sent']),
                                 self.eval_batch_size):
            subbatch_words_srldep = self.ptb_dataset['test_srldep_sent'][
                start_index:start_index + self.eval_batch_size]
            subbatch_pos_srldep = self.ptb_dataset['test_srldep_pos'][
                start_index:start_index + self.eval_batch_size]
            subbatch_sentences = [[(tag, word) for j, (tag, word) in enumerate(zip(tags, words))]
                                  for i, (tags, words) in enumerate(
                                      zip(subbatch_pos_srldep, subbatch_words_srldep))]
            _, _, srldep_dict = \
                model(subbatch_sentences,
                      gold_verbs=self.ptb_dataset['test_srldep_verb'][
                          start_index:start_index + self.eval_batch_size])
            srldep_pred.extend(srldep_dict)

    # const parsing:
    if self.hparams.joint_syn:
        summary_dict["synconst test F1"] = evaluate.evalb(
            self.evalb_dir, self.ptb_dataset['test_synconst_tree'], syntree_pred)

        # dep parsing:
        test_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in syntree_pred]
        test_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in syntree_pred]
        assert len(test_pred_head) == len(test_pred_type)
        assert len(test_pred_type) == len(self.ptb_dataset['test_syndep_type'])
        summary_dict["syndep test uas"], summary_dict["syndep test las"] = dep_eval.eval(
            len(test_pred_head), self.ptb_dataset['test_syndep_sent'],
            self.ptb_dataset['test_syndep_pos'], test_pred_head, test_pred_type,
            self.ptb_dataset['test_syndep_head'], self.ptb_dataset['test_syndep_type'],
            punct_set=self.hparams.punct_set, symbolic_root=False)

    if self.hparams.joint_pos:
        summary_dict["pos test"] = pos_eval.eval(
            self.ptb_dataset['test_srlspan_goldpos'], pos_pred)

    # srl span:
    if self.hparams.joint_srl:
        print("===============================================")
        print("wsj srl span test eval:")
        precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
            srl_eval.compute_srl_f1(self.ptb_dataset['test_srlspan_sent'],
                                    self.ptb_dataset['test_srlspan_dict'],
                                    srlspan_pred, srl_conll_eval_path=False))
        summary_dict["srlspan test F1"] = f1
        summary_dict["srlspan test precision"] = precision
        summary_dict["srlspan test recall"] = recall  # FIX: was set to precision
        print("===============================================")
        print("wsj srl dep test eval:")
        precision, recall, f1 = (
            srl_eval.compute_dependency_f1(self.ptb_dataset['test_srldep_sent'],
                                           self.ptb_dataset['test_srldep_dict'],
                                           srldep_pred, srl_conll_eval_path=False,
                                           use_gold=self.hparams.use_gold_predicate))
        summary_dict["srldep test F1"] = f1
        summary_dict["srldep test precision"] = precision
        summary_dict["srldep test recall"] = recall  # FIX: was set to precision
        print("===============================================")
        print(
            '============================================================================================================================')

        # Brown (out-of-domain) evaluation: reset prediction buffers.
        syntree_pred = []
        srlspan_pred = []
        srldep_pred = []
        for start_index in range(0, len(self.ptb_dataset['brown_srlspan_sent']),
                                 self.eval_batch_size):
            subbatch_words = self.ptb_dataset['brown_srlspan_sent'][
                start_index:start_index + self.eval_batch_size]
            subbatch_pos = self.ptb_dataset['brown_srlspan_pos'][
                start_index:start_index + self.eval_batch_size]
            subbatch_sentences = [[(tag, word) for j, (tag, word) in enumerate(zip(tags, words))]
                                  for i, (tags, words) in enumerate(zip(subbatch_pos, subbatch_words))]
            syntree, srlspan_dict, _ = \
                model(subbatch_sentences,
                      gold_verbs=self.ptb_dataset['brown_srlspan_verb'][
                          start_index:start_index + self.eval_batch_size])
            syntree_pred.extend(syntree)
            srlspan_pred.extend(srlspan_dict)

        for start_index in range(0, len(self.ptb_dataset['brown_srldep_sent']),
                                 self.eval_batch_size):
            subbatch_words_srldep = self.ptb_dataset['brown_srldep_sent'][
                start_index:start_index + self.eval_batch_size]
            # FIX: the POS slice read 'brown_srldep_sent' (copy/paste); it must
            # come from the POS table like every sibling loop.
            subbatch_pos_srldep = self.ptb_dataset['brown_srldep_pos'][
                start_index:start_index + self.eval_batch_size]
            subbatch_sentences = [[(tag, word) for j, (tag, word) in enumerate(zip(tags, words))]
                                  for i, (tags, words) in enumerate(
                                      zip(subbatch_pos_srldep, subbatch_words_srldep))]
            _, _, srldep_dict = \
                model(subbatch_sentences,
                      gold_verbs=self.ptb_dataset['brown_srldep_verb'][
                          start_index:start_index + self.eval_batch_size])
            srldep_pred.extend(srldep_dict)

        print("===============================================")
        print("brown srl span test eval:")
        precision, recall, f1, ul_prec, ul_recall, ul_f1 = (
            srl_eval.compute_srl_f1(self.ptb_dataset['brown_srlspan_sent'],
                                    self.ptb_dataset['brown_srlspan_dict'],
                                    srlspan_pred, srl_conll_eval_path=False))
        summary_dict["srlspan brown F1"] = f1
        summary_dict["srlspan brown precision"] = precision
        summary_dict["srlspan brown recall"] = recall  # FIX: was set to precision
        print("===============================================")
        print("brown srl dep test eval:")
        precision, recall, f1 = (
            srl_eval.compute_dependency_f1(self.ptb_dataset['brown_srldep_sent'],
                                           self.ptb_dataset['brown_srldep_dict'],
                                           srldep_pred, srl_conll_eval_path=False,
                                           use_gold=self.hparams.use_gold_predicate))
        summary_dict["srldep brown F1"] = f1
        summary_dict["srldep brown precision"] = precision
        summary_dict["srldep brown recall"] = recall  # FIX: was set to precision
        print("===============================================")

    print("test-elapsed {} ".format(format_elapsed(test_start_time)))
    print(
        '============================================================================================================================')

    # Combined dev score used for model selection.
    dev_score = (summary_dict['synconst dev F1'].fscore
                 + summary_dict['syndep dev las']
                 + summary_dict["srlspan dev F1"]
                 + summary_dict["srldep dev F1"]
                 + summary_dict['pos dev'])
    if dev_score > self.best_dev_score:
        if self.best_model_path is not None:
            extensions = [".pt"]
            for ext in extensions:
                path = self.best_model_path + ext
                if os.path.exists(path):
                    print("Removing previous model file {}...".format(path))
                    os.remove(path)

        self.best_dev_score = dev_score
        # FIX: store the new path on self (it was a local before, so the
        # previous-best file could never be found and removed next time), and
        # format the FScore's scalar .fscore field, matching check_dev.
        self.best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}_devsrlspan={:.2f}_devsrldep={:.2f}".format(
            self.model_path_base, summary_dict['synconst dev F1'].fscore,
            summary_dict['syndep dev uas'], summary_dict['syndep dev las'],
            summary_dict["srlspan dev F1"], summary_dict["srldep dev F1"])
        print("Saving new best model to {}...".format(self.best_model_path))
        torch.save({
            'spec': model.spec,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, self.best_model_path + ".pt")

    log_data = "{} epoch, dev-fscore {:},test-fscore {:}, dev-uas {:.2f}, dev-las {:.2f}," \
               "test-uas {:.2f}, test-las {:.2f}, dev-srlspan {:.2f}, test-wsj-srlspan {:.2f}, test-brown-srlspan {:.2f}," \
               " dev-srldep {:.2f}, test-wsj-srldep {:.2f}, test-brown-srldep {:.2f}, dev-pos {:.2f}, test-pos {:.2f}," \
               "dev_score {:.2f}, best_dev_score {:.2f}" \
        .format(epoch_num, summary_dict["synconst dev F1"], summary_dict["synconst test F1"],
                summary_dict["syndep dev uas"], summary_dict["syndep dev las"],
                summary_dict["syndep test uas"], summary_dict["syndep test las"],
                summary_dict["srlspan dev F1"], summary_dict["srlspan test F1"],
                summary_dict["srlspan brown F1"], summary_dict["srldep dev F1"],
                summary_dict["srldep test F1"], summary_dict["srldep brown F1"],
                summary_dict["pos dev"], summary_dict["pos test"],
                dev_score, self.best_dev_score)

    # FIX: the original opened the log twice ('w' then 'r+') and never closed
    # either handle.  Create the file if missing, then prepend the new line
    # inside a context manager.
    if not os.path.exists(self.log_path):
        open(self.log_path, 'w').close()
    with open(self.log_path, 'r+') as flog:
        content = flog.read()
        flog.seek(0, 0)
        flog.write(log_data + '\n' + content)
def check_dev(epoch_num): nonlocal best_dev_score nonlocal best_model_path nonlocal best_epoch print("Start dev eval:") dev_start_time = time.time() dev_fscore = evaluate.FScore(0, 0, 0) dev_uas = 0 dev_las = 0 pos_dev = 0 summary_dict = {} summary_dict["srlspan dev F1"] = 0 summary_dict["srldep dev F1"] = 0 parser.eval() syntree_pred = [] srlspan_pred = [] srldep_pred = [] pos_pred = [] if hparams.joint_syn_dep or hparams.joint_syn_const: for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size): subbatch_trees = dev_treebank[dev_start_index:dev_start_index + args.eval_batch_size] subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees] syntree, _, _ = parser.parse_batch(subbatch_sentences) syntree_pred.extend(syntree) #const parsing: dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, syntree_pred) #dep parsing: dev_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in syntree_pred] dev_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in syntree_pred] assert len(dev_pred_head) == len(dev_pred_type) assert len(dev_pred_type) == len(syndep_dev_types) dev_uas, dev_las = dep_eval.eval(len(dev_pred_head), syndep_dev_sent, syndep_dev_pos, dev_pred_head, dev_pred_type, syndep_dev_heads, syndep_dev_types, punct_set=punct_set, symbolic_root=False) #for srl different dev set if hparams.joint_srl_span or hparams.joint_pos: for dev_start_index in range(0, len(srlspan_dev_sent), args.eval_batch_size): subbatch_words = srlspan_dev_sent[ dev_start_index:dev_start_index + args.eval_batch_size] subbatch_pos = srlspan_dev_predpos[ dev_start_index:dev_start_index + args.eval_batch_size] subbatch_sentences = [[ (tag, word) for j, (tag, word) in enumerate(zip(tags, words)) ] for i, ( tags, words) in enumerate(zip(subbatch_pos, subbatch_words))] if hparams.use_gold_predicate: srlspan_tree, srlspan_dict, _ = parser.parse_batch( subbatch_sentences, gold_verbs=srlspan_dev_verb[ 
dev_start_index:dev_start_index + args.eval_batch_size], syndep_heads=srlspan_dev_heads[ dev_start_index:dev_start_index + args.eval_batch_size]) else: srlspan_tree, srlspan_dict, _ = parser.parse_batch( subbatch_sentences, syndep_heads=srlspan_dev_heads[ dev_start_index:dev_start_index + args.eval_batch_size]) srlspan_pred.extend(srlspan_dict) pos_pred.extend([[leaf.goldtag for leaf in tree.leaves()] for tree in srlspan_tree]) if hparams.joint_srl_span: print("===============================================") print("srl span dev eval:") precision, recall, f1, ul_prec, ul_recall, ul_f1 = ( srl_eval.compute_srl_f1(srlspan_dev_sent, srlspan_dev_dict, srlspan_pred, srl_conll_eval_path=False)) summary_dict["srlspan dev F1"] = f1 summary_dict["srlspan dev precision"] = precision summary_dict["srlspan dev recall"] = precision if hparams.joint_pos: pos_dev = pos_eval.eval(srlspan_dev_goldpos, pos_pred) if hparams.joint_srl_dep: for dev_start_index in range(0, len(srldep_dev_sent), args.eval_batch_size): subbatch_words = srldep_dev_sent[ dev_start_index:dev_start_index + args.eval_batch_size] subbatch_pos = srldep_dev_predpos[ dev_start_index:dev_start_index + args.eval_batch_size] subbatch_sentences = [[ (tag, word) for j, (tag, word) in enumerate(zip(tags, words)) ] for i, ( tags, words) in enumerate(zip(subbatch_pos, subbatch_words))] if hparams.use_gold_predicate: _, _, srldep_dict = parser.parse_batch( subbatch_sentences, gold_verbs=srldep_dev_verb[ dev_start_index:dev_start_index + args.eval_batch_size], syndep_heads=srldep_dev_heads[ dev_start_index:dev_start_index + args.eval_batch_size]) else: _, _, srldep_dict = parser.parse_batch( subbatch_sentences, syndep_heads=srldep_dev_heads[ dev_start_index:dev_start_index + args.eval_batch_size]) srldep_pred.extend(srldep_dict) print("===============================================") print("srl dep dev eval:") precision, recall, f1 = (srl_eval.compute_dependency_f1( srldep_dev_sent, srldep_dev_dict, srldep_pred, 
srl_conll_eval_path=False, use_gold=hparams.use_gold_predicate)) summary_dict["srldep dev F1"] = f1 summary_dict["srldep dev precision"] = precision summary_dict["srldep dev recall"] = precision print("===============================================") print("dev-elapsed {} " "total-elapsed {}".format( format_elapsed(dev_start_time), format_elapsed(start_time), )) print( '============================================================================================================================' ) if dev_fscore.fscore + dev_las + summary_dict[ "srlspan dev F1"] + summary_dict[ "srldep dev F1"] + pos_dev > best_dev_score: if best_model_path is not None: extensions = [".pt"] for ext in extensions: path = best_model_path + ext if os.path.exists(path): print( "Removing previous model file {}...".format(path)) os.remove(path) best_dev_score = dev_fscore.fscore + dev_las + summary_dict[ "srlspan dev F1"] + summary_dict["srldep dev F1"] + pos_dev best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}_devsrlspan={:.2f}_devsrldep={:.2f}".format( args.model_path_base, dev_fscore.fscore, dev_uas, dev_las, summary_dict["srlspan dev F1"], summary_dict["srldep dev F1"]) print("Saving new best model to {}...".format(best_model_path)) torch.save( { 'spec': parser.spec, 'state_dict': parser.state_dict(), 'trainer': trainer.state_dict(), }, best_model_path + ".pt")