コード例 #1
0
def test_fragments():
    """Extract fragments from a small toy treebank and check their counts.

    Each treebank line holds a bracketed tree and its sentence separated by a
    tab. Every tree is binarized and its leaves are renumbered left to right
    before the treebank is converted with ``getctrees``. The test then checks
    the number of extracted fragments and the sum of their exact counts.
    """
    from discodop._fragments import getctrees, extractfragments, exactcounts
    treebank = """\
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (JJ 4) (NN 5))))\
	The cat saw the hungry dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\
	The cat saw the dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\
	The mouse saw the cat
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (JJ 4) (NN 5))))\
	The mouse saw the yellow cat
(S (NP (DT 0) (JJ 1) (NN 2)) (VP (VBP 3) (NP (DT 4) (NN 5))))\
	The little mouse saw the cat
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\
	The cat ate the dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\
	The mouse ate the cat""".splitlines()
    trees, sents = [], []
    for line in treebank:
        columns = line.split('\t')
        tree = binarize(Tree(columns[0]))
        # give the leaves consecutive indices in left-to-right order
        for position, leafidx in enumerate(tree.treepositions('leaves')):
            tree[leafidx] = position
        trees.append(tree)
        sents.append(columns[1].split())
    params = getctrees(zip(trees, sents))
    ctrees = params['trees1']
    fragments = extractfragments(ctrees, 0, 0, params['vocab'],
                                 disc=True, approx=False)
    counts = exactcounts(list(fragments.values()), ctrees, ctrees)
    assert len(fragments) == 25
    assert sum(counts) == 100
コード例 #2
0
    def test_binarize(self):
        """Check binarization output and ``unbinarize`` round trips.

        The same tree object, with VMFIN marked as head, is binarized under
        several option sets. After each call the printed result must match
        the expected bracketing, and ``unbinarize`` must return a tree equal
        to the original. Finally a flat six-child tree is binarized
        head-outward with otherwise default options.
        """
        treestr = '(S (VP (PDS 0) (ADV 3) (VVINF 4)) (VMFIN 1) (PIS 2))'
        origtree = Tree(treestr)
        tree = Tree(treestr)
        tree[1].type = HEAD  # set VMFIN as head
        # (keyword arguments for binarize, expected bracketing of the result)
        expectations = [
            (dict(horzmarkov=0),
             '(S (VP (PDS 0) (VP|<> (ADV 3) (VVINF 4))) (S|<> (VMFIN 1)'
             ' (PIS 2)))'),
            (dict(horzmarkov=1),
             '(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VVINF 4))) (S|<VMFIN> '
             '(VMFIN 1) (PIS 2)))'),
            (dict(horzmarkov=1, leftmostunary=False, rightmostunary=True,
                  headoutward=True),
             '(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VP|<VVINF> (VVINF 4)))) '
             '(S|<VMFIN> (S|<VMFIN> (VMFIN 1)) (PIS 2)))'),
            (dict(horzmarkov=1, leftmostunary=True, rightmostunary=False,
                  headoutward=True),
             '(S (S|<VP> (VP (VP|<PDS> (PDS 0) (VP|<ADV> (ADV 3) '
             '(VVINF 4)))) (S|<VMFIN> (VMFIN 1) (PIS 2))))'),
            (dict(factor='left', horzmarkov=2, headoutward=True),
             '(S (S|<VMFIN,PIS> (VP (VP|<PDS,ADV> (PDS 0) (ADV 3)) '
             '(VVINF 4)) (VMFIN 1)) (PIS 2))'),
        ]
        for kwargs, expected in expectations:
            assert str(binarize(tree, **kwargs)) == expected
            assert unbinarize(tree) == origtree

        flat = Tree('(S (A 0) (B 1) (C 2) (D 3) (E 4) (F 5))')
        assert str(binarize(flat, headoutward=True)) == (
            '(S (A 0) (S|<B,C,D,E,F> (B 1) (S|<C,D,E,F> (C 2) (S|<D,E,F> '
            '(D 3) (S|<E,F> (E 4) (F 5))))))')
コード例 #3
0
    def get_subtree(self, nt):
        """Return a derivation subtree.

        Parameters
        ----------
        nt : str
            The nonterminal to start with.

        Returns
        -------
        Tree
            The tree with the current node as root, labeled with the witness
            edge's nonterminal. If the edge has no successors, the single
            child is ``self.get_label()[0]``. Otherwise the children are
            built recursively, one per ``(node, nonterminal)`` successor pair.

        Raises
        ------
        ValueError
            If there is no witness edge for ``nt``.

        """
        edge = self.get_witness(nt)[0]
        if edge is None:
            raise ValueError("There is no witness for %s" % nt)
        # Fetch the successors once; the result is used both for the leaf
        # test and for building the children.
        successors = edge.get_successors()
        if not successors:
            return Tree(edge.get_nonterminal(), [self.get_label()[0]])
        return Tree(edge.get_nonterminal(),
                    [node.get_subtree(child_nt)
                     for node, child_nt in successors])
