import collections
from statistics import mean

import matplotlib.pyplot as plt
import numpy
import torch
from nltk.parse import DependencyGraph

# `data_ptb`, `tree_utils`, and `plot` are local modules of the surrounding
# repo; the exact import paths are assumed here.
import data_ptb
import tree_utils
from hinton import plot


def test(parser, corpus, device, prt=False, gap=0):
  """Compute UF1 and UAS scores.

  Args:
    parser: pretrained model
    corpus: labeled corpus
    device: cpu or gpu
    prt: bool, whether to print examples
    gap: distance gap for building non-binary trees

  Returns:
    UF1: unlabeled F1 score for constituency parsing
  """
  parser.eval()

  prec_list = []
  reca_list = []
  f1_list = []
  dtree_list = []

  corpus_sys = {}
  corpus_ref = {}
  nsens = 0

  word2idx = corpus.dictionary.word2idx
  dataset = zip(corpus.test_sens, corpus.test_trees, corpus.test_nltktrees)

  for sen, sen_tree, sen_nltktree in dataset:
    # Map tokens to vocabulary indices, falling back to <unk> for OOV words.
    x = [word2idx[w] if w in word2idx else word2idx['<unk>'] for w in sen]
    data = torch.LongTensor([x]).to(device)
    pos = torch.LongTensor([list(range(len(sen)))]).to(device)

    _, p_dict = parser(data, pos)
    block = p_dict['block']
    cibling = p_dict['cibling']
    head = p_dict['head']
    distance = p_dict['distance']
    height = p_dict['height']

    distance = distance.clone().squeeze(0).cpu().numpy().tolist()
    height = height.clone().squeeze(0).cpu().numpy().tolist()
    head = head.clone().squeeze(0).cpu().numpy()
    max_height = numpy.max(height)

    # Constituency parsing: induce a tree from the predicted syntactic
    # distances and score its unlabeled brackets against the gold tree.
    parse_tree = tree_utils.build_tree(distance, sen, gap=gap)

    model_out, _ = tree_utils.get_brackets(parse_tree)
    std_out, _ = tree_utils.get_brackets(sen_tree)
    overlap = model_out.intersection(std_out)

    corpus_sys[nsens] = tree_utils.mrg(parse_tree)
    corpus_ref[nsens] = tree_utils.mrg_labeled(sen_nltktree)

    prec = float(len(overlap)) / (len(model_out) + 1e-8)
    reca = float(len(overlap)) / (len(std_out) + 1e-8)
    if not std_out:
      reca = 1.
      if not model_out:
        prec = 1.
    f1 = 2 * prec * reca / (prec + reca + 1e-8)
    prec_list.append(prec)
    reca_list.append(reca)
    f1_list.append(f1)

    # Dependency parsing: turn the predicted head distribution into a
    # CoNLL-style dependency graph, aligning model tokens with the gold
    # POS-tagged words (non-word tags are skipped).
    new_words = []
    true_words = sen_nltktree.pos()
    for d, c, w, ph in zip(distance, height, sen, head):
      next_word = true_words.pop(0)
      while next_word[1] not in data_ptb.WORD_TAGS:
        next_word = true_words.pop(0)
      new_words.append({
          'address': len(new_words) + 1,
          'word': next_word[0],
          'lemma': None,
          'ctag': None,
          'tag': next_word[1],
          'feats': None,
          # The token with maximal height is the root (head index 0).
          'head': numpy.argmax(ph) + 1 if c < max_height else 0,
          'deps': collections.defaultdict(list),
          'rel': None,
          'distance': d,
          'height': c
      })
    while true_words:
      next_word = true_words.pop(0)
      assert next_word[1] not in data_ptb.WORD_TAGS

    dtree = DependencyGraph()
    for w in new_words:
      dtree.add_node(w)
    dtree_list.append(dtree)

    if prt and len(dtree_list) % 100 == 0:
      cibling = cibling.clone().squeeze(0).cpu().numpy()
      block = block.clone().squeeze(0).cpu().numpy()
      for word_i, d_i, imp_i, block_i, cibling_i, head_i in zip(
          sen, distance, height, block, cibling, head):
        print('%20s\t%10.2f\t%5.2f\t%s\t%s\t%s' %
              (word_i, d_i, imp_i, plot(block_i, max_val=1.),
               plot(head_i, max_val=1.), plot(cibling_i, max_val=1.)))
      print('Standard output:', sen_tree)
      print('Model output:', parse_tree)
      print(dtree.to_conll(10))
      print()

      # Save a heatmap of the head distribution for this sentence.
      fig_i, ax_i = plt.subplots()
      ax_i.set_xticks(numpy.arange(len(sen)))
      ax_i.set_yticks(numpy.arange(len(sen)))
      ax_i.set_xticklabels(sen)
      ax_i.set_yticklabels(sen)
      plt.setp(
          ax_i.get_xticklabels(), rotation=45, ha='right',
          rotation_mode='anchor')
      for row in range(len(sen)):
        for col in range(len(sen)):
          _ = ax_i.text(
              col, row, '%.2f' % (head[row, col]),
              ha='center', va='center', color='w')
      fig_i.tight_layout()
      plt.savefig(
          './figures/sentence-%d.png' % (len(dtree_list)),
          dpi=300, format='png')

    nsens += 1

  print('Constituency parsing performance:')
  print('Mean Prec: %.4f, Mean Reca: %.4f, Mean F1: %.4f' %
        (mean(prec_list), mean(reca_list), mean(f1_list)))

  correct, total = tree_utils.corpus_stats_labeled(corpus_sys, corpus_ref)
  print(correct)
  print(total)
  print('SBAR: %.4f' % (correct['SBAR'] / total['SBAR']))
  print('NP: %.4f' % (correct['NP'] / total['NP']))
  print('VP: %.4f' % (correct['VP'] / total['VP']))
  print('PP: %.4f' % (correct['PP'] / total['PP']))
  print('ADJP: %.4f' % (correct['ADJP'] / total['ADJP']))
  print('ADVP: %.4f' % (correct['ADVP'] / total['ADVP']))
  print(tree_utils.corpus_average_depth(corpus_sys))

  print('-' * 89)
  print('Dependency parsing performance:')
  print('Stanford Style:')
  tree_utils.evald(dtree_list, '../data/ptb/test.stanford', directed=True)
  tree_utils.evald(dtree_list, '../data/ptb/test.stanford', directed=False)
  print('Conll Style:')
  tree_utils.evald(dtree_list, '../data/ptb/test.conll', directed=True)
  tree_utils.evald(dtree_list, '../data/ptb/test.conll', directed=False)

  return mean(f1_list)
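# A hedged sketch of the glue this fragment assumes but does not show: `args`
# coming from argparse and `model` restored from a checkpoint. The flag names
# mirror the attributes used below (args.data, args.cuda, args.print,
# args.gap); the `--checkpoint` flag and the torch.load call are assumptions,
# not necessarily the original script's interface.
import argparse

arg_parser = argparse.ArgumentParser(description='PTB parsing evaluation')
arg_parser.add_argument('--data', type=str, default='../data/ptb',
                        help='location of the PTB corpus')
arg_parser.add_argument('--checkpoint', type=str, default='model.pt',
                        help='path of the saved model (assumed flag)')
arg_parser.add_argument('--cuda', action='store_true',
                        help='use CUDA')
arg_parser.add_argument('--print', action='store_true',
                        help='print example parses during evaluation')
arg_parser.add_argument('--gap', type=float, default=0,
                        help='distance gap for building non-binary trees')
args = arg_parser.parse_args()

# Loading a fully pickled model requires its class definition on sys.path.
with open(args.checkpoint, 'rb') as f:
  model = torch.load(f, map_location='cuda:0' if args.cuda else 'cpu')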
print('Loading PTB dataset...')
ptb_corpus = data_ptb.Corpus(args.data)

print('Evaluating...')
if args.cuda:
  eval_device = torch.device('cuda:0')
else:
  eval_device = torch.device('cpu')

print('=' * 89)
test(model, ptb_corpus, eval_device, prt=args.print, gap=args.gap)
print('=' * 89)

# Visualize the learned relation weights: an 8x8 grid of bar charts, one
# subplot per (i, j) entry of `rel_weight`, each with one bar for the two
# relations 'p' and 'd'.
rel_weight = model.rel_weight.detach().cpu().numpy()

fig, axs = plt.subplots(8, 8, sharex=True, sharey=True)
names = ['p', 'd']
for i in range(rel_weight.shape[0]):
  for j in range(rel_weight.shape[1]):
    print(plot(rel_weight[i, j], max_val=1.), end=' ')
    values = rel_weight[i, j]
    if i == 0:
      axs[i, j].set_title('%d' % (j,))
    if j == 0:
      axs[i, j].set_ylabel('%d' % (i,))
    axs[i, j].bar(names, values)
  print()
plt.savefig('./figures/mask_weights.png', dpi=300, format='png')
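# For reference, a rough stand-in for the `plot` helper imported above, based
# only on how it is called (a 1-D array plus a max_val, returning a printable
# string). This is an assumption about its contract; the repo's real helper
# may render Hinton-diagram glyphs instead.
def plot_fallback(vec, max_val=1.):
  """Render a vector as a fixed-width string of density characters."""
  chars = ' .:-=+*#%@'
  scaled = numpy.clip(numpy.asarray(vec, dtype=float) / max_val, 0., 1.)
  return ''.join(chars[int(v * (len(chars) - 1))] for v in scaled)

# Example: plot_fallback([0., 0.5, 1.]) returns ' =@'.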