示例#1
0
def initialise_itg(input_fsa, grammars, glue_grammars, options):
    """
    Sample one derivation of the input under an ITG projection of the grammars
    and return the weights of the rules it uses, keyed by LHS label.

    Every rule yielded by ``iteritg`` for ``grammars`` (resp. ``glue_grammars``)
    is collected into a main (resp. glue) CFG. The input is parsed with
    ``ExactNederhof`` under the ``SumTimes`` semiring, inside values are
    computed with ``compute_values`` and one derivation is drawn with
    ``sample_k``.

    :param input_fsa: the input automaton to parse
    :param grammars: main grammars to project with ``iteritg``
    :param glue_grammars: glue grammars to project with ``iteritg``
    :param options: options object; reads ``start``, ``goal`` and ``generations``
    :return: dict mapping each sampled rule's ``lhs.label`` to its weight
        converted with ``as_real``; None if the ITG grammar cannot parse the
        input or no derivation could be sampled
    """
    semiring = SumTimes
    itg_grammar = CFG()
    for g in grammars:
        for r in iteritg(g):
            itg_grammar.add(r)
    itg_glue = CFG()
    for g in glue_grammars:
        for r in iteritg(g):
            itg_glue.add(r)

    logging.info('ITG grammar: terminals=%d nonterminals=%d rules=%d',
                 itg_grammar.n_terminals(), itg_grammar.n_nonterminals(),
                 len(itg_grammar))
    parser = ExactNederhof([itg_grammar],
                           input_fsa,
                           glue_grammars=[itg_glue],
                           semiring=semiring)
    forest = parser.do(root=Nonterminal(options.start),
                       goal=Nonterminal(options.goal))
    if not forest:
        logging.info('The ITG grammar cannot parse this input.')
        return None

    tsort = TopSortTable(forest)
    root = tsort.root()
    inside_nodes = compute_values(forest,
                                  tsort,
                                  semiring,
                                  infinity=options.generations)
    # Only one sample is requested: take it lazily rather than building a list,
    # and fall back to the same "None" failure contract used above instead of
    # raising an opaque IndexError when the sampler yields nothing.
    samples = sample_k(forest, root, semiring, node_values=inside_nodes,
                       n_samples=1)
    d = next(iter(samples), None)
    if d is None:
        logging.info('No derivation could be sampled from the ITG forest.')
        return None

    # If several sampled rules share an LHS label, the last one wins.
    return {r.lhs.label: semiring.as_real(r.weight) for r in d}
示例#2
0
 def setUp(self):
     """Create the fixture grammar S -> S X (0.9) | X (0.1), X -> 'a' (1.0) and a 'Prob' model."""
     self.cfg = CFG()
     # (lhs, rhs, weight) triples, added to the grammar in this order.
     fixture = (
         (Nonterminal('S'), [Nonterminal('S'), Nonterminal('X')], 0.9),
         (Nonterminal('S'), [Nonterminal('X')], 0.1),
         (Nonterminal('X'), [Terminal('a')], 1.0),
     )
     for lhs, rhs, weight in fixture:
         self.cfg.add(make_production(lhs, rhs, weight))
     self.model = PCFG('Prob')
示例#3
0
def initialise_coarse(input_fsa, grammars, glue_grammars, options):
    """
    Sample one derivation of the input under a coarse projection of the
    grammars and derive refinement conditions from it.

    Every rule yielded by ``itercoarse`` for ``grammars`` (resp.
    ``glue_grammars``) is collected into a main (resp. glue) CFG. The input is
    parsed with ``ExactNederhof`` under the ``SumTimes`` semiring, inside values
    are computed with ``compute_values`` and one derivation is drawn with
    ``sample_k``. The sampled rules' LHS labels are grouped by their first
    element and handed to ``refine_conditions``.

    :param input_fsa: the input automaton to parse
    :param grammars: main grammars to project with ``itercoarse``; the first one
        is also passed to ``refine_conditions``
    :param glue_grammars: glue grammars to project with ``itercoarse``
    :param options: options object; reads ``start``, ``goal`` and ``generations``
    :return: the result of ``refine_conditions``; None if the coarse grammar
        cannot parse the input or no derivation could be sampled
    """
    semiring = SumTimes
    coarse_grammar = CFG()
    for g in grammars:
        for r in itercoarse(g, semiring):
            coarse_grammar.add(r)
    coarse_glue = CFG()
    for g in glue_grammars:
        for r in itercoarse(g, semiring):
            coarse_glue.add(r)

    logging.info('Coarse grammar: terminals=%d nonterminals=%d rules=%d',
                 coarse_grammar.n_terminals(), coarse_grammar.n_nonterminals(),
                 len(coarse_grammar))
    parser = ExactNederhof([coarse_grammar],
                           input_fsa,
                           glue_grammars=[coarse_glue],
                           semiring=semiring)
    forest = parser.do(root=Nonterminal(options.start),
                       goal=Nonterminal(options.goal))
    if not forest:
        logging.info('The coarse grammar cannot parse this input.')
        return None

    tsort = TopSortTable(forest)
    root = tsort.root()
    inside_nodes = compute_values(forest,
                                  tsort,
                                  semiring,
                                  infinity=options.generations)
    # Only one sample is requested: take it lazily rather than building a list,
    # and fall back to the same "None" failure contract used above instead of
    # raising an opaque IndexError when the sampler yields nothing.
    samples = sample_k(forest, root, semiring, node_values=inside_nodes,
                       n_samples=1)
    d = next(iter(samples), None)
    if d is None:
        logging.info('No derivation could be sampled from the coarse forest.')
        return None

    # Group the sampled LHS labels by their first element.
    # NOTE(review): presumably label[0] is the underlying nonterminal and
    # label[1:] its span annotation as produced by itercoarse — confirm.
    spans = defaultdict(set)
    for r in d:
        spans[r.lhs.label[0]].add(r.lhs.label[1:])

    return refine_conditions(spans, grammars[0], semiring)