コード例 #4
0
def reattach():
    """Re-draw tree after re-attaching node under new parent.

    Reads the edit from the query string (``sentno``, ``tree``, ``nodeid``,
    ``newparent``) and applies one of three edits to the submitted tree:

    - ``newparent == 'deletenode'``: remove node ``nodeid`` and splice its
      children into its parent. This is refused for the root (id 0) or a POS
      node.
    - ``nodeid`` starting with ``'newlabel_'``: insert a new node, labeled
      with the text after the first ``'_'``, as the only child of
      ``newparent``. All former children of ``newparent`` move under the new
      node. This is refused below a POS node.
    - otherwise: detach node ``nodeid`` from its parent and append it to
      ``newparent``. This is refused if ``newparent`` is in ``nodeid``'s
      subtree, or if ``nodeid`` is its parent's only child.

    Node ids are expected in the form ``t<treeid>_<nodeindex>`` (the leading
    't' is stripped and the rest split on '_'); the tree id is ignored.

    Returns
    -------
    Markup
        The "accept this tree" link, any error message, the HTML drawing of
        the (possibly edited) tree, then a tab and the tree in bracket
        notation. If ``validate`` rejects the submitted tree, its error text
        is returned instead as a plain ``str``.
    """
    sentno = int(request.args.get('sentno'))  # 1-indexed
    sent = SENTENCES[QUEUE[sentno - 1][0]]
    senttok, _ = worker.postokenize(sent)
    treestr = request.args.get('tree', '')
    try:
        tree, _sent1 = validate(treestr, senttok)
    except ValueError as err:
        return str(err)
    dt = DrawTree(tree, senttok)
    error = ''
    if request.args.get('newparent') == 'deletenode':
        # remove nodeid by replacing it with its children
        _treeid, nodeid = request.args.get('nodeid', '').lstrip('t').split('_')
        nodeid = int(nodeid)
        x = dt.nodes[nodeid]
        # a node whose first child is an int (a leaf index) is a POS node
        if nodeid == 0 or isinstance(x[0], int):
            error = 'ERROR: cannot remove ROOT or POS node'
        else:
            children = list(x)
            # Detach the children from x before re-inserting them elsewhere.
            # NOTE(review): if no parent of x is found below, these children
            # are silently lost; this should only happen for unreachable nodes.
            x[:] = []
            for y in dt.nodes[0].subtrees():
                # find the parent by identity, not equality
                if any(child is x for child in y):
                    # index() compares by equality; x is empty here, so this
                    # relies on no earlier sibling being an equal empty node
                    i = y.index(x)
                    y[i:i + 1] = children
                    # rebuild the drawing from the canonicalized tree
                    # (presumably to renumber nodes for the client)
                    tree = canonicalize(dt.nodes[0])
                    dt = DrawTree(tree, senttok)  # kludge..
                    break
    elif request.args.get('nodeid', '').startswith('newlabel_'):
        # splice in a new node under parentid
        _treeid, newparent = request.args.get('newparent',
                                              '').lstrip('t').split('_')
        newparent = int(newparent)
        # label is everything after the first '_' of 'newlabel_<label>'
        label = request.args.get('nodeid').split('_', 1)[1]
        y = dt.nodes[newparent]
        if isinstance(y[0], int):
            error = 'ERROR: cannot add node under POS tag'
        else:
            children = list(y)
            # empty y first so the children are detached before being
            # wrapped in the new node
            y[:] = []
            y[:] = [Tree(label, children)]
            tree = canonicalize(dt.nodes[0])
            dt = DrawTree(tree, senttok)  # kludge..
    else:  # re-attach existing node at existing new parent
        _treeid, nodeid = request.args.get('nodeid', '').lstrip('t').split('_')
        nodeid = int(nodeid)
        _treeid, newparent = request.args.get('newparent',
                                              '').lstrip('t').split('_')
        newparent = int(newparent)
        # remove node from old parent
        # dt.nodes[nodeid].parent.pop(dt.nodes[nodeid].parent_index)
        x = dt.nodes[nodeid]
        y = dt.nodes[newparent]
        # refuse if the new parent is x itself or one of its descendants
        for node in x.subtrees():
            if node is y:
                error = ('ERROR: cannot re-attach subtree'
                         ' under (descendant of) itself\n')
                break
        else:
            # locate x's current parent by identity
            for node in dt.nodes[0].subtrees():
                if any(child is x for child in node):
                    if len(node) > 1:
                        node.remove(x)
                        dt.nodes[newparent].append(x)
                        tree = canonicalize(dt.nodes[0])
                        dt = DrawTree(tree, senttok)  # kludge..
                    else:
                        error = ('ERROR: re-attaching only child creates'
                                 ' empty node %s; remove manually\n' % node)
                    break
            # NOTE(review): if no parent is found (e.g. nodeid 0), nothing
            # changes and no error is reported.
    treestr = writediscbrackettree(tree, senttok, pretty=True).rstrip()
    link = ('<a href="/annotate/accept?%s">accept this tree</a>' %
            urlencode(dict(sentno=sentno, tree=treestr)))
    # count the action whenever no error message was set
    if error == '':
        session['actions'][REATTACH] += 1
        session.modified = True
    # NOTE(review): error and treestr contain tree labels that come from
    # request args, yet they are wrapped in Markup without HTML escaping.
    # Confirm that clients cannot inject markup; escaping treestr may break
    # clients that read it back verbatim after the tab.
    return Markup('%s\n\n%s%s\t%s' % (link, error,
                                      dt.text(unicodelines=True,
                                              html=True,
                                              funcsep='-',
                                              morphsep='/',
                                              nodeprops='t0'), treestr))
