示例#1
0
    def expand_unary(self, i, j):
        """Close bin (i,j) under the grammars' unary rules.

        Every item already in the bin, and every item the bin accepts along
        the way, is queued on a min-heap keyed by (nonterminal rank, total
        cost).  Each popped item that is still indexed by the bin serves as
        the sole antecedent for the rules that g.unary_rules lists under its
        nonterminal, for every grammar g whose filterspan accepts (i, j).
        """
        cell = self.bins[i][j]
        rank_of = self.nonterminals.getrank

        agenda = []
        for (totalcost, item) in cell:
            agenda.append((rank_of(item.x), totalcost, item))
        heapq.heapify(agenda)

        while agenda:
            trank, _, titem = heapq.heappop(agenda)
            if log.level >= 3:
                log.write("Applying unary rules to %s\n" % titem)

            # The item may have been defeated or pruned after it was queued.
            if titem not in cell.index:
                continue

            for (g, _dotchart) in self.grammars:
                if not g.filterspan(i, j, self.n):
                    continue
                for (_estcost, r) in g.unary_rules.get(titem.x, ()):
                    rank = rank_of(r.lhs)

                    # Only build items whose nonterminal ranks strictly after
                    # the trigger item's; anything else (e.g. from a unary
                    # cycle) could corrupt the forest, so it is counted and
                    # skipped.
                    if rank <= trank:
                        self.unary_pruned += 1
                        continue

                    (totalcost, (cost, dcost, newstates)) = self.compute_item(r, (titem,), i, j)
                    ded = forest.Deduction((titem,), r, dcost, viterbi=cost)
                    newitem = forest.Item(r.lhs, i, j, deds=[ded], states=newstates, viterbi=cost)
                    # Re-queue only what the bin actually accepted.
                    if cell.add(totalcost, newitem):
                        heapq.heappush(agenda, (rank, totalcost, newitem))
示例#2
0
 def expand_goal(self, bin1):
     """Connect every start-symbol item of bin1 to the goal.

     For each (cost, item) pair in bin1 whose item is labelled
     self.start_nonterminal, the models' finaltransition feature vectors
     (one per model, fed that model's state from item.states) are summed,
     scored with self.weights.dot and added to the item's Viterbi cost.
     A goal-level Item (label None, span (0, self.n)) with that single
     deduction is then added to self.goal.
     """
     for (_cost1, candidate) in bin1:
         if candidate.x != self.start_nonterminal:
             continue
         if log.level >= 3:
             log.write("Considering: %s\n" % str(candidate))

         finals = (model.finaltransition(candidate.states[k])
                   for (k, model) in enumerate(self.models))
         dcost = sum(finals, svector.Vector())
         cost = candidate.viterbi + self.weights.dot(dcost)

         ded = forest.Deduction((candidate,), None, dcost, viterbi=cost)
         goal_item = forest.Item(None, 0, self.n, deds=[ded], states=(), viterbi=cost)
         self.goal.add(cost, goal_item)
示例#3
0
 def add_axiom(self, i, j, r):
     """Add an antecedent-free item built from rule r to bin (i,j).

     The item is costed with an empty antecedent tuple.  If its total cost
     is not below the bin's cutoff, it is dropped and counted in
     self.prepruned instead.
     """
     target = self.bins[i][j]
     totalcost, (cost, dcost, newstates) = self.compute_item(r, (), i, j)

     if not totalcost < target.cutoff:
         if log.level >= 4:
             log.write("Prepruning: %s\n" % r)
         self.prepruned += 1
         return

     ded = forest.Deduction((), r, dcost, viterbi=cost)
     target.add(totalcost, forest.Item(r.lhs, i, j, deds=[ded], states=newstates, viterbi=cost))