示例#4
0
 def setUp(self):
     """Build a small viterbi-weighted grammar plus its forest, top-sort table and edge lookup."""
     self.semiring = semiring.viterbi
     self.cfg = CFG()
     # (lhs name, rhs symbols, probability); probabilities are mapped into the
     # test semiring with from_real before the rule is built.
     table = (
         ('S02', [Nonterminal('S01'), Nonterminal('X12'), Nonterminal('PUNC')], 0.5),
         ('S01', [Nonterminal('X01')], 0.1),
         ('X01', [Terminal('Hello')], 0.7),
         ('X01', [Terminal('hello')], 0.1),
         ('X12', [Terminal('World')], 0.6),
         ('X12', [Terminal('world')], 0.2),
         ('PUNC', [Terminal('!')], 0.1),
         ('PUNC', [Terminal('!!!')], 0.3),
         ('A', [Terminal('dead')], 0.3),
         ('B', [], 0.3),
     )
     for lhs_name, rhs, prob in table:
         weight = self.semiring.from_real(prob)
         self.cfg.add(get_rule(Nonterminal(lhs_name), rhs, weight))
     self.forest = cfg_to_hg([self.cfg], [], PCFG('LogProb'))
     self.tsort = AcyclicTopSortTable(self.forest)
     self.omega = HypergraphLookupFunction(self.forest)
示例#5
0
class InferenceTestCase(unittest.TestCase):
    """Viterbi, sampling and ancestral-sampling checks on a small acyclic forest."""

    # Probability of the best derivation of the fixture grammar:
    # S02 -> S01 X12 PUNC (0.5), S01 -> X01 (0.1), X01 -> Hello (0.7),
    # X12 -> World (0.6), PUNC -> !!! (0.3): 0.5 * 0.1 * 0.7 * 0.6 * 0.3.
    _BEST_SCORE = 0.006299999999999999

    def setUp(self):
        self.semiring = semiring.viterbi
        self.cfg = CFG()
        # (lhs, rhs, probability) rows; probabilities are mapped into the test
        # semiring with from_real. 'A' and 'B' are not referenced by any other
        # rule's right-hand side.
        rules = [
            (Nonterminal('S02'),
             [Nonterminal('S01'), Nonterminal('X12'), Nonterminal('PUNC')], 0.5),
            (Nonterminal('S01'), [Nonterminal('X01')], 0.1),
            (Nonterminal('X01'), [Terminal('Hello')], 0.7),
            (Nonterminal('X01'), [Terminal('hello')], 0.1),
            (Nonterminal('X12'), [Terminal('World')], 0.6),
            (Nonterminal('X12'), [Terminal('world')], 0.2),
            (Nonterminal('PUNC'), [Terminal('!')], 0.1),
            (Nonterminal('PUNC'), [Terminal('!!!')], 0.3),
            (Nonterminal('A'), [Terminal('dead')], 0.3),
            (Nonterminal('B'), [], 0.3),
        ]
        for lhs, rhs, prob in rules:
            self.cfg.add(get_rule(lhs, rhs, self.semiring.from_real(prob)))
        self.forest = cfg_to_hg([self.cfg], [], PCFG('LogProb'))
        self.tsort = AcyclicTopSortTable(self.forest)
        self.omega = HypergraphLookupFunction(self.forest)

    def test_viterbi(self):
        d = viterbi_derivation(self.forest, self.tsort)
        score = self.omega.reduce(self.semiring.times, d)
        self.assertAlmostEqual(self.semiring.as_real(score), self._BEST_SCORE)

        # Repeated decoding must always return the same derivation.
        ten = [viterbi_derivation(self.forest, self.tsort) for _ in range(10)]
        self.assertEqual(len(set(ten)), 1)
        self.assertAlmostEqual(
            self.semiring.as_real(self.omega.reduce(self.semiring.times, ten[0])),
            self._BEST_SCORE)

    def test_sample(self):
        size = 1000
        counts = Counter(sample_derivations(self.forest, self.tsort, size))
        top, n = counts.most_common(1)[0]
        score = self.omega.reduce(semiring.inside.times, top)
        self.assertAlmostEqual(self.semiring.as_real(score), self._BEST_SCORE)
        # The best derivation accounts for 0.0063 / 0.0128 (~0.49) of the total
        # fixture mass, so it should dominate a sample of this size.
        self.assertGreater(n / size, 0.4)

    def test_ancestral(self):
        sampler = AncestralSampler(self.forest, self.tsort)
        size = 1000
        counts = Counter(sampler.sample(size))
        top, n = counts.most_common(1)[0]
        # Z matches ln(0.5 * 0.1 * (0.7 + 0.1) * (0.6 + 0.2) * (0.1 + 0.3)) =
        # ln(0.0128), presumably the log of the total inside weight — compare
        # approximately, never with exact float equality.
        self.assertAlmostEqual(sampler.Z, -4.358310174252031)
        self.assertAlmostEqual(n / size, sampler.prob(top), places=1,
                               msg='Random effects apply - double check.')