コード例 #5
0
def test_treedraw():
    """Draw some trees. Only tests whether no exception occurs.

    Each (tree, sentence) pair is rendered with ``DrawTree``. The drawing is
    printed with Unicode lines and ANSI colors, falling back to plain ASCII
    when the console cannot encode them. Sentences may contain ``None``
    tokens, which are printed as ``...``. The test also checks that every
    sentence has a tree, so no test case is silently dropped by ``zip``.
    """
    trees = '''(ROOT (S (ADV 0) (VVFIN 1) (NP (PDAT 2) (NN 3)) (PTKNEG 4) \
				(PP (APPRART 5) (NN 6) (NP (ART 7) (ADJA 8) (NN 9)))) ($. 10))
			(S (NP (NN 1) (EX 3)) (VP (VB 0) (JJ 2)))
			(S (VP (PDS 0) (ADV 3) (VVINF 4)) (PIS 2) (VMFIN 1))
			(top (du (comp 0) (smain (noun 1) (verb 2) (inf (verb 8) (inf \
				(adj 3) (pp (prep 4) (np (det 5) (noun 6))) (part 7) (verb 9) \
				(pp (prep 10) (np (det 11) (noun 12) (pp (prep 13) (mwu \
				(noun 14) (noun 15))))))))) (punct 16))
			(top (smain (noun 0) (verb 1) (inf (verb 5) (inf (np (det 2) \
				(adj 3) (noun 4)) (verb 6) (pp (prep 7) (noun 8))))) (punct 9))
			(top (smain (noun 0) (verb 1) (noun 2) (inf (adv 3) (verb 4))) \
				(punct 5))
			(top (punct 5) (du (smain (noun 0) (verb 1) (ppart (np (det 2) \
				(noun 3)) (verb 4))) (conj (sv1 (conj (noun 6) (vg 7) (np \
				(det 8) (noun 9))) (verb 10) (noun 11) (part 12)) (vg 13) \
				(sv1 (verb 14) (ti (comp 19) (inf (np (conj (det 15) (vg 16) \
				(det 17)) (noun 18)) (verb 20)))))) (punct 21))
			(top (punct 10) (punct 16) (punct 18) (smain (np (det 0) (noun 1) \
				(pp (prep 2) (np (det 3) (noun 4)))) (verb 5) (adv 6) (np \
				(noun 7) (noun 8)) (part 9) (np (det 11) (noun 12) (pp \
				(prep 13) (np (det 14) (noun 15)))) (conj (vg 20) (ppres \
				(adj 17) (pp (prep 22) (np (det 23) (adj 24) (noun 25)))) \
				(ppres (adj 19)) (ppres (adj 21)))) (punct 26))
			(top (punct 10) (punct 11) (punct 16) (smain (np (det 0) \
				(noun 1)) (verb 2) (np (det 3) (noun 4)) (adv 5) (du (cp \
				(comp 6) (ssub (noun 7) (verb 8) (inf (verb 9)))) (du \
				(smain (noun 12) (verb 13) (adv 14) (part 15)) (noun 17)))) \
				(punct 18) (punct 19))
			(top (smain (noun 0) (verb 1) (inf (verb 8) (inf (verb 9) (inf \
				(adv 2) (pp (prep 3) (noun 4)) (pp (prep 5) (np (det 6) \
				(noun 7))) (verb 10))))) (punct 11))
			(top (smain (noun 0) (verb 1) (pp (prep 2) (np (det 3) (adj 4) \
				(noun 5) (rel (noun 6) (ssub (noun 7) (verb 10) (ppart \
				(adj 8) (part 9) (verb 11))))))) (punct 12))
			(top (smain (np (det 0) (noun 1)) (verb 2) (ap (adv 3) (num 4) \
				(cp (comp 5) (np (det 6) (adj 7) (noun 8) (rel (noun 9) (ssub \
				(noun 10) (verb 11) (pp (prep 12) (np (det 13) (adj 14) \
				(adj 15) (noun 16))))))))) (punct 17))
			(top (smain (np (det 0) (noun 1)) (verb 2) (adv 3) (pp (prep 4) \
				(np (det 5) (noun 6)) (part 7))) (punct 8))
			(top (punct 7) (conj (smain (noun 0) (verb 1) (np (det 2) \
				(noun 3)) (pp (prep 4) (np (det 5) (noun 6)))) (smain \
				(verb 8) (np (det 9) (num 10) (noun 11)) (part 12)) (vg 13) \
				(smain (verb 14) (noun 15) (pp (prep 16) (np (det 17) \
				(noun 18) (pp (prep 19) (np (det 20) (noun 21))))))) \
				(punct 22))
			(top (smain (np (det 0) (noun 1) (rel (noun 2) (ssub (np (num 3) \
				(noun 4)) (adj 5) (verb 6)))) (verb 7) (ppart (verb 8) (pp \
				(prep 9) (noun 10)))) (punct 11))
			(top (conj (sv1 (np (det 0) (noun 1)) (verb 2) (ppart (verb 3))) \
				(vg 4) (sv1 (verb 5) (pp (prep 6) (np (det 7) (adj 8) \
				(noun 9))))) (punct 10))
			(top (smain (noun 0) (verb 1) (np (det 2) (noun 3)) (inf (adj 4) \
				(verb 5) (cp (comp 6) (ssub (noun 7) (adv 8) (verb 10) (ap \
				(num 9) (cp (comp 11) (np (det 12) (adj 13) (noun 14) (pp \
				(prep 15) (conj (np (det 16) (noun 17)) (vg 18) (np \
				(noun 19))))))))))) (punct 20))
			(top (punct 8) (smain (noun 0) (verb 1) (inf (verb 5) \
				(inf (verb 6) (conj (inf (pp (prep 2) (np (det 3) (noun 4))) \
				(verb 7)) (inf (verb 9)) (vg 10) (inf (verb 11)))))) \
				(punct 12))
			(top (smain (verb 2) (noun 3) (adv 4) (ppart (np (det 0) \
				(noun 1)) (verb 5))) (punct 6))
			(top (conj (smain (np (det 0) (noun 1)) (verb 2) (adj 3) (pp \
				(prep 4) (np (det 5) (noun 6)))) (vg 7) (smain (np (det 8) \
				(noun 9) (pp (prep 10) (np (det 11) (noun 12)))) (verb 13) \
				(pp (prep 14) (np (det 15) (noun 16))))) (punct 17))
			(top (conj (smain (noun 0) (verb 1) (inf (ppart (np (noun 2) \
				(noun 3)) (verb 4)) (verb 5))) (vg 6) (smain (noun 7) \
				(inf (ppart (np (det 8) (noun 9)))))) (punct 10))
			(A (B1 (t 6) (t 13)) (B2 (t 3) (t 7) (t 10))  (B3 (t 1) \
				(t 9) (t 11) (t 14) (t 16)) (B4 (t 0) (t 5) (t 8)))
			(A (B1 6 13) (B2 3 7 10)  (B3 1 \
				9 11 14 16) (B4 0 5 8))
			(VP (VB 0) (PRT 2))
			(VP (VP 0 3) (NP (PRP 1) (NN 2)))
			(ROOT (S (VP_2 (PP (APPR 0) (ART 1) (NN 2) (PP (APPR 3) (ART 4) \
				(ADJA 5) (NN 6))) (ADJD 10) (PP (APPR 11) (NN 12)) (VVPP 13)) \
				(VAFIN 7) (NP (ART 8) (NN 9))) ($. 14))'''
    sents = '''Leider stehen diese Fragen nicht im Vordergrund der \
				augenblicklichen Diskussion .
			is Mary happy there
			das muss man jetzt machen
			Of ze had gewoon met haar vriendinnen rond kunnen slenteren in de \
				buurt van Trafalgar Square .
			Het had een prachtige dag kunnen zijn in Londen .
			Cathy zag hen wild zwaaien .
			Het was een spel geworden , zij en haar vriendinnen kozen iemand \
				uit en probeerden zijn of haar nationaliteit te raden .
			Elk jaar in het hoogseizoen trokken daar massa's toeristen \
				voorbij , hun fototoestel in de aanslag , pratend , gillend \
				en lachend in de vreemdste talen .
			Haar vader stak zijn duim omhoog alsof hij wilde zeggen : " het \
				komt wel goed , joch " .
			Ze hadden languit naast elkaar op de strandstoelen kunnen gaan \
				liggen .
			Het hoorde bij de warme zomerdag die ze ginds achter had gelaten .
			De oprijlaan was niet meer dan een hobbelige zandstrook die zich \
				voortslingerde tussen de hoge grijze boomstammen .
			Haar moeder kleefde bijna tegen het autoraampje aan .
			Ze veegde de tranen uit haar ooghoeken , tilde haar twee koffers \
				op en begaf zich in de richting van het landhuis .
			Het meisje dat vijf keer juist raadde werd getrakteerd op ijs .
			Haar neus werd platgedrukt en leek op een jonge champignon .
			Cathy zag de BMW langzaam verdwijnen tot hij niet meer was dan \
				een zilveren schijnsel tussen de bomen en struiken .
			Ze had met haar moeder kunnen gaan winkelen , zwemmen of \
				terrassen .
			Dat werkwoord had ze zelf uitgevonden .
			De middagzon hing klein tussen de takken en de schaduwen van de \
				wolken drentelden over het gras .
			Zij zou mams rug ingewreven hebben en mam de hare .
			0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
			0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
			0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
			0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
			Mit einer Messe in der Sixtinischen Kapelle ist das Konklave \
				offiziell zu Ende gegangen .'''
    from discodop.tree import DrawTree
    trees = [Tree(a) for a in trees.splitlines()]
    sents = [a.split() for a in sents.splitlines()]
    # Sentences containing None tokens (printed as '...'). Each needs its own
    # tree whose leaf indices fit the sentence; without these trees, zip()
    # below would silently drop the two sentences.
    trees.extend([Tree('(VP (VB 0) (PRT 2))'),
                  Tree('(VP (VP 0 3) (NP (PRP 1) (NN 2)))')])
    sents.extend([['Wake', None, 'up'], [None, 'your', 'friend', None]])
    assert len(trees) == len(sents), (len(trees), len(sents))
    for n, (tree, sent) in enumerate(zip(trees, sents)):
        drawtree = DrawTree(tree, sent)
        print('\ntree, sent',
              n,
              tree,
              ' '.join('...' if a is None else a for a in sent),
              repr(drawtree),
              sep='\n')
        try:
            print(drawtree.text(unicodelines=True, ansi=True))
        except (UnicodeDecodeError, UnicodeEncodeError):
            # console cannot encode box-drawing characters / ANSI codes
            print(drawtree.text(unicodelines=False, ansi=False))
