def test_optimalbinarize():
	"""Verify that all optimal parsing complexities are lower than or
	equal to the complexities of right-to-left binarizations."""
	from discodop.treetransforms import optimalbinarize, complexityfanout
	from discodop.treebank import NegraCorpusReader
	corpus = NegraCorpusReader('alpinosample.export', punct='move')
	total = violations = violationshd = 0
	for n, (tree, sent) in enumerate(zip(list(
			corpus.trees().values())[:-2000], corpus.sents().values())):
		t = addbitsets(tree)
		if all(fanout(x) == 1 for x in t.subtrees()):
			continue
		print(n, tree, '\n', ' '.join(sent))
		total += 1
		optbin = optimalbinarize(tree.copy(True), headdriven=False,
				h=None, v=1)
		# undo head-ordering to get a normal right-to-left binarization
		normbin = addbitsets(binarize(canonicalize(Tree.convert(tree))))
		if (max(map(complexityfanout, optbin.subtrees()))
				> max(map(complexityfanout, normbin.subtrees()))):
			print('non-hd\n', tree)
			print(max(map(complexityfanout, optbin.subtrees())), optbin)
			print(max(map(complexityfanout, normbin.subtrees())),
					normbin, '\n')
			violations += 1

		optbin = optimalbinarize(tree.copy(True), headdriven=True, h=1, v=1)
		normbin = addbitsets(binarize(Tree.convert(tree), horzmarkov=1))
		if (max(map(complexityfanout, optbin.subtrees()))
				> max(map(complexityfanout, normbin.subtrees()))):
			print('hd\n', tree)
			print(max(map(complexityfanout, optbin.subtrees())), optbin)
			print(max(map(complexityfanout, normbin.subtrees())),
					normbin, '\n')
			violationshd += 1
	print('opt. bin. violations normal: %d / %d; hd: %d / %d' % (
			violations, total, violationshd, total))
	assert violations == violationshd == 0

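# Illustrative sketch (not from the original code): the same comparison on a
# single discontinuous tree. Assumes, as the test above does, that binarize,
# canonicalize, and addbitsets are available from discodop.treetransforms
# and Tree from discodop.tree.
def demo_optimalbinarize():
	from discodop.tree import Tree
	from discodop.treetransforms import (optimalbinarize, binarize,
			canonicalize, addbitsets, complexityfanout)
	tree = Tree('(S (VP (PDS 0) (ADV 3) (VVINF 4)) (PIS 2) (VMFIN 1))')
	optbin = optimalbinarize(tree.copy(True), headdriven=False, h=None, v=1)
	normbin = addbitsets(binarize(canonicalize(Tree.convert(tree))))
	# the optimal binarization should never be more complex to parse
	print(max(map(complexityfanout, optbin.subtrees())),
			max(map(complexityfanout, normbin.subtrees())))
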
def test_fragments():
	from discodop._fragments import getctrees, extractfragments, exactcounts
	treebank = """\
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (JJ 4) (NN 5))))\t\
The cat saw the hungry dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The cat saw the dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The mouse saw the cat
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (JJ 4) (NN 5))))\t\
The mouse saw the yellow cat
(S (NP (DT 0) (JJ 1) (NN 2)) (VP (VBP 3) (NP (DT 4) (NN 5))))\t\
The little mouse saw the cat
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The cat ate the dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The mouse ate the cat""".splitlines()
	trees = [binarize(Tree(line.split('\t')[0])) for line in treebank]
	sents = [line.split('\t')[1].split() for line in treebank]
	for tree in trees:
		for n, idx in enumerate(tree.treepositions('leaves')):
			tree[idx] = n
	params = getctrees(zip(trees, sents))
	fragments = extractfragments(params['trees1'], 0, 0, params['vocab'],
			disc=True, approx=False)
	counts = exactcounts(params['trees1'], params['trees1'],
			list(fragments.values()))
	assert len(fragments) == 25
	assert sum(counts) == 100

def convertTreeToVector(treeStringsInput, featureMap):
	totalTrees = len(treeStringsInput)
	treeStrings = treeStringsInput[:]
	# duplicate the input trees so their fragments recur in the treebank
	treeStrings.extend(treeStrings[:])
	text = BracketStringReader(treeStrings)
	trees = [treetransforms.binarize(tree, horzmarkov=1, vertmarkov=1)
			for _, (tree, _) in text.itertrees(0)]
	sents = [sent for _, (_, sent) in text.itertrees(0)]
	result = fragments.getfragments(trees, sents, numproc=1,
			disc=False, cover=True)
	featureVector = {}
	found = 0
	total = 0
	for tree, sentDict in result.items():
		total += 1
		if tree in featureMap:
			found += 1
			# print('%3d\t%s' % (sum(sentDict.values()), tree))
			treeIndex = featureMap[tree.strip()]
			for key, count in sentDict.items():
				if key < totalTrees:
					if treeIndex in featureVector:
						featureVector[treeIndex] += count
					else:
						featureVector[treeIndex] = count
	print('Found', found, 'out of', total, 'trees in featurespace.')
	return featureVector

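# Hypothetical usage sketch (toy data, names invented for illustration):
# build a feature map from the fragments of a small training set, then
# vectorize a held-out tree against it. Assumes the same module-level
# imports as the function above (BracketStringReader, treetransforms,
# fragments).
def demo_convertTreeToVector():
	trainTreeStrings = [
			'(S (NP (DT the) (NN cat)) (VP (VB saw) (NP (DT the) (NN dog))))',
			'(S (NP (DT the) (NN dog)) (VP (VB saw) (NP (DT the) (NN cat))))']
	text = BracketStringReader(trainTreeStrings)
	trees = [treetransforms.binarize(tree, horzmarkov=1, vertmarkov=1)
			for _, (tree, _) in text.itertrees(0)]
	sents = [sent for _, (_, sent) in text.itertrees(0)]
	trainFragments = fragments.getfragments(trees, sents, numproc=1,
			disc=False, cover=True)
	# map each fragment string to a feature index
	featureMap = {frag: i for i, frag in enumerate(sorted(trainFragments))}
	vector = convertTreeToVector(
			['(S (NP (DT the) (NN cat)) (VP (VB saw) (NP (DT the) '
			'(NN dog))))'], featureMap)
	print(vector)
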
def test_fragments():
	from discodop._fragments import getctrees, extractfragments, exactcounts
	treebank = """\
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (JJ 4) (NN 5))))\t\
The cat saw the hungry dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The cat saw the dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The mouse saw the cat
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (JJ 4) (NN 5))))\t\
The mouse saw the yellow cat
(S (NP (DT 0) (JJ 1) (NN 2)) (VP (VBP 3) (NP (DT 4) (NN 5))))\t\
The little mouse saw the cat
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The cat ate the dog
(S (NP (DT 0) (NN 1)) (VP (VBP 2) (NP (DT 3) (NN 4))))\t\
The mouse ate the cat""".splitlines()
	trees = [binarize(Tree(line.split('\t')[0])) for line in treebank]
	sents = [line.split('\t')[1].split() for line in treebank]
	for tree in trees:
		for n, idx in enumerate(tree.treepositions('leaves')):
			tree[idx] = n
	params = getctrees(zip(trees, sents))
	fragments = extractfragments(params['trees1'], 0, 0, params['vocab'],
			disc=True, approx=False)
	counts = exactcounts(list(fragments.values()),
			params['trees1'], params['trees1'])
	assert len(fragments) == 25
	assert sum(counts) == 100

def dobinarization(trees, sents, binarization, relationalrealizational):
	"""Apply binarization."""
	# fixme: this n should correspond to sentence id
	tbfanout, n = treebank.treebankfanout(trees)
	logging.info('treebank fan-out before binarization: %d #%d\n%s\n%s',
			tbfanout, n, trees[n], ' '.join(sents[n]))
	# binarization
	begin = time.clock()
	msg = 'binarization: %s' % binarization.method
	if binarization.fanout_marks_before_bin:
		trees = [treetransforms.addfanoutmarkers(t) for t in trees]
	if binarization.method is None:
		pass
	elif binarization.method == 'default':
		msg += ' %s h=%d v=%d %s' % (
				binarization.factor, binarization.h, binarization.v,
				'tailmarker' if binarization.tailmarker else '')
		for a in trees:
			treetransforms.binarize(a, factor=binarization.factor,
					tailmarker=binarization.tailmarker,
					horzmarkov=binarization.h,
					vertmarkov=binarization.v,
					leftmostunary=binarization.leftmostunary,
					rightmostunary=binarization.rightmostunary,
					reverse=binarization.revmarkov,
					headidx=-1 if binarization.markhead else None,
					filterfuncs=(relationalrealizational['ignorefunctions']
						+ (relationalrealizational['adjunctionlabel'], ))
						if relationalrealizational else (),
					labelfun=binarization.labelfun)
	elif binarization.method == 'optimal':
		trees = [Tree.convert(treetransforms.optimalbinarize(tree))
				for n, tree in enumerate(trees)]
	elif binarization.method == 'optimalhead':
		msg += ' h=%d v=%d' % (binarization.h, binarization.v)
		trees = [Tree.convert(treetransforms.optimalbinarize(
				tree, headdriven=True, h=binarization.h, v=binarization.v))
				for n, tree in enumerate(trees)]
	trees = [treetransforms.addfanoutmarkers(t) for t in trees]
	logging.info('%s; cpu time elapsed: %gs', msg, time.clock() - begin)
	trees = [treetransforms.canonicalize(a).freeze() for a in trees]
	return trees

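# Illustrative sketch (not from the original code): the `binarization`
# argument comes from discodop's parameter-file machinery; a namedtuple
# carrying exactly the attributes read above is enough to drive
# dobinarization in isolation.
from collections import namedtuple

BinarizationSettings = namedtuple('BinarizationSettings',
		'method factor h v tailmarker leftmostunary rightmostunary '
		'revmarkov markhead labelfun fanout_marks_before_bin')
defaultsettings = BinarizationSettings(method='default', factor='right',
		h=1, v=1, tailmarker='', leftmostunary=False, rightmostunary=False,
		revmarkov=False, markhead=False, labelfun=None,
		fanout_marks_before_bin=False)
# given `trees` (discodop Tree objects) and `sents` (lists of tokens):
# trees = dobinarization(trees, sents, defaultsettings, None)
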
def test_grammar(debug=False):
	"""Demonstrate grammar extraction."""
	from discodop.grammar import treebankgrammar, dopreduction, doubledop
	from discodop import plcfrs
	from discodop.containers import Grammar
	from discodop.treebank import NegraCorpusReader
	from discodop.treetransforms import addfanoutmarkers
	from discodop.disambiguation import getderivations, marginalize
	corpus = NegraCorpusReader('alpinosample.export', punct='move')
	sents = list(corpus.sents().values())
	trees = [addfanoutmarkers(binarize(a.copy(True), horzmarkov=1))
			for a in list(corpus.trees().values())[:10]]
	if debug:
		print('plcfrs\n', Grammar(treebankgrammar(trees, sents)))
		print('dop reduction')
	grammar = Grammar(dopreduction(trees[:2], sents[:2])[0],
			start=trees[0].label)
	if debug:
		print(grammar)
	_ = grammar.testgrammar()

	grammarx, _backtransform, _, _ = doubledop(trees, sents,
			debug=False, numproc=1)
	if debug:
		print('\ndouble dop grammar')
	grammar = Grammar(grammarx, start=trees[0].label)
	grammar.getmapping(None, striplabelre=None,
			neverblockre=re.compile('^#[0-9]+|.+}<'),
			splitprune=False, markorigin=False)
	if debug:
		print(grammar)
	result, msg = grammar.testgrammar()
	assert result, 'RFE should sum to 1.\n%s' % msg
	for tree, sent in zip(corpus.trees().values(), sents):
		if debug:
			print('sentence:', ' '.join(a.encode('unicode-escape').decode()
					for a in sent))
		chart, msg = plcfrs.parse(sent, grammar, exhaustive=True)
		if debug:
			print('\n', msg, '\ngold ', tree, '\n', 'double dop', end='')
		if chart:
			getderivations(chart, 100)
			_parses, _msg = marginalize('mpp', chart)
		elif debug:
			print('no parse\n', chart)
	if debug:
		print()
	tree = Tree.parse('(ROOT (S (F (E (S (C (B (A 0))))))))', parse_leaf=int)
	Grammar(treebankgrammar([tree], [[str(a) for a in range(10)]]))

def test_binarize(self):
	treestr = '(S (VP (PDS 0) (ADV 3) (VVINF 4)) (PIS 2) (VMFIN 1))'
	origtree = Tree(treestr)
	tree = Tree(treestr)
	assert str(binarize(tree, horzmarkov=0, tailmarker='')) == (
			'(S (VP (PDS 0) (VP|<> (ADV 3) (VVINF 4))) (S|<> (PIS 2) '
			'(VMFIN 1)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1, tailmarker='')) == (
			'(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VVINF 4))) (S|<PIS> '
			'(PIS 2) (VMFIN 1)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1, leftmostunary=False,
			rightmostunary=True, tailmarker='')) == (
			'(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VP|<VVINF> (VVINF 4)))) '
			'(S|<PIS> (PIS 2) (S|<VMFIN> (VMFIN 1))))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1, leftmostunary=True,
			rightmostunary=False, tailmarker='')) == (
			'(S (S|<VP> (VP (VP|<PDS> (PDS 0) (VP|<ADV> (ADV 3) '
			'(VVINF 4)))) (S|<PIS> (PIS 2) (VMFIN 1))))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, factor='left', horzmarkov=2,
			tailmarker='')) == (
			'(S (S|<PIS,VMFIN> (VP (VP|<ADV,VVINF> (PDS 0) (ADV 3)) '
			'(VVINF 4)) (PIS 2)) (VMFIN 1))')
	assert unbinarize(tree) == origtree
	tree = Tree('(S (A 0) (B 1) (C 2) (D 3) (E 4) (F 5))')
	assert str(binarize(tree, tailmarker='', reverse=False)) == (
			'(S (A 0) (S|<B,C,D,E,F> (B 1) (S|<C,D,E,F> (C 2) (S|<D,E,F> '
			'(D 3) (S|<E,F> (E 4) (F 5))))))')

def test_binarize(self):
	treestr = '(S (VP (PDS 0) (ADV 3) (VVINF 4)) (VMFIN 1) (PIS 2))'
	origtree = Tree(treestr)
	tree = Tree(treestr)
	tree[1].type = HEAD  # set VMFIN as head
	assert str(binarize(tree, horzmarkov=0)) == (
			'(S (VP (PDS 0) (VP|<> (ADV 3) (VVINF 4))) (S|<> (VMFIN 1)'
			' (PIS 2)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1)) == (
			'(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VVINF 4))) (S|<VMFIN> '
			'(VMFIN 1) (PIS 2)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1, leftmostunary=False,
			rightmostunary=True, headoutward=True)) == (
			'(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VP|<VVINF> (VVINF 4)))) '
			'(S|<VMFIN> (S|<VMFIN> (VMFIN 1)) (PIS 2)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1, leftmostunary=True,
			rightmostunary=False, headoutward=True)) == (
			'(S (S|<VP> (VP (VP|<PDS> (PDS 0) (VP|<ADV> (ADV 3) '
			'(VVINF 4)))) (S|<VMFIN> (VMFIN 1) (PIS 2))))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, factor='left', horzmarkov=2,
			headoutward=True)) == (
			'(S (S|<VMFIN,PIS> (VP (VP|<PDS,ADV> (PDS 0) (ADV 3)) '
			'(VVINF 4)) (VMFIN 1)) (PIS 2))')
	assert unbinarize(tree) == origtree
	tree = Tree('(S (A 0) (B 1) (C 2) (D 3) (E 4) (F 5))')
	assert str(binarize(tree, headoutward=True)) == (
			'(S (A 0) (S|<B,C,D,E,F> (B 1) (S|<C,D,E,F> (C 2) (S|<D,E,F> '
			'(D 3) (S|<E,F> (E 4) (F 5))))))')

def test_fragments():
	from discodop._fragments import getctrees, extractfragments, exactcounts
	treebank = [binarize(Tree(x)) for x in """\
(S (NP (DT The) (NN cat)) (VP (VBP saw) (NP (DT the) (JJ hungry) (NN dog))))
(S (NP (DT The) (NN cat)) (VP (VBP saw) (NP (DT the) (NN dog))))
(S (NP (DT The) (NN mouse)) (VP (VBP saw) (NP (DT the) (NN cat))))
(S (NP (DT The) (NN mouse)) (VP (VBP saw) (NP (DT the) (JJ yellow) (NN cat))))
(S (NP (DT The) (JJ little) (NN mouse)) (VP (VBP saw) (NP (DT the) (NN cat))))
(S (NP (DT The) (NN cat)) (VP (VBP ate) (NP (DT the) (NN dog))))
(S (NP (DT The) (NN mouse)) (VP (VBP ate) (NP (DT the) (NN cat))))\
""".splitlines()]
	sents = [tree.leaves() for tree in treebank]
	for tree in treebank:
		for n, idx in enumerate(tree.treepositions('leaves')):
			tree[idx] = n
	params = getctrees(treebank, sents)
	fragments = extractfragments(params['trees1'], params['sents1'], 0, 0,
			params['labels'], discontinuous=True, approx=False)
	counts = exactcounts(params['trees1'], params['trees1'],
			list(fragments.values()))
	assert len(fragments) == 25
	assert sum(counts) == 100
	for (a, b), c in sorted(zip(fragments, counts), key=repr):
		print("%s\t%d" % (re.sub("[0-9]+",
				lambda x: b[int(x.group())], a), c))

def test_binarize(self):
	treestr = '(S (VP (PDS 0) (ADV 3) (VVINF 4)) (VMFIN 1) (PIS 2))'
	origtree = Tree(treestr)
	tree = Tree(treestr)
	sethead(tree[1])  # set VMFIN as head
	assert str(binarize(tree, horzmarkov=0)) == (
			'(S (VP (PDS 0) (VP|<> (ADV 3) (VVINF 4))) (S|<> (VMFIN 1)'
			' (PIS 2)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1)) == (
			'(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VVINF 4))) (S|<VMFIN> '
			'(VMFIN 1) (PIS 2)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1, leftmostunary=False,
			rightmostunary=True, headoutward=True)) == (
			'(S (VP (PDS 0) (VP|<ADV> (ADV 3) (VP|<VVINF> (VVINF 4)))) '
			'(S|<VMFIN> (S|<VMFIN> (VMFIN 1)) (PIS 2)))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, horzmarkov=1, leftmostunary=True,
			rightmostunary=False, headoutward=True)) == (
			'(S (S|<VP> (VP (VP|<PDS> (PDS 0) (VP|<ADV> (ADV 3) '
			'(VVINF 4)))) (S|<VMFIN> (VMFIN 1) (PIS 2))))')
	assert unbinarize(tree) == origtree
	assert str(binarize(tree, factor='left', horzmarkov=2,
			headoutward=True)) == (
			'(S (S|<VMFIN,PIS> (VP (VP|<PDS,ADV> (PDS 0) (ADV 3)) '
			'(VVINF 4)) (VMFIN 1)) (PIS 2))')
	assert unbinarize(tree) == origtree
	tree = Tree('(S (A 0) (B 1) (C 2) (D 3) (E 4) (F 5))')
	assert str(binarize(tree, headoutward=True)) == (
			'(S (A 0) (S|<B,C,D,E,F> (B 1) (S|<C,D,E,F> (C 2) (S|<D,E,F> '
			'(D 3) (S|<E,F> (E 4) (F 5))))))')

def test_grammar(debug=False):
	"""Demonstrate grammar extraction."""
	from discodop.grammar import treebankgrammar, dopreduction, doubledop
	from discodop import plcfrs
	from discodop.containers import Grammar
	from discodop.treebank import NegraCorpusReader
	from discodop.treetransforms import addfanoutmarkers, removefanoutmarkers
	from discodop.disambiguation import recoverfragments
	from discodop.kbest import lazykbest
	from math import exp
	corpus = NegraCorpusReader('alpinosample.export', punct='move')
	sents = list(corpus.sents().values())
	trees = [addfanoutmarkers(binarize(a.copy(True), horzmarkov=1))
			for a in list(corpus.trees().values())[:10]]
	if debug:
		print('plcfrs\n', Grammar(treebankgrammar(trees, sents)))
		print('dop reduction')
	grammar = Grammar(dopreduction(trees[:2], sents[:2])[0],
			start=trees[0].label)
	if debug:
		print(grammar)
	_ = grammar.testgrammar()

	grammarx, backtransform, _, _ = doubledop(trees, sents,
			debug=debug, numproc=1)
	if debug:
		print('\ndouble dop grammar')
	grammar = Grammar(grammarx, start=trees[0].label)
	grammar.getmapping(grammar, striplabelre=None,
			neverblockre=re.compile(b'^#[0-9]+|.+}<'),
			splitprune=False, markorigin=False)
	if debug:
		print(grammar)
	assert grammar.testgrammar()[0], "RFE should sum to 1."
	for tree, sent in zip(corpus.trees().values(), sents):
		if debug:
			print("sentence:", ' '.join(a.encode('unicode-escape').decode()
					for a in sent))
		chart, msg = plcfrs.parse(sent, grammar, exhaustive=True)
		if debug:
			print('\n', msg, '\ngold ', tree, '\n', 'double dop', end='')
		if chart:
			mpp, parsetrees = {}, {}
			derivations, _ = lazykbest(chart, 1000, b'}<')
			for d, (t, p) in zip(chart.rankededges[chart.root()],
					derivations):
				r = Tree(recoverfragments(d.key, chart, backtransform))
				r = str(removefanoutmarkers(unbinarize(r)))
				mpp[r] = mpp.get(r, 0.0) + exp(-p)
				parsetrees.setdefault(r, []).append((t, p))
			if debug:
				print(len(mpp), 'parsetrees',
						sum(map(len, parsetrees.values())), 'derivations')
			for t, tp in sorted(mpp.items(), key=itemgetter(1)):
				if debug:
					print(tp, t, '\nmatch:', t == str(tree))
				if len(set(parsetrees[t])) != len(parsetrees[t]):
					print('chart:\n', chart)
				assert len(set(parsetrees[t])) == len(parsetrees[t])
				if debug:
					for deriv, p in sorted(parsetrees[t], key=itemgetter(1)):
						print(' <= %6g %s' % (exp(-p), deriv))
		elif debug:
			print('no parse\n', chart)
	if debug:
		print()
	tree = Tree.parse("(ROOT (S (F (E (S (C (B (A 0))))))))", parse_leaf=int)
	Grammar(treebankgrammar([tree], [[str(a) for a in range(10)]]))

def getgrammars(trees, sents, stages, testmaxwords, resultdir,
		numproc, lexmodel, simplelexsmooth, top):
	"""Read off the requested grammars."""
	tbfanout, n = treebank.treebankfanout(trees)
	logging.info('binarized treebank fan-out: %d #%d', tbfanout, n)
	for n, stage in enumerate(stages):
		if stage.split:
			traintrees = [treetransforms.binarize(
					treetransforms.splitdiscnodes(
						Tree.convert(a), stage.markorigin),
					childchar=':', dot=True,
					ids=grammar.UniqueIDs()).freeze()
					for a in trees]
			logging.info('splitted discontinuous nodes')
		else:
			traintrees = trees
		if stage.mode.startswith('pcfg'):
			if tbfanout != 1 and not stage.split:
				raise ValueError('Cannot extract PCFG from treebank '
						'with discontinuities.')
		backtransform = extrarules = None
		if lexmodel and simplelexsmooth:
			extrarules = lexicon.simplesmoothlexicon(lexmodel)
		if stage.dop:
			if stage.dop == 'doubledop':
				(xgrammar, backtransform, altweights, fragments
						) = grammar.doubledop(
						traintrees, sents, binarized=stage.binarized,
						iterate=stage.iterate, complement=stage.complement,
						numproc=numproc, extrarules=extrarules)
				# dump fragments
				with codecs.getwriter('utf-8')(gzip.open(
						'%s/%s.fragments.gz' % (resultdir, stage.name),
						'w')) as out:
					out.writelines('%s\t%d\n' % (treebank.writetree(
							a, b, 0, 'bracket'
							if stage.mode.startswith('pcfg')
							else 'discbracket').rstrip(), sum(c.values()))
							for (a, b), c in fragments.items())
			elif stage.dop == 'reduction':
				xgrammar, altweights = grammar.dopreduction(
						traintrees, sents, packedgraph=stage.packedgraph,
						extrarules=extrarules)
			else:
				raise ValueError('unrecognized DOP model: %r' % stage.dop)
			nodes = sum(len(list(a.subtrees())) for a in traintrees)
			if lexmodel and not simplelexsmooth:  # FIXME: altweights?
				xgrammar = lexicon.smoothlexicon(xgrammar, lexmodel)
			msg = grammar.grammarinfo(xgrammar)
			rules, lex = grammar.write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			gram = Grammar(rules, lex, start=top,
					bitpar=stage.mode.startswith('pcfg'),
					binarized=stage.binarized)
			for name in altweights:
				gram.register(u'%s' % name, altweights[name])
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lex)
			logging.info('DOP model based on %d sentences, %d nodes, '
					'%d nonterminals', len(traintrees), nodes,
					len(gram.toid))
			logging.info(msg)
			if stage.estimator != 'rfe':
				gram.switch(u'%s' % stage.estimator)
			logging.info(gram.testgrammar()[1])
			if stage.dop == 'doubledop':
				# backtransform keys are line numbers to rules file;
				# to see them together do:
				# $ paste <(zcat dop.rules.gz) <(zcat dop.backtransform.gz)
				with codecs.getwriter('ascii')(gzip.open(
						'%s/%s.backtransform.gz' % (resultdir, stage.name),
						'w')) as out:
					out.writelines('%s\n' % a for a in backtransform)
				if n and stage.prune:
					msg = gram.getmapping(stages[n - 1].grammar,
							striplabelre=None if stages[n - 1].dop
								else re.compile(b'@.+$'),
							neverblockre=re.compile(b'.+}<'),
							splitprune=stage.splitprune
								and stages[n - 1].split,
							markorigin=stages[n - 1].markorigin)
				else:
					# recoverfragments() relies on this mapping to identify
					# binarization nodes
					msg = gram.getmapping(None,
							striplabelre=None,
							neverblockre=re.compile(b'.+}<'),
							splitprune=False, markorigin=False)
				logging.info(msg)
			elif n and stage.prune:  # dop reduction
				msg = gram.getmapping(stages[n - 1].grammar,
						striplabelre=None if stages[n - 1].dop
							and stages[n - 1].dop != 'doubledop'
							else re.compile(b'@[-0-9]+$'),
						neverblockre=re.compile(stage.neverblockre)
							if stage.neverblockre else None,
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				if stage.mode == 'dop-rerank':
					gram.getrulemapping(
							stages[n - 1].grammar, re.compile(br'@[-0-9]+\b'))
				logging.info(msg)
			# write prob models
			np.savez_compressed(  # pylint: disable=no-member
					'%s/%s.probs.npz' % (resultdir, stage.name),
					**{name: mod for name, mod
						in zip(gram.modelnames, gram.models)})
		else:  # not stage.dop
			xgrammar = grammar.treebankgrammar(traintrees, sents,
					extrarules=extrarules)
			logging.info('induced %s based on %d sentences',
					('PCFG' if tbfanout == 1 or stage.split else 'PLCFRS'),
					len(traintrees))
			if stage.split or os.path.exists('%s/pcdist.txt' % resultdir):
				logging.info(grammar.grammarinfo(xgrammar))
			else:
				logging.info(grammar.grammarinfo(xgrammar,
						dump='%s/pcdist.txt' % resultdir))
			if lexmodel and not simplelexsmooth:
				xgrammar = lexicon.smoothlexicon(xgrammar, lexmodel)
			rules, lex = grammar.write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			gram = Grammar(rules, lex, start=top,
					bitpar=stage.mode.startswith('pcfg'))
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lex)
			logging.info(gram.testgrammar()[1])
			if n and stage.prune:
				msg = gram.getmapping(stages[n - 1].grammar,
						striplabelre=None,
						neverblockre=re.compile(stage.neverblockre)
							if stage.neverblockre else None,
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				logging.info(msg)
		logging.info('wrote grammar to %s/%s.{rules,lex%s}.gz', resultdir,
				stage.name,
				',backtransform' if stage.dop == 'doubledop' else '')

		outside = None
		if stage.estimates in ('SX', 'SXlrgaps'):
			if stage.estimates == 'SX' and tbfanout != 1 and not stage.split:
				raise ValueError('SX estimate requires PCFG.')
			elif stage.mode != 'plcfrs':
				raise ValueError('estimates require parser w/agenda.')
			begin = time.clock()
			logging.info('computing %s estimates', stage.estimates)
			if stage.estimates == 'SX':
				outside = estimates.getpcfgestimates(
						gram, testmaxwords, gram.toid[trees[0].label])
			elif stage.estimates == 'SXlrgaps':
				outside = estimates.getestimates(
						gram, testmaxwords, gram.toid[trees[0].label])
			logging.info('estimates done. cpu time elapsed: %gs',
					time.clock() - begin)
			np.savez_compressed(  # pylint: disable=no-member
					'%s/%s.outside.npz' % (resultdir, stage.name),
					outside=outside)
			logging.info('saved %s estimates', stage.estimates)
		elif stage.estimates:
			raise ValueError('unrecognized value; specify SX or SXlrgaps.')
		stage.update(grammar=gram, backtransform=backtransform,
				outside=outside)

def test_grammar(debug=False):
	"""Demonstrate grammar extraction."""
	from discodop.grammar import treebankgrammar, dopreduction, doubledop
	from discodop import plcfrs
	from discodop.containers import Grammar
	from discodop.treebank import NegraCorpusReader
	from discodop.treetransforms import addfanoutmarkers, removefanoutmarkers
	from discodop.disambiguation import recoverfragments
	from discodop.kbest import lazykbest
	from math import exp
	corpus = NegraCorpusReader('alpinosample.export', punct='move')
	sents = list(corpus.sents().values())
	trees = [addfanoutmarkers(binarize(a.copy(True), horzmarkov=1))
			for a in list(corpus.trees().values())[:10]]
	if debug:
		print('plcfrs\n', Grammar(treebankgrammar(trees, sents)))
		print('dop reduction')
	grammar = Grammar(dopreduction(trees[:2], sents[:2])[0],
			start=trees[0].label)
	if debug:
		print(grammar)
	_ = grammar.testgrammar()

	grammarx, backtransform, _, _ = doubledop(trees, sents,
			debug=False, numproc=1)
	if debug:
		print('\ndouble dop grammar')
	grammar = Grammar(grammarx, start=trees[0].label)
	grammar.getmapping(grammar, striplabelre=None,
			neverblockre=re.compile('^#[0-9]+|.+}<'),
			splitprune=False, markorigin=False)
	if debug:
		print(grammar)
	assert grammar.testgrammar()[0], "RFE should sum to 1."
	for tree, sent in zip(corpus.trees().values(), sents):
		if debug:
			print("sentence:", ' '.join(a.encode('unicode-escape').decode()
					for a in sent))
		chart, msg = plcfrs.parse(sent, grammar, exhaustive=True)
		if debug:
			print('\n', msg, '\ngold ', tree, '\n', 'double dop', end='')
		if chart:
			mpp, parsetrees = {}, {}
			derivations, _ = lazykbest(chart, 1000, '}<')
			for d, (t, p) in zip(chart.rankededges[chart.root()],
					derivations):
				r = Tree(recoverfragments(d.key, chart, backtransform))
				r = str(removefanoutmarkers(unbinarize(r)))
				mpp[r] = mpp.get(r, 0.0) + exp(-p)
				parsetrees.setdefault(r, []).append((t, p))
			if debug:
				print(len(mpp), 'parsetrees',
						sum(map(len, parsetrees.values())), 'derivations')
			for t, tp in sorted(mpp.items(), key=itemgetter(1)):
				if debug:
					print(tp, t, '\nmatch:', t == str(tree))
				if len(set(parsetrees[t])) != len(parsetrees[t]):
					print('chart:\n', chart)
				assert len(set(parsetrees[t])) == len(parsetrees[t])
				if debug:
					for deriv, p in sorted(parsetrees[t], key=itemgetter(1)):
						print(' <= %6g %s' % (exp(-p), deriv))
		elif debug:
			print('no parse\n', chart)
	if debug:
		print()
	tree = Tree.parse("(ROOT (S (F (E (S (C (B (A 0))))))))", parse_leaf=int)
	Grammar(treebankgrammar([tree], [[str(a) for a in range(10)]]))

def test():
	""" Run some tests. """
	from discodop import plcfrs
	from discodop.containers import Grammar
	from discodop.treebank import NegraCorpusReader
	from discodop.treetransforms import binarize, unbinarize, \
			addfanoutmarkers, removefanoutmarkers
	from discodop.disambiguation import recoverfragments
	from discodop.kbest import lazykbest
	from discodop.fragments import getfragments
	logging.basicConfig(level=logging.DEBUG, format='%(message)s')
	filename = "alpinosample.export"
	corpus = NegraCorpusReader('.', filename, punct='move')
	sents = list(corpus.sents().values())
	trees = [addfanoutmarkers(binarize(a.copy(True), horzmarkov=1))
			for a in list(corpus.parsed_sents().values())[:10]]

	print('plcfrs')
	lcfrs = Grammar(treebankgrammar(trees, sents), start=trees[0].label)
	print(lcfrs)

	print('dop reduction')
	grammar = Grammar(dopreduction(trees[:2], sents[:2])[0],
			start=trees[0].label)
	print(grammar)
	grammar.testgrammar()

	fragments = getfragments(trees, sents, 1)
	debug = '--debug' in sys.argv
	grammarx, backtransform, _ = doubledop(trees, fragments, debug=debug)
	print('\ndouble dop grammar')
	grammar = Grammar(grammarx, start=trees[0].label)
	grammar.getmapping(grammar, striplabelre=None,
			neverblockre=re.compile(b'^#[0-9]+|.+}<'),
			splitprune=False, markorigin=False)
	print(grammar)
	assert grammar.testgrammar(), "DOP1 should sum to 1."
	for tree, sent in zip(corpus.parsed_sents().values(), sents):
		print("sentence:", ' '.join(a.encode('unicode-escape').decode()
				for a in sent))
		chart, msg = plcfrs.parse(sent, grammar, exhaustive=True)
		print('\n', msg, end='')
		print("\ngold ", tree)
		print("double dop", end='')
		if chart:
			mpp = {}
			parsetrees = {}
			derivations, _ = lazykbest(chart, 1000, b'}<')
			for d, (t, p) in zip(chart.rankededges[chart.root()],
					derivations):
				r = Tree(recoverfragments(d.getkey(), chart,
						grammar, backtransform))
				r = str(removefanoutmarkers(unbinarize(r)))
				mpp[r] = mpp.get(r, 0.0) + exp(-p)
				parsetrees.setdefault(r, []).append((t, p))
			print(len(mpp), 'parsetrees', end='')
			print(sum(map(len, parsetrees.values())), 'derivations')
			for t, tp in sorted(mpp.items(), key=itemgetter(1)):
				print(tp, '\n', t, end='')
				print("match:", t == str(tree))
				assert len(set(parsetrees[t])) == len(parsetrees[t])
				if not debug:
					continue
				for deriv, p in sorted(parsetrees[t], key=itemgetter(1)):
					print(' <= %6g %s' % (exp(-p), deriv))
		else:
			print("no parse")
			print(chart)
		print()
	tree = Tree.parse("(ROOT (S (F (E (S (C (B (A 0))))))))", parse_leaf=int)
	Grammar(treebankgrammar([tree], [[str(a) for a in range(10)]]))

indices = []
with open('../../datasets/preprocessed/test_indices.txt') as f:
	for line in f:
		indices.append(int(line))

vectorizer = feature_extraction.DictVectorizer(sparse=True)

treeStrings = [line[:-1] for line in io.open(
		'../../datasets/preprocessed/test_trees.txt', encoding='utf-8')]
print('Total of', len(treeStrings), 'trees in test set.')
treeStrings.extend([line[:-1] for line in io.open(
		'../../datasets/preprocessed/trees.txt', encoding='utf-8')])
text = BracketStringReader(treeStrings)
print('Made treebank')
trees = [treetransforms.binarize(tree, horzmarkov=1, vertmarkov=1)
		for _, (tree, _) in text.itertrees(0)]
print('Binarized trees')
sents = [sent for _, (_, sent) in text.itertrees(0)]
print('Starting fragment extraction')
result = fragments.getfragments(trees, sents, numproc=1,
		disc=False, cover=True)
print('Extracted fragments')

treeIndex = 0
found = 0
total = 0
for tree, sentDict in result.items():
	total += 1
	if tree in featureMap:
		found += 1

def bitext():
	""" Bitext parsing with a synchronous CFG.
	Translation would require a special decoder (instead of normal kbest
	derivations where the whole sentence is given). """
	print("bitext parsing with a synchronous CFG")
	trees = [Tree.parse(a, parse_leaf=int) for a in """\
(ROOT (S (NP (NNP (John 0) (John 7))) (VP (VB (misses 1) (manque 5))\
 (PP (IN (a` 6)) (NP (NNP (Mary 2) (Mary 4)))))) (SEP (| 3)))
(ROOT (S (NP (NNP (Mary 0) (Mary 4))) (VP (VB (likes 1) (aimes 5))\
 (NP (DT (la 6)) (NN (pizza 2) (pizza 7))))) (SEP (| 3)))""".split('\n')]
	sents = [["0"] * len(a.leaves()) for a in trees]
	for a in trees:
		treetransforms.binarize(a)
	compiled_scfg = Grammar(treebankgrammar(trees, sents))
	print("sentences:")
	for tree in trees:
		print(' '.join(w for _, w in sorted(tree.pos())))
	print("treebank:")
	for tree in trees:
		print(tree)
	print(compiled_scfg, "\n")
	print("correct translations:")
	assert parse(compiled_scfg, ["0"] * 7,
			"John likes Mary | John aimes Mary".split())
	assert parse(compiled_scfg, ["0"] * 9,
			"John misses pizza | la pizza manque a` John".split())
	print("incorrect translations:")
	assert not parse(compiled_scfg, ["0"] * 7,
			"John likes Mary | Mary aimes John".split())
	assert not parse(compiled_scfg, ["0"] * 9,
			"John misses pizza | John manque a` la pizza".split())

	# the following SCFG is taken from:
	# http://cdec-decoder.org/index.php?title=SCFG_translation
	# the grammar has been binarized and some new non-terminals had to be
	# introduced because terminals cannot appear in binary rules.
	lexicon = ("|", "ein", "ich", "Haus", "kleines", "grosses", "sah",
			"fand", "small", "little", "big", "large", "house", "shell",
			"a", "I", "saw", "found")
	another_scfg = Grammar([
			((('DT', '_ein', '_a'), ((0, ), (1, ))), 0.5),
			((('JJ', '_kleines', '_small'), ((0, ), (1, ))), 0.1),
			((('JJ', '_kleines', '_little'), ((0, ), (1, ))), 0.9),
			((('JJ', '_grosses', '_big'), ((0, ), (1, ))), 0.8),
			((('JJ', '_grosses', '_large'), ((0, ), (1, ))), 0.2345),
			((('NN_house', '_Haus', '_house'), ((0, ), (1, ))), 1),
			((('NN_shell', '_Haus', '_shell'), ((0, ), (1, ))), 1),
			((('NP', '_ich', '_I'), ((0, ), (1, ))), 0.6),
			((('NP', 'DT', 'NP|<JJ-NN>'), ((0, 1), (0, 1))), 0.5),
			((('NP|<JJ-NN>', 'JJ', 'NN_house'), ((0, 1), (0, 1))), 0.1),
			((('NP|<JJ-NN>', 'JJ', 'NN_shell'), ((0, 1), (0, 1))), 1.3),
			((('ROOT', 'S', '_|'), ((0, 1, 0), )), 1),
			((('S', 'NP', 'VP'), ((0, 1), (0, 1))), 0.2),
			((('VP', 'V', 'NP'), ((0, 1), (0, 1))), 0.1),
			((('V', '_sah', '_saw'), ((0, ), (1, ))), 0.4),
			((('V', '_fand', '_found'), ((0, ), (1, ))), 0.4)]
			+ [((('_%s' % word, 'Epsilon'), (word, )), 1)
				for word in lexicon])
	print(another_scfg)
	sents = [
			"ich sah ein kleines Haus | I saw a small house".split(),
			"ich sah ein kleines Haus | I saw a little house".split(),
			"ich sah ein kleines Haus | I saw a small shell".split(),
			"ich sah ein kleines Haus | I saw a little shell".split()]
	for sent in sents:
		assert parse(another_scfg, sent), sent

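# Aside (illustrative, inferred from the rule tuples above): each rule has
# the shape ((lhs, rhs1[, rhs2]), yieldfunction), weight. The yield function
# contains one tuple per span of the lhs, whose 0/1 entries select material
# from the first or second child in order. In the ROOT rule above,
# ((0, 1, 0), ) concatenates S's first span, the separator '_|', and S's
# second span into one contiguous span, whereas a rule such as
example_rule = ((('VP_2', 'A', 'B'), ((0, ), (1, ))), 1.0)
# derives a fan-out 2 constituent whose two spans come from the two
# (fan-out 1) children A and B.
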
def getgrammars(trees, sents, stages, bintype, horzmarkov, vertmarkov,
		factor, tailmarker, revmarkov, leftmostunary, rightmostunary,
		pospa, markhead, fanout_marks_before_bin, testmaxwords, resultdir,
		numproc, lexmodel, simplelexsmooth, top, relationalrealizational):
	""" Apply binarization and read off the requested grammars. """
	# fixme: this n should correspond to sentence id
	tbfanout, n = treebankfanout(trees)
	logging.info('treebank fan-out before binarization: %d #%d\n%s\n%s',
			tbfanout, n, trees[n], ' '.join(sents[n]))
	# binarization
	begin = time.clock()
	if fanout_marks_before_bin:
		trees = [addfanoutmarkers(t) for t in trees]
	if bintype == 'binarize':
		bintype += ' %s h=%d v=%d %s' % (factor, horzmarkov, vertmarkov,
				'tailmarker' if tailmarker else '')
		for a in trees:
			binarize(a, factor=factor, tailmarker=tailmarker,
					horzmarkov=horzmarkov, vertmarkov=vertmarkov,
					leftmostunary=leftmostunary,
					rightmostunary=rightmostunary,
					reverse=revmarkov, pospa=pospa,
					headidx=-1 if markhead else None,
					filterfuncs=(relationalrealizational['ignorefunctions']
						+ (relationalrealizational['adjunctionlabel'], ))
						if relationalrealizational else ())
	elif bintype == 'optimal':
		trees = [Tree.convert(optimalbinarize(tree))
				for n, tree in enumerate(trees)]
	elif bintype == 'optimalhead':
		trees = [Tree.convert(optimalbinarize(tree, headdriven=True,
				h=horzmarkov, v=vertmarkov))
				for n, tree in enumerate(trees)]
	trees = [addfanoutmarkers(t) for t in trees]
	logging.info('binarized %s cpu time elapsed: %gs',
			bintype, time.clock() - begin)
	logging.info('binarized treebank fan-out: %d #%d', *treebankfanout(trees))
	trees = [canonicalize(a).freeze() for a in trees]

	for n, stage in enumerate(stages):
		if stage.split:
			traintrees = [binarize(splitdiscnodes(Tree.convert(a),
					stage.markorigin), childchar=':').freeze()
					for a in trees]
			logging.info('splitted discontinuous nodes')
		else:
			traintrees = trees
		if stage.mode.startswith('pcfg'):
			assert tbfanout == 1 or stage.split
		backtransform = None
		if stage.dop:
			if stage.usedoubledop:
				# find recurring fragments in treebank,
				# as well as depth 1 'cover' fragments
				fragments = getfragments(traintrees, sents, numproc,
						iterate=stage.iterate, complement=stage.complement)
				xgrammar, backtransform, altweights = doubledop(
						traintrees, fragments)
			else:  # DOP reduction
				xgrammar, altweights = dopreduction(
						traintrees, sents, packedgraph=stage.packedgraph)
			nodes = sum(len(list(a.subtrees())) for a in traintrees)
			if lexmodel and simplelexsmooth:
				newrules = simplesmoothlexicon(lexmodel)
				xgrammar.extend(newrules)
				for weights in altweights.values():
					weights.extend(w for _, w in newrules)
			elif lexmodel:
				xgrammar = smoothlexicon(xgrammar, lexmodel)
			msg = grammarinfo(xgrammar)
			rules, lexicon = write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			grammar = Grammar(rules, lexicon, start=top,
					bitpar=stage.mode.startswith('pcfg'))
			for name in altweights:
				grammar.register(u'%s' % name, altweights[name])
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lexicon)
			logging.info('DOP model based on %d sentences, %d nodes, '
					'%d nonterminals', len(traintrees), nodes,
					len(grammar.toid))
			logging.info(msg)
			if stage.estimator != 'dop1':
				grammar.switch(u'%s' % stage.estimator)
			_sumsto1 = grammar.testgrammar()
			if stage.usedoubledop:
				# backtransform keys are line numbers to rules file;
				# to see them together do:
				# $ paste <(zcat dop.rules.gz) <(zcat dop.backtransform.gz)
				with codecs.getwriter('ascii')(gzip.open(
						'%s/%s.backtransform.gz' % (resultdir, stage.name),
						'w')) as out:
					out.writelines('%s\n' % a for a in backtransform)
				if n and stage.prune:
					msg = grammar.getmapping(stages[n - 1].grammar,
							striplabelre=None if stages[n - 1].dop
								else re.compile(b'@.+$'),
							neverblockre=re.compile(b'.+}<'),
							splitprune=stage.splitprune
								and stages[n - 1].split,
							markorigin=stages[n - 1].markorigin)
				else:
					# recoverfragments() relies on this mapping to identify
					# binarization nodes
					msg = grammar.getmapping(None,
							striplabelre=None,
							neverblockre=re.compile(b'.+}<'),
							splitprune=False, markorigin=False)
				logging.info(msg)
			elif n and stage.prune:  # dop reduction
				msg = grammar.getmapping(stages[n - 1].grammar,
						striplabelre=None if stages[n - 1].dop
							and not stages[n - 1].usedoubledop
							else re.compile(b'@[-0-9]+$'),
						neverblockre=re.compile(stage.neverblockre)
							if stage.neverblockre else None,
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				if stage.mode == 'dop-rerank':
					grammar.getrulemapping(stages[n - 1].grammar)
				logging.info(msg)
			# write prob models
			np.savez_compressed('%s/%s.probs.npz' % (resultdir, stage.name),
					**{name: mod for name, mod
						in zip(grammar.modelnames, grammar.models)})
		else:  # not stage.dop
			xgrammar = treebankgrammar(traintrees, sents)
			logging.info('induced %s based on %d sentences',
					('PCFG' if tbfanout == 1 or stage.split else 'PLCFRS'),
					len(traintrees))
			if stage.split or os.path.exists('%s/pcdist.txt' % resultdir):
				logging.info(grammarinfo(xgrammar))
			else:
				logging.info(grammarinfo(xgrammar,
						dump='%s/pcdist.txt' % resultdir))
			if lexmodel and simplelexsmooth:
				newrules = simplesmoothlexicon(lexmodel)
				xgrammar.extend(newrules)
			elif lexmodel:
				xgrammar = smoothlexicon(xgrammar, lexmodel)
			rules, lexicon = write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			grammar = Grammar(rules, lexicon, start=top,
					bitpar=stage.mode.startswith('pcfg'))
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lexicon)
			_sumsto1 = grammar.testgrammar()
			if n and stage.prune:
				msg = grammar.getmapping(stages[n - 1].grammar,
						striplabelre=None,
						neverblockre=re.compile(stage.neverblockre)
							if stage.neverblockre else None,
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				logging.info(msg)
		logging.info('wrote grammar to %s/%s.{rules,lex%s}.gz', resultdir,
				stage.name, ',backtransform' if stage.usedoubledop else '')

		outside = None
		if stage.getestimates == 'SX':
			assert tbfanout == 1 or stage.split, 'SX estimate requires PCFG.'
			logging.info('computing PCFG estimates')
			begin = time.clock()
			outside = getpcfgestimates(grammar, testmaxwords,
					grammar.toid[trees[0].label])
			logging.info('estimates done. cpu time elapsed: %gs',
					time.clock() - begin)
			np.savez('pcfgoutside.npz', outside=outside)
			logging.info('saved PCFG estimates')
		elif stage.useestimates == 'SX':
			assert tbfanout == 1 or stage.split, 'SX estimate requires PCFG.'
			assert stage.mode != 'pcfg', (
					'estimates require agenda-based parser.')
			outside = np.load('pcfgoutside.npz')['outside']
			logging.info('loaded PCFG estimates')
		if stage.getestimates == 'SXlrgaps':
			logging.info('computing PLCFRS estimates')
			begin = time.clock()
			outside = getestimates(grammar, testmaxwords,
					grammar.toid[trees[0].label])
			logging.info('estimates done. cpu time elapsed: %gs',
					time.clock() - begin)
			np.savez('outside.npz', outside=outside)
			logging.info('saved estimates')
		elif stage.useestimates == 'SXlrgaps':
			outside = np.load('outside.npz')['outside']
			logging.info('loaded PLCFRS estimates')
		stage.update(grammar=grammar, backtransform=backtransform,
				outside=outside)

assert len(argv) == 2, f"use {argv[0]} <data.conf>"
cp = ConfigParser()
cp.read(argv[1])
config = corpusparam(**cp["Corpus"], **cp["Grammar"])

from discodop.tree import Tree
from discodop.treebank import READERS
from discodop.treetransforms import addfanoutmarkers, binarize, collapseunary
from discodop.lexgrammar import SupertagCorpus, SupertagGrammar

corpus = READERS[config.inputfmt](config.filename,
		encoding=config.inputenc, punct="move")
trees = [addfanoutmarkers(binarize(
			collapseunary(Tree.convert(t), collapseroot=True,
				collapsepos=True),
			horzmarkov=config.h, vertmarkov=config.v))
		for t in corpus.trees().values()]
sents = list(corpus.sents().values())

corpus = SupertagCorpus(trees, sents)

size = len(corpus.sent_corpus)
portions = config.split.split()
names = "train dev test".split()
assert len(portions) in [3, 4]

if portions[0] == "debug":
	portions = tuple(int(portion)
			for portion in portions[1:2] + portions[1:])
	limits = tuple((name, slice(0, end))
			for name, end in zip(names, portions))
else:
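# Hypothetical minimal data.conf for the snippet above. Only the fields
# actually read here are shown; the corpusparam class may require more,
# and section placement of each key is an assumption.
#
#   [Corpus]
#   filename = alpinosample.export
#   inputfmt = export
#   inputenc = utf8
#   split = debug 3 3
#
#   [Grammar]
#   h = 1
#   v = 1
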
def getfragments(trees, sents, numproc=1, iterate=False, complement=False):
	""" Get recurring fragments with exact counts in a single treebank.

	:returns: a dictionary whose keys are fragments as strings, and
		frequencies / indices as values.
	:param trees: a sequence of binarized Tree objects. """
	if numproc == 0:
		numproc = cpu_count()
	numtrees = len(trees)
	assert numtrees
	mult = 1  # 3 if numproc > 1 else 1
	fragments = {}
	trees = trees[:]
	work = workload(numtrees, mult, numproc)
	PARAMS.update(disc=True, indices=True, approx=False, complete=False,
			quadratic=False, complement=complement)
	if numproc == 1:
		initworkersimple(trees, list(sents))
		mymap = map
		myapply = APPLY
	else:
		logging.info("work division:\n%s", "\n".join(" %s: %r" % kv
				for kv in sorted(dict(numchunks=len(work),
					numproc=numproc).items())))
		# start worker processes
		pool = Pool(processes=numproc, initializer=initworkersimple,
				initargs=(trees, list(sents)))
		mymap = pool.map
		myapply = pool.apply
	# collect recurring fragments
	logging.info("extracting recurring fragments")
	for a in mymap(worker, work):
		fragments.update(a)
	# add 'cover' fragments corresponding to single productions
	cover = myapply(coverfragworker, ())
	before = len(fragments)
	fragments.update(cover)
	logging.info("merged %d unseen cover fragments",
			len(fragments) - before)
	fragmentkeys = list(fragments)
	bitsets = [fragments[a] for a in fragmentkeys]
	countchunk = len(bitsets) // numproc + 1
	work = list(range(0, len(bitsets), countchunk))
	work = [(n, len(work), bitsets[a:a + countchunk])
			for n, a in enumerate(work)]
	logging.info("getting exact counts for %d fragments", len(bitsets))
	counts = []
	for a in mymap(exactcountworker, work):
		counts.extend(a)
	if numproc != 1:
		pool.close()
		pool.join()
		del pool
	if iterate:  # optionally collect fragments of fragments
		logging.info("extracting fragments of recurring fragments")
		PARAMS['complement'] = False  # needs to be turned off if it was on
		newfrags = fragments
		trees, sents = None, None
		ids = count()
		for _ in range(10):  # up to 10 iterations
			newtrees = [binarize(
					introducepreterminals(Tree.parse(tree, parse_leaf=int),
						ids=ids), childchar="}")
					for tree, _ in newfrags]
			newsents = [["#%d" % next(ids) if word is None else word
					for word in sent] for _, sent in newfrags]
			newfrags, newcounts = iteratefragments(
					fragments, newtrees, newsents, trees, sents, numproc)
			if len(newfrags) == 0:
				break
			if trees is None:
				trees = []
				sents = []
			trees.extend(newtrees)
			sents.extend(newsents)
			fragmentkeys.extend(newfrags)
			counts.extend(newcounts)
			fragments.update(zip(newfrags, newcounts))
	logging.info("found %d fragments", len(fragmentkeys))
	return dict(zip(fragmentkeys, counts))

def getfragments(trees, sents, numproc=1, disc=True,
		iterate=False, complement=False, indices=True, cover=True):
	"""Get recurring fragments with exact counts in a single treebank.

	:returns: a dictionary whose keys are fragments as strings, and
		indices as values. When ``disc`` is ``True``, keys are of the form
		``(frag, sent)`` where ``frag`` is a unicode string, and ``sent``
		is a list of words as unicode strings; when ``disc`` is ``False``,
		keys are of the form ``frag`` where ``frag`` is a unicode string.
	:param trees: a sequence of binarized Tree objects.
	:param numproc: number of processes to use; pass 0 to use detected #
		CPUs.
	:param disc: when disc=True, assume trees with discontinuous
		constituents.
	:param iterate, complement: see :func:`_fragments.extractfragments`"""
	if numproc == 0:
		numproc = cpu_count()
	numtrees = len(trees)
	if not numtrees:
		raise ValueError('no trees.')
	mult = 1  # 3 if numproc > 1 else 1
	fragments = {}
	trees = trees[:]
	work = workload(numtrees, mult, numproc)
	PARAMS.update(disc=disc, indices=indices, approx=False, complete=False,
			complement=complement, debug=False, adjacent=False,
			twoterms=False)
	initworkersimple(trees, list(sents), disc)
	if numproc == 1:
		mymap = map
		myapply = APPLY
	else:
		logging.info("work division:\n%s", "\n".join(" %s: %r" % kv
				for kv in sorted(dict(numchunks=len(work),
					numproc=numproc).items())))
		# start worker processes
		pool = Pool(processes=numproc, initializer=initworkersimple,
				initargs=(trees, list(sents), disc))
		mymap = pool.map
		myapply = pool.apply
	# collect recurring fragments
	logging.info("extracting recurring fragments")
	for a in mymap(worker, work):
		fragments.update(a)
	# add 'cover' fragments corresponding to single productions
	if cover:
		cover = myapply(coverfragworker, ())
		before = len(fragments)
		fragments.update(cover)
		logging.info("merged %d unseen cover fragments",
				len(fragments) - before)
	fragmentkeys = list(fragments)
	bitsets = [fragments[a] for a in fragmentkeys]
	countchunk = len(bitsets) // numproc + 1
	work = list(range(0, len(bitsets), countchunk))
	work = [(n, len(work), bitsets[a:a + countchunk])
			for n, a in enumerate(work)]
	logging.info("getting exact counts for %d fragments", len(bitsets))
	counts = []
	for a in mymap(exactcountworker, work):
		counts.extend(a)
	if numproc != 1:
		pool.close()
		pool.join()
		del pool
	if iterate:  # optionally collect fragments of fragments
		logging.info("extracting fragments of recurring fragments")
		PARAMS['complement'] = False  # needs to be turned off if it was on
		newfrags = fragments
		trees, sents = None, None
		ids = count()
		for _ in range(10):  # up to 10 iterations
			newtrees = [binarize(
					introducepreterminals(Tree.parse(tree, parse_leaf=int),
						ids=ids), childchar="}")
					for tree, _ in newfrags]
			newsents = [["#%d" % next(ids) if word is None else word
					for word in sent] for _, sent in newfrags]
			newfrags, newcounts = iteratefragments(
					fragments, newtrees, newsents, trees, sents, numproc)
			if len(newfrags) == 0:
				break
			if trees is None:
				trees = []
				sents = []
			trees.extend(newtrees)
			sents.extend(newsents)
			fragmentkeys.extend(newfrags)
			counts.extend(newcounts)
			fragments.update(zip(newfrags, newcounts))
	logging.info("found %d fragments", len(fragmentkeys))
	if not disc:
		return {a.decode('utf-8'): b
				for a, b in zip(fragmentkeys, counts)}
	return {(a.decode('utf-8'), b): c
			for (a, b), c in zip(fragmentkeys, counts)}

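# A minimal usage sketch for getfragments on a toy treebank (assumes
# discodop's Tree and binarize, as used elsewhere in this module; leaves
# are replaced by integer indices, mirroring the test functions above).
def demo_getfragments():
	from discodop.tree import Tree
	from discodop.treetransforms import binarize
	trees = [binarize(Tree(t)) for t in (
			'(S (NP (DT 0) (NN 1)) (VP (VB 2) (NP (DT 3) (NN 4))))',
			'(S (NP (DT 0) (NN 1)) (VP (VB 2) (NP (DT 3) (NN 4))))')]
	sents = [['the', 'cat', 'saw', 'the', 'dog'],
			['the', 'mouse', 'heard', 'the', 'cat']]
	for tree in trees:
		for n, idx in enumerate(tree.treepositions('leaves')):
			tree[idx] = n  # ensure integer leaf indices
	result = getfragments(trees, sents, numproc=1, disc=False)
	for frag, indices in result.items():
		print(frag, indices)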