示例#6
0
def core(job, args, outdir):
    """
    Run the main pipeline on a single input segment.

    :param job: a pair ``(id, input string)``
    :param args: command-line options; reads ``grammar``, ``grammarfmt``,
        ``log``, ``extra_grammar``, ``glue_grammar``, ``unkmodel`` and
        ``unklhs`` here and forwards the whole object to ``t_do``
    :param outdir: where to save results (forwarded to ``t_do``)
    :return: the parsing time, i.e. the first value returned by ``t_do``
    """

    def load_and_log(path, size_msg):
        # Load one grammar in the configured format and log its sizes.
        loaded = load_grammar(path, args.grammarfmt, args.log)
        logging.info(size_msg, loaded.n_terminals(), loaded.n_nonterminals(),
                     len(loaded))
        return loaded

    # Main grammar first, then any additional grammars.
    logging.info('Loading main grammar...')
    main_grammars = [
        load_and_log(args.grammar,
                     'Main grammar: terminals=%d nonterminals=%d productions=%d')
    ]
    for extra_path in args.extra_grammar or ():
        logging.info('Loading additional grammar: %s', extra_path)
        main_grammars.append(load_and_log(
            extra_path,
            'Additional grammar: terminals=%d nonterminals=%d productions=%d'))

    # Glue grammars.
    glue_grammars = []
    for glue_path in args.glue_grammar or ():
        logging.info('Loading glue grammar: %s', glue_path)
        glue_grammars.append(load_and_log(
            glue_path,
            'Glue grammar: terminals=%d nonterminals=%d productions=%d'))

    # Surface forms of every terminal known to any loaded grammar.
    surface_lexicon = {
        terminal.surface
        for g in chain(main_grammars, glue_grammars)
        for terminal in g.iterterminals()
    }

    seg_id, input_str = job
    seg = make_sentence(seg_id, input_str, semiring.inside, surface_lexicon,
                        args.unkmodel)

    grammars = list(main_grammars)
    if args.unkmodel == 'passthrough':
        # TODO: what feature value? one? zero?
        oov_productions = get_oov_cfg_productions(seg.oovs, args.unklhs,
                                                  'LogProb', semiring.inside.one)
        grammars.append(CFG(oov_productions))

    logging.info('[%d] Parsing %d words: %s', seg.id, len(seg), seg)
    dt, _ = t_do(seg, grammars, glue_grammars, args, outdir)
    logging.info('[%d] parsing time: %s', seg.id, dt)
    return dt