コード例 #6
0
def test_allfragments():
    """Compare ``recurringfragments`` with a hand-written fragment table.

    A single binarized tree is passed with ``disc=False``, ``indices=False``,
    ``maxdepth=3`` and ``maxfrontier=999``. The returned mapping must equal
    the table below exactly; each line is a fragment and its count,
    separated by a tab.
    """
    from discodop.fragments import recurringfragments
    model = """\
(DT the)	1
(DT The)	1
(JJ hungry)	1
(NN cat)	1
(NN dog)	1
(NP|<DT.JJ,NN> (JJ hungry) (NN ))	1
(NP|<DT.JJ,NN> (JJ hungry) (NN dog))	1
(NP|<DT.JJ,NN> (JJ ) (NN ))	1
(NP|<DT.JJ,NN> (JJ ) (NN dog))	1
(NP (DT ) (NN ))	1
(NP (DT ) (NN cat))	1
(NP (DT ) (NP|<DT.JJ,NN> ))	1
(NP (DT ) (NP|<DT.JJ,NN> (JJ hungry) (NN )))	1
(NP (DT ) (NP|<DT.JJ,NN> (JJ hungry) (NN dog)))	1
(NP (DT ) (NP|<DT.JJ,NN> (JJ ) (NN )))	1
(NP (DT ) (NP|<DT.JJ,NN> (JJ ) (NN dog)))	1
(NP (DT The) (NN ))	1
(NP (DT The) (NN cat))	1
(NP (DT the) (NP|<DT.JJ,NN> ))	1
(NP (DT the) (NP|<DT.JJ,NN> (JJ hungry) (NN )))	1
(NP (DT the) (NP|<DT.JJ,NN> (JJ hungry) (NN dog)))	1
(NP (DT the) (NP|<DT.JJ,NN> (JJ ) (NN )))	1
(NP (DT the) (NP|<DT.JJ,NN> (JJ ) (NN dog)))	1
(S (NP (DT ) (NN cat)) (VP ))	1
(S (NP (DT ) (NN cat)) (VP (VBP ) (NP )))	1
(S (NP (DT ) (NN cat)) (VP (VBP ) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP (DT ) (NN cat)) (VP (VBP saw) (NP )))	1
(S (NP (DT ) (NN cat)) (VP (VBP saw) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP (DT ) (NN )) (VP ))	1
(S (NP (DT ) (NN )) (VP (VBP ) (NP )))	1
(S (NP (DT ) (NN )) (VP (VBP ) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP (DT ) (NN )) (VP (VBP saw) (NP )))	1
(S (NP (DT ) (NN )) (VP (VBP saw) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP (DT The) (NN cat)) (VP ))	1
(S (NP (DT The) (NN cat)) (VP (VBP ) (NP )))	1
(S (NP (DT The) (NN cat)) (VP (VBP ) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP (DT The) (NN cat)) (VP (VBP saw) (NP )))	1
(S (NP (DT The) (NN cat)) (VP (VBP saw) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP (DT The) (NN )) (VP ))	1
(S (NP (DT The) (NN )) (VP (VBP ) (NP )))	1
(S (NP (DT The) (NN )) (VP (VBP ) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP (DT The) (NN )) (VP (VBP saw) (NP )))	1
(S (NP (DT The) (NN )) (VP (VBP saw) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP ) (VP ))	1
(S (NP ) (VP (VBP ) (NP )))	1
(S (NP ) (VP (VBP ) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(S (NP ) (VP (VBP saw) (NP )))	1
(S (NP ) (VP (VBP saw) (NP (DT ) (NP|<DT.JJ,NN> ))))	1
(VBP saw)	1
(VP (VBP ) (NP ))	1
(VP (VBP ) (NP (DT ) (NP|<DT.JJ,NN> )))	1
(VP (VBP ) (NP (DT ) (NP|<DT.JJ,NN> (JJ ) (NN ))))	1
(VP (VBP ) (NP (DT the) (NP|<DT.JJ,NN> )))	1
(VP (VBP ) (NP (DT the) (NP|<DT.JJ,NN> (JJ ) (NN ))))	1
(VP (VBP saw) (NP ))	1
(VP (VBP saw) (NP (DT ) (NP|<DT.JJ,NN> )))	1
(VP (VBP saw) (NP (DT ) (NP|<DT.JJ,NN> (JJ ) (NN ))))	1
(VP (VBP saw) (NP (DT the) (NP|<DT.JJ,NN> )))	1
(VP (VBP saw) (NP (DT the) (NP|<DT.JJ,NN> (JJ ) (NN ))))	1"""
    expected = {}
    for row in model.splitlines():
        fields = row.split('\t')
        expected[fields[0]] = int(fields[1])
    answers = recurringfragments(
        [Tree('(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) '
              '(NP|<DT.JJ,NN> (JJ 4) (NN 5)))))')],
        [['The', 'cat', 'saw', 'the', 'hungry', 'dog']],
        disc=False,
        indices=False,
        maxdepth=3,
        maxfrontier=999)
    assert expected
    assert answers
    assert answers == expected
