예제 #1
0
    def test_cfg_approximation_conversion(self):
        """Coarse-to-fine parsing: CFG approximation first, then the LCFRS.

        Parses a^n b^m c^n d^m (n=2, m=3) with the CFG approximation of the
        test grammar, prunes the coarse chart into a whitelist, and re-parses
        with the fine grammar restricted to that whitelist. Results are only
        printed, not asserted.
        """
        grammar = self.build_nm_grammar()
        disco_grammar_rules = list(transform_grammar_cfg_approx(grammar))
        print(disco_grammar_rules)
        disco_grammar = Grammar(disco_grammar_rules, start=grammar.start())
        print(disco_grammar)
        n = 2
        m = 3
        inp = ["a"] * n + ["b"] * m + ["c"] * n + ["d"] * m

        # Coarse pass over the CFG approximation, with beam pruning.
        chart, msg = parse(inp, disco_grammar, beam_beta=exp(-4))
        chart.filter()
        print(chart)
        print(msg)

        fine_grammar_rules = list(transform_grammar(grammar))

        fine = Grammar(fine_grammar_rules, start=grammar.start())
        # Map the fine grammar onto the coarse one. The positional arguments
        # are presumably (coarse, striplabelre, neverblockre, splitprune,
        # markorigin) -- verify against disco-dop's Grammar.getmapping.
        # Raw string: '\*' is not a valid escape in a plain string literal.
        fine.getmapping(disco_grammar, re.compile(r'\*[0-9]+$'), None, True, True)

        whitelist, msg = prunechart(chart, fine, k=10000, splitprune=True, markorigin=True, finecfg=False)
        print(msg)
        print(whitelist)

        # Fine pass restricted to the items whitelisted by the coarse chart.
        chart2, msg = parse(inp, fine, whitelist=whitelist, splitprune=True, markorigin=True)
        print(msg)
        print(chart2)
예제 #2
0
def test_grammar(debug=False):
    """Demonstrate grammar extraction.

    Extracts a DOP-reduction grammar from the first two binarized trees of
    ``alpinosample.export`` and a Double-DOP grammar from the first ten,
    asserts that the Double-DOP rule weights sum to 1, and parses every
    sentence of the sample with it (k-best derivations followed by MPP
    marginalization). Intermediate results are printed only when ``debug``
    is true.
    """
    from discodop.grammar import treebankgrammar, dopreduction, doubledop
    from discodop import plcfrs
    from discodop.containers import Grammar
    from discodop.treebank import NegraCorpusReader
    from discodop.treetransforms import addfanoutmarkers
    from discodop.disambiguation import getderivations, marginalize
    reader = NegraCorpusReader('alpinosample.export', punct='move')
    sentences = list(reader.sents().values())
    bintrees = []
    for orig in list(reader.trees().values())[:10]:
        # binarize a copy so the corpus trees stay untouched
        bintrees.append(
            addfanoutmarkers(binarize(orig.copy(True), horzmarkov=1)))
    if debug:
        print('plcfrs\n', Grammar(treebankgrammar(bintrees, sentences)))
        print('dop reduction')
    dopred_rules = dopreduction(bintrees[:2], sentences[:2])[0]
    grammar = Grammar(dopred_rules, start=bintrees[0].label)
    if debug:
        print(grammar)
    grammar.testgrammar()

    dd_rules, _backtransform, _, _ = doubledop(
        bintrees, sentences, debug=False, numproc=1)
    if debug:
        print('\ndouble dop grammar')
    grammar = Grammar(dd_rules, start=bintrees[0].label)
    grammar.getmapping(None, striplabelre=None,
                       neverblockre=re.compile('^#[0-9]+|.+}<'),
                       splitprune=False, markorigin=False)
    if debug:
        print(grammar)
    sums_to_one, testmsg = grammar.testgrammar()
    assert sums_to_one, 'RFE should sum to 1.\n%s' % testmsg

    for goldtree, words in zip(reader.trees().values(), sentences):
        if debug:
            escaped = [w.encode('unicode-escape').decode() for w in words]
            print('sentence:', ' '.join(escaped))
        chart, parsemsg = plcfrs.parse(words, grammar, exhaustive=True)
        if debug:
            print('\n', parsemsg, '\ngold ', goldtree, '\n', 'double dop',
                  end='')
        if not chart:
            if debug:
                print('no parse\n', chart)
        else:
            getderivations(chart, 100)
            _mpp, _margmsg = marginalize('mpp', chart)
        if debug:
            print()
    chain = Tree.parse('(ROOT (S (F (E (S (C (B (A 0))))))))', parse_leaf=int)
    Grammar(treebankgrammar([chain], [[str(i) for i in range(10)]]))
예제 #3
0
 def __init__(self,
              grammar,
              input=None,
              save_preprocessing=None,
              load_preprocessing=None,
              k=50,
              heuristics=None,
              la=None,
              variational=False,
              sum_op=False,
              nontMap=None,
              cfg_ctf=False,
              beam_beta=0.0,
              beam_delta=50,
              pruning_k=10000,
              grammarInfo=None,
              projection_mode=False,
              latent_viterbi_mode=False,
              secondaries=None
              ):
     """Convert ``grammar`` for disco-dop and store the parser settings.

     :param grammar: grammar converted with ``transform_grammar`` into
         ``self.disco_grammar`` (start symbol ``grammar.start()``).
     :param input: initial parser input, stored as ``self.input``.
     :param save_preprocessing: accepted but not used here.
     :param load_preprocessing: accepted but not used here.
     :param k: stored as ``self.k`` (presumably the k-best list size).
     :param heuristics: accepted but not used here.
     :param la: latent annotation or list of latent annotations.
     :param variational: stored as ``self.variational``.
     :param sum_op: selects ``add`` instead of ``prod`` as ``self.op``.
     :param nontMap: stored as ``self.nontMap``.
     :param cfg_ctf: if true, also build a CFG approximation of ``grammar``
         (``self.disco_cfg_grammar``) and map ``self.disco_grammar`` onto it.
     :param beam_beta: beam pruning factor, between 0.0 and 1.0; 0.0 disables.
     :param beam_delta: maximum span length to which beam_beta is applied.
     :param pruning_k: stored as ``self.pruning_k``.
     :param grammarInfo: if given, every latent annotation in ``la`` must
         pass ``check_rule_split_alignment()`` (checked with ``assert``).
     :param projection_mode: stored as ``self.projection_mode``.
     :param latent_viterbi_mode: stored as ``self.latent_viterbi_mode``.
     :param secondaries: stored as ``self.secondaries`` (default: ``[]``).
     """
     rule_list = list(transform_grammar(grammar))
     self.disco_grammar = Grammar(rule_list, start=grammar.start())
     self.chart = None
     self.input = input
     self.grammar = grammar
     self.k = k
     self.beam_beta = beam_beta # beam pruning factor, between 0.0 and 1.0; 0.0 to disable.
     self.beam_delta = beam_delta  # maximum span length to which beam_beta is applied
     self.counter = 0
     self.la = la
     self.nontMap = nontMap
     self.variational = variational
     self.op = add if sum_op else prod
     self.debug = False
     self.log_mode = True
     self.estimates = None
     self.cfg_approx = cfg_ctf
     self.pruning_k = pruning_k
     self.grammarInfo = grammarInfo
     self.projection_mode = projection_mode
     self.latent_viterbi_mode = latent_viterbi_mode
     self.secondaries = [] if secondaries is None else secondaries
     self.secondary_mode = "DEFAULT"
     self.k_best_reranker = None
     if grammarInfo is not None:
         # la may be a single annotation or an iterable of annotations
         if isinstance(self.la, PyLatentAnnotation):
             assert self.la.check_rule_split_alignment()
         else:
             for l in self.la:
                 assert l.check_rule_split_alignment()
     if cfg_ctf:
         cfg_rule_list = list(transform_grammar_cfg_approx(grammar))
         self.disco_cfg_grammar = Grammar(cfg_rule_list, start=grammar.start())
         # Raw string: '\*' is not a valid escape in a plain string literal.
         self.disco_grammar.getmapping(self.disco_cfg_grammar, re.compile(r'\*[0-9]+$'), None, True, True)
예제 #4
0
def readgrammar(rulesfile, lexiconfile, start, backtransformfile=None):
	"""Read grammar into global variables.

	Sets ``GRAMMAR`` to a ``Grammar`` built from ``rulesfile`` and
	``lexiconfile`` with start symbol ``start``, and ``BACKTRANSFORM`` to the
	lines of ``backtransformfile`` (or None when it is not given). Files whose
	names end in ``.gz`` are decompressed with gzip. All files are read in
	binary mode, so plain and gzipped inputs yield the same type (bytes); the
	lexicon is decoded as UTF-8. Every file is closed after reading."""
	global GRAMMAR, BACKTRANSFORM

	def readbytes(filename):
		"""Return the raw contents of filename, gunzipping ``.gz`` files."""
		opener = gzip.open if filename.endswith('.gz') else open
		with opener(filename, 'rb') as inp:
			return inp.read()

	rules = readbytes(rulesfile)
	lexicon = readbytes(lexiconfile).decode('utf-8')
	# bitpar format is detected by a leading digit in the rules file;
	# slicing keeps this a bytes comparison (rules[0] would be an int).
	bitpar = rules[:1].isdigit()
	GRAMMAR = Grammar(rules, lexicon,
			start=start, bitpar=bitpar)
	BACKTRANSFORM = None
	if backtransformfile:
		BACKTRANSFORM = readbytes(backtransformfile).splitlines()
		_ = GRAMMAR.getmapping(None, neverblockre=re.compile(b'.+}<'))