示例#7
0
class HypergraphTestCase(unittest.TestCase):
    """Construction and adjacency checks for ``Hypergraph`` and ``cfg_to_hg``."""

    def setUp(self):
        # Fixture grammar: S -> S X (0.9) | X (0.1); X -> 'a' (1.0).
        self.cfg = CFG()
        self.cfg.add(
            make_production(
                Nonterminal('S'),
                [Nonterminal('S'), Nonterminal('X')], 0.9))
        self.cfg.add(make_production(Nonterminal('S'), [Nonterminal('X')],
                                     0.1))
        self.cfg.add(make_production(Nonterminal('X'), [Terminal('a')], 1.0))
        self.model = PCFG('Prob')

    def test_construct(self):
        # A freshly constructed hypergraph has no nodes and no edges.
        hg = Hypergraph()
        self.assertEqual(hg.n_nodes(), 0)
        self.assertEqual(hg.n_edges(), 0)

    def test_update(self):
        # Fixture has 3 symbols (S, X, 'a') and 3 productions.
        hg = cfg_to_hg([self.cfg], [], self.model)
        self.assertEqual(hg.n_nodes(), 3)
        self.assertEqual(hg.n_edges(), 3)

    def test_nonterminal(self):
        hg = Hypergraph()
        hg.add_node(Nonterminal('S'))
        S = hg.fetch(Nonterminal('S'))
        self.assertNotEqual(S, -1)  # fetch returns -1 for an unknown label
        self.assertEqual(hg.label(S), Nonterminal('S'))
        self.assertTrue(hg.is_nonterminal(S))

    def test_terminal(self):
        hg = Hypergraph()
        hg.add_node(Terminal('a'))
        a = hg.fetch(Terminal('a'))
        self.assertNotEqual(a, -1)  # fetch returns -1 for an unknown label
        self.assertEqual(hg.label(a), Terminal('a'))
        self.assertTrue(hg.is_terminal(a))

    def test_rule(self):
        # Adding one edge also creates the nodes for its LHS and RHS symbols.
        hg = Hypergraph()
        rule = make_production(
            Nonterminal('S'),
            [Nonterminal('X'), Terminal('a')], 1.0)
        e = hg.add_xedge(rule.lhs, rule.rhs, self.model(rule), rule, False)
        self.assertEqual(hg.rule(e), rule)
        self.assertEqual(hg.n_nodes(), 3)
        self.assertEqual(hg.n_edges(), 1)

    def test_cfg(self):
        hg = cfg_to_hg([self.cfg], [], self.model)

        S = hg.fetch(Nonterminal('S'))
        X = hg.fetch(Nonterminal('X'))
        a = hg.fetch(Terminal('a'))

        self.assertTrue(hg.is_source(a))
        self.assertFalse(hg.is_source(X))
        self.assertFalse(hg.is_source(S))
        self.assertEqual(len(list(hg.iterbs(S))), 2)
        self.assertEqual(len(list(hg.iterbs(X))), 1)
        self.assertEqual(len(list(hg.iterbs(a))), 0)
        self.assertEqual(hg.label(S), Nonterminal('S'))
        self.assertEqual(hg.label(X), Nonterminal('X'))
        self.assertEqual(hg.label(a), Terminal('a'))

    def test_stars(self):
        hg = Hypergraph()
        r1 = make_production(
            Nonterminal('S'),
            [Nonterminal('S'), Nonterminal('X')], 0.5)
        r2 = make_production(Nonterminal('S'), [Nonterminal('X')], 0.5)
        r3 = make_production(Nonterminal('X'), [Terminal('a')], 0.5)
        e1 = hg.add_xedge(r1.lhs, r1.rhs, self.model(r1), r1, False)
        e2 = hg.add_xedge(r2.lhs, r2.rhs, self.model(r2), r2, False)
        e3 = hg.add_xedge(r3.lhs, r3.rhs, self.model(r3), r3, False)
        S = hg.fetch(Nonterminal('S'))
        X = hg.fetch(Nonterminal('X'))
        a = hg.fetch(Terminal('a'))

        # These are unordered collections, so compare them as sets:
        # assertSetEqual reports a proper set difference on failure, whereas
        # assertSequenceEqual cannot index into a set.

        # Forward star: edges whose RHS contains the node.
        self.assertSetEqual(set(hg.iterfs(S)), {e1})
        self.assertSetEqual(set(hg.iterfs(X)), {e1, e2})
        self.assertSetEqual(set(hg.iterfs(a)), {e3})

        # Backward star: edges whose LHS is the node.
        self.assertSetEqual(set(hg.iterbs(S)), {e1, e2})
        self.assertSetEqual(set(hg.iterbs(X)), {e3})
        self.assertSetEqual(set(hg.iterbs(a)), set())

        # Dependencies: RHS symbols of the node's backward-star edges.
        self.assertSetEqual(set(hg.iterdeps(S)), {S, X})
        self.assertSetEqual(set(hg.iterdeps(X)), {a})
        self.assertSetEqual(set(hg.iterdeps(a)), set())