コード例 #1
0
ファイル: lmstate.py プロジェクト: rupenp/transforest
    def _toforest(self, lmforest, sent, cache, level=0):
        if self in cache:
            return cache[self]

        this_node = Node("", "X [0-0]", 1, Vector(), sent) # no node fv, temporary iden=""
        is_root = level == 0 # don't include final </s> </s>
        
        for prev_states, extra_fv, extra_words  in self.backptrs:
            prev_nodes = [p._toforest(lmforest, sent, cache, level+1) for p in prev_states if p is not None]

            if is_root:
                extra_words = extra_words[:-LMState.lm.order+1]
                
            edge = Hyperedge(this_node, prev_nodes, extra_fv,
                             prev_nodes + LMState.lm.ppqstr(extra_words).split())

            edge.rule = Rule("a(\"a\")", "b", "")
            edge.rule.ruleid = 1

            this_node.add_edge(edge)
        
        cache[self] = this_node
        this_node.iden = str(len(cache)) # post-order id
        lmforest.add_node(this_node)
        if is_root:
            lmforest.root = this_node
        return this_node
コード例 #2
0
ファイル: forest.py プロジェクト: srush/tf-fork
 def recover_oracle(self):
     '''oracle is already stored implicitly in the forest
     returns best_score, best_parseval, best_tree, edgelist
     '''
     edgelist = self.root.get_oracle_edgelist()
     fv = Hyperedge.deriv2fvector(edgelist)
     tr = Hyperedge.deriv2tree(edgelist)
     bleu_p1 = self.bleu.rescore(tr)
     return bleu_p1, tr, fv, edgelist
コード例 #3
0
ファイル: oracle.py プロジェクト: rupenp/transforest
def extract_oracle(forest):
    """oracle is already stored implicitly in the forest
       returns best_score, best_parseval, best_tree, edgelist
    """
    global implicit_oracle
    implicit_oracle = True
    edgelist = get_edgelist(forest.root)
    fv = Hyperedge.deriv2fvector(edgelist)
    tr = Hyperedge.deriv2tree(edgelist)
    return fv[0], Parseval(), tr, edgelist
コード例 #4
0
ファイル: lmstate.py プロジェクト: rupenp/transforest
    def start_state(root):
        ''' None -> <s>^{g-1} . TOP </s>^{g-1} '''

##        LMState.cache = {}

        lmstr = LMState.lm.raw_startsyms()
        lhsstr = lmstr + [root] + LMState.lm.raw_stopsyms()
        
        edge = Hyperedge(None, [root], Vector(), lhsstr)
        edge.lmlhsstr = LMState.lm.startsyms() + [root] + LMState.lm.stopsyms()
        edge.rule = Rule.parse("ROOT(TOP) -> x0 ### ")
        sc = root.bestres[0] if FLAGS.futurecost else 0
        return LMState(None, [DottedRule(edge, dot=len(lmstr))], LMState.lm.startsyms(),
                       step=0, score=sc) # future cost
コード例 #5
0
ファイル: decoder.py プロジェクト: rupenp/transforest
 def _oracle(self, forest):
     sc, parseval, tr, edgelist = forest_oracle(forest, forest.goldtree)
     forest.oracle_tree = tr
     forest.oracle_fvector = Hyperedge.deriv2fvector(edgelist)
     if Decoder.MAX_NUM_BRACKETS < 0:
         forest.oracle_size_ratio = 1
     else:
         forest.oracle_size_ratio = len(tr.all_label_spans()) / Decoder.MAX_NUM_BRACKETS
コード例 #6
0
ファイル: pattern_matching.py プロジェクト: srush/tf-fork
 def add_lex_th(self, lhs, node, isBP):
     ''' add lexical translation rules '''
     ruleset = self.ruleset
     if lhs in ruleset:
         rules = ruleset[lhs]
         
         # add all translation hyperedges
         for rule in rules:
             newrhs = [x[1:-1] for x in rule.rhs]
             tfedge = Hyperedge(node, [], Vector(rule.fields), newrhs)
             tfedge.rule = rule
             node.edges.append(tfedge)
                     
     elif not isBP: # add a default translation hyperedge (monotonic)
         defword = '"%s"' % node.word
         rule = Rule(lhs, [defword], self.deffields)
         tfedge = Hyperedge(node, [], Vector(self.deffields), [defword[1:-1]])
         tfedge.rule = rule
         ruleset.add_rule(rule)
         node.edges.append(tfedge)
