Example #1
    def build_corpus(self):
        print(f'Loading training trees from `{self.train_path}`...')
        if self.multitask == 'ccg':
            train_treebank = ccg.fromfile(self.train_path)
        else:
            with open(self.train_path) as f:
                train_treebank = [fromstring(line.strip()) for line in f]

        print(f'Loading development trees from `{self.dev_path}`...')
        with open(self.dev_path) as f:
            dev_treebank = [fromstring(line.strip()) for line in f]

        print(f'Loading test trees from `{self.test_path}`...')
        with open(self.test_path) as f:
            test_treebank = [fromstring(line.strip()) for line in f]

        if self.multitask == 'spans':
            # need trees with span-information
            train_treebank = [tree.convert() for tree in train_treebank]
            dev_treebank = [tree.convert() for tree in dev_treebank]
            test_treebank = [tree.convert() for tree in test_treebank]

        print("Constructing vocabularies...")
        if self.vocab_path is not None:
            print(f'Using word vocabulary specified in `{self.vocab_path}`')
            with open(self.vocab_path) as f:
                vocab = json.load(f)
            words = [word for word, count in vocab.items() for _ in range(count)]
        else:
            words = [word for tree in train_treebank for word in tree.words()]

        if self.multitask == 'none':
            labels = []
        else:
            labels = [label for tree in train_treebank for label in tree.labels()]

        if self.multitask == 'none':
            words = [UNK, START] + words
        else:
            words = [UNK, START, STOP] + words

        word_vocab = Vocabulary.fromlist(words, unk_value=UNK)
        label_vocab = Vocabulary.fromlist(labels)

        self.word_vocab = word_vocab
        self.label_vocab = label_vocab

        self.train_treebank = train_treebank
        self.dev_treebank = dev_treebank
        self.test_treebank = test_treebank

        print('\n'.join((
            'Corpus statistics:',
            f'Vocab: {word_vocab.size:,} words, {label_vocab.size:,} nonterminals',
            f'Train: {len(train_treebank):,} sentences',
            f'Dev: {len(dev_treebank):,} sentences',
            f'Test: {len(test_treebank):,} sentences')))
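
# A minimal, self-contained sketch (plain Python, hypothetical `vocab.json`
# contents) of the count-expansion trick used above: repeating each word
# `count` times lets a frequency-counting vocabulary builder such as
# `Vocabulary.fromlist` see the same counts it would get from raw trees.
from collections import Counter

vocab = {'the': 3, 'cat': 1}  # stand-in for json.load(f)
words = [word for word, count in vocab.items() for _ in range(count)]
assert Counter(words) == Counter(['the', 'the', 'the', 'cat'])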
Example #2
def predict_input_crf(args):
    print('Predicting with crf parser.')

    parser = load_model(args.checkpoint)

    ##
    right_branching = fromstring(
        "(S (NP (@ The) (@ (@ other) (@ (@ hungry) (@ cat)))) (@ (VP meows ) (@ .)))"
    ).convert()
    left_branching = fromstring(
        "(S (NP (@ The) (@ (@ (@ other) (@ hungry)) (@ cat))) (@ (VP meows ) (@ .)))"
    ).convert()

    # right_branching = fromstring("(X (X (@ The) (@ (@ other) (@ (@ hungry) (@ cat)))) (@ (X meows ) (@ .)))").convert()
    # left_branching = fromstring("(X (X (@ The) (@ (@ (@ other) (@ hungry)) (@ cat))) (@ (X meows ) (@ .)))").convert()

    right_nll = parser.forward(right_branching, is_train=False)
    left_nll = parser.forward(left_branching, is_train=False)

    print('Right:', right_nll.value())
    print('Left:', left_nll.value())
    ##

    while True:
        sentence = input('Input a sentence: ')
        words = sentence.split()

        print('Processed:', ' '.join(parser.word_vocab.process(words)))
        print()

        print('Parse:')
        tree, nll = parser.parse(words)
        print('  {} {:.3f}'.format(tree.linearize(with_tag=False),
                                   nll.value()))
        print()

        print('Samples:')
        parse, parse_logprob, samples, entropy = parser.parse_sample_entropy(
            words, num_samples=8, alpha=args.alpha)
        for tree, nll in samples:
            print('  {} {:.3f}'.format(tree.linearize(with_tag=False),
                                       nll.value()))
            # print('  {} ||| {} ||| {:.3f}'.format(
            #     tree.convert().linearize(with_tag=False), tree.un_cnf().linearize(with_tag=False), nll.value()))
            print()
        print('Parse (temperature {}):'.format(args.alpha))
        print('  {} {:.3f}'.format(parse.linearize(with_tag=False),
                                   -parse_logprob.value()))
        print()

        print('Entropy:')
        print('  {:.3f}'.format(entropy.value()))

        print('-' * 79)
        print()
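
# A minimal sketch of the branching probe above, with hypothetical NLL values:
# the model prefers whichever hand-built bracketing has the lower negative
# log-likelihood, and exp(left - right) gives the odds p(right) / p(left).
import math

right_nll, left_nll = 40.1, 42.3       # hypothetical right_nll.value(), left_nll.value()
odds = math.exp(left_nll - right_nll)  # p(right_branching) / p(left_branching)
assert odds > 1                        # the right-branching parse wins here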
Example #3
def read(path):
    with open(path) as f:
        trees = [
            fromstring(line.strip()) for line in f.readlines()
            if line.strip()
        ]
    return trees
Example #4
    def read_proposals(self, path):
        print(f'Loading discriminative (proposal) samples from `{path}`...')
        with open(path) as f:
            lines = [line.strip() for line in f.readlines()]
        sent_id = 0
        samples = []
        proposals = []
        for line in lines:
            sample_id, logprob, tree = line.split('|||')
            sample_id, logprob, tree = int(sample_id), float(
                logprob), fromstring(add_dummy_tags(tree.strip()))
            if sample_id > sent_id:
                # arrived at the first sample of the next sentence
                if self.num_samples > len(samples):
                    raise ValueError(
                        'not enough samples for line {}'.format(sent_id))
                elif self.num_samples < len(samples):
                    samples = samples[:self.num_samples]
                else:
                    pass
                proposals.append(samples)
                sent_id = sample_id
                samples = []
            samples.append((tree, logprob))
        proposals.append(samples)
        return proposals
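
# A minimal sketch of the proposal-file format `read_proposals` assumes:
# one sample per line, `sentence-id ||| log-probability ||| tree`, with the
# samples for each sentence stored on consecutive lines (hypothetical values).
line = '0 ||| -42.7 ||| (S (NP The cat) (VP meows) .)'
sample_id, logprob, tree_str = (field.strip() for field in line.split('|||'))
assert int(sample_id) == 0 and float(logprob) == -42.7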
Example #5
def inspect_model(args):
    assert args.model_type == 'disc-rnng', args.model_type

    print(f'Inspecting attention for sentences in `{args.infile}`.')

    parser = load_model(args.checkpoint)

    with open(args.infile, 'r') as f:
        lines = [line.strip() for line in f.readlines()]
    lines = lines[:args.max_lines]
    if is_tree(lines[0]):
        sentences = [fromstring(line).words() for line in lines]
    else:
        sentences = [line.split() for line in lines]

    def inspect_after_reduce(parser):
        subtree = parser.stack._stack[-1].subtree
        head = subtree.label
        children = [
            child.label if isinstance(child, InternalNode) else child.word
            for child in subtree.children
        ]
        attention = parser.composer._attn
        gate = np.mean(parser.composer._gate)
        attention = [attention] if not isinstance(
            attention, list) else attention  # in case .value() returns a float
        attentive = [
            f'{child} ({attn:.2f})'
            for child, attn in zip(children, attention)
        ]
        print('  ', head, '|', ' '.join(attentive), f'[{gate:.2f}]')

    def parse_with_inspection(parser, words):
        parser.eval()
        nll = 0.
        word_ids = [parser.word_vocab.index_or_unk(word) for word in words]
        parser.initialize(word_ids)
        while not parser.stack.is_finished():
            u = parser.parser_representation()
            action_logits = parser.f_action(u)
            action_id = np.argmax(action_logits.value() +
                                  parser._add_actions_mask())
            nll += dy.pickneglogsoftmax(action_logits, action_id)
            parser.parse_step(action_id)
            if action_id == parser.REDUCE_ID:
                inspect_after_reduce(parser)
        tree = parser.get_tree()
        tree.substitute_leaves(iter(words))  # replaces UNKs with originals
        return tree, nll

    for sentence in sentences:
        tree, _ = parser.parse(sentence)
        print('>', ' '.join(sentence))
        print('>', tree.linearize(with_tag=False))
        parse_with_inspection(parser, sentence)
        print()
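
# A minimal numpy sketch of the masked greedy step in `parse_with_inspection`,
# assuming `parser._add_actions_mask()` returns an additive 0 / -inf mask:
# adding it to the logits before argmax restricts the choice to valid actions.
import numpy as np

logits = np.array([2.0, 5.0, 1.0])
mask = np.array([0.0, -np.inf, 0.0])  # action 1 is currently invalid
assert np.argmax(logits + mask) == 0  # best *valid* action, not action 1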
Example #6
def main():

    with open('/Users/daan/data/ptb-benepar/23.auto.clean') as f:
        lines = [line.strip() for line in f.readlines()]

    treebank = [trees.fromstring(line, strip_top=True) for line in lines[:100]]
    tree = treebank[0].cnf()

    # Obtain the word and label vocabularies
    words = [vocabulary.UNK, START, STOP] + [word for word in tree.words()]
    labels = [
        (trees.DUMMY, )
    ] + [label for tree in treebank for label in tree.cnf().labels()]

    word_vocab = vocabulary.Vocabulary.fromlist(words,
                                                unk_value=vocabulary.UNK)
    label_vocab = vocabulary.Vocabulary.fromlist(labels)

    model = dy.ParameterCollection()
    parser = ChartParser(
        model,
        word_vocab,
        label_vocab,
        word_embedding_dim=100,
        lstm_layers=2,
        lstm_dim=100,
        label_hidden_dim=100,
        dropout=0.,
    )
    optimizer = dy.AdamTrainer(model)

    for i in range(1000):
        dy.renew_cg()

        t0 = time.time()
        loss = parser.forward(tree)
        pred, _ = parser.parse(tree.words())
        sample, _ = parser.sample(tree.words())
        t1 = time.time()

        loss.forward()
        loss.backward()
        optimizer.update()

        t2 = time.time()

        print('step', i, 'loss', round(loss.value(), 2), 'forward-time',
              round(t1 - t0, 3), 'backward-time', round(t2 - t1, 3), 'length',
              len(tree.words()))
        print('>', tree.un_cnf().linearize(with_tag=False))
        print('>', pred.linearize(with_tag=False))
        print('>', sample.linearize(with_tag=False))
        print()
Example #7
def predict_entropy(args):
    print(
        f'Predicting entropy for lines in `{args.infile}`, writing to `{args.outfile}`...'
    )
    print(f'Loading model from `{args.checkpoint}`.')
    print(f'Using {args.num_samples} samples.')

    parser = load_model(args.checkpoint)
    parser.eval()

    with open(args.infile, 'r') as f:
        lines = [line.strip() for line in f.readlines()]

    if is_tree(lines[0]):
        sentences = [fromstring(line.strip()).words() for line in lines]
    else:
        sentences = [line.strip().split() for line in lines]

    with open(args.outfile, 'w') as f:
        print('id',
              'entropy',
              'num-samples',
              'model',
              'file',
              file=f,
              sep='\t')
        for i, words in enumerate(tqdm(sentences)):
            dy.renew_cg()
            if args.num_samples == 0:
                assert args.model_type == 'crf', 'exact computation only for crf.'
                entropy = parser.entropy(words)
            else:
                if args.model_type == 'crf':
                    samples = parser.sample(words,
                                            num_samples=args.num_samples)
                    if args.num_samples == 1:
                        samples = [samples]
                else:
                    samples = [
                        parser.sample(words, alpha=args.alpha)
                        for _ in range(args.num_samples)
                    ]
                trees, nlls = zip(*samples)
                entropy = dy.esum(list(nlls)) / len(nlls)
            print(i,
                  entropy.value(),
                  args.num_samples,
                  args.model_type,
                  args.infile,
                  file=f,
                  sep='\t')
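
# A minimal numpy sketch of the estimator used above: for samples y_i ~ p, the
# entropy H(p) = E[-log p(y)] is approximated by the average sampled NLL,
# which is exactly what `dy.esum(list(nlls)) / len(nlls)` computes.
import numpy as np

probs = np.array([0.5, 0.3, 0.2])           # toy distribution over 3 trees
draws = np.random.default_rng(0).choice(len(probs), size=10000, p=probs)
estimate = np.mean(-np.log(probs[draws]))   # Monte Carlo entropy estimate
exact = -np.sum(probs * np.log(probs))
assert abs(estimate - exact) < 0.05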
Example #8
def predict_perplexity(args):

    np.random.seed(args.numpy_seed)

    with open(args.infile, 'r') as f:
        lines = [line.strip() for line in f.readlines()]

    if is_tree(lines[0]):
        sentences = [fromstring(line.strip()).words() for line in lines]
    else:
        sentences = [line.strip().split() for line in lines]

    model = load_model(args.checkpoint)
    proposal = load_model(args.proposal_model)
    decoder = GenerativeDecoder(model=model,
                                proposal=proposal,
                                num_samples=args.num_samples,
                                alpha=args.alpha)

    proposal_type = 'disc-rnng' if isinstance(proposal, DiscRNNG) else 'crf'

    filename_base = 'proposal={}_num-samples={}_temp={}_seed={}'.format(
        proposal_type, args.num_samples, args.alpha, args.numpy_seed)
    proposals_path = os.path.join(args.outdir, filename_base + '.props')
    result_path = os.path.join(args.outdir, filename_base + '.tsv')

    print('Predicting perplexity with Generative RNNG.')
    print(f'Loading model from `{args.checkpoint}`.')
    print(f'Loading proposal from `{args.proposal_model}`.')
    print(f'Loading lines from `{args.infile}`.')
    print(f'Writing proposals to `{proposals_path}`.')
    print(f'Writing predictions to `{result_path}`.')

    print('Sampling proposals...')
    decoder.generate_proposal_samples(sentences, proposals_path)
    print('Computing perplexity...')
    _, perplexity = decoder.predict_from_proposal_samples(proposals_path)

    with open(result_path, 'w') as f:
        print('\t'.join(
            ('proposal', 'file', 'perplexity', 'num-samples', 'temp', 'seed')),
              file=f)
        print('\t'.join(
            (proposal_type, os.path.basename(args.infile), str(perplexity),
             str(args.num_samples), str(args.alpha), str(args.numpy_seed))),
              file=f)
Example #9
def predict_tree_file(args):
    assert os.path.exists(args.infile), 'specify file to parse with --infile.'

    print(f'Predicting trees for lines in `{args.infile}`.')

    with open(args.infile, 'r') as f:
        lines = [
            fromstring(line.strip()).words() for line in f if line.strip()
        ]

    if args.model_type == 'disc':
        print('Loading discriminative model...')
        parser = load_model(args.checkpoint)
        parser.eval()
        print('Done.')

    elif args.model_type == 'gen':
        exit('Not yet...')

        print('Loading generative model...')
        parser = GenerativeDecoder()
        parser.load_model(path=args.checkpoint)
        if args.proposal_model:
            parser.load_proposal_model(path=args.proposal_model)
        if args.proposal_samples:
            parser.load_proposal_samples(path=args.proposal_samples)

    trees = []
    for line in tqdm(lines):
        tree, _ = parser.parse(line)
        trees.append(tree.linearize())

    pred_path = args.outfile
    result_path = args.outfile + '.results'
    # Save the predicted trees.
    with open(pred_path, 'w') as f:
        print('\n'.join(trees), file=f)
    # Score the trees.
    fscore = evalb(args.evalb_dir, pred_path, args.infile, result_path)
    print(
        f'Predictions saved in `{pred_path}`. Results saved in `{result_path}`.'
    )
    print(f'F-score {fscore:.2f}.')
Example #10
def sample_proposals(args):
    assert os.path.exists(args.infile), 'specify file to parse with --infile.'

    print(f'Sampling proposal trees for sentences in `{args.infile}`.')

    with open(args.infile, 'r') as f:
        lines = [line.strip() for line in f.readlines()]

    if is_tree(lines[0]):
        sentences = [fromstring(line).words() for line in lines]
    else:
        sentences = [line.split() for line in lines]

    parser = load_model(args.checkpoint)

    samples = []
    if args.model_type == 'crf':
        for i, words in enumerate(tqdm(sentences)):
            dy.renew_cg()
            for tree, nll in parser.sample(words,
                                           num_samples=args.num_samples):
                sample = ' ||| '.join((str(i), str(-nll.value()),
                                       tree.linearize(with_tag=False)))
                samples.append(sample)
                print(sample)

    else:
        for i, words in enumerate(tqdm(sentences)):
            for _ in range(args.num_samples):
                dy.renew_cg()
                tree, nll = parser.sample(words, alpha=args.alpha)
                samples.append(' ||| '.join((str(i), str(-nll.value()),
                                             tree.linearize(with_tag=False))))

    with open(args.outfile, 'w') as f:
        print('\n'.join(samples), file=f, end='')
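
# A minimal round-trip sketch, with hypothetical values: each sample line
# written above is `sentence-id ||| log-probability ||| tree`, the format that
# `read_proposals` (Example #4) later splits back apart on '|||'.
fields = ('0', '-42.7', '(S (NP The cat) (VP meows) .)')
line = ' ||| '.join(fields)
assert tuple(field.strip() for field in line.split('|||')) == fields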
Example #11
    def load_trees(self, path):
        with open(path) as f:
            trees = [fromstring(line.strip()).convert() for line in f]
        return trees
Example #12
    def build_corpus(self):
        print(f'Loading training trees from `{self.train_path}`...')
        with open(self.train_path) as f:
            train_treebank = [fromstring(line.strip()) for line in f]

        print(f'Loading development trees from `{self.dev_path}`...')
        with open(self.dev_path) as f:
            dev_treebank = [fromstring(line.strip()) for line in f]

        print(f'Loading test trees from `{self.test_path}`...')
        with open(self.test_path) as f:
            test_treebank = [fromstring(line.strip()) for line in f]

        if self.unlabeled:
            print('Converting trees to unlabeled form...')
            for tree in train_treebank:
                tree.unlabelize()

        if self.model_type == 'crf':
            print('Converting trees to CNF...')
            train_treebank = [tree.cnf() for tree in train_treebank]

            if self.unlabeled:
                for tree in train_treebank:
                    tree.remove_chains()

        print("Constructing vocabularies...")
        if self.vocab_path is not None:
            print(f'Using word vocabulary specified in `{self.vocab_path}`')
            with open(self.vocab_path) as f:
                vocab = json.load(f)
            words = [
                word for word, count in vocab.items() for _ in range(count)
            ]
        else:
            words = [word for tree in train_treebank for word in tree.words()]

        if self.max_sent_len > 0:
            filtered_treebank = [
                tree for tree in train_treebank
                if len(tree.words()) <= self.max_sent_len
            ]

            print(
                "Using sentences with length <= {}: {:.1%} of all training trees."
                .format(self.max_sent_len,
                        len(filtered_treebank) / len(train_treebank)))

            train_treebank = filtered_treebank

        if self.min_label_count > 1:
            counted_labels = Counter(
                [label for tree in train_treebank for label in tree.labels()])
            # use a set for fast membership tests in the filter below
            filtered_labels = {
                label for label, count in counted_labels.most_common()
                if count >= self.min_label_count
            }
            filtered_treebank = [
                tree for tree in train_treebank
                if all(label in filtered_labels for label in tree.labels())
            ]

            print(
                "Using labels with count >= {}: {}/{} ({:.1%}) of all labels and {:.1%} of all training trees."
                .format(self.min_label_count, len(filtered_labels),
                        len(counted_labels),
                        len(filtered_labels) / len(counted_labels),
                        len(filtered_treebank) / len(train_treebank)))

            train_treebank = filtered_treebank

        labels = [label for tree in train_treebank for label in tree.labels()]

        if self.model_type == 'crf':
            words = [UNK, START, STOP] + words
        else:
            words = [UNK] + words

        word_vocab = Vocabulary.fromlist(words, unk_value=UNK)
        label_vocab = Vocabulary.fromlist(labels)

        ##
        # counted_labels = Counter(label_vocab.counts).most_common()
        # pprint(counted_labels)
        ##

        if self.model_type.endswith('rnng'):
            # Order is very important! See DiscParser/GenParser classes to know why.
            if self.model_type == 'disc-rnng':
                actions = [SHIFT, REDUCE] + [NT(label) for label in label_vocab]
            elif self.model_type == 'gen-rnng':
                actions = ([REDUCE] + [NT(label) for label in label_vocab] +
                           [GEN(word) for word in word_vocab])
            action_vocab = Vocabulary()
            for action in actions:
                action_vocab.add(action)
        else:
            action_vocab = Vocabulary()

        self.word_vocab = word_vocab
        self.label_vocab = label_vocab
        self.action_vocab = action_vocab

        self.train_treebank = train_treebank
        self.dev_treebank = dev_treebank
        self.test_treebank = test_treebank

        print('\n'.join((
            'Corpus statistics:',
            f'Vocab: {word_vocab.size:,} words, {label_vocab.size:,} nonterminals, {action_vocab.size:,} actions',
            f'Train: {len(train_treebank):,} sentences',
            f'Dev: {len(dev_treebank):,} sentences',
            f'Test: {len(test_treebank):,} sentences')))
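
# A minimal sketch of the ordering assumption flagged above ("Order is very
# important!"): adding actions in list order fixes their ids, so for a
# disc-rnng vocabulary SHIFT is 0, REDUCE is 1, and NT ids start at 2.
# SHIFT, REDUCE, and NT here are stand-ins for the real constants.
SHIFT, REDUCE = 'SHIFT', 'REDUCE'
actions = [SHIFT, REDUCE] + ['NT({})'.format(l) for l in ('S', 'NP', 'VP')]
index = {action: i for i, action in enumerate(actions)}
assert index[SHIFT] == 0 and index[REDUCE] == 1 and index['NT(S)'] == 2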
Example #13
    def predict_from_proposal_samples(self, inpath, unlabeled=False):
        """Predict MAP trees and perplexity from proposal samples in one fell swoop."""

        # load scored proposal samples
        all_samples = defaultdict(list)  # i -> [samples for sentence i]
        with open(inpath) as f:
            for line in f:
                i, proposal_logprob, tree = line.strip().split(' ||| ')
                i, proposal_logprob, tree = int(i), float(
                    proposal_logprob), fromstring(add_dummy_tags(tree.strip()))
                if unlabeled:
                    tree.unlabelize()
                all_samples[i].append((tree, proposal_logprob))

        # check if number of samples is as desired
        for i, samples in all_samples.items():
            if self.num_samples > len(samples):
                raise ValueError('not enough samples for line {}'.format(i))
            elif self.num_samples < len(samples):
                all_samples[i] = samples[:self.num_samples]
            else:
                pass

        # score the trees
        for i, samples in tqdm(all_samples.items()):
            # count and remove duplicates
            samples = self.count_samples(samples)
            scored_samples = []
            for (tree, proposal_logprob, count) in samples:
                dy.renew_cg()
                joint_logprob = -self.model.forward(tree,
                                                    is_train=False).value()
                scored_samples.append(
                    (tree, proposal_logprob, joint_logprob, count))
            all_samples[i] = scored_samples

        # get the predictions
        trees = []
        nlls = []
        lengths = []
        for i, scored in all_samples.items():
            # sort the scored tuples according to the joint logprob
            ranked = sorted(scored, reverse=True, key=lambda t: t[2])
            # pick by highest logprob to estimate the map tree
            tree, _, _, _ = ranked[0]

            # estimate the perplexity
            weights, counts = np.zeros(len(scored)), np.zeros(len(scored))
            # `j` avoids shadowing the sentence index `i` of the outer loop
            for j, (_, proposal_logprob, joint_logprob,
                    count) in enumerate(scored):
                weights[j] = joint_logprob - proposal_logprob
                counts[j] = count
            # log-mean-exp for stability
            a = weights.max()
            logprob = a + np.log(np.mean(np.exp(weights - a) * counts))

            trees.append(tree.linearize())  # the estimated MAP tree
            nlls.append(-logprob)  # the estimate for -log p(x)
            lengths.append(len(tree.words()))  # needed to compute perplexity

        # the perplexity is averaged over the total number of words
        perplexity = np.exp(np.sum(nlls) / np.sum(lengths))

        return trees, round(perplexity, 2)
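
# A minimal numpy sketch of the stabilized importance-sampling estimate above:
# log p(x) is approximated by log mean(count_i * exp(joint_i - proposal_i)),
# computed as a + log mean(exp(w - a) * counts) with a = max(w) so that the
# exponentials cannot overflow (toy values below).
import numpy as np

joint = np.array([-50.0, -48.5, -52.0])   # log p(x, y_i) for sampled trees
proposal = np.array([-3.0, -1.2, -4.0])   # log q(y_i | x) under the proposal
counts = np.array([1.0, 2.0, 1.0])        # multiplicities of duplicate trees
w = joint - proposal
a = w.max()
logprob = a + np.log(np.mean(np.exp(w - a) * counts))
# Same value computed naively; safe here only because the toy numbers are small.
assert np.isclose(logprob, np.log(np.mean(np.exp(w) * counts)))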