def test(model, corpus, sess, seq_len):
    """Evaluate constituency trees induced from a session-run model's gate distances.

    NOTE(review): another ``def test`` appears later in this file; if both are
    module-level, that later definition rebinds the name and this one cannot
    be reached by name -- confirm which variant is intended.

    Behaviour worth knowing before calling:
      * The ``corpus`` argument is ignored: a corpus is reloaded from
        ``data/penn`` via ``data_ptb.Corpus`` on every call.
      * The train split is always evaluated (the ``args.wsj10`` switch is
        hard-coded on) and sentences longer than 12 tokens are skipped, yet
        ``len(corpus.test_sens)`` is what gets printed first.
      * Model output and precision/recall/F1 are printed for every sentence;
        token-level distances and a bar chart saved to ``figure/<n>.png`` are
        produced for every 100th evaluated sentence.

    Args:
        model: object exposing ``model.cell.forward_propagate(length)``,
            ``model.cell.input``, ``model.cell.seq_len`` and ``model.targets``.
        corpus: unused; replaced by the reloaded corpus.
        sess: session object whose ``run(fetches, feed_dict=...)`` evaluates
            the model (TensorFlow-1 style).
        seq_len: value fed under ``model.cell.seq_len``.

    Returns:
        ``f1_list.mean(axis=0)``: a length-1 array holding the mean sentence F1.
    """
    prt = True
    corpus = data_ptb.Corpus('data/penn')
    prec_list = []
    reca_list = []
    f1_list = []
    pred_tree_list = []
    targ_tree_list = []
    nsens = 0
    word2idx = corpus.dict.word2idx
    # args.wsj10 is hard-coded on: always evaluate the train split.
    dataset = zip(corpus.train_sens, corpus.train_trees, corpus.train_nltktrees)
    corpus_sys = {}
    corpus_ref = {}
    # forward_propagate(length) takes a Python length and its result is
    # fetched by sess.run, so it presumably builds new graph ops; calling it
    # once per sentence would keep growing the graph. Build once per length.
    fetches_by_length = {}
    print(len(corpus.test_sens))
    for sen, sen_tree, sen_nltktree in dataset:
        if len(sen) > 12:  # args.wsj10 and len(sen) > 12
            continue
        token_ids = numpy.array([word2idx[w] if w in word2idx else word2idx['<unk>'] for w in sen])
        # Row 0 holds the sentence; 79 zero rows pad the batch to 80 rows,
        # matching the (80, 1) targets fed below.
        batch_ids = numpy.stack([token_ids] + [numpy.zeros(token_ids.shape) for _ in range(79)])
        length = batch_ids.shape[1]
        if length not in fetches_by_length:
            fetches_by_length[length] = model.cell.forward_propagate(length)
        _, _, distance_forget, distance_input = sess.run(
            [fetches_by_length[length]],
            feed_dict={model.cell.input: batch_ids,
                       model.cell.seq_len: seq_len,
                       model.targets: numpy.zeros((80, 1))})[0]
        # Keep index 0 of the last axis (presumably batch row 0, the real
        # sentence -- TODO confirm); the code below indexes the result as
        # [layer, token].
        distance_forget = distance_forget[:, :, 0]
        distance_input = distance_input[:, :, 0]
        nsens += 1
        if prt and nsens % 100 == 0:
            for i, word in enumerate(sen):
                print('%15s\t%s\t%s' % (word, str(distance_forget[:, i]), str(distance_input[:, i])))
            print('Standard output:', sen_tree)

        # Trees are built from the 2nd layer's forget distances, dropping the
        # first and last positions to line up with sen[1:-1].
        sen_cut = sen[1:-1]
        depth = distance_forget[1][1:-1]
        parse_tree = build_tree(depth, sen_cut)
        corpus_sys[nsens] = MRG(parse_tree)
        corpus_ref[nsens] = MRG_labeled(sen_nltktree)
        pred_tree_list.append(parse_tree)
        targ_tree_list.append(sen_tree)

        model_out, _ = get_brackets(parse_tree)
        std_out, _ = get_brackets(sen_tree)
        overlap = model_out.intersection(std_out)
        prec = float(len(overlap)) / (len(model_out) + 1e-8)
        reca = float(len(overlap)) / (len(std_out) + 1e-8)
        # An empty bracket set on either side counts as a perfect score.
        if len(std_out) == 0:
            reca = 1.
        if len(model_out) == 0:
            prec = 1.
        f1 = 2 * prec * reca / (prec + reca + 1e-8)
        prec_list.append(prec)
        reca_list.append(reca)
        f1_list.append(f1)
        if prt:  # every sentence (the original test was nsens % 1 == 0)
            print('Model output:', parse_tree)
            print('Prec: %f, Reca: %f, F1: %f' % (prec, reca, f1))

        if prt and nsens % 100 == 0:
            print('-' * 80)
            _, axarr = plt.subplots(3, sharex=True, figsize=(distance_forget.shape[1] // 2, 6))
            for layer, label in enumerate(('1st layer', '2nd layer', '3rd layer')):
                ax = axarr[layer]
                ax.bar(numpy.arange(distance_forget.shape[1]) - 0.2, distance_forget[layer], width=0.4)
                ax.bar(numpy.arange(distance_input.shape[1]) + 0.2, distance_input[layer], width=0.4)
                ax.set_ylim([0., 1.])
                ax.set_ylabel(label)
            plt.sca(axarr[2])
            # left/right instead of the deprecated xmin/xmax keyword aliases.
            plt.xlim(left=-0.5, right=distance_forget.shape[1] - 0.5)
            plt.xticks(numpy.arange(distance_forget.shape[1]), sen, fontsize=10, rotation=45)
            plt.savefig('figure/%d.png' % (nsens))
            plt.close()

    prec_list, reca_list, f1_list = (
        numpy.array(prec_list).reshape((-1, 1)),
        numpy.array(reca_list).reshape((-1, 1)),
        numpy.array(f1_list).reshape((-1, 1)),
    )
    if prt:
        print('-' * 80)
        numpy.set_printoptions(precision=4)
        print('Mean Prec:', prec_list.mean(axis=0),
              ', Mean Reca:', reca_list.mean(axis=0),
              ', Mean F1:', f1_list.mean(axis=0))
        print('Number of sentence: %i' % nsens)

        correct, total = corpus_stats_labeled(corpus_sys, corpus_ref)
        print(correct)
        print(total)
        print('ADJP:', correct['ADJP'], total['ADJP'])
        print('NP:', correct['NP'], total['NP'])
        print('PP:', correct['PP'], total['PP'])
        print('INTJ:', correct['INTJ'], total['INTJ'])
        print(corpus_average_depth(corpus_sys))

        evalb(pred_tree_list, targ_tree_list)

    return f1_list.mean(axis=0)
def _bracket_prf(parse_tree, sen_tree):
    """Score ``parse_tree`` against ``sen_tree`` using the bracket sets from ``get_brackets``.

    Returns:
        ``(prec, reca, f1)`` as floats. A side with no brackets scores 1.0, and
        a 1e-8 epsilon keeps every division finite.
    """
    model_out, _ = get_brackets(parse_tree)
    std_out, _ = get_brackets(sen_tree)
    overlap = model_out.intersection(std_out)
    prec = float(len(overlap)) / (len(model_out) + 1e-8)
    reca = float(len(overlap)) / (len(std_out) + 1e-8)
    if len(std_out) == 0:
        reca = 1.
    if len(model_out) == 0:
        prec = 1.
    f1 = 2 * prec * reca / (prec + reca + 1e-8)
    return prec, reca, f1


def _plot_gate_distances(distance, distance_in, sen, nsens):
    """Save a 3-panel bar chart (one panel per layer) to ``figure/<nsens>.png``.

    For each token position x, ``distance[layer]`` is drawn at x - 0.2 and
    ``distance_in[layer]`` at x + 0.2; ``sen`` supplies the x tick labels.
    """
    _, axarr = plt.subplots(3, sharex=True, figsize=(distance.shape[1] // 2, 6))
    for layer, label in enumerate(('1st layer', '2nd layer', '3rd layer')):
        ax = axarr[layer]
        ax.bar(numpy.arange(distance.shape[1]) - 0.2, distance[layer], width=0.4)
        ax.bar(numpy.arange(distance_in.shape[1]) + 0.2, distance_in[layer], width=0.4)
        ax.set_ylim([0., 1.])
        ax.set_ylabel(label)
    plt.sca(axarr[2])
    # left/right instead of the deprecated xmin/xmax keyword aliases.
    plt.xlim(left=-0.5, right=distance.shape[1] - 0.5)
    plt.xticks(numpy.arange(distance.shape[1]), sen, fontsize=10, rotation=45)
    plt.savefig('figure/%d.png' % (nsens))
    plt.close()


def test(model, corpus, cuda, prt=False):
    """Evaluate constituency trees induced from a PyTorch model's distances.

    Each sentence is run through ``model`` one at a time. The 2nd row of
    ``model.distance[0]`` is turned into a tree with ``build_tree``, and that
    tree is scored against the gold tree.

    Reads the module-level ``args.wsj10``. When set, the *train* split is used
    and sentences longer than 12 tokens are skipped. Otherwise the test split
    is used.

    Args:
        model: module whose forward pass populates ``model.distance`` and that
            provides ``init_hidden(batch_size)``; switched to eval mode here.
        corpus: object with ``dictionary.word2idx`` and
            ``{train,test}_{sens,trees,nltktrees}`` sequences.
        cuda: when truthy, inputs are moved to the GPU.
        prt: when true, prints a summary and runs labeled stats and ``evalb``.
            Per-sentence details and a figure are produced every 100th sentence.

    Returns:
        ``f1_list.mean(axis=0)``: a length-1 array holding the mean sentence F1.
    """
    model.eval()
    prec_list = []
    reca_list = []
    f1_list = []
    pred_tree_list = []
    targ_tree_list = []
    nsens = 0
    word2idx = corpus.dictionary.word2idx
    if args.wsj10:
        dataset = zip(corpus.train_sens, corpus.train_trees, corpus.train_nltktrees)
    else:
        dataset = zip(corpus.test_sens, corpus.test_trees, corpus.test_nltktrees)

    corpus_sys = {}
    corpus_ref = {}
    for sen, sen_tree, sen_nltktree in dataset:
        if args.wsj10 and len(sen) > 12:
            continue
        x = numpy.array([word2idx[w] if w in word2idx else word2idx['<unk>'] for w in sen])
        # x[:, None] has shape (len(sen), 1): a single-sentence batch.
        inputs = Variable(torch.LongTensor(x[:, None]))
        if cuda:
            inputs = inputs.cuda()
        hidden = model.init_hidden(1)
        # The forward pass is run for its side effect of filling model.distance.
        model(inputs, hidden)
        distance = model.distance[0].squeeze().data.cpu().numpy()
        distance_in = model.distance[1].squeeze().data.cpu().numpy()

        nsens += 1
        if prt and nsens % 100 == 0:
            for i, word in enumerate(sen):
                print('%15s\t%s\t%s' % (word, str(distance[:, i]), str(distance_in[:, i])))
            print('Standard output:', sen_tree)

        # Drop the first and last positions to line up with sen[1:-1].
        sen_cut = sen[1:-1]
        depth = distance[1][1:-1]
        parse_tree = build_tree(depth, sen_cut)
        corpus_sys[nsens] = MRG(parse_tree)
        corpus_ref[nsens] = MRG_labeled(sen_nltktree)
        pred_tree_list.append(parse_tree)
        targ_tree_list.append(sen_tree)

        prec, reca, f1 = _bracket_prf(parse_tree, sen_tree)
        prec_list.append(prec)
        reca_list.append(reca)
        f1_list.append(f1)

        if prt and nsens % 100 == 0:
            print('Model output:', parse_tree)
            print('Prec: %f, Reca: %f, F1: %f' % (prec, reca, f1))
            print('-' * 80)
            _plot_gate_distances(distance, distance_in, sen, nsens)

    prec_list, reca_list, f1_list = (
        numpy.array(prec_list).reshape((-1, 1)),
        numpy.array(reca_list).reshape((-1, 1)),
        numpy.array(f1_list).reshape((-1, 1)),
    )
    if prt:
        print('-' * 80)
        numpy.set_printoptions(precision=4)
        print('Mean Prec:', prec_list.mean(axis=0),
              ', Mean Reca:', reca_list.mean(axis=0),
              ', Mean F1:', f1_list.mean(axis=0))
        print('Number of sentence: %i' % nsens)

        correct, total = corpus_stats_labeled(corpus_sys, corpus_ref)
        print(correct)
        print(total)
        print('ADJP:', correct['ADJP'], total['ADJP'])
        print('NP:', correct['NP'], total['NP'])
        print('PP:', correct['PP'], total['PP'])
        print('INTJ:', correct['INTJ'], total['INTJ'])
        print(corpus_average_depth(corpus_sys))

        evalb(pred_tree_list, targ_tree_list)

    return f1_list.mean(axis=0)
def generate_parse(tree_path='.data/sst/trees/dev.txt', batch_size=16):
    """Induce trees for bracketed sentences with ``model`` and score them with EVALB.

    Each non-blank line of ``tree_path`` is lower-cased and parsed with
    ``nltk.Tree.fromstring``. Its leaves are sent through ``model`` in batches
    of ``batch_size``. The argmax over the last axis of ``model.encoder.probs``
    (with position 0 forced to ``args.nslot``) gives each sentence's depths
    for ``build_tree``. The first sentence of each batch is printed. Only
    sentences of at most 100 tokens are passed to ``evalb``.

    Relies on the module-level names ``model``, ``TEXT``, ``args``,
    ``build_tree``, ``plot``, ``np`` and ``torch``.

    Args:
        tree_path: file with one bracketed tree per line (default: SST dev).
        batch_size: number of sentences per forward pass; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    from nltk import Tree
    from utils import evalb

    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')

    batch = []
    pred_tree_list = []
    targ_tree_list = []

    def process_batch():
        """Parse every example in ``batch``, record the trees, then reset ``batch``."""
        nonlocal batch  # only ``batch`` is rebound; the tree lists are mutated
        idx = TEXT.process([example['sents'] for example in batch], device=hidden[0].device)
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            model(idx)
        probs = model.encoder.probs
        distance = torch.argmax(probs, dim=-1)
        distance[0] = args.nslot
        probs = probs.data.cpu().numpy()
        for i, example in enumerate(batch):
            sents = example['sents']
            sents_tree = example['sents_tree']
            depth = distance[:, i]
            parse_tree = build_tree(depth, sents)
            if len(sents) <= 100:
                pred_tree_list.append(parse_tree)
                targ_tree_list.append(sents_tree)
            if i == 0:
                for j, word in enumerate(sents):
                    print('%20s\t%2.2f\t%s' % (word, depth[j], plot(probs[j, i], 1.)))
                print(parse_tree)
                print(sents_tree)
                print('-' * 80)
        batch = []

    np.set_printoptions(precision=2, suppress=True, linewidth=5000, formatter={'float': '{: 0.2f}'.format})

    model.eval()
    hidden = model.encoder.init_hidden(1)  # used only for its device

    # ``with`` guarantees the file is closed (the original never closed it).
    with open(tree_path, 'r', encoding='utf-8') as fin:
        for line in fin:
            if not line.strip():
                continue  # nltk's Tree.fromstring raises on an empty string
            line = line.lower()
            sents_tree = Tree.fromstring(line)
            sents = sents_tree.leaves()
            batch.append({'sents_tree': sents_tree, 'sents': sents})
            if len(batch) >= batch_size:
                process_batch()

    if batch:
        process_batch()

    evalb(pred_tree_list, targ_tree_list, evalb_path='./EVALB')