예제 #5
0
def test_issue51():
    """Parse a word the grammar does not cover, then filter the chart.

    The grammar has two rules: S -> A, and a lexical rule for A covering
    'a' (written with the 'Epsilon' marker). Parsing ``['b']`` must produce
    a chart on which ``filter()`` can be called; the test passes as long as
    nothing raises.
    """
    from discodop.containers import Grammar
    from discodop.plcfrs import parse
    rules = [
        ((('S', 'A'), ((0,),)), 1.0),
        ((('A', 'Epsilon'), ('a',)), 1.0),
    ]
    grammar = Grammar(rules, start='S')
    result, _ = parse(['b'], grammar)
    result.filter()
예제 #6
0
def test_grammar(debug=False):
    """Demonstrate grammar extraction.

    Builds grammars from ``alpinosample.export``: a PLCFRS (printed only when
    ``debug``), a DOP reduction of the first two binarized trees, and a
    Double-DOP grammar of the first ten, whose rule weights are asserted to
    sum to 1. Each sentence of the sample is then parsed with the Double-DOP
    grammar; up to 1000 k-best derivations are turned back into trees and
    their exp(-p) values summed per tree. Besides the sum-to-1 assertion, the
    only check is that no tree collects duplicate (t, p) derivation entries.
    Output is printed only when ``debug`` is true, except for the chart dump
    that precedes a failing duplicate check.
    """
    from discodop.grammar import treebankgrammar, dopreduction, doubledop
    from discodop import plcfrs
    from discodop.containers import Grammar
    from discodop.treebank import NegraCorpusReader
    from discodop.treetransforms import addfanoutmarkers, removefanoutmarkers
    from discodop.disambiguation import recoverfragments
    from discodop.kbest import lazykbest
    from math import exp
    corpus = NegraCorpusReader('alpinosample.export', punct='move')
    sents = list(corpus.sents().values())
    # Binarize copies of the first ten trees (horzmarkov=1) and add
    # fan-out markers.
    trees = [
        addfanoutmarkers(binarize(a.copy(True), horzmarkov=1))
        for a in list(corpus.trees().values())[:10]
    ]
    if debug:
        print('plcfrs\n', Grammar(treebankgrammar(trees, sents)))
        print('dop reduction')
    grammar = Grammar(dopreduction(trees[:2], sents[:2])[0],
                      start=trees[0].label)
    if debug:
        print(grammar)
    _ = grammar.testgrammar()  # result ignored for the DOP reduction

    # Double-DOP: backtransform is needed below to recover fragments.
    grammarx, backtransform, _, _ = doubledop(trees,
                                              sents,
                                              debug=False,
                                              numproc=1)
    if debug:
        print('\ndouble dop grammar')
    grammar = Grammar(grammarx, start=trees[0].label)
    # Label mapping of the grammar onto itself; labels matching
    # '^#[0-9]+' or containing '}<' are never blocked.
    grammar.getmapping(grammar,
                       striplabelre=None,
                       neverblockre=re.compile('^#[0-9]+|.+}<'),
                       splitprune=False,
                       markorigin=False)
    if debug:
        print(grammar)
    assert grammar.testgrammar()[0], "RFE should sum to 1."
    for tree, sent in zip(corpus.trees().values(), sents):
        if debug:
            print("sentence:",
                  ' '.join(a.encode('unicode-escape').decode() for a in sent))
        chart, msg = plcfrs.parse(sent, grammar, exhaustive=True)
        if debug:
            print('\n', msg, '\ngold ', tree, '\n', 'double dop', end='')
        if chart:
            # mpp: tree string -> summed exp(-p) over its derivations;
            # parsetrees: tree string -> list of (t, p) derivation entries.
            mpp, parsetrees = {}, {}
            # Up to 1000 derivations; the '}<' argument presumably marks
            # binarization labels for lazykbest -- TODO confirm.
            derivations, _ = lazykbest(chart, 1000, '}<')
            # Pair each ranked edge of the root item with its derivation.
            for d, (t, p) in zip(chart.rankededges[chart.root()], derivations):
                r = Tree(recoverfragments(d.key, chart, backtransform))
                r = str(removefanoutmarkers(unbinarize(r)))
                # p is used as a negative log probability (exp(-p)).
                mpp[r] = mpp.get(r, 0.0) + exp(-p)
                parsetrees.setdefault(r, []).append((t, p))
            if debug:
                print(len(mpp), 'parsetrees',
                      sum(map(len, parsetrees.values())), 'derivations')
            # Trees in ascending order of summed probability.
            for t, tp in sorted(mpp.items(), key=itemgetter(1)):
                if debug:
                    print(tp, t, '\nmatch:', t == str(tree))
                # Dump the chart before failing on duplicate derivations.
                if len(set(parsetrees[t])) != len(parsetrees[t]):
                    print('chart:\n', chart)
                    assert len(set(parsetrees[t])) == len(parsetrees[t])
                if debug:
                    for deriv, p in sorted(parsetrees[t], key=itemgetter(1)):
                        print(' <= %6g %s' % (exp(-p), deriv))
        elif debug:
            print('no parse\n', chart)
        if debug:
            print()
    # Extract a grammar from a single unary chain with an integer leaf;
    # presumably a check that this does not raise.
    tree = Tree.parse("(ROOT (S (F (E (S (C (B (A 0))))))))", parse_leaf=int)
    Grammar(treebankgrammar([tree], [[str(a) for a in range(10)]]))
예제 #7
0
    def test_individual_parsing_stages(self):
        """Step through grammar conversion, parsing and weight projection.

        Converts the test grammar for disco-dop, parses ``a a a`` with
        'SXlrgaps' estimates, prints the filtered chart item by item, converts
        the chart into a derivation hypergraph (serialized to a temporary file
        that is removed afterwards), and checks the edge weights projected
        from an initial latent annotation, with and without variational
        projection.
        """
        import os

        grammar = self.build_grammar()

        for r in transform_grammar(grammar):
            pprint(r)

        rule_list = list(transform_grammar(grammar))
        pprint(rule_list)
        disco_grammar = Grammar(rule_list, start=grammar.start())
        print(disco_grammar)

        inp = ["a"] * 3
        estimates = 'SXlrgaps', getestimates(disco_grammar, 40, grammar.start())
        print(type(estimates))
        chart, msg = parse(inp, disco_grammar, estimates=estimates)
        print(chart)
        print(msg)
        chart.filter()
        print("filtered chart")
        print(disco_grammar.nonterminals)
        print(type(disco_grammar.nonterminals))

        print(chart)

        root = chart.root()
        print("root", root, type(root))
        print(chart.indices(root))
        print(chart.itemstr(root))
        print(chart.stats())
        print("root label", chart.label(root))
        print(root, chart.itemid1(chart.label(root), chart.indices(root)))

        def item_repr(item):
            # "<nonterminal>[<item id>]" for printing chart edges
            return disco_grammar.nonterminalstr(chart.label(item)) + "[" + str(item) + "]"

        for i in range(1, chart.numitems() + 1):
            print(i, chart.label(i), chart.indices(i), chart.numedges(i))
            for edge_num in range(chart.numedges(i)):
                edge = chart.getEdgeForItem(i, edge_num)
                if isinstance(edge, tuple):
                    # child item ids sit at positions 1 and 2; 0 means absent
                    print("\t", item_repr(i), "->",
                          ' '.join(item_repr(j) for j in (edge[1], edge[2]) if j != 0))
                else:
                    # non-tuple edge: an index into the input words
                    print("\t", item_repr(i), "->", inp[edge])
        print(chart.getEdgeForItem(root, 0))

        manager = PyDerivationManager(grammar)
        manager.convert_chart_to_hypergraph(chart, disco_grammar, debug=True)

        # mkstemp creates the file atomically (mktemp is deprecated and
        # race-prone); the file is removed once serialization has run.
        fd, hg_path = tempfile.mkstemp()
        os.close(fd)
        print(hg_path)
        try:
            manager.serialize(bytes(hg_path, encoding="utf-8"))
        finally:
            os.remove(hg_path)

        gi = PyGrammarInfo(grammar, manager.get_nonterminal_map())
        sm = PyStorageManager()
        la = build_PyLatentAnnotation_initial(grammar, gi, sm)

        vec = py_edge_weight_projection(la, manager, variational=True, debug=True, log_mode=False)
        print(vec)
        self.assertEqual([1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 1.0], vec)

        vec = py_edge_weight_projection(la, manager, variational=False, debug=True, log_mode=False)
        print(vec)
        self.assertEqual([1.0, 1.0, 1.0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 1.0], vec)

        der = manager.viterbi_derivation(0, vec, grammar)
        print(der)

        # Repeated log-mode projections (printed only).
        for _ in range(5):
            vec2 = py_edge_weight_projection(la, manager, debug=True, log_mode=True)
            print(vec2)
