def _read_dep_eval_data(dep_path, type_vocab, min_sents=40000, min_len=300):
    """Read a CoNLL-X file into the containers used for dependency evaluation.

    Args:
        dep_path: Path of the CoNLL-X file to read.
        type_vocab: Dependency-label vocabulary handed to ``CoNLLXReader``.
        min_sents: Minimum number of rows of the returned head/length arrays.
        min_len: Minimum number of columns of the returned head array.

    Returns:
        ``(headid, types, words, postags, lengths)`` where ``headid`` is an
        int array whose row ``i`` holds the head indices of sentence ``i`` in
        its first ``lengths[i]`` columns (zeros elsewhere); ``lengths`` is an
        int array of sentence lengths; ``types``, ``words`` and ``postags`` are
        lists with one entry per sentence taken from the reader's instances.
    """
    dep_reader = CoNLLXReader(dep_path, type_vocab)
    print('Reading dependency parsing data from %s' % dep_path)

    head_rows, types, words, postags, sent_lengths = [], [], [], [], []
    try:
        inst = dep_reader.getNext()
        while inst is not None:
            size = inst.length()
            sent = inst.sentence
            head_rows.append(inst.heads[:size])
            sent_lengths.append(size)
            types.append(inst.types)
            words.append(sent.words)
            postags.append(sent.postags)
            inst = dep_reader.getNext()
    finally:
        # Close the reader even if a malformed line makes getNext() raise.
        dep_reader.close()

    # Keep the historical minimum shape (40000 x 300) so downstream consumers
    # see the same arrays as before for inputs that fit, but grow instead of
    # raising IndexError on more sentences or longer sentences.
    num_rows = max(min_sents, len(sent_lengths))
    num_cols = max([min_len] + sent_lengths)
    headid = np.zeros([num_rows, num_cols], dtype=int)
    lengths = np.zeros(num_rows, dtype=int)
    for row, (heads, size) in enumerate(zip(head_rows, sent_lengths)):
        headid[row, :size] = heads
        lengths[row] = size
    return headid, types, words, postags, lengths


def run_test(args):
    """Evaluate a saved parser on the constituency and dependency test sets.

    Loads the checkpoint at ``args.model_path_base`` (must end in ``.pt``),
    reads the CoNLL-X dependency test file and the constituency test trees
    (CTB paths when ``args.dataset == 'ctb'``, PTB paths otherwise), parses
    the test trees in batches of ``args.eval_batch_size`` and prints the
    ``evaluate.evalb`` result, then the dependency scores read off the
    predicted trees' leaves (``father`` / ``type``): UAS/LAS and complete-match
    rates with and without punctuation, plus root accuracy.

    Args:
        args: Parsed command-line namespace supplying the paths and settings
            referenced above.
    """
    if args.dataset == 'ctb':
        const_test_path = args.consttest_ctb_path
        dep_test_path = args.deptest_ctb_path
    else:
        const_test_path = args.consttest_ptb_path
        dep_test_path = args.deptest_ptb_path

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"

    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = Zparser.ChartParser.from_spec(info['spec'], info['state_dict'])
    parser.eval()

    (dep_test_headid, dep_test_type, dep_test_word,
     dep_test_pos, dep_test_lengs) = _read_dep_eval_data(dep_test_path, parser.type_vocab)

    print("Loading test trees from {}...".format(const_test_path))
    test_treebank = trees.load_trees(const_test_path, dep_test_headid, dep_test_type, dep_test_word)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Parsing test sentences...")
    start_time = time.time()

    # NOTE(review): adjacent literals concatenate into the single string
    # ".``'':,", not a collection of tags. If dep_eval tests membership with
    # `in`, this is a substring check; also no extra tags are listed for the
    # 'ctb' dataset. Confirm against dep_eval before changing.
    punct_set = '.' '``' "''" ':' ','

    test_predicted = []
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index + args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees]
        predicted, _ = parser.parse_batch(subbatch_sentences)
        # Drop the unused second return value before the next batch
        # (presumably to release it early).
        del _
        test_predicted.extend(p.convert() for p in predicted)

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted)
    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )

    # Predicted dependency heads and labels are stored on the leaves.
    test_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in test_predicted]
    test_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in test_predicted]
    assert len(test_pred_head) == len(test_pred_type)
    assert len(test_pred_type) == len(dep_test_type)

    stats, stats_nopunc, stats_root, test_total_inst = dep_eval.eval(
        len(test_pred_head), dep_test_word, dep_test_pos, test_pred_head, test_pred_type,
        dep_test_headid, dep_test_type, dep_test_lengs,
        punct_set=punct_set, symbolic_root=False)
    ucorr, lcorr, total, ucomplete, lcomplete = stats
    ucorr_np, lcorr_np, total_np, ucomplete_np, lcomplete_np = stats_nopunc
    root_corr, total_root = stats_root

    print(
        'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            ucorr, lcorr, total, ucorr * 100 / total, lcorr * 100 / total,
            ucomplete * 100 / test_total_inst, lcomplete * 100 / test_total_inst))
    print(
        'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% ' % (
            ucorr_np, lcorr_np, total_np, ucorr_np * 100 / total_np, lcorr_np * 100 / total_np,
            ucomplete_np * 100 / test_total_inst, lcomplete_np * 100 / test_total_inst))
    print('best test Root: corr: %d, total: %d, acc: %.2f%%' % (
        root_corr, total_root, root_corr * 100 / total_root))
    print(
        '============================================================================================================================')