示例#4
0
def make_forest(fieldss):
    """Build a forest from per-hypothesis search-graph records.

    fieldss is an iterable of dicts, one per hypothesis, read through the
    keys 'hyp', 'back', 'scores', 'src-phrase', 'tgt-phrase',
    'cover-start', 'cover-end' and 'forward' (presumably parsed from a
    phrase-based decoder's search-graph dump -- confirm against the
    caller).  Every hypothesis becomes a PHRASE Item spanning (0, 0).

    * Hypothesis 0 gets one deduction by an empty PHRASE rule.
    * Every other record adds to its own node a deduction whose single
      antecedent is the node named by its back-pointer.  The rule's source
      and target sides are an index-1 PHRASE symbol followed by the
      record's source and target words.  The cost vector is built by
      svector.Vector from the second group of scores_re, plus '_core<k>'
      entries for the comma-separated values of the first group.
    * Records with a negative 'forward' value mark goal nodes.

    Returns a goal Item (label None) with one deduction per goal node.

    Raises ValueError if a 'scores' field does not match scores_re, and
    KeyError if a back-pointer names a hypothesis not seen earlier.
    """
    nodes = {}
    goal_ids = set()
    for fields in fieldss:
        # Normalise the id to int so node keys agree with the
        # int(fields['back']) lookups below and with the test for the
        # initial hypothesis 0.
        node_id = int(fields['hyp'])
        if node_id not in nodes:
            nodes[node_id] = forest.Item(sym.fromtag('PHRASE'), 0, 0, [])
        node = nodes[node_id]

        if node_id == 0:
            r = rule.Rule(sym.fromtag('PHRASE'), rule.Phrase([]), rule.Phrase([]))
            node.deds.append(forest.Deduction((), r, svector.Vector()))
            continue

        m = scores_re.match(fields['scores'])
        if m is None:
            raise ValueError("could not parse scores of hypothesis %d: %r"
                             % (node_id, fields['scores']))
        core_values = [float(x) for x in m.group(1).split(',')]
        dcost = svector.Vector(m.group(2).encode('utf8'))
        for i, x in enumerate(core_values):
            dcost["_core%d" % i] = x

        back = int(fields['back'])
        ant = nodes[back]
        f = fields['src-phrase'].encode('utf8').split()
        e = fields['tgt-phrase'].encode('utf8').split()
        # cover-start/cover-end are compared as an inclusive range.
        if len(f) != int(fields['cover-end']) - int(fields['cover-start']) + 1:
            sys.stderr.write("warning: French phrase length didn't match covered length\n")

        f = rule.Phrase([sym.setindex(sym.fromtag('PHRASE'), 1)] + f)
        e = rule.Phrase([sym.setindex(sym.fromtag('PHRASE'), 1)] + e)
        r = rule.Rule(sym.fromtag('PHRASE'), f, e)
        node.deds.append(forest.Deduction((ant,), r, dcost))

        if int(fields['forward']) < 0:  # goal
            goal_ids.add(node_id)

    goal = forest.Item(None, 0, 0, [])
    for node_id in goal_ids:
        goal.deds.append(forest.Deduction((nodes[node_id],), None, svector.Vector()))
    return goal
示例#5
0
    def expand_cell(self, i, j, bintuples):
        """Fill bin (i,j) by exhaustively combining rules with antecedents.

        bintuples is a list of (rule, bin, ...) tuples where rule matches
        the input span (i,j) and the bins are the bins of potential
        antecedents.  bins[0] yields (score, rule) pairs; bins[1] and
        bins[2] yield (score, item) pairs.

        A rule of arity 1 is combined with every item of bins[1], and a
        rule of arity 2 with every (bins[1], bins[2]) pair.  A candidate
        whose total cost is below the bin's cutoff is added to the bin;
        otherwise it is counted in self.prepruned.  Rules of any other
        arity are logged and skipped.
        """
        bin = self.bins[i][j]

        for bins in bintuples:
            for (_rscore, r) in bins[0]:
                arity = r.arity()
                # Lazily enumerate antecedent tuples in the same order as
                # nested loops over bins[1] (outer) and bins[2] (inner).
                if arity == 1:
                    antss = ((ant1,) for (_s1, ant1) in bins[1])
                elif arity == 2:
                    antss = ((ant1, ant2)
                             for (_s1, ant1) in bins[1]
                             for (_s2, ant2) in bins[2])
                else:
                    log.write("this shouldn't happen: rule %s has arity %s\n" % (r, arity))
                    continue

                for ants in antss:
                    (totalcost, (cost, dcost, newstates)) = self.compute_item(r, ants, i, j)
                    if totalcost < bin.cutoff:
                        ded = forest.Deduction(ants, r, dcost, viterbi=cost)
                        item = forest.Item(r.lhs, i, j, deds=[ded], states=newstates, viterbi=cost)
                        bin.add(totalcost, item)
                    else:
                        if log.level >= 4:
                            log.write("Prepruning: %s (totalcost=%f, cutoff=%f)\n" % (r, totalcost, bin.cutoff))
                        self.prepruned += 1