コード例 #7
0
ファイル: unittests.py プロジェクト: jstenhouse/disco-dop
def test_grammar(debug=False):
    """Demonstrate grammar extraction.

    Reads ``alpinosample.export`` (punctuation moved). Binarizes the first 10
    trees (``horzmarkov=1``) and adds fanout markers to them. Then:

    1. builds a DOP-reduction grammar from the first two trees and runs
       ``testgrammar()`` on it (the result is ignored);
    2. builds a Double-DOP grammar from the 10 trees and asserts that its
       relative frequency estimates sum to 1. It then parses every corpus
       sentence and checks that, among the 1000-best derivations, no parse
       tree receives the same (derivation, score) pair twice;
    3. extracts a grammar from a tree consisting of a chain of unary nodes.

    Parameters
    ----------
    debug : bool
        If True, print grammars, parse results and derivations.
    """
    from discodop.grammar import treebankgrammar, dopreduction, doubledop
    from discodop import plcfrs
    from discodop.containers import Grammar
    from discodop.treebank import NegraCorpusReader
    from discodop.treetransforms import addfanoutmarkers, removefanoutmarkers
    from discodop.disambiguation import recoverfragments
    from discodop.kbest import lazykbest
    from math import exp
    corpus = NegraCorpusReader('alpinosample.export', punct='move')
    sents = list(corpus.sents().values())
    # copy before binarizing so the corpus trees used as gold stay unchanged
    trees = [
        addfanoutmarkers(binarize(a.copy(True), horzmarkov=1))
        for a in list(corpus.trees().values())[:10]
    ]
    if debug:
        print('plcfrs\n', Grammar(treebankgrammar(trees, sents)))
        print('dop reduction')
    grammar = Grammar(dopreduction(trees[:2], sents[:2])[0],
                      start=trees[0].label)
    if debug:
        print(grammar)
    _ = grammar.testgrammar()

    grammarx, backtransform, _, _ = doubledop(trees,
                                              sents,
                                              debug=False,
                                              numproc=1)
    if debug:
        print('\ndouble dop grammar')
    grammar = Grammar(grammarx, start=trees[0].label)
    # labels matching '#<digits>' or containing '}<' are passed as
    # neverblockre
    grammar.getmapping(grammar,
                       striplabelre=None,
                       neverblockre=re.compile('^#[0-9]+|.+}<'),
                       splitprune=False,
                       markorigin=False)
    if debug:
        print(grammar)
    assert grammar.testgrammar()[0], "RFE should sum to 1."
    for tree, sent in zip(corpus.trees().values(), sents):
        if debug:
            print("sentence:",
                  ' '.join(a.encode('unicode-escape').decode() for a in sent))
        chart, msg = plcfrs.parse(sent, grammar, exhaustive=True)
        if debug:
            print('\n', msg, '\ngold ', tree, '\n', 'double dop', end='')
        if chart:
            # mpp: parse tree string -> summed probability of its derivations
            # parsetrees: parse tree string -> list of (derivation, score)
            mpp, parsetrees = {}, {}
            derivations, _ = lazykbest(chart, 1000, '}<')
            for d, (t, p) in zip(chart.rankededges[chart.root()], derivations):
                # map the Double-DOP derivation back to a tree, then undo
                # binarization and fanout markers
                r = Tree(recoverfragments(d.key, chart, backtransform))
                r = str(removefanoutmarkers(unbinarize(r)))
                # p is converted with exp(-p), i.e. it appears to be a
                # negative log probability
                mpp[r] = mpp.get(r, 0.0) + exp(-p)
                parsetrees.setdefault(r, []).append((t, p))
            if debug:
                print(len(mpp), 'parsetrees',
                      sum(map(len, parsetrees.values())), 'derivations')
            for t, tp in sorted(mpp.items(), key=itemgetter(1)):
                if debug:
                    print(tp, t, '\nmatch:', t == str(tree))
                # duplicate derivations for one parse tree indicate an error;
                # dump the chart before failing
                if len(set(parsetrees[t])) != len(parsetrees[t]):
                    print('chart:\n', chart)
                    assert len(set(parsetrees[t])) == len(parsetrees[t])
                if debug:
                    for deriv, p in sorted(parsetrees[t], key=itemgetter(1)):
                        print(' <= %6g %s' % (exp(-p), deriv))
        elif debug:
            print('no parse\n', chart)
        if debug:
            print()
    # Chain of unary nodes over a single leaf.
    # NOTE(review): the sentence passed has 10 tokens while the tree has one
    # leaf; confirm this mismatch is intended.
    tree = Tree.parse("(ROOT (S (F (E (S (C (B (A 0))))))))", parse_leaf=int)
    Grammar(treebankgrammar([tree], [[str(a) for a in range(10)]]))