def check_dev(epoch_num): nonlocal best_dev_score nonlocal best_model_path dev_start_time = time.time() parser.eval() dev_predicted = [] for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size): subbatch_trees = dev_treebank[dev_start_index:dev_start_index+args.eval_batch_size] subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees] predicted, _,= parser.parse_batch(subbatch_sentences) del _ dev_predicted.extend([p.convert() for p in predicted]) dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted) print( "dev-fscore {} " "dev-elapsed {} " "total-elapsed {}".format( dev_fscore, format_elapsed(dev_start_time), format_elapsed(start_time), ) ) dev_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in dev_predicted] dev_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in dev_predicted] assert len(dev_pred_head) == len(dev_pred_type) assert len(dev_pred_type) == len(dep_dev_type) stats, stats_nopunc, stats_root, num_inst = dep_eval.eval(len(dev_pred_head), dep_dev_word, dep_dev_pos, dev_pred_head, dev_pred_type, dep_dev_headid, dep_dev_type, dep_dev_lengs, punct_set=punct_set, symbolic_root=False) dev_ucorr, dev_lcorr, dev_total, dev_ucomlpete, dev_lcomplete = stats dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucomlpete_nopunc, dev_lcomplete_nopunc = stats_nopunc dev_root_corr, dev_total_root = stats_root dev_total_inst = num_inst print( 'W. 
Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % ( dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst)) print( 'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % ( dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc * 100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 / dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst)) print('Root: corr: %d, total: %d, acc: %.2f%%' % ( dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root)) dev_uas = dev_ucorr_nopunc * 100 / dev_total_nopunc dev_las = dev_lcorr_nopunc * 100 / dev_total_nopunc if dev_fscore.fscore + dev_las > best_dev_score : if best_model_path is not None: extensions = [".pt"] for ext in extensions: path = best_model_path + ext if os.path.exists(path): print("Removing previous model file {}...".format(path)) os.remove(path) best_dev_score = dev_fscore.fscore + dev_las best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}".format( args.model_path_base, dev_fscore.fscore, dev_uas,dev_las) print("Saving new best model to {}...".format(best_model_path)) torch.save({ 'spec': parser.spec, 'state_dict': parser.state_dict(), 'trainer' : trainer.state_dict(), }, besthh_model_path + ".pt")