示例#6
0
文件: decoder.py 项目: jungikim/sbmt
    def expand_cell_cubeprune(self, i, j, cubes):
        """Fill bin (i,j) by cube pruning over `cubes`.

        Each non-empty cube seeds the candidate heap with the cell returned
        by cube.first().  Candidates are popped cheapest-first, at most
        self.pop_limit of them when that is not None.  A popped candidate
        whose total cost is below the bin's cutoff becomes an Item in the
        bin; otherwise it is counted in self.prepruned.

        Either way, its cube.successors are visited.  A successor is pushed
        once it has been reached as many times as cube.n_predecessors
        reports.  Candidates still queued at the end are added to
        self.discarded, and self.max_popped keeps the largest pop count
        seen.
        """
        # initialize candidate list
        cand = []
        index = collections.defaultdict(int)
        for cube in cubes:
            if len(cube) > 0:
                ranks = cube.first()
                r, ants = cube[ranks]
                (totalcost, info) = self.compute_item(r, ants, i, j,
                                                      cube.latticev)
                cand.append((totalcost, info, cube, ranks))
                index[cube, ranks] += 1
        heapq.heapify(cand)

        bin = self.bins[i][j]

        popped = 0
        while cand and (self.pop_limit is None or popped < self.pop_limit):
            # Get the best item on the heap
            (totalcost, (cost, dcost, newstates), cube,
             ranks) = heapq.heappop(cand)
            popped += 1
            r, ants = cube[ranks]

            if totalcost < bin.cutoff:
                # Turn it into a real Item
                ded = forest.Deduction(ants, r, dcost, viterbi=cost)
                item = forest.Item(r.lhs,
                                   i,
                                   j,
                                   deds=[ded],
                                   states=newstates,
                                   viterbi=cost)
                bin.add(totalcost, item)
            else:
                self.prepruned += 1

            # Put item's successors into the heap
            for nextranks in cube.successors(ranks):
                index[cube, nextranks] += 1
                if index[cube, nextranks] == cube.n_predecessors(nextranks):
                    r, ants = cube[nextranks]
                    # Score successors with the cube's lattice vertex, just
                    # like the cube's first candidate above, so that all
                    # candidates of one cube are costed the same way.
                    (totalcost, info) = self.compute_item(r, ants, i, j,
                                                          cube.latticev)
                    heapq.heappush(cand, (totalcost, info, cube, nextranks))

        self.discarded += len(cand)
        self.max_popped = max(self.max_popped, popped)
示例#7
0
def parse(n, xrules, rules):
    """Assemble position-annotated rules into forests over one sentence.

    n = length of sentence
    xrules = rules with position info, to be assembled into forest;
             looked up by (nonterminal, i, k) span keys
    rules = grammar of rules from all sentences
    N.B. This does not work properly without tight_phrases

    The chart is filled bottom-up by span length.  Non-START items are
    derived from the xrules of their span whose antecedent chart entries
    all exist.  START items glue those items left to right: S -> X for
    spans starting at 0, and S -> S X for any split point.

    A greedy pass then picks spans, longest first and leftmost first
    within a length, that hold any item and do not overlap an
    already-picked span.  For each picked span, the items of all its
    sub-spans are listed in order of increasing span length.

    Returns a list of such item lists, one per picked span.
    """
    chart = [[dict((v, None) for v in nonterminals) for j in xrange(n + 1)]
             for i in xrange(n + 1)]

    # NOTE(review): within one span, nonterminals are visited in the order
    # of `nonterminals`, so S -> X only sees X items of the same span if
    # START comes after them in that order -- confirm this holds.
    for l in xrange(1, n + 1):
        for i in xrange(n - l + 1):
            k = i + l

            for x in nonterminals:
                if x != START:
                    item = forest.Item(x, i, k)
                    for r in xrules.get((x, i, k), ()):
                        # Antecedents are the chart entries for the rule's
                        # positions that carry a (subi, subj) span.
                        ants = []
                        for fi in xrange(len(r.f)):
                            if isinstance(r.fpos[fi], tuple):
                                (subi, subj) = r.fpos[fi]
                                ants.append(chart[subi][subj][sym.clearindex(r.f[fi])])
                        if None not in ants:
                            # the reason for the lookup in rules is to allow
                            # duplicate rules to be freed
                            item.derive(ants, rules[r], r.scores[0])
                    if len(item.deds) > 0:
                        chart[i][k][x] = item

                else:  # x == START
                    item = forest.Item(x, i, k)

                    # S -> X
                    if i == 0:
                        for y in nonterminals:
                            if y != START and chart[i][k][y] is not None:
                                item.derive([chart[i][k][y]], gluestop[y])

                    # S -> S X.  Split points i and k can never match (no
                    # zero-length items exist), so only inner splits are tried.
                    for j in xrange(i + 1, k):
                        left = chart[i][j][START]
                        if left is None:
                            continue
                        for y in nonterminals:
                            if chart[j][k][y] is not None:
                                item.derive([left, chart[j][k][y]], glue[y])

                    if len(item.deds) > 0:
                        chart[i][k][x] = item
                        # Add 1/len(deds) to the first score of each glue
                        # rule used here (keeping only that score).  The rule
                        # objects come from glue/gluestop and are shared
                        # across spans, so this accumulates.
                        share = 1. / len(item.deds)
                        for ded in item.deds:
                            ded.rule.scores = [ded.rule.scores[0] + share]

    # find biggest nonoverlapping spans
    covered = [False] * n
    spans = []
    for l in xrange(n, 0, -1):
        for i in xrange(n - l + 1):
            k = i + l
            if not any(chart[i][k][v] is not None for v in nonterminals):
                continue
            # don't let any of the spans overlap
            if any(covered[i:k]):
                continue
            for j in xrange(i, k):
                covered[j] = True
            spans.append((i, k))

    # put in topological order
    itemlists = []
    for (start, stop) in spans:
        items = []
        for l in xrange(1, stop - start + 1):
            for i in xrange(start, stop - l + 1):
                k = i + l
                for v in nonterminals:
                    if chart[i][k][v] is not None:
                        items.append(chart[i][k][v])
        if items:
            itemlists.append(items)

    return itemlists
