def test(model, corpus, sess, seq_len):
    """Score trees induced from the model's gate distances against gold trees.

    For each retained sentence the cell's forward op is evaluated through
    ``sess`` to obtain a "forget" and an "input" distance array.  A tree is
    built with ``build_tree`` from row 1 of the forget distances (the row
    the plots label '2nd layer') and its brackets are compared with those of
    the reference tree.  Per-sentence scores are printed; every 100th
    sentence is additionally dumped and plotted to ``figure/<n>.png``.
    Corpus-level statistics and ``evalb`` output are printed at the end.

    Args:
        model: object exposing ``cell.forward_propagate``, ``cell.input``,
            ``cell.seq_len`` and ``targets``.
        corpus: currently ignored; it is rebound to a freshly loaded
            ``data_ptb.Corpus('data/penn')`` on entry.
        sess: object whose ``run(fetches, feed_dict=...)`` evaluates the op
            (presumably a TensorFlow session -- confirm against the caller).
        seq_len: value fed to ``model.cell.seq_len``.

    Returns:
        Mean per-sentence F1 as a length-1 array (mean over axis 0 of an
        ``(n, 1)`` array).
    """
    verbose = True
    # NOTE(review): the ``corpus`` argument is discarded and the corpus is
    # reloaded from disk here -- confirm this is intended.
    corpus = data_ptb.Corpus('data/penn')

    precisions, recalls, f1_scores = [], [], []
    predicted_trees, gold_trees = [], []
    sys_by_id, ref_by_id = {}, {}

    vocab = corpus.dict.word2idx
    # The former ``args.wsj10`` switch is hard-wired on: the training split is
    # always used and the length filter below is always active.
    samples = zip(corpus.train_sens, corpus.train_trees, corpus.train_nltktrees)
    print(len(corpus.test_sens))

    n_seen = 0
    for sen, gold_tree, gold_nltk in samples:
        if len(sen) > 12:
            continue

        token_ids = numpy.array(
            [vocab[w] if w in vocab else vocab['<unk>'] for w in sen])
        # Row 0 carries the sentence; 79 all-zero rows pad the feed to 80 rows.
        zero_rows = [numpy.zeros(token_ids.shape) for _ in range(79)]
        feed_input = numpy.stack([token_ids] + zero_rows)

        # NOTE(review): the op is requested anew for every sentence -- verify
        # whether forward_propagate adds graph nodes on each call.
        fetch_op = model.cell.forward_propagate(feed_input.shape[1])
        feed = {
            model.cell.input: feed_input,
            model.cell.seq_len: seq_len,
            model.targets: numpy.zeros((80, 1)),
        }
        _, _, distance_forget, distance_input = sess.run([fetch_op], feed_dict=feed)[0]

        # Keep only index 0 of the last axis (the row holding the sentence,
        # assuming that axis is the 80-row batch -- TODO confirm).
        distance_forget = distance_forget[:, :, 0]
        distance_input = distance_input[:, :, 0]

        n_seen += 1
        dump_details = verbose and n_seen % 100 == 0
        if dump_details:
            for pos, word in enumerate(sen):
                print('%15s\t%s\t%s' % (word, str(distance_forget[:, pos]), str(distance_input[:, pos])))
            print('Standard output:', gold_tree)

        # First and last tokens are dropped from both the words and the gate
        # values before building the tree.
        inner_words = sen[1:-1]
        depth = distance_forget[1][1:-1]
        parse_tree = build_tree(depth, inner_words)

        sys_by_id[n_seen] = MRG(parse_tree)
        ref_by_id[n_seen] = MRG_labeled(gold_nltk)

        predicted_trees.append(parse_tree)
        gold_trees.append(gold_tree)

        model_out, _ = get_brackets(parse_tree)
        std_out, _ = get_brackets(gold_tree)
        n_overlap = float(len(model_out.intersection(std_out)))
        n_model = len(model_out)
        n_gold = len(std_out)

        # The epsilon guards against division by zero; empty bracket sets are
        # then scored as perfect.
        prec = n_overlap / (n_model + 1e-8)
        reca = n_overlap / (n_gold + 1e-8)
        if n_gold == 0:
            reca = 1.
            if n_model == 0:
                prec = 1.
        f1 = 2 * prec * reca / (prec + reca + 1e-8)

        precisions.append(prec)
        recalls.append(reca)
        f1_scores.append(f1)

        if verbose:
            print('Model output:', parse_tree)
            print('Prec: %f, Reca: %f, F1: %f' % (prec, reca, f1))

        if dump_details:
            print('-' * 80)

            n_forget = distance_forget.shape[1]
            n_input = distance_input.shape[1]
            _, axarr = plt.subplots(3, sharex=True, figsize=(n_forget // 2, 6))
            # One subplot per row: forget and input bars side by side.
            for row, ylabel in enumerate(('1st layer', '2nd layer', '3rd layer')):
                axis = axarr[row]
                axis.bar(numpy.arange(n_forget) - 0.2, distance_forget[row], width=0.4)
                axis.bar(numpy.arange(n_input) + 0.2, distance_input[row], width=0.4)
                axis.set_ylim([0., 1.])
                axis.set_ylabel(ylabel)
            plt.sca(axarr[2])
            plt.xlim(xmin=-0.5, xmax=n_forget - 0.5)
            plt.xticks(numpy.arange(n_forget), sen, fontsize=10, rotation=45)

            plt.savefig('figure/%d.png' % (n_seen))
            plt.close()

    prec_arr = numpy.array(precisions).reshape((-1, 1))
    reca_arr = numpy.array(recalls).reshape((-1, 1))
    f1_arr = numpy.array(f1_scores).reshape((-1, 1))

    if verbose:
        print('-' * 80)
        numpy.set_printoptions(precision=4)
        print('Mean Prec:', prec_arr.mean(axis=0),
              ', Mean Reca:', reca_arr.mean(axis=0),
              ', Mean F1:', f1_arr.mean(axis=0))
        print('Number of sentence: %i' % n_seen)

        correct, total = corpus_stats_labeled(sys_by_id, ref_by_id)
        print(correct)
        print(total)
        print('ADJP:', correct['ADJP'], total['ADJP'])
        print('NP:', correct['NP'], total['NP'])
        print('PP:', correct['PP'], total['PP'])
        print('INTJ:', correct['INTJ'], total['INTJ'])
        print(corpus_average_depth(sys_by_id))

        evalb(predicted_trees, gold_trees)

    return f1_arr.mean(axis=0)
def test(model, corpus, cuda, prt=False):
    """Score trees induced from ``model.distance`` against the gold trees.

    The model is put in eval mode and run on one sentence at a time.  After
    each forward pass ``model.distance[0]`` and ``model.distance[1]`` are
    read; a tree is built with ``build_tree`` from row 1 of the former (the
    row the plots label '2nd layer') and its brackets are compared with the
    reference tree's.  With ``prt`` set, every 100th sentence is printed and
    plotted to ``figure/<n>.png`` and corpus-level statistics plus ``evalb``
    output are printed at the end.

    Reads the module-level ``args.wsj10`` flag: when set, the training split
    is used and sentences longer than 12 tokens are skipped; otherwise the
    test split is used unfiltered.

    Args:
        model: callable as ``model(input, hidden)``, exposing ``eval``,
            ``init_hidden`` and ``distance``.
        corpus: object exposing ``dictionary.word2idx`` and the
            ``{train,test}_{sens,trees,nltktrees}`` sequences.
        cuda: if truthy, the input is moved with ``.cuda()``.
        prt: enable the diagnostic printing and plotting.

    Returns:
        Mean per-sentence F1 as a length-1 array (mean over axis 0 of an
        ``(n, 1)`` array).
    """
    model.eval()

    precisions, recalls, f1_scores = [], [], []
    predicted_trees, gold_trees = [], []
    sys_by_id, ref_by_id = {}, {}

    vocab = corpus.dictionary.word2idx
    if args.wsj10:
        samples = zip(corpus.train_sens, corpus.train_trees, corpus.train_nltktrees)
    else:
        samples = zip(corpus.test_sens, corpus.test_trees, corpus.test_nltktrees)

    n_seen = 0
    for sen, gold_tree, gold_nltk in samples:
        if args.wsj10 and len(sen) > 12:
            continue

        token_ids = numpy.array(
            [vocab[w] if w in vocab else vocab['<unk>'] for w in sen])
        # Shape (len(sen), 1): one column holding the single sentence.
        tokens = Variable(torch.LongTensor(token_ids[:, None]))
        if cuda:
            tokens = tokens.cuda()

        # The outputs are not used; model.distance is read right after the
        # call (presumably populated by the forward pass -- confirm).
        _, _ = model(tokens, model.init_hidden(1))

        distance = model.distance[0].squeeze().data.cpu().numpy()
        distance_in = model.distance[1].squeeze().data.cpu().numpy()

        n_seen += 1
        dump_details = prt and n_seen % 100 == 0
        if dump_details:
            for pos, word in enumerate(sen):
                print('%15s\t%s\t%s' % (word, str(distance[:, pos]), str(distance_in[:, pos])))
            print('Standard output:', gold_tree)

        # First and last tokens are dropped from both the words and the gate
        # values before building the tree.
        inner_words = sen[1:-1]
        depth = distance[1][1:-1]
        parse_tree = build_tree(depth, inner_words)

        sys_by_id[n_seen] = MRG(parse_tree)
        ref_by_id[n_seen] = MRG_labeled(gold_nltk)

        predicted_trees.append(parse_tree)
        gold_trees.append(gold_tree)

        model_out, _ = get_brackets(parse_tree)
        std_out, _ = get_brackets(gold_tree)
        n_overlap = float(len(model_out.intersection(std_out)))
        n_model = len(model_out)
        n_gold = len(std_out)

        # The epsilon guards against division by zero; empty bracket sets are
        # then scored as perfect.
        prec = n_overlap / (n_model + 1e-8)
        reca = n_overlap / (n_gold + 1e-8)
        if n_gold == 0:
            reca = 1.
            if n_model == 0:
                prec = 1.
        f1 = 2 * prec * reca / (prec + reca + 1e-8)

        precisions.append(prec)
        recalls.append(reca)
        f1_scores.append(f1)

        if dump_details:
            print('Model output:', parse_tree)
            print('Prec: %f, Reca: %f, F1: %f' % (prec, reca, f1))

            print('-' * 80)

            n_cols = distance.shape[1]
            n_cols_in = distance_in.shape[1]
            _, axarr = plt.subplots(3, sharex=True, figsize=(n_cols // 2, 6))
            # One subplot per row: both distance arrays as side-by-side bars.
            for row, ylabel in enumerate(('1st layer', '2nd layer', '3rd layer')):
                axis = axarr[row]
                axis.bar(numpy.arange(n_cols) - 0.2, distance[row], width=0.4)
                axis.bar(numpy.arange(n_cols_in) + 0.2, distance_in[row], width=0.4)
                axis.set_ylim([0., 1.])
                axis.set_ylabel(ylabel)
            plt.sca(axarr[2])
            plt.xlim(xmin=-0.5, xmax=n_cols - 0.5)
            plt.xticks(numpy.arange(n_cols), sen, fontsize=10, rotation=45)

            plt.savefig('figure/%d.png' % (n_seen))
            plt.close()

    prec_arr = numpy.array(precisions).reshape((-1, 1))
    reca_arr = numpy.array(recalls).reshape((-1, 1))
    f1_arr = numpy.array(f1_scores).reshape((-1, 1))

    if prt:
        print('-' * 80)
        numpy.set_printoptions(precision=4)
        print('Mean Prec:', prec_arr.mean(axis=0),
              ', Mean Reca:', reca_arr.mean(axis=0),
              ', Mean F1:', f1_arr.mean(axis=0))
        print('Number of sentence: %i' % n_seen)

        correct, total = corpus_stats_labeled(sys_by_id, ref_by_id)
        print(correct)
        print(total)
        print('ADJP:', correct['ADJP'], total['ADJP'])
        print('NP:', correct['NP'], total['NP'])
        print('PP:', correct['PP'], total['PP'])
        print('INTJ:', correct['INTJ'], total['INTJ'])
        print(corpus_average_depth(sys_by_id))

        evalb(predicted_trees, gold_trees)

    return f1_arr.mean(axis=0)
Beispiel #3
0
def generate_parse():
    """Parse the SST dev trees with the model and score them with ``evalb``.

    Reads ``.data/sst/trees/dev.txt`` one line at a time, lower-cases the
    line, parses it with ``nltk.Tree.fromstring`` and takes the tree's leaves
    as the sentence.  Sentences are processed in batches of 16 (the last
    batch may be smaller).  For each sentence a tree is built with
    ``build_tree`` from the per-position argmax of ``model.encoder.probs``;
    predicted/gold pairs for sentences of at most 100 tokens are collected
    and finally passed to ``evalb`` with ``evalb_path='./EVALB'``.

    Side effects: changes the global numpy print options, switches the model
    to eval mode, and prints diagnostics for the first sentence of each batch.
    Relies on the module-level ``TEXT``, ``model``, ``args``, ``build_tree``
    and ``plot``.
    """
    from nltk import Tree
    from utils import evalb

    batch = []
    pred_tree_list = []
    targ_tree_list = []

    def process_batch():
        """Run the model on the pending batch, collect trees, then reset it."""
        # Only ``batch`` is rebound below; the two lists are mutated in place
        # and therefore need no ``nonlocal`` declaration.
        nonlocal batch
        idx = TEXT.process([example['sents'] for example in batch],
                           device=hidden[0].device)

        # The return value is unused; model.encoder.probs is read right after
        # the call.
        model(idx)

        probs = model.encoder.probs
        # Indexed below as (position, sentence-in-batch); the argmax over the
        # last axis picks one index per position.
        distance = torch.argmax(probs, dim=-1)
        # The first position of every sentence is overwritten with args.nslot.
        distance[0] = args.nslot
        probs = probs.data.cpu().numpy()

        for i, example in enumerate(batch):
            sents = example['sents']
            sents_tree = example['sents_tree']
            depth = distance[:, i]

            parse_tree = build_tree(depth, sents)

            # Sentences longer than 100 tokens are excluded from scoring.
            if len(sents) <= 100:
                pred_tree_list.append(parse_tree)
                targ_tree_list.append(sents_tree)

            # Diagnostics for the first sentence of each batch only.
            if i == 0:
                for j, word in enumerate(sents):
                    print('%20s\t%2.2f\t%s' %
                          (word, depth[j], plot(probs[j, i], 1.)))
                print(parse_tree)
                print(sents_tree)
                print('-' * 80)

        batch = []

    np.set_printoptions(precision=2,
                        suppress=True,
                        linewidth=5000,
                        formatter={'float': '{: 0.2f}'.format})

    model.eval()
    # Used by process_batch only to find the device to place inputs on.
    hidden = model.encoder.init_hidden(1)

    # The context manager closes the file even if a batch raises; the
    # original left the handle open.
    with open('.data/sst/trees/dev.txt', 'r') as fin:
        for line in fin:
            line = line.lower()
            sents_tree = Tree.fromstring(line)
            sents = sents_tree.leaves()
            batch.append({'sents_tree': sents_tree, 'sents': sents})

            if len(batch) == 16:
                process_batch()

    # Flush the final, possibly partial batch.
    if len(batch) > 0:
        process_batch()

    evalb(pred_tree_list, targ_tree_list, evalb_path='./EVALB')