コード例 #7
0
ファイル: forest.py プロジェクト: rupenp/transforest
    def load(filename, lower=True, sentid=0):
        '''now return a generator! use load().next() for singleton.
           and read the last line as the gold tree -- TODO: optional!
           and there is an empty line at the end
        '''

        file = getfile(filename)
        line = None
        total_time = 0
        num_sents = 0
        
        while True:            
            
            start_time = time.time()
            ##'\tThe complicated language in ...\n"
            ## tag is often missing
            try:
                if line is None or line == "\n":
                    line = "\n"
                    while line == "\n":
                        line = file.readline()  # emulate seek                    
                tag, sent = line.split("\t")
            except:
                ## no more forests
                break

            num_sents += 1
            
            sent = sent.split()
            cased_sent = sent [:]
            if lower:
                sent = [w.lower() for w in sent]   # mark johnson: lowercase all words
            num = int(file.readline())

            forest = Forest(num, sent, cased_sent, tag)
            forest.labelspans = {}
            forest.short_edges = {}

            delta = num_spu = 0
            for i in xrange(1, num+1):

                ## '2\tDT* [0-1]\t1 ||| 1232=2 ...\n'
                ## node-based features here: wordedges, greedyheavy, word(1), [word(2)], ...
                line = file.readline()
                try:
                    keys, fields = line.split(" ||| ")
                except:
                    keys = line
                    fields = ""


                iden, labelspan, size = keys.split("\t") ## iden can be non-ints
                size = int(size)

                fvector = FVector.parse(fields)
                node = Node(iden, labelspan, size, fvector, sent)
                forest.add_node(node)

                if cache_same:
                    if labelspan in forest.labelspans:
                        node.same = forest.labelspans[labelspan]
                        node.fvector = node.same.fvector
                    else:
                        forest.labelspans[labelspan] = node

                for j in xrange(size):
                    is_oracle = False

                    ## '\t1 ||| 0=8.86276 1=2 3\n'
                    tails, fields = file.readline().strip().split(" ||| ")
                    
                    if tails[0] == "*":  #oracle edge
                        is_oracle = True
                        tails = tails[1:]
                        
                    tails = tails.split() ## could be non-integers
                    tailnodes = []

                    for x in tails:
                        assert x in forest.nodes, "BAD TOPOL ORDER: node #%s is referred to " % x + \
                               "(in a hyperedge of node #%s) before being defined" % iden
                        ## topological ordering
                        tail = forest.nodes[x]
                        tailnodes.append(tail)

                    use_same = False
                    if fields[-1] == "~":
                        use_same = True
                        fields = fields[:-1]
                        
                    fvector = FVector.parse(fields)
                    edge = Hyperedge(node, tailnodes, fvector)

                    if cache_same:

                        short_edge = edge.shorter()
                        if short_edge in forest.short_edges:
                            edge.same = forest.short_edges[short_edge]
                            if use_same:
                                edge.fvector += edge.same.fvector
                        else:
                            forest.short_edges[short_edge] = edge

                    node.add_edge(edge)
                    if is_oracle:
                        node.oracle_edge = edge

                    
                if node.sp_terminal():
                    node.word = node.edges[0].subs[0].word

            ## splitted nodes 12-3-4 => (12, 3, 4)
            tmp = sorted([(map(int, x.iden.split("-")), x) for x in forest.nodeorder])   
            forest.nodeorder = [x for (_, x) in tmp]

            forest.rehash()
            sentid += 1
            
##            print >> logs, "sent #%d %s, %d words, %d nodes, %d edges, loaded in %.2lf secs" \
##                  % (sentid, forest.tag, forest.len, num, forest.num_edges, time.time() - basetime)

            forest.root = node
            node.set_root(True)

            line = file.readline()

            if line is not None and line.strip() != "":
                if line[0] == "(":
                    forest.goldtree = Tree.parse(line.strip(), trunc=True, lower=True)
                    line = file.readline()
            else:
                line = None

            total_time += time.time() - start_time

            if num_sents % 100 == 0:
                print >> logs, "... %d sents loaded (%.2lf secs per sent) ..." \
                      % (num_sents, total_time/num_sents)
                
            yield forest

        Forest.load_time = total_time
        print >> logs, "%d forests loaded in %.2lf secs (avg %.2lf per sent)" \
              % (num_sents, total_time, total_time/num_sents)