示例#8
0
    def expand_cell_cubeprune(self, i, j, bintuples):
        """Fill bin (i,j) by cube pruning.

        bintuples is a list of (rule, bin, ...) tuples where rule matches
        the input span (i,j) and the bins are the bins of potential
        antecedents.  Every bin element is a (score, object) pair and only
        the object is used: bins[0] holds rules and bins[1:] hold
        antecedent items.  Cubes with an empty dimension are skipped.

        Candidates are popped cheapest-first, at most self.pop_limit of
        them when that is not None.  A popped candidate below the bin's
        cutoff becomes an Item; otherwise it is counted in self.prepruned.
        Candidates still queued at the end are added to self.discarded,
        and self.max_popped keeps the largest pop count seen.
        """
        # initialize candidate list: the (0, 0, ..., 0) corner of each cube
        cand = []
        index = collections.defaultdict(int)
        for bins in bintuples:
            if log.level >= 3:
                log.write("Enqueueing cube %s\n" % ",".join(str(bin) for bin in bins))
            if not all(len(b) > 0 for b in bins):
                continue
            r = bins[0][0][1]
            ants = tuple(b[0][1] for b in bins[1:])
            (totalcost, info) = self.compute_item(r, ants, i, j)
            ranks = (0,) * len(bins)
            cand.append((totalcost, info, bins, ranks))
            index[bins, ranks] += 1
        heapq.heapify(cand)

        bin = self.bins[i][j]

        popped = 0
        while cand and (self.pop_limit is None or popped < self.pop_limit):

            (totalcost, (cost, dcost, newstates), bins, ranks) = heapq.heappop(cand)
            popped += 1

            if log.level >= 3:
                log.write("pop %d: totalcost=%s cutoff=%s\n" % (popped, totalcost, bin.cutoff))
            r = bins[0][ranks[0]][1]
            ants = tuple(bins[bj][ranks[bj]][1] for bj in xrange(1, len(bins)))

            if totalcost < bin.cutoff:
                ded = forest.Deduction(ants, r, dcost, viterbi=cost)
                item = forest.Item(r.lhs, i, j, deds=[ded], states=newstates, viterbi=cost)
                bin.add(totalcost, item)
            else:
                if log.level >= 4:
                    log.write("Prepruning: %s (totalcost=%f, cutoff=%f)\n" % (r, totalcost, bin.cutoff))
                self.prepruned += 1
                # NOTE: successors of a prepruned candidate are still
                # explored below.  Stopping here instead would bet that the
                # rest of the heap also falls outside the beam.

            # Put the candidate's successors (one step along each
            # dimension) into the heap.
            for bi in xrange(len(bins)):
                step = ranks[bi] + 1
                if step >= len(bins[bi]):
                    continue
                nextranks = ranks[:bi] + (step,) + ranks[bi + 1:]
                index[bins, nextranks] += 1

                # Push a cell only when the count of popped predecessors
                # reaches its number of nonzero coordinates.
                n_predecessors = len([rank for rank in nextranks if rank > 0])
                if index[bins, nextranks] != n_predecessors:
                    continue

                # Build fresh rule/antecedents rather than mutating `ants`
                # in place: `ants` may already be held by the Deduction
                # created above.
                if bi == 0:
                    nextr = bins[0][step][1]
                    nextants = ants
                else:
                    nextr = r
                    nextants = ants[:bi - 1] + (bins[bi][step][1],) + ants[bi:]

                (nexttotal, nextinfo) = self.compute_item(nextr, nextants, i, j)
                heapq.heappush(cand, (nexttotal, nextinfo, bins, nextranks))
                if log.level >= 3:
                    log.write(" push: totalcost=%s\n" % nexttotal)

        self.discarded += len(cand)
        self.max_popped = max(self.max_popped, popped)