コード例 #8
0
    def evaluate(self,
                 sentences: SupertagParseDataset,
                 mini_batch_size: int = 32,
                 num_workers: int = 1,
                 embedding_storage_mode: str = "none",
                 out_path=None,
                 only_disc: str = "both",
                 accuracy: str = "both",
                 pos_accuracy: bool = True,
                 return_loss: bool = True) -> Tuple[Result, float]:
        """ Predicts supertags, pos tags and parse trees, and reports the
            predictions scores for a set of sentences.
            :param sentences: a ``DataSet`` of sentences. For each sentence
                a gold parse tree is expected as value of the `tree` label, as
                provided by ``SupertagParseDataset``.
            :param mini_batch_size: number of sentences per batch.
            :param num_workers: number of ``DataLoader`` workers.
            :param embedding_storage_mode: passed on to ``self.predict``.
            :param out_path: currently unused.
            :param only_disc: selects which constituents are evaluated.
                "both" reports F1 for all constituents and for discontinuous
                constituents only. "param" uses the `DISC_ONLY` setting of the
                evaluation parameters ``self.evalparam``. "true" evaluates
                only discontinuous constituents. Any other value evaluates
                all constituents.
            :param accuracy: either 'none', 'best', 'kbest' or 'both'.
                Determines if the accuracy is computed from the best, or k-best
                predicted tags.
            :param pos_accuracy: if set, reports acc. of predicted pos tags.
            :param return_loss: if set, nll loss wrt. gold tags is reported,
                otherwise the second component in the returned tuple is 0.
            :returns: tuple with evaluation ``Result``, where the main score
                is the f1-score (for all constituents, if only_disc == "both"),
                and the loss averaged over batches. For an empty dataset the
                accuracies and the loss are reported as 0 and the coverage as
                1, instead of raising ``ZeroDivisionError``.
            :raises Exception: if no evaluator parameter file was specified.
        """
        from flair.datasets import DataLoader
        from discodop.tree import ParentedTree, Tree
        from discodop.treetransforms import unbinarize, removefanoutmarkers
        from discodop.eval import Evaluator
        from timeit import default_timer
        from collections import Counter

        def ratio(numerator, denominator):
            # empty datasets would otherwise raise ZeroDivisionError
            return numerator / denominator if denominator else 0.0

        if self.__evalparam__ is None:
            raise Exception(
                "Need to specify evaluator parameter file before evaluating")
        if only_disc == "both":
            evaluators = {
                "F1-all": Evaluator({
                    **self.evalparam, "DISC_ONLY": False
                }),
                "F1-disc": Evaluator({
                    **self.evalparam, "DISC_ONLY": True
                })
            }
        else:
            mode = self.evalparam["DISC_ONLY"] if only_disc == "param" else (
                only_disc == "true")
            strmode = "F1-disc" if mode else "F1-all"
            evaluators = {
                strmode: Evaluator({
                    **self.evalparam, "DISC_ONLY": mode
                })
            }

        data_loader = DataLoader(sentences,
                                 batch_size=mini_batch_size,
                                 num_workers=num_workers)

        # predict supertags and parse trees; predictions are stored on the
        # sentences under the 'predicted' label name
        eval_loss = 0
        start_time = default_timer()
        for batch in data_loader:
            loss = self.predict(batch,
                                embedding_storage_mode=embedding_storage_mode,
                                supertag_storage_mode=accuracy,
                                postag_storage_mode=pos_accuracy,
                                label_name='predicted',
                                return_loss=return_loss)
            eval_loss += loss if return_loss else 0
        end_time = default_timer()

        # second pass: compare stored predictions with gold annotations
        i = 0
        batches = 0
        noparses = 0
        acc_ctr = Counter()
        for batch in data_loader:
            for sentence in batch:
                for token in sentence:
                    if accuracy in ("kbest", "both") and token.get_tag("supertag").value in \
                            (l.value for l in token.get_tags_proba_dist('predicted-supertag')):
                        acc_ctr["kbest"] += 1
                    if accuracy in ("best", "both") and token.get_tag("supertag").value == \
                            token.get_tag('predicted-supertag').value:
                        acc_ctr["best"] += 1
                    if pos_accuracy and token.get_tag(
                            "pos").value == token.get_tag(
                                "predicted-pos").value:
                        acc_ctr["pos"] += 1
                acc_ctr["all"] += len(sentence)
                sent = [token.text for token in sentence]
                gold = Tree(sentence.get_labels("tree")[0].value)
                gold = ParentedTree.convert(
                    unbinarize(removefanoutmarkers(gold)))
                parse = Tree(sentence.get_labels("predicted")[0].value)
                parse = ParentedTree.convert(
                    unbinarize(removefanoutmarkers(parse)))
                if parse.label == "NOPARSE":
                    noparses += 1
                for evaluator in evaluators.values():
                    evaluator.add(i, gold.copy(deep=True), list(sent),
                                  parse.copy(deep=True), list(sent))
                i += 1
            batches += 1
        scores = {
            strmode: float_or_zero(evaluator.acc.scores()['lf'])
            for strmode, evaluator in evaluators.items()
        }
        if accuracy in ("both", "kbest"):
            scores["accuracy-kbest"] = ratio(acc_ctr["kbest"], acc_ctr["all"])
        if accuracy in ("both", "best"):
            scores["accuracy-best"] = ratio(acc_ctr["best"], acc_ctr["all"])
        if pos_accuracy:
            scores["accuracy-pos"] = ratio(acc_ctr["pos"], acc_ctr["all"])
        scores["coverage"] = 1 - ratio(noparses, i)
        scores["time"] = end_time - start_time
        return (Result(
            scores['F1-all'] if 'F1-all' in scores else scores['F1-disc'],
            "\t".join(f"{key}" for key in scores),
            "\t".join(f"{s}" for s in scores.values()),
            '\n\n'.join(evaluator.summary()
                        for evaluator in evaluators.values())),
                ratio(eval_loss, batches))