コード例 #8
0
ファイル: forest.py プロジェクト: zhangxt/lineardpparser
    def load(filename, lower=False, sentid=0):
        '''now return a generator! use load().next() for singleton.
           and read the last line as the gold tree -- TODO: optional!
           and there is an empty line at the end
        '''

        file = getfile(filename)

        line = None
        total_time = 0
        num_sents = 0

        while True:

            start_time = time.time()
            ##'\tThe complicated language in ...\n"
            ## tag is often missing
            try:
                if line is None or line == "\n":
                    line = "\n"
                    while line == "\n":
                        line = file.readline()  # emulate seek
                tag, sent = line.split("\t")
            except:
                ## no more forests
                break

            num_sents += 1

            sent = sent.split()
            cased_sent = sent[:]
            if lower:
                sent = [w.lower()
                        for w in sent]  # mark johnson: lowercase all words
            num = int(file.readline())

            forest = Forest(num, sent, cased_sent, tag)
            forest.labelspans = {}
            forest.short_edges = {}

            delta = num_spu = 0
            for i in xrange(1, num + 1):

                ## '2\tDT* [0-1]\t1 ||| 1232=2 ...\n'
                ## node-based features here: wordedges, greedyheavy, word(1), [word(2)], ...
                line = file.readline()
                try:
                    keys, fields = line.split(" ||| ")
                except:
                    keys = line
                    fields = ""

                iden, labelspan, size = keys.split(
                    "\t")  ## iden can be non-ints
                size = int(size)

                fvector = FVector(fields)  # TODO: myvector
                node = Node(iden, labelspan, size, fvector, sent)
                forest.add_node(node)

                if cache_same:
                    if labelspan in forest.labelspans:
                        node.same = forest.labelspans[labelspan]
                        node.fvector = node.same.fvector
                    else:
                        forest.labelspans[labelspan] = node

                for j in xrange(size):
                    is_oracle = False

                    ## '\t1 ||| 0=8.86276 1=2 3\n'
                    tails, fields = file.readline().strip().split(" ||| ")

                    if tails[0] == "*":  #oracle edge
                        is_oracle = True
                        tails = tails[1:]

                    tails = tails.split()  ## could be non-integers
                    tailnodes = []

                    for x in tails:
                        assert x in forest.nodes, "BAD TOPOL ORDER: node #%s is referred to " % x + \
                               "(in a hyperedge of node #%s) before being defined" % iden
                        ## topological ordering
                        tail = forest.nodes[x]
                        tailnodes.append(tail)

                    use_same = False
                    if fields[-1] == "~":
                        use_same = True
                        fields = fields[:-1]

                    fvector = FVector(fields)
                    edge = Hyperedge(node, tailnodes, fvector)

                    if cache_same:

                        short_edge = edge.shorter()
                        if short_edge in forest.short_edges:
                            edge.same = forest.short_edges[short_edge]
                            if use_same:
                                edge.fvector += edge.same.fvector
                        else:
                            forest.short_edges[short_edge] = edge

                    node.add_edge(edge)
                    if is_oracle:
                        node.oracle_edge = edge

                if node.sp_terminal():
                    node.word = node.edges[0].subs[0].word

            ## splitted nodes 12-3-4 => (12, 3, 4)
            tmp = sorted([(map(int, x.iden.split("-")), x)
                          for x in forest.nodeorder])
            forest.nodeorder = [x for (_, x) in tmp]

            forest.rehash()
            sentid += 1

            ##          print >> logs, "sent #%d %s, %d words, %d nodes, %d edges, loaded in %.2lf secs" \
            ##                % (sentid, forest.tag, forest.len, num, forest.num_edges, time.time() - basetime)

            forest.root = node
            node.set_root(True)

            line = file.readline()

            if line is not None and line.strip() != "":
                if line[0] == "(":
                    ##                    forest.goldtree = Tree.parse(line.strip(), trunc=True, lower=False)
                    line = file.readline()
            else:
                line = None

            total_time += time.time() - start_time

            if num_sents % 100 == 0:
                print >> logs, "... %d sents loaded (%.2lf secs per sent) ..." \
                      % (num_sents, total_time/num_sents)

            yield forest

        Forest.load_time = total_time
        if num_sents > 0:
            print >> logs, "%d forests loaded in %.2lf secs (avg %.2lf per sent)" \
                  % (num_sents, total_time, total_time/num_sents)
コード例 #9
0
ファイル: oracle.py プロジェクト: rupenp/transforest
def forest_oracle(forest, goldtree, del_puncs=False, prune_results=False):
    """ returns best_score, best_parseval, best_tree, edgelist
           now non-recursive topol-sort-style
    """

    if hasattr(forest.root, "oracle_edge"):
        return extract_oracle(forest)

    ## modifies forest also!!
    if del_puncs:
        idx_mapping, newforest = check_puncs(forest, goldtree.tag_seq)
    else:
        idx_mapping, newforest = lambda x: x, forest

    goldspans = merge_labels(goldtree.all_label_spans(), idx_mapping)
    goldbrs = set(goldspans)  ## including TOP

    for node in newforest:
        if node.is_terminal():
            results = Oracles.unit("(%s %s)" % (node.label, node.word))  ## multiplication unit

        else:
            a, b = (
                (0, 0)
                if node.is_spurious()
                else ((1, 1) if (merge_label((node.label, node.span), idx_mapping) in goldbrs) else (1, 0))
            )

            label = "" if node.is_spurious() else node.label
            results = Oracles()  ## addition unit
            for edge in node.edges:
                edgeres = Oracles.unit()  ## multiplication unit

                for sub in edge.subs:
                    assert hasattr(sub, "oracles"), "%s ; %s ; %s" % (node, sub, edge)
                    edgeres = edgeres * sub.oracles

                ##                nodehead = (a, RES((b, -edge.fvector[0], label, [edge])))   ## originally there is label
                assert 0 in edge.fvector, edge
                nodehead = (a, RES((b, -edge.fvector[0], [edge])))
                results += nodehead * edgeres  ## mul

        if prune_results:
            prune(results)
        node.oracles = results
        if debug:
            print >> logs, node.labelspan(), "\n", results, "----------"

    res = (-1, RES((-1, 0, []))) * newforest.root.oracles  ## scale, remove TOP match

    num_gold = len(goldspans) - 1  ## omit TOP.  N.B. goldspans, not brackets! (NP (NP ...))

    best_parseval = None
    for num_test in res:
        ##        num_matched, score, tree_str, edgelist = res[num_test]
        num_matched, score, edgelist = res[num_test]
        this = Parseval.get_parseval(num_matched, num_test, num_gold)
        if best_parseval is None or this < best_parseval:
            best_parseval = this
            best_score = score
            ##            best_tree = tree_str
            best_edgelist = edgelist

    best_tree = Hyperedge.deriv2tree(best_edgelist)

    ## annotate the forest for oracle so that next-time you can preload oracle
    for edge in best_edgelist:
        edge.head.oracle_edge = edge

    ## very careful here: desymbol !
    ##    return -best_score, best_parseval, Tree.parse(desymbol(best_tree)), best_edgelist
    return -best_score, best_parseval, best_tree, best_edgelist
コード例 #10
0
ファイル: forest.py プロジェクト: srush/tf-fork
    def load(filename, is_tforest=False, lower=False, sentid=0, first=None, lm=None):
        '''now returns a generator! use load().next() for singleton.
           and read the last line as the gold tree -- TODO: optional!
           and there is an empty line at the end
        '''
        if first is None: # N.B.: must be here, not in the param line (after program initializes)
            first = FLAGS.first
            
        file = getfile(filename)
        line = None
        total_time = 0
        num_sents = 0        
        
        while True:            
            
            start_time = time.time()
            ##'\tThe complicated language in ...\n"
            ## tag is often missing
            line = file.readline()  # emulate seek
            if len(line) == 0:
                break
            try:
                ## strict format, no consecutive breaks
#                 if line is None or line == "\n":
#                     line = "\n"
#                     while line == "\n":
#                         line = file.readline()  # emulate seek
                        
                tag, sent = line.split("\t")   # foreign sentence
                
            except:
                ## no more forests
                yield None
                continue

            num_sents += 1

            # caching the original, word-based, true-case sentence
            sent = sent.split() ## no splitting with " "
            cased_sent = sent [:]            
            if lower:
                sent = [w.lower() for w in sent]   # mark johnson: lowercase all words

            #sent = words_to_chars(sent, encode_back=True)  # split to chars

            ## read in references
            refnum = int(file.readline().strip())
            refs = []
            for i in xrange(refnum):
                refs.append(file.readline().strip())

            ## sizes: number of nodes, number of edges (optional)
            num, nedges = map(int, file.readline().split("\t"))   

            forest = Forest(sent, cased_sent, tag, is_tforest)

            forest.tag = tag

            forest.refs = refs
            forest.bleu = Bleu(refs=refs)  ## initial (empty test) bleu; used repeatedly later
            
            forest.labelspans = {}
            forest.short_edges = {}
            forest.rules = {}

            for i in xrange(1, num+1):

                ## '2\tDT* [0-1]\t1 ||| 1232=2 ...\n'
                ## node-based features here: wordedges, greedyheavy, word(1), [word(2)], ...
                line = file.readline()
                try:
                    keys, fields = line.split(" ||| ")
                except:
                    keys = line
                    fields = ""

                iden, labelspan, size = keys.split("\t") ## iden can be non-ints
                size = int(size)

                fvector = Vector(fields) #
##                remove_blacklist(fvector)
                node = Node(iden, labelspan, size, fvector, sent)
                forest.add_node(node)

                if cache_same:
                    if labelspan in forest.labelspans:
                        node.same = forest.labelspans[labelspan]
                        node.fvector = node.same.fvector
                    else:
                        forest.labelspans[labelspan] = node

                for j in xrange(size):
                    is_oracle = False

                    ## '\t1 ||| 0=8.86276 1=2 3\n'
                    ## N.B.: can't just strip! "\t... ||| ... ||| \n" => 2 fields instead of 3
                    tails, rule, fields = file.readline().strip("\t\n").split(" ||| ")

                    if tails != "" and tails[0] == "*":  #oracle edge
                        is_oracle = True
                        tails = tails[1:]

                    tails = tails.split() ## N.B.: don't split by " "!
                    tailnodes = []
                    lhsstr = [] # 123 "thank" 456

                    lmstr = []
                    lmscore = 0
                    lmlhsstr = []
                    
                    for x in tails:
                        if x[0]=='"': # word
                            word = desymbol(x[1:-1])
                            lhsstr.append(word)  ## desymbol here and only here; ump will call quoteattr
                            
                            if lm is not None:
                                this = lm.word2index(word)
                                lmscore += lm.ngram.wordprob(this, lmstr)
                                lmlhsstr.append(this)
                                lmstr += [this,]
                                
                        else: # variable

                            assert x in forest.nodes, "BAD TOPOL ORDER: node #%s is referred to " % x + \
                                         "(in a hyperedge of node #%s) before being defined" % iden
                            tail = forest.nodes[x]
                            tailnodes.append(tail)
                            lhsstr.append(tail)                            

                            if lm is not None:
                                lmstr = []  # "..." "..." x0 "..."
                                lmlhsstr.append(tail) # sync with lhsstr

                    fvector = Vector(fields)
                    if lm is not None:
                        fvector["lm1"] = lmscore # hack

                    edge = Hyperedge(node, tailnodes, fvector, lhsstr)
                    edge.lmlhsstr = lmlhsstr

                    ## new
                    x = rule.split()
                    edge.ruleid = int(x[0])
                    if len(x) > 1:
                        edge.rule = Rule.parse(" ".join(x[1:]) + " ### " + fields)
                        forest.rules[edge.ruleid] = edge.rule #" ".join(x[1:]) #, None)
                    else:
                        edge.rule = forest.rules[edge.ruleid] # cahced rule

                    node.add_edge(edge)
                    if is_oracle:
                        node.oracle_edge = edge
                    
                if node.sp_terminal():
                    node.word = node.edges[0].subs[0].word

            ## splitted nodes 12-3-4 => (12, 3, 4)
            tmp = sorted([(map(int, x.iden.split("-")), x) for x in forest.nodeorder])   
            forest.nodeorder = [x for (_, x) in tmp]

            forest.rehash()
            sentid += 1
            
##            print >> logs, "sent #%d %s, %d words, %d nodes, %d edges, loaded in %.2lf secs" \
##                  % (sentid, forest.tag, forest.len, num, forest.num_edges, time.time() - basetime)

            forest.root = node
            node.set_root(True)
            line = file.readline()

            if line is not None and line.strip() != "":
                if line[0] == "(":
                    forest.goldtree = Tree.parse(line.strip(), trunc=True, lower=False)
                    line = file.readline()
            else:
                line = None

            forest.number_nodes()
            #print forest.root.position_id
          

            total_time += time.time() - start_time

            if num_sents % 100 == 0:
                print >> logs, "... %d sents loaded (%.2lf secs per sent) ..." \
                      % (num_sents, total_time/num_sents)

            forest.subtree() #compute the subtree string for each node

            yield forest

            if first is not None and num_sents >= first:
                break                

        # better check here instead of zero-division exception
        if num_sents == 0:
            print >> logs, "NO FORESTS FOUND!!! (empty input file?)"
            sys.exit(1)            
#            yield None # new: don't halt -- WHY?
        
        Forest.load_time = total_time
        print >> logs, "%d forests loaded in %.2lf secs (avg %.2lf per sent)" \
              % (num_sents, total_time, total_time/(num_sents+0.001))
コード例 #11
0
ファイル: pattern_matching.py プロジェクト: srush/tf-fork
    def add_nonter_th(self, node):
        ''' add translation hyperedges to non-terminal node '''
        ruleset = self.ruleset
        tfedges = []
        for edge in node.edges:
            # enumerate all the possible frags
            basefrags = [("%s(" % node.label, [], 1)]
            lastchild = len(edge.subs) - 1
            if len(edge.subs) >= 5: # this guy has too many children! it cannot be matched!
                deflhs = "%s(%s)" % (node.label, " ".join(sub.label for sub in edge.subs))
                defrhs = edge.subs
                defheight = 1
                basefrags = [(deflhs, defrhs, defheight)]
            else:
                for (id, sub) in enumerate(edge.subs):
                    oldfrags = basefrags
                    # cross-product
                    basefrags = [PatternMatching.combinetwofrags(oldfrag, frag, id, lastchild, self.max_height) \
                                 for oldfrag in oldfrags for frag in sub.frags]

            # for each frag add translation hyperedges
            for extfrag in basefrags:
                (extlhs, extrhs, extheight) = extfrag
                # add frags
                if extheight <= self.max_height - 1:
                    node.frags.append(extfrag)

                if self.filter:
                    self.all_lhss.add(extlhs)
                else:
                    # add translation hyperedges
                    if extlhs in ruleset:
                        for des in extrhs:
                            self.descendants[node.iden].add(des.iden) #unit(set(extrhs))
                        #print self.descendants[node.iden]
                    
                        rules = ruleset[extlhs]
                
                        # add all translation hyperedges
                        for rule in rules:
                            rhsstr = [x[1:-1] if x[0]=='"' \
                                      else extrhs[int(x.split('x')[1])] \
                                      for x in rule.rhs]
                            tfedge = Hyperedge(node, extrhs,\
                                     Vector(rule.fields), rhsstr)
                            tfedge.rule = rule
                            tfedges.append(tfedge)
            
            if (not self.filter) and (len(tfedges) == 0):  # no translation hyperedge
                for des in edge.subs:
                    self.descendants[node.iden].add(des.iden) #unit(set(edge.subs))
                # add a default translation hyperedge
                deflhs = "%s(%s)" % (node.label, " ".join(sub.label for sub in edge.subs))
                defrhs = ["x%d" % i for i, _ in enumerate(edge.subs)] # N.B.: do not supply str
                defrule = Rule(deflhs, defrhs, self.deffields)
                tfedge = Hyperedge(node, edge.subs,\
                                   Vector(self.deffields), edge.subs)
                tfedge.rule = defrule
                ruleset.add_rule(defrule)
                tfedges.append(tfedge)

        if not self.filter:
            # inside replace
            node.edges = tfedges