예제 #8
0
파일: parser.py 프로젝트: tivaro/disco-dop
def readgrammars(resultdir, stages, postagging=None, top='ROOT'):
	"""Read the grammars from a previous experiment.

	Expects a directory ``resultdir`` which contains the relevant grammars and
	the parameter file ``params.prm``, as produced by ``runexp``.

	For each stage this reads ``<name>.rules.gz`` and ``<name>.lex.gz`` and,
	where applicable, ``<name>.backtransform.gz``, ``<name>.probs.npz`` and
	``<name>.outside.npz``. It builds the ``Grammar``, sets up the mapping to
	the previous stage's grammar when the stage prunes, and stores the results
	with ``stage.update(grammar=..., backtransform=..., outside=...)``. When
	``postagging.method == 'unknownword'``, it also fills in the unknown-word
	function, lexicon and signatures of ``postagging`` from the first stage's
	grammar. All files are closed after reading.

	:raises ValueError: for unsupported combinations of stage options."""
	for n, stage in enumerate(stages):
		logging.info('reading: %s', stage.name)
		prefix = '%s/%s' % (resultdir, stage.name)
		with gzip.open(prefix + '.rules.gz') as rulesfile:
			rules = rulesfile.read()
		with gzip.open(prefix + '.lex.gz') as lexfile:
			lexicon = codecs.getreader('utf-8')(lexfile).read()
		# gzip yields bytes, so sniffing the rules needs a bytes pattern
		# (a str pattern raises TypeError on Python 3).
		bitpar = bool(stage.mode.startswith('pcfg')
				or re.match(br'[-.e0-9]+\b', rules))
		grammar = Grammar(rules, lexicon,
				start=top, bitpar=bitpar, binarized=stage.binarized)
		backtransform = outside = None
		if stage.dop:
			if stage.estimates is not None:
				raise ValueError('not supported')
			if stage.dop == 'doubledop':
				with gzip.open(prefix + '.backtransform.gz') as btfile:
					backtransform = btfile.read().splitlines()
				if n and stage.prune:
					_ = grammar.getmapping(stages[n - 1].grammar,
						striplabelre=re.compile(b'@.+$'),
						neverblockre=re.compile(b'^#[0-9]+|.+}<'),
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				else:
					# recoverfragments() relies on this mapping to identify
					# binarization nodes
					_ = grammar.getmapping(None,
						neverblockre=re.compile(b'.+}<'))
			elif n and stage.prune:  # dop reduction
				_ = grammar.getmapping(stages[n - 1].grammar,
					striplabelre=re.compile(b'@[-0-9]+$'),
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
				if stage.mode == 'dop-rerank':
					grammar.getrulemapping(
							stages[n - 1].grammar, re.compile(br'@[-0-9]+\b'))
			probsfile = prefix + '.probs.npz'
			if os.path.exists(probsfile):
				# The NpzFile keeps the archive open until it is closed;
				# indexing it loads each array fully into memory.
				with np.load(probsfile) as probmodels:  # pylint: disable=no-member
					for name in probmodels.files:
						if name != 'default':
							grammar.register(unicode(name), probmodels[name])
		else:  # not stage.dop
			if n and stage.prune:
				_ = grammar.getmapping(stages[n - 1].grammar,
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
			if stage.estimates in ('SX', 'SXlrgaps'):
				if stage.estimates == 'SX' and grammar.maxfanout != 1:
					raise ValueError('SX estimate requires PCFG.')
				if stage.mode != 'plcfrs':
					raise ValueError('estimates require parser w/agenda.')
				with np.load(prefix + '.outside.npz') as estimates:  # pylint: disable=no-member
					outside = estimates['outside']
				logging.info('loaded %s estimates', stage.estimates)
			elif stage.estimates:
				raise ValueError('unrecognized value; specify SX or SXlrgaps.')

		if stage.mode.startswith('pcfg-bitpar'):
			if grammar.maxfanout != 1:
				raise ValueError('bitpar requires a PCFG.')

		_sumsto1, msg = grammar.testgrammar()
		logging.info('%s: %s', stage.name, msg)
		stage.update(grammar=grammar, backtransform=backtransform,
				outside=outside)
	if postagging and postagging.method == 'unknownword':
		postagging.unknownwordfun = UNKNOWNWORDFUNC[postagging.model]
		postagging.lexicon = {w for w in stages[0].grammar.lexicalbyword
				if not w.startswith(UNK)}
		postagging.sigs = {w for w in stages[0].grammar.lexicalbyword
				if w.startswith(UNK)}
예제 #9
0
class DiscodopKbestParser(AbstractParser):
    """Parser that delegates chart parsing to disco-dop.

    The grammar is converted with ``transform_grammar`` into a disco-dop
    ``Grammar`` (``self.disco_grammar``). ``parse()`` fills ``self.chart``.
    A derivation is then read off the chart in one of four ways: by k-best
    enumeration (``lazykbest``), by projecting latent-annotation edge weights
    onto the chart hypergraph, by latent Viterbi, or by a k-best reranker.
    The choice depends on ``projection_mode``, ``latent_viterbi_mode`` and
    the secondary mode (see ``best_derivation_tree``).
    """

    def __init__(self,
                 grammar,
                 input=None,
                 save_preprocessing=None,
                 load_preprocessing=None,
                 k=50,
                 heuristics=None,
                 la=None,
                 variational=False,
                 sum_op=False,
                 nontMap=None,
                 cfg_ctf=False,
                 beam_beta=0.0,
                 beam_delta=50,
                 pruning_k=10000,
                 grammarInfo=None,
                 projection_mode=False,
                 latent_viterbi_mode=False,
                 secondaries=None
                 ):
        """Convert ``grammar`` for disco-dop and store the parser settings.

        :param grammar: grammar converted with ``transform_grammar``.
        :param input: initial input (see ``set_input``).
        :param save_preprocessing: accepted but not used.
        :param load_preprocessing: accepted but not used.
        :param k: number of derivations requested from ``lazykbest``.
        :param heuristics: accepted but not used.
        :param la: latent annotation, or list of latent annotations, used by
            the projection-based and latent-Viterbi derivations.
        :param variational: use variational projection in the default mode.
        :param sum_op: combine weights with ``add`` instead of ``prod``.
        :param nontMap: nonterminal map for ``PyDerivationManager``; required
            for projection-based derivations.
        :param cfg_ctf: also build a CFG approximation and use it for
            coarse-to-fine pruning in ``parse``.
        :param beam_beta: beam pruning factor, between 0.0 and 1.0; 0.0
            disables it.
        :param beam_delta: maximum span length to which beam_beta is applied.
        :param pruning_k: ``k`` passed to ``prunechart`` in coarse-to-fine mode.
        :param grammarInfo: if given, each latent annotation must pass
            ``check_rule_split_alignment()`` and the chart hypergraph must be
            consistent with it (checked with ``assert``).
        :param projection_mode: use projection-based derivations by default.
        :param latent_viterbi_mode: use latent Viterbi derivations by default.
        :param secondaries: stored as ``self.secondaries`` (default ``[]``).
        """
        rule_list = list(transform_grammar(grammar))
        self.disco_grammar = Grammar(rule_list, start=grammar.start())
        self.chart = None
        self.input = input
        self.grammar = grammar
        self.k = k
        self.beam_beta = beam_beta # beam pruning factor, between 0.0 and 1.0; 0.0 to disable.
        self.beam_delta = beam_delta  # maximum span length to which beam_beta is applied
        self.counter = 0
        self.la = la
        self.nontMap = nontMap
        self.variational = variational
        self.op = add if sum_op else prod
        self.debug = False
        self.log_mode = True
        self.estimates = None
        self.cfg_approx = cfg_ctf
        self.pruning_k = pruning_k
        self.grammarInfo = grammarInfo
        self.projection_mode = projection_mode
        self.latent_viterbi_mode = latent_viterbi_mode
        self.secondaries = [] if secondaries is None else secondaries
        self.secondary_mode = "DEFAULT"
        self.k_best_reranker = None
        if grammarInfo is not None:
            # la may be a single annotation or an iterable of annotations
            if isinstance(self.la, PyLatentAnnotation):
                assert self.la.check_rule_split_alignment()
            else:
                for l in self.la:
                    assert l.check_rule_split_alignment()
        if cfg_ctf:
            cfg_rule_list = list(transform_grammar_cfg_approx(grammar))
            self.disco_cfg_grammar = Grammar(cfg_rule_list, start=grammar.start())
            # Raw string: '\*' is not a valid escape in a plain string literal.
            self.disco_grammar.getmapping(self.disco_cfg_grammar, re.compile(r'\*[0-9]+$'), None, True, True)

    def best(self):
        """Not implemented by this parser; returns None."""
        pass

    def recognized(self):
        """Return True iff a chart exists and its root item id is non-zero."""
        return bool(self.chart) and self.chart.root() != 0

    def max_rule_product_derivation(self):
        """Projection-based derivation combining weights with ``prod``."""
        if self.recognized():
            return self.__projection_based_derivation_tree(self.la, variational=False, op=prod)

    def max_rule_sum_derivation(self):
        """Projection-based derivation combining weights with ``add``."""
        if self.recognized():
            return self.__projection_based_derivation_tree(self.la, variational=False,
                                                           op=add)

    def variational_derivation(self):
        """Projection-based derivation with variational weight projection."""
        if self.recognized():
            # The latent annotation(s), not the parser itself, are projected.
            return self.__projection_based_derivation_tree(self.la, variational=True, op=prod)

    def __projection_based_derivation_tree(self, la, variational=False, op=prod):
        """Return the Viterbi derivation under edge weights projected from ``la``.

        Converts ``self.chart`` into a derivation hypergraph and projects edge
        weights from each latent annotation in ``la`` (a single annotation or
        a list). The weights are added in log mode and combined with ``op``
        otherwise. If no derivation is found, falls back to the latent Viterbi
        derivation and then to the first k-best derivation.

        Returns None if ``self.nontMap`` is unset or every fallback fails.
        """
        if self.nontMap is None:
            print("A nonterminal map is required for weight projection based parsing!")
            return None
        manager = PyDerivationManager(self.grammar, self.nontMap)
        manager.convert_chart_to_hypergraph(self.chart, self.disco_grammar, debug=False)
        if self.grammarInfo is not None:
            assert manager.is_consistent_with_grammar(self.grammarInfo)
        manager.set_io_cycle_limit(200)
        manager.set_io_precision(0.000001)

        if not isinstance(la, list):
            la = [la]

        edge_weights = None
        for l in la:
            edge_weights_l = py_edge_weight_projection(l, manager, variational=variational, debug=self.debug,
                                                       log_mode=self.log_mode)
            if edge_weights is None:
                edge_weights = edge_weights_l
            elif self.log_mode:
                edge_weights = [w1 + w2 for w1, w2 in zip(edge_weights, edge_weights_l)]
            else:
                edge_weights = [op(w1, w2) for w1, w2 in zip(edge_weights, edge_weights_l)]

        if self.debug:
            import math
            # NaN compares unequal to everything, so ``w == float("nan")``
            # is always False; use math.isnan / math.isinf instead.
            nans = sum(1 for weight in edge_weights if math.isnan(weight))
            infs = sum(1 for weight in edge_weights if math.isinf(weight))
            zeros = sum(1 for weight in edge_weights if weight == 0.0)
            print("[", len(edge_weights), nans, infs, zeros, "]")
            if len(edge_weights) < 100:
                print(edge_weights)

        der = manager.viterbi_derivation(0, edge_weights, self.grammar, op=op, log_mode=self.log_mode)
        if der is not None:
            return LCFRSDerivationWrapper(der)
        print("p", end="")
        # latent_viterbi_derivation() already returns a wrapped derivation,
        # so it must not be wrapped a second time.
        der = self.latent_viterbi_derivation(debug=self.debug)
        if der is None:
            # Last resort; yields None (not StopIteration) if k-best is empty.
            _, der = next(self.k_best_derivation_trees(), (None, None))
        return der

    def set_secondary_mode(self, mode):
        """Set the mode consulted by ``best_derivation_tree``."""
        self.secondary_mode = mode

    def latent_viterbi_derivation(self, debug=False):
        """Return the latent Viterbi derivation (wrapped), or None.

        Uses the first latent annotation when ``self.la`` is a list. With
        ``debug``, the hypergraph is serialized to
        ``/tmp/my_debug_hypergraph.hg``.
        """
        manager = PyDerivationManager(self.grammar, self.nontMap)
        manager.convert_chart_to_hypergraph(self.chart, self.disco_grammar, debug=False)
        if debug:
            manager.serialize(b'/tmp/my_debug_hypergraph.hg')
        if isinstance(self.la, list):
            la = self.la[0]
        else:
            la = self.la
        vit_der = manager.latent_viterbi_derivation(0, la, self.grammar, debug=debug)
        if vit_der is not None:
            vit_der = LCFRSDerivationWrapper(vit_der)
        return vit_der

    def best_derivation_tree(self):
        """Return the best derivation according to the configured modes.

        Projection-based (variational or max-rule-product), latent Viterbi,
        the k-best reranker, or otherwise the first k-best derivation
        (None if there is none).
        """
        if (self.projection_mode and self.secondary_mode == "DEFAULT") \
                or self.secondary_mode in {"VARIATIONAL", "MAX-RULE-PRODUCT"}:
            variational = self.secondary_mode == "VARIATIONAL" \
                or (self.variational and self.secondary_mode == "DEFAULT")
            return self.__projection_based_derivation_tree(self.la, variational=variational, op=self.op)
        elif (self.latent_viterbi_mode and self.secondary_mode == "DEFAULT") \
                or self.secondary_mode == "LATENT-VITERBI":
            return self.latent_viterbi_derivation()
        elif self.secondary_mode == "LATENT-RERANK":
            return self.k_best_reranker.best_derivation_tree()
        else:
            return next((tree for _, tree in self.k_best_derivation_trees()), None)

    def all_derivation_trees(self):
        """Not implemented by this parser; returns None."""
        pass

    def set_input(self, parser_input):
        """Set the input for the next call to ``parse``."""
        self.input = parser_input

    def parse(self):
        """Parse ``self.input`` and store the filtered chart in ``self.chart``.

        With ``cfg_approx``, the input is first parsed with the CFG
        approximation (beam pruning). That chart is pruned into a whitelist
        that restricts an exhaustive parse with the full grammar. Otherwise
        the full grammar is parsed directly with beam pruning. ``self.chart``
        is None when the coarse parse yields no chart or disco-dop raises
        ValueError during the fine parse.
        """
        self.counter += 1
        if self.cfg_approx:
            chart, msg = pcfg.parse(self.input,
                                    self.disco_cfg_grammar,
                                    beam_beta=self.beam_beta,
                                    beam_delta=self.beam_delta)
            if chart:
                chart.filter()
                whitelist, msg = prunechart(chart,
                                            self.disco_grammar,
                                            k=self.pruning_k,
                                            splitprune=True,
                                            markorigin=True,
                                            finecfg=False)
                try:
                    self.chart, msg = parse(self.input,
                                            self.disco_grammar,
                                            estimates=self.estimates,
                                            whitelist=whitelist,
                                            splitprune=True,
                                            markorigin=True,
                                            exhaustive=True)
                except ValueError as e:
                    self.chart = None
                    print("discodop error", e, file=stderr)
            else:
                # No coarse parse: do not keep the previous sentence's chart.
                self.chart = None
        else:
            self.chart, msg = parse(self.input,
                                    self.disco_grammar,
                                    estimates=self.estimates,
                                    beam_beta=self.beam_beta,
                                    beam_delta=self.beam_delta,
                                    exhaustive=True)
        if self.chart:
            self.chart.filter()

    def clear(self):
        """Drop the input, the chart and the reranker's per-sentence state."""
        self.input = None
        self.chart = None
        if self.k_best_reranker:
            self.k_best_reranker.k_best_list = None
            self.k_best_reranker.ranking = None
            self.k_best_reranker.ranker = None

    def k_best_derivation_trees(self):
        """Yield ``(weight, DiscodopDerivation)`` for up to ``self.k`` derivations.

        Derivation strings that nltk cannot parse are reported on stderr and
        skipped.
        """
        for tree_string, weight in lazykbest(self.chart, self.k):
            try:
                tree = nltk.Tree.fromstring(tree_string)
                yield weight, DiscodopDerivation(tree, self.grammar)
            except ValueError:
                print("\nill-bracketed string:", tree_string, file=stderr)
예제 #10
0
def test():
	"""Run some tests.

	Using the first ten trees of ``alpinosample.export``:
	print a PLCFRS grammar and a DOP-reduction grammar;
	build a Double-DOP grammar and assert that its rule weights sum to 1;
	parse every sentence of the sample with it and print the parse trees
	ranked by summed derivation probability.
	With ``--debug`` on the command line, the derivations of each tree are
	printed as well."""
	from discodop import plcfrs
	from discodop.containers import Grammar
	from discodop.treebank import NegraCorpusReader
	from discodop.treetransforms import binarize, unbinarize, \
			addfanoutmarkers, removefanoutmarkers
	from discodop.disambiguation import recoverfragments
	from discodop.kbest import lazykbest
	from discodop.fragments import getfragments
	logging.basicConfig(level=logging.DEBUG, format='%(message)s')
	filename = "alpinosample.export"
	corpus = NegraCorpusReader('.', filename, punct='move')
	sents = list(corpus.sents().values())
	trees = [addfanoutmarkers(binarize(a.copy(True), horzmarkov=1))
			for a in list(corpus.parsed_sents().values())[:10]]

	print('plcfrs')
	lcfrs = Grammar(treebankgrammar(trees, sents), start=trees[0].label)
	print(lcfrs)

	print('dop reduction')
	grammar = Grammar(dopreduction(trees[:2], sents[:2])[0],
			start=trees[0].label)
	print(grammar)
	grammar.testgrammar()

	fragments = getfragments(trees, sents, 1)
	debug = '--debug' in sys.argv
	grammarx, backtransform, _ = doubledop(trees, fragments, debug=debug)
	print('\ndouble dop grammar')
	grammar = Grammar(grammarx, start=trees[0].label)
	grammar.getmapping(grammar, striplabelre=None,
			neverblockre=re.compile(b'^#[0-9]+|.+}<'),
			splitprune=False, markorigin=False)
	print(grammar)
	# testgrammar() may return a (result, message) tuple, and a non-empty
	# tuple is always truthy; assert on the result itself.
	sumsto1 = grammar.testgrammar()
	if isinstance(sumsto1, tuple):
		sumsto1 = sumsto1[0]
	assert sumsto1, "DOP1 should sum to 1."
	for tree, sent in zip(corpus.parsed_sents().values(), sents):
		print("sentence:", ' '.join(a.encode('unicode-escape').decode()
				for a in sent))
		chart, msg = plcfrs.parse(sent, grammar, exhaustive=True)
		print('\n', msg, end='')
		print("\ngold ", tree)
		print("double dop", end=' ')
		if chart:
			mpp = {}  # tree string -> summed exp(-p) over its derivations
			parsetrees = {}  # tree string -> list of (t, p) derivations
			derivations, _ = lazykbest(chart, 1000, b'}<')
			for d, (t, p) in zip(chart.rankededges[chart.root()], derivations):
				r = Tree(recoverfragments(d.getkey(), chart,
					grammar, backtransform))
				r = str(removefanoutmarkers(unbinarize(r)))
				mpp[r] = mpp.get(r, 0.0) + exp(-p)
				parsetrees.setdefault(r, []).append((t, p))
			print(len(mpp), 'parsetrees', end=' ')
			print(sum(map(len, parsetrees.values())), 'derivations')
			for t, tp in sorted(mpp.items(), key=itemgetter(1)):
				print(tp, '\n', t, end=' ')
				print("match:", t == str(tree))
				# no tree may collect the same (t, p) derivation twice
				assert len(set(parsetrees[t])) == len(parsetrees[t])
				if not debug:
					continue
				for deriv, p in sorted(parsetrees[t], key=itemgetter(1)):
					print(' <= %6g %s' % (exp(-p), deriv))
		else:
			print("no parse")
			print(chart)
		print()
	tree = Tree.parse("(ROOT (S (F (E (S (C (B (A 0))))))))", parse_leaf=int)
	Grammar(treebankgrammar([tree], [[str(a) for a in range(10)]]))
예제 #11
0
def main():
	"""Command line interface to create grammars from treebanks.

	Arguments: ``<model> <treebankfile> <grammarfile>``. Options:
	``--inputfmt``, ``--inputenc``, ``--dopestimator``, ``--numproc``,
	``--gzip`` and ``--packed``. Writes ``<grammarfile>.rules`` and
	``<grammarfile>.lex`` (with a ``.gz`` suffix under ``--gzip``). For the
	``doubledop`` model it also writes ``<grammarfile>.backtransform``.
	Invalid arguments print an error with ``USAGE`` and exit with status 2."""
	import gzip
	from getopt import gnu_getopt, GetoptError
	from discodop.treetransforms import addfanoutmarkers, canonicalize
	from discodop.treebank import getreader, splitpath
	from discodop.fragments import getfragments
	logging.basicConfig(level=logging.DEBUG, format='%(message)s')
	shortoptions = ''
	flags = ('gzip', 'packed')
	options = ('inputfmt=', 'inputenc=', 'dopestimator=', 'numproc=')
	try:
		opts, args = gnu_getopt(sys.argv[1:], shortoptions, flags + options)
		model, treebankfile, grammarfile = args
	except (GetoptError, ValueError) as err:
		print('error: %r\n%s' % (err, USAGE))
		sys.exit(2)
	opts = dict(opts)
	# gnu_getopt reports long options with their '--' prefix.
	estimator = opts.get('--dopestimator', 'dop1')
	if model not in ('pcfg', 'plcfrs', 'dopreduction', 'doubledop'):
		print('error: unrecognized model: %r\n%s' % (model, USAGE))
		sys.exit(2)
	if estimator not in ('dop1', 'ewe', 'shortest'):
		print('error: unrecognized estimator: %r\n%s' % (estimator, USAGE))
		sys.exit(2)
	if estimator != 'dop1' and model not in ('dopreduction', 'doubledop'):
		# only the DOP models produce alternative weights
		print('error: estimator %r requires a DOP model\n%s' % (
				estimator, USAGE))
		sys.exit(2)

	# read treebank
	reader = getreader(opts.get('--inputfmt', 'export'))
	corpus = reader(*splitpath(treebankfile),
			encoding=opts.get('--inputenc', 'utf8'))
	trees = list(corpus.parsed_sents().values())
	sents = list(corpus.sents().values())
	for a in trees:
		canonicalize(a)
		addfanoutmarkers(a)

	# read off grammar
	backtransform = altweights = None
	if model in ('pcfg', 'plcfrs'):
		grammar = treebankgrammar(trees, sents)
	elif model == 'dopreduction':
		grammar, altweights = dopreduction(trees, sents,
				packedgraph='--packed' in opts)
	elif model == 'doubledop':
		numproc = int(opts.get('--numproc', 1))
		fragments = getfragments(trees, sents, numproc)
		grammar, backtransform, altweights = doubledop(trees, fragments)
	if estimator in ('ewe', 'shortest'):
		grammar = [(rule, w) for (rule, _), w in
				zip(grammar, altweights[estimator])]

	print(grammarinfo(grammar))
	# Output filenames are kept separate from the grammar contents, which
	# write_lcfrs_grammar() returns below.
	rulesname = grammarfile + '.rules'
	lexiconname = grammarfile + '.lex'
	if '--gzip' in opts:
		myopen = gzip.open
		rulesname += '.gz'
		lexiconname += '.gz'
	else:
		myopen = open
	bitpar = model == 'pcfg' or opts.get('--inputfmt') == 'bracket'
	rules, lexicon = write_lcfrs_grammar(grammar, bitpar=bitpar)
	try:
		from discodop.containers import Grammar
	except ImportError:
		pass
	else:
		cgrammar = Grammar(rules, lexicon)
		cgrammar.testgrammar()
	# write output
	with myopen(rulesname, 'w') as rulesfile:
		rulesfile.write(rules)
	with codecs.getwriter('utf-8')(myopen(lexiconname, 'w')) as lexiconfile:
		lexiconfile.write(lexicon)
	if model == 'doubledop':
		backtransformfile = '%s.backtransform%s' % (grammarfile,
			'.gz' if '--gzip' in opts else '')
		with myopen(backtransformfile, 'w') as btfile:
			btfile.writelines('%s\n' % a for a in backtransform)
		print('wrote backtransform to', backtransformfile)
	print('wrote grammar to %s and %s.' % (rulesname, lexiconname))
예제 #12
0
def test_grammar(debug=False):
	"""Extract grammars from a small sample treebank and parse with them.

	The first ten trees of 'alpinosample.export' are binarized
	(horzmarkov=1) and given fan-out markers. From them a PLCFRS is read off
	(only printed in debug mode), a DOP reduction grammar is built from the
	first two trees, and a Double-DOP grammar is built whose probabilities
	must pass ``testgrammar()``. Every sentence of the sample is then parsed
	exhaustively with the Double-DOP grammar; the derivations collected for
	each resulting parse tree must not contain duplicates. Finally a grammar
	is read off a single tree consisting of a unary chain.

	:param debug: when true, print grammars, parse results and derivations.
	"""
	from discodop.grammar import treebankgrammar, dopreduction, doubledop
	from discodop import plcfrs
	from discodop.containers import Grammar
	from discodop.treebank import NegraCorpusReader
	from discodop.treetransforms import addfanoutmarkers, removefanoutmarkers
	from discodop.disambiguation import recoverfragments
	from discodop.kbest import lazykbest
	from math import exp
	reader = NegraCorpusReader('alpinosample.export', punct='move')
	sentences = list(reader.sents().values())
	sample = list(reader.trees().values())[:10]
	bintrees = [addfanoutmarkers(binarize(orig.copy(True), horzmarkov=1))
			for orig in sample]
	if debug:
		print('plcfrs\n', Grammar(treebankgrammar(bintrees, sentences)))
		print('dop reduction')
	reduction = Grammar(dopreduction(bintrees[:2], sentences[:2])[0],
			start=bintrees[0].label)
	if debug:
		print(reduction)
	reduction.testgrammar()

	ddrules, backtransform, _, _ = doubledop(bintrees, sentences,
			debug=debug, numproc=1)
	if debug:
		print('\ndouble dop grammar')
	ddgrammar = Grammar(ddrules, start=bintrees[0].label)
	# recoverfragments() below relies on this mapping to identify
	# binarization nodes
	ddgrammar.getmapping(ddgrammar, striplabelre=None,
			neverblockre=re.compile(b'^#[0-9]+|.+}<'),
			splitprune=False, markorigin=False)
	if debug:
		print(ddgrammar)
	assert ddgrammar.testgrammar()[0], "RFE should sum to 1."
	for gold, sent in zip(reader.trees().values(), sentences):
		if debug:
			print("sentence:", ' '.join(word.encode('unicode-escape').decode()
					for word in sent))
		chart, msg = plcfrs.parse(sent, ddgrammar, exhaustive=True)
		if debug:
			print('\n', msg, '\ngold ', gold, '\n', 'double dop', end='')
		if not chart:
			if debug:
				print('no parse\n', chart)
		else:
			mpp, parsetrees = {}, {}
			# lazykbest must run before rankededges is consulted
			derivations, _ = lazykbest(chart, 1000, b'}<')
			edges = chart.rankededges[chart.root()]
			for edge, (deriv, logprob) in zip(edges, derivations):
				fragtree = Tree(recoverfragments(edge.key, chart, backtransform))
				parse = str(removefanoutmarkers(unbinarize(fragtree)))
				mpp[parse] = mpp.get(parse, 0.0) + exp(-logprob)
				parsetrees.setdefault(parse, []).append((deriv, logprob))
			if debug:
				numderivs = sum(len(derivs) for derivs in parsetrees.values())
				print(len(mpp), 'parsetrees', numderivs, 'derivations')
			for parse, prob in sorted(mpp.items(), key=itemgetter(1)):
				derivs = parsetrees[parse]
				if debug:
					print(prob, parse, '\nmatch:', parse == str(gold))
				if len(set(derivs)) != len(derivs):
					# show the chart before failing on duplicate derivations
					print('chart:\n', chart)
					assert len(set(derivs)) == len(derivs)
				if debug:
					for deriv, logprob in sorted(derivs, key=itemgetter(1)):
						print(' <= %6g %s' % (exp(-logprob), deriv))
		if debug:
			print()
	unary = Tree.parse("(ROOT (S (F (E (S (C (B (A 0))))))))", parse_leaf=int)
	Grammar(treebankgrammar([unary], [[str(i) for i in range(10)]]))
예제 #13
0
파일: parser.py 프로젝트: arne-cl/disco-dop
def main():
	"""Handle command line arguments and parse the input with the grammars.

	Positional arguments: the rules file and the lexicon file of the coarse
	grammar (files ending in ``.gz`` are decompressed transparently). When
	``--ctf`` gives a nonzero threshold and there are 4-6 arguments, the
	third and fourth arguments are the rules and lexicon of a fine grammar
	that is pruned with the coarse stage, followed by optional input and
	output files. Otherwise the third and fourth arguments are the optional
	input and output files. Missing input/output default to stdin/stdout.

	Options read here: ``-b`` (``m`` of each stage, default 1), ``-s``
	(start label, default TOP), ``--ctf`` (pruning threshold ``k`` of the
	fine stage), ``--mpd`` (objective of the fine stage), and ``--prob`` /
	``-z``, which are passed on to ``doparsing``.
	"""
	print('PLCFRS parser - Andreas van Cranenburgh', file=sys.stderr)
	options = 'ctf= prob mpd'.split()
	try:
		opts, args = gnu_getopt(sys.argv[1:], 'u:b:s:z', options)
	except GetoptError as err:
		print(err, USAGE)
		return
	# explicit check instead of an assert, which python -O would strip
	if not 2 <= len(args) <= 6:
		print('incorrect number of arguments', USAGE)
		return
	for n, filename in enumerate(args):
		assert os.path.exists(filename), (
				'file %d not found: %r' % (n + 1, filename))

	def readfile(filename, decode=False):
		"""Return the whole contents of a (possibly gzipped) file.

		``.gz`` files are opened with ``gzip.open``, others with ``open``,
		both in their default mode; with ``decode`` the data is decoded as
		UTF-8. The file is closed afterwards."""
		opener = gzip.open if filename.endswith('.gz') else open
		with opener(filename) as inp:
			if decode:
				return codecs.getreader('utf-8')(inp).read()
			return inp.read()

	opts = dict(opts)
	k = int(opts.get('-b', 1))
	top = opts.get('-s', 'TOP')
	threshold = int(opts.get('--ctf', 0))
	prob = '--prob' in opts
	oneline = '-z' in opts
	rules = readfile(args[0])
	lexicon = readfile(args[1], decode=True)
	# detect bitpar format
	bitpar = rules[0] in string.digits
	coarse = Grammar(rules, lexicon, start=top, bitpar=bitpar)
	stages = []
	stage = DEFAULTSTAGE.copy()
	stage.update(
			name='coarse',
			mode='pcfg' if bitpar else 'plcfrs',
			grammar=coarse,
			backtransform=None,
			m=k)
	stages.append(DictObj(stage))
	# Without a nonzero --ctf threshold, arguments 3 and 4 are taken as the
	# input and output files, even when more arguments are given.
	if 4 <= len(args) <= 6 and threshold:
		rules = readfile(args[2])
		lexicon = readfile(args[3], decode=True)
		# detect bitpar format
		bitpar = rules[0] in string.digits
		fine = Grammar(rules, lexicon, start=top, bitpar=bitpar)
		fine.getmapping(coarse, striplabelre=re.compile(b'@.+$'))
		stage = DEFAULTSTAGE.copy()
		stage.update(
				name='fine',
				mode='pcfg' if bitpar else 'plcfrs',
				grammar=fine,
				backtransform=None,
				m=k,
				prune=True,
				k=threshold,
				objective='mpd' if '--mpd' in opts else 'mpp')
		stages.append(DictObj(stage))
		infile = (io.open(args[4], encoding='utf-8')
				if len(args) >= 5 else sys.stdin)
		out = (io.open(args[5], 'w', encoding='utf-8')
				if len(args) == 6 else sys.stdout)
	else:
		infile = (io.open(args[2], encoding='utf-8')
				if len(args) >= 3 else sys.stdin)
		out = (io.open(args[3], 'w', encoding='utf-8')
				if len(args) == 4 else sys.stdout)
	try:
		doparsing(Parser(stages), infile, out, prob, oneline)
	finally:
		# close (and thereby flush) only the files opened here,
		# never the standard streams
		if infile is not sys.stdin:
			infile.close()
		if out is not sys.stdout:
			out.close()
예제 #14
0
파일: parser.py 프로젝트: arne-cl/disco-dop
def readgrammars(resultdir, stages, postagging=None, top='ROOT'):
	""" Read the grammars from a previous experiment.

	Expects a directory 'resultdir' which contains the relevant grammars and
	the parameter file 'params.prm', as produced by runexp.

	For each stage, reads ``<name>.rules.gz`` and ``<name>.lex.gz`` (and for
	DOP stages ``<name>.probs.npz``, plus ``<name>.backtransform.gz`` for
	Double-DOP), builds the Grammar with start label ``top``, sets up the
	mapping onto the previous stage's grammar when the stage prunes, and
	stores the result via ``stage.update(grammar=..., backtransform=...,
	outside=None)``. All files are closed as soon as they have been read.

	When ``postagging['method'] == 'unknownword'``, the ``postagging`` dict
	is extended in place with 'unknownwordfun', and with 'lexicon' and
	'sigs' (the words of the first stage's lexicon without / with the UNK
	prefix).
	"""
	for n, stage in enumerate(stages):
		logging.info('reading: %s', stage.name)
		with gzip.open('%s/%s.rules.gz' % (resultdir, stage.name)) as inp:
			rules = inp.read()
		with codecs.getreader('utf-8')(gzip.open('%s/%s.lex.gz' % (
				resultdir, stage.name))) as inp:
			lexicon = inp.read()
		grammar = Grammar(rules, lexicon,
				start=top, bitpar=stage.mode.startswith('pcfg'))
		backtransform = None
		if stage.dop:
			assert stage.useestimates is None, 'not supported'
			if stage.usedoubledop:
				with gzip.open('%s/%s.backtransform.gz' % (
						resultdir, stage.name)) as inp:
					backtransform = inp.read().splitlines()
				if n and stage.prune:
					_ = grammar.getmapping(stages[n - 1].grammar,
						striplabelre=re.compile(b'@.+$'),
						neverblockre=re.compile(b'^#[0-9]+|.+}<'),
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				else:
					# recoverfragments() relies on this mapping to identify
					# binarization nodes
					_ = grammar.getmapping(None,
						neverblockre=re.compile(b'.+}<'))
			elif n and stage.prune:  # dop reduction
				_ = grammar.getmapping(stages[n - 1].grammar,
					striplabelre=re.compile(b'@[-0-9]+$'),
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
				if stage.mode == 'dop-rerank':
					grammar.getrulemapping(stages[n - 1].grammar)
			# indexing the NpzFile loads each array into memory, so the
			# archive can be closed once all models are registered
			with np.load('%s/%s.probs.npz' % (
					resultdir, stage.name)) as probmodels:
				for name in probmodels.files:
					if name != 'default':
						grammar.register(unicode(name), probmodels[name])
		else:  # not stage.dop
			if n and stage.prune:
				_ = grammar.getmapping(stages[n - 1].grammar,
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
		if stage.mode == 'pcfg-bitpar':
			assert grammar.maxfanout == 1
		grammar.testgrammar()
		stage.update(grammar=grammar, backtransform=backtransform, outside=None)
	if postagging and postagging['method'] == 'unknownword':
		postagging['unknownwordfun'] = getunknownwordfun(postagging['model'])
		postagging['lexicon'] = {w for w in stages[0].grammar.lexicalbyword
				if not w.startswith(UNK)}
		postagging['sigs'] = {w for w in stages[0].grammar.lexicalbyword
				if w.startswith(UNK)}
예제 #15
0
파일: runexp.py 프로젝트: arne-cl/disco-dop
def getgrammars(trees, sents, stages, bintype, horzmarkov, vertmarkov, factor,
		tailmarker, revmarkov, leftmostunary, rightmostunary, pospa, markhead,
		fanout_marks_before_bin, testmaxwords, resultdir, numproc,
		lexmodel, simplelexsmooth, top, relationalrealizational):
	""" Apply binarization and read off the requested grammars.

	The treebank ``trees`` (sentences ``sents``) is binarized according to
	``bintype`` ('binarize', 'optimal' or 'optimalhead'; other values leave
	the trees unbinarized). Fan-out markers are then added, and the trees
	are canonicalized and frozen.

	For each stage in ``stages``, a grammar is read off:
	- Double-DOP, DOP reduction, or a plain treebank PCFG/PLCFRS;
	- optionally smoothed with ``lexmodel``;
	- written to ``resultdir/<stage.name>.{rules,lex}.gz``, plus
	  ``.backtransform.gz`` and ``.probs.npz`` for DOP stages;
	- loaded as a Grammar and mapped onto the previous stage's grammar
	  when the stage prunes.

	Outside estimates are computed or loaded when requested. Each stage is
	updated in place with ``grammar``, ``backtransform`` and ``outside``.

	Note: SX and SXlrgaps estimates are saved to and loaded from
	'pcfgoutside.npz' / 'outside.npz' in the current working directory,
	not in ``resultdir``. """
	# time.clock() was deprecated in Python 3.3 and removed in Python 3.8;
	# prefer time.process_time(), which likewise measures CPU time.
	cputime = getattr(time, 'process_time', None) or time.clock
	# fixme: this n should correspond to sentence id
	tbfanout, n = treebankfanout(trees)
	logging.info('treebank fan-out before binarization: %d #%d\n%s\n%s',
			tbfanout, n, trees[n], ' '.join(sents[n]))
	# binarization
	begin = cputime()
	if fanout_marks_before_bin:
		trees = [addfanoutmarkers(t) for t in trees]
	if bintype == 'binarize':
		bintype += ' %s h=%d v=%d %s' % (factor, horzmarkov, vertmarkov,
			'tailmarker' if tailmarker else '')
		for a in trees:
			binarize(a, factor=factor, tailmarker=tailmarker,
					horzmarkov=horzmarkov, vertmarkov=vertmarkov,
					leftmostunary=leftmostunary, rightmostunary=rightmostunary,
					reverse=revmarkov, pospa=pospa,
					headidx=-1 if markhead else None,
					filterfuncs=(relationalrealizational['ignorefunctions']
						+ (relationalrealizational['adjunctionlabel'], ))
						if relationalrealizational else ())
	elif bintype == 'optimal':
		trees = [Tree.convert(optimalbinarize(tree)) for tree in trees]
	elif bintype == 'optimalhead':
		trees = [Tree.convert(optimalbinarize(tree, headdriven=True,
				h=horzmarkov, v=vertmarkov)) for tree in trees]
	trees = [addfanoutmarkers(t) for t in trees]
	logging.info('binarized %s cpu time elapsed: %gs',
						bintype, cputime() - begin)
	logging.info('binarized treebank fan-out: %d #%d', *treebankfanout(trees))
	trees = [canonicalize(a).freeze() for a in trees]

	for n, stage in enumerate(stages):
		if stage.split:
			traintrees = [binarize(splitdiscnodes(Tree.convert(a),
					stage.markorigin), childchar=':').freeze() for a in trees]
			logging.info('splitted discontinuous nodes')
		else:
			traintrees = trees
		# a PCFG stage requires a continuous treebank or split nodes
		if stage.mode.startswith('pcfg'):
			assert tbfanout == 1 or stage.split
		backtransform = None
		if stage.dop:
			if stage.usedoubledop:
				# find recurring fragments in treebank,
				# as well as depth 1 'cover' fragments
				fragments = getfragments(traintrees, sents, numproc,
						iterate=stage.iterate, complement=stage.complement)
				xgrammar, backtransform, altweights = doubledop(
						traintrees, fragments)
			else:  # DOP reduction
				xgrammar, altweights = dopreduction(
						traintrees, sents, packedgraph=stage.packedgraph)
			nodes = sum(len(list(a.subtrees())) for a in traintrees)
			if lexmodel and simplelexsmooth:
				newrules = simplesmoothlexicon(lexmodel)
				xgrammar.extend(newrules)
				# extend each alternative weight list with the weights of
				# the new rules, so it keeps one entry per rule
				for weights in altweights.values():
					weights.extend(w for _, w in newrules)
			elif lexmodel:
				xgrammar = smoothlexicon(xgrammar, lexmodel)
			msg = grammarinfo(xgrammar)
			rules, lexicon = write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			grammar = Grammar(rules, lexicon, start=top,
					bitpar=stage.mode.startswith('pcfg'))
			for name in altweights:
				grammar.register(u'%s' % name, altweights[name])
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lexicon)
			logging.info('DOP model based on %d sentences, %d nodes, '
				'%d nonterminals', len(traintrees), nodes, len(grammar.toid))
			logging.info(msg)
			if stage.estimator != 'dop1':
				grammar.switch(u'%s' % stage.estimator)
			grammar.testgrammar()
			if stage.usedoubledop:
				# backtransform keys are line numbers to rules file;
				# to see them together do:
				# $ paste <(zcat dop.rules.gz) <(zcat dop.backtransform.gz)
				with codecs.getwriter('ascii')(gzip.open(
						'%s/%s.backtransform.gz' % (resultdir, stage.name),
						'w')) as out:
					out.writelines('%s\n' % a for a in backtransform)
				if n and stage.prune:
					msg = grammar.getmapping(stages[n - 1].grammar,
						striplabelre=None if stages[n - 1].dop
							else re.compile(b'@.+$'),
						neverblockre=re.compile(b'.+}<'),
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				else:
					# recoverfragments() relies on this mapping to identify
					# binarization nodes
					msg = grammar.getmapping(None,
						striplabelre=None,
						neverblockre=re.compile(b'.+}<'),
						splitprune=False, markorigin=False)
				logging.info(msg)
			elif n and stage.prune:  # dop reduction
				msg = grammar.getmapping(stages[n - 1].grammar,
					striplabelre=None if stages[n - 1].dop
						and not stages[n - 1].usedoubledop
						else re.compile(b'@[-0-9]+$'),
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
				if stage.mode == 'dop-rerank':
					grammar.getrulemapping(stages[n - 1].grammar)
				logging.info(msg)
			# write prob models
			np.savez_compressed('%s/%s.probs.npz' % (resultdir, stage.name),
					**{name: mod for name, mod
						in zip(grammar.modelnames, grammar.models)})
		else:  # not stage.dop
			xgrammar = treebankgrammar(traintrees, sents)
			logging.info('induced %s based on %d sentences',
				('PCFG' if tbfanout == 1 or stage.split else 'PLCFRS'),
				len(traintrees))
			if stage.split or os.path.exists('%s/pcdist.txt' % resultdir):
				logging.info(grammarinfo(xgrammar))
			else:
				logging.info(grammarinfo(xgrammar,
						dump='%s/pcdist.txt' % resultdir))
			if lexmodel and simplelexsmooth:
				newrules = simplesmoothlexicon(lexmodel)
				xgrammar.extend(newrules)
			elif lexmodel:
				xgrammar = smoothlexicon(xgrammar, lexmodel)
			rules, lexicon = write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			grammar = Grammar(rules, lexicon, start=top,
					bitpar=stage.mode.startswith('pcfg'))
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lexicon)
			grammar.testgrammar()
			if n and stage.prune:
				msg = grammar.getmapping(stages[n - 1].grammar,
					striplabelre=None,
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
				logging.info(msg)
		logging.info('wrote grammar to %s/%s.{rules,lex%s}.gz', resultdir,
				stage.name, ',backtransform' if stage.usedoubledop else '')

		outside = None
		if stage.getestimates == 'SX':
			assert tbfanout == 1 or stage.split, 'SX estimate requires PCFG.'
			logging.info('computing PCFG estimates')
			begin = cputime()
			outside = getpcfgestimates(grammar, testmaxwords,
					grammar.toid[trees[0].label])
			logging.info('estimates done. cpu time elapsed: %gs',
					cputime() - begin)
			np.savez('pcfgoutside.npz', outside=outside)
			logging.info('saved PCFG estimates')
		elif stage.useestimates == 'SX':
			assert tbfanout == 1 or stage.split, 'SX estimate requires PCFG.'
			assert stage.mode != 'pcfg', (
					'estimates require agenda-based parser.')
			outside = np.load('pcfgoutside.npz')['outside']
			logging.info('loaded PCFG estimates')
		if stage.getestimates == 'SXlrgaps':
			logging.info('computing PLCFRS estimates')
			begin = cputime()
			outside = getestimates(grammar, testmaxwords,
					grammar.toid[trees[0].label])
			logging.info('estimates done. cpu time elapsed: %gs',
						cputime() - begin)
			np.savez('outside.npz', outside=outside)
			logging.info('saved estimates')
		elif stage.useestimates == 'SXlrgaps':
			outside = np.load('outside.npz')['outside']
			logging.info('loaded PLCFRS estimates')
		stage.update(grammar=grammar, backtransform=backtransform,
				outside=outside)
예제 #16
0
파일: runexp.py 프로젝트: tivaro/disco-dop
def getgrammars(trees, sents, stages, testmaxwords, resultdir,
		numproc, lexmodel, simplelexsmooth, top):
	"""Read off the requested grammars.

	For each stage in ``stages``, a grammar is read off ``trees`` (the logged
	message presumes they are already binarized) with sentences ``sents``.
	Discontinuous nodes are first split when ``stage.split`` is set. Then
	one of the following is extracted, optionally smoothed with
	``lexmodel``:
	- a Double-DOP grammar (``stage.dop == 'doubledop'``);
	- a DOP reduction (``stage.dop == 'reduction'``);
	- a plain treebank grammar.

	The grammar is written to ``resultdir/<stage.name>.{rules,lex}.gz``.
	DOP stages also write ``.probs.npz``, and Double-DOP additionally
	writes ``.fragments.gz`` and ``.backtransform.gz``. The grammar is then
	loaded as a Grammar and mapped onto the previous stage's grammar when
	the stage prunes. Requested outside estimates ('SX' or 'SXlrgaps') are
	computed and saved to ``resultdir/<stage.name>.outside.npz``. Each stage
	is updated in place with ``grammar``, ``backtransform`` and ``outside``.

	:raises ValueError: for a PCFG stage on a treebank with discontinuities
		(unless ``stage.split``), an unrecognized DOP model or estimates
		value, or estimates requested for an unsupported stage.
	"""
	# time.clock() was deprecated in Python 3.3 and removed in Python 3.8;
	# prefer time.process_time(), which likewise measures CPU time.
	cputime = getattr(time, 'process_time', None) or time.clock
	tbfanout, n = treebank.treebankfanout(trees)
	logging.info('binarized treebank fan-out: %d #%d', tbfanout, n)
	for n, stage in enumerate(stages):
		if stage.split:
			traintrees = [treetransforms.binarize(
					treetransforms.splitdiscnodes(
						Tree.convert(a), stage.markorigin),
					childchar=':', dot=True, ids=grammar.UniqueIDs()).freeze()
					for a in trees]
			logging.info('splitted discontinuous nodes')
		else:
			traintrees = trees
		if stage.mode.startswith('pcfg'):
			if tbfanout != 1 and not stage.split:
				raise ValueError('Cannot extract PCFG from treebank '
						'with discontinuities.')
		backtransform = extrarules = None
		if lexmodel and simplelexsmooth:
			extrarules = lexicon.simplesmoothlexicon(lexmodel)
		if stage.dop:
			if stage.dop == 'doubledop':
				(xgrammar, backtransform, altweights, fragments
					) = grammar.doubledop(
						traintrees, sents, binarized=stage.binarized,
						iterate=stage.iterate, complement=stage.complement,
						numproc=numproc, extrarules=extrarules)
				# dump fragments
				with codecs.getwriter('utf-8')(gzip.open('%s/%s.fragments.gz' %
						(resultdir, stage.name), 'w')) as out:
					out.writelines('%s\t%d\n' % (treebank.writetree(a, b, 0,
							'bracket' if stage.mode.startswith('pcfg')
							else 'discbracket').rstrip(), sum(c.values()))
							for (a, b), c in fragments.items())
			elif stage.dop == 'reduction':
				xgrammar, altweights = grammar.dopreduction(
						traintrees, sents, packedgraph=stage.packedgraph,
						extrarules=extrarules)
			else:
				raise ValueError('unrecognized DOP model: %r' % stage.dop)
			nodes = sum(len(list(a.subtrees())) for a in traintrees)
			if lexmodel and not simplelexsmooth:  # FIXME: altweights?
				xgrammar = lexicon.smoothlexicon(xgrammar, lexmodel)
			msg = grammar.grammarinfo(xgrammar)
			rules, lex = grammar.write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			gram = Grammar(rules, lex, start=top,
					bitpar=stage.mode.startswith('pcfg'),
					binarized=stage.binarized)
			for name in altweights:
				gram.register(u'%s' % name, altweights[name])
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lex)
			logging.info('DOP model based on %d sentences, %d nodes, '
				'%d nonterminals', len(traintrees), nodes, len(gram.toid))
			logging.info(msg)
			if stage.estimator != 'rfe':
				gram.switch(u'%s' % stage.estimator)
			logging.info(gram.testgrammar()[1])
			if stage.dop == 'doubledop':
				# backtransform keys are line numbers to rules file;
				# to see them together do:
				# $ paste <(zcat dop.rules.gz) <(zcat dop.backtransform.gz)
				with codecs.getwriter('ascii')(gzip.open(
						'%s/%s.backtransform.gz' % (resultdir, stage.name),
						'w')) as out:
					out.writelines('%s\n' % a for a in backtransform)
				if n and stage.prune:
					msg = gram.getmapping(stages[n - 1].grammar,
						striplabelre=None if stages[n - 1].dop
							else re.compile(b'@.+$'),
						neverblockre=re.compile(b'.+}<'),
						splitprune=stage.splitprune and stages[n - 1].split,
						markorigin=stages[n - 1].markorigin)
				else:
					# recoverfragments() relies on this mapping to identify
					# binarization nodes
					msg = gram.getmapping(None,
						striplabelre=None,
						neverblockre=re.compile(b'.+}<'),
						splitprune=False, markorigin=False)
				logging.info(msg)
			elif n and stage.prune:  # dop reduction
				msg = gram.getmapping(stages[n - 1].grammar,
					striplabelre=None if stages[n - 1].dop
						and stages[n - 1].dop != 'doubledop'
						else re.compile(b'@[-0-9]+$'),
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
				if stage.mode == 'dop-rerank':
					gram.getrulemapping(
							stages[n - 1].grammar, re.compile(br'@[-0-9]+\b'))
				logging.info(msg)
			# write prob models
			np.savez_compressed(  # pylint: disable=no-member
					'%s/%s.probs.npz' % (resultdir, stage.name),
					**{name: mod for name, mod
						in zip(gram.modelnames, gram.models)})
		else:  # not stage.dop
			xgrammar = grammar.treebankgrammar(traintrees, sents,
					extrarules=extrarules)
			logging.info('induced %s based on %d sentences',
				('PCFG' if tbfanout == 1 or stage.split else 'PLCFRS'),
				len(traintrees))
			if stage.split or os.path.exists('%s/pcdist.txt' % resultdir):
				logging.info(grammar.grammarinfo(xgrammar))
			else:
				logging.info(grammar.grammarinfo(xgrammar,
						dump='%s/pcdist.txt' % resultdir))
			if lexmodel and not simplelexsmooth:
				xgrammar = lexicon.smoothlexicon(xgrammar, lexmodel)
			rules, lex = grammar.write_lcfrs_grammar(
					xgrammar, bitpar=stage.mode.startswith('pcfg'))
			gram = Grammar(rules, lex, start=top,
					bitpar=stage.mode.startswith('pcfg'))
			with gzip.open('%s/%s.rules.gz' % (
					resultdir, stage.name), 'wb') as rulesfile:
				rulesfile.write(rules)
			with codecs.getwriter('utf-8')(gzip.open('%s/%s.lex.gz' % (
					resultdir, stage.name), 'wb')) as lexiconfile:
				lexiconfile.write(lex)
			logging.info(gram.testgrammar()[1])
			if n and stage.prune:
				msg = gram.getmapping(stages[n - 1].grammar,
					striplabelre=None,
					neverblockre=re.compile(stage.neverblockre)
						if stage.neverblockre else None,
					splitprune=stage.splitprune and stages[n - 1].split,
					markorigin=stages[n - 1].markorigin)
				logging.info(msg)
		logging.info('wrote grammar to %s/%s.{rules,lex%s}.gz',
				resultdir, stage.name,
				',backtransform' if stage.dop == 'doubledop' else '')

		outside = None
		if stage.estimates in ('SX', 'SXlrgaps'):
			if stage.estimates == 'SX' and tbfanout != 1 and not stage.split:
				raise ValueError('SX estimate requires PCFG.')
			elif stage.mode != 'plcfrs':
				raise ValueError('estimates require parser w/agenda.')
			begin = cputime()
			logging.info('computing %s estimates', stage.estimates)
			if stage.estimates == 'SX':
				outside = estimates.getpcfgestimates(gram, testmaxwords,
						gram.toid[trees[0].label])
			elif stage.estimates == 'SXlrgaps':
				outside = estimates.getestimates(gram, testmaxwords,
						gram.toid[trees[0].label])
			logging.info('estimates done. cpu time elapsed: %gs',
					cputime() - begin)
			np.savez_compressed(  # pylint: disable=no-member
					'%s/%s.outside.npz' % (resultdir, stage.name),
					outside=outside)
			logging.info('saved %s estimates', stage.estimates)
		elif stage.estimates:
			raise ValueError('unrecognized value; specify SX or SXlrgaps.')

		stage.update(grammar=gram, backtransform=backtransform,
				outside=outside)