示例#1
0
def extract(tree, sentence, fclasses, do_sub=True, logprob=None):
    ''' Gather the features of tree into one FVector and return it.

        With do_sub set (and tree not a terminal), the features of every
        subtree are accumulated first, recursively; the features firing on
        this level are added afterwards.  Feature classes whose is_global()
        is true are only applied when tree.is_root() holds.  Each extraction
        result goes through FVector.convert_fullname before being added
        (the mapping from full-names to ids lives in fvector.py).

        When the global use_pp is true, the per-class extractions are
        submitted to job_server and their results are merged only after all
        of them have been queued.

        If logprob is not None, it is stored under feature id 0.

        The non-sub version (just this level) is used by BUDecoder
        (the forest decoder).
    '''

    fvector = FVector()

    tree.annotate(None, do_sub=do_sub)

    # subtrees first, then this level
    if do_sub and not tree.is_terminal():
        for child in tree.subs:
            fvector += extract(child, sentence, fclasses)

    pending = []
    for fclass in fclasses:
        # global feature classes fire at the root only
        if fclass.is_global() and not tree.is_root():
            continue

        if not use_pp:
            fvector += FVector.convert_fullname(fclass.extract(tree, sentence))
        else:
            pending.append(job_server.submit(fclass.extract, (tree, sentence), (quantize,)))

    if use_pp:
        # calling a job yields its result
        for job in pending:
            fvector += FVector.convert_fullname(job())

    if logprob is not None:
        fvector[0] = logprob
    return fvector
示例#2
0
    def load(filename):
        '''Lazily read n-best lists from filename; yields one NBestList each.

           Input layout (tab-separated), one group per sentence:

               small.13.1      7       50
               #0      5 7           0=-42.9527 ...
               ...

           The header line holds the tag, the gold size and k; each of the
           next k lines holds a sentence id, the "matched test" bracket
           counts, a dummy field and the feature vector.
           N.B.  a dummy TAB between sizes and fvector, sorry.

           The candidate whose Parseval compares lowest with "<" (the first
           such one on ties) becomes the oracle of the list.

           Raises ValueError when a list declares no candidates (k <= 0),
           because no oracle can be chosen for it.

           Once the input is exhausted, the accumulated loading time is
           stored in NBestList.load_time and a summary line is written to
           logs.  Time spent by the consumer between yields is not counted.
        '''

        total_time = 0
        num_sents = 0
        f = getfile(filename)
        while True:

            start_time = time.time()

            line = f.readline()
            if line == '':
                # readline() returns '' only at end of file
                break

            num_sents += 1
            tag, goldsize, k = line.split("\t")
            goldsize = int(goldsize)
            k = int(k)

            kparses = []
            best_pp = None   ## CAREFUL! could be 0
            # reset for every list, so that an oracle can never leak over
            # from the previous sentence
            oracle = None
            oracle_testbr = None
            for i in xrange(k):
                sentid, sizes, _, fv = f.readline().split("\t")
                matchbr, testbr = map(int, sizes.split())
                fvector = FVector.parse(fv)
                pp = Parseval.get_parseval(matchbr, testbr, goldsize)

                curr = [fvector, pp]
                kparses.append(curr)

                if best_pp is None or pp < best_pp:  ## < is better in oracle
                    best_pp = pp
                    oracle = curr
                    oracle_testbr = testbr

            if oracle is None:
                raise ValueError("nbest list %s contains no parses (k=%d)" % (tag, k))

            forest = NBestList(k, tag, kparses, goldsize)
            forest.oracle_tree = oracle
            forest.oracle_fvector, forest.oracle_pp = oracle

            if Decoder.MAX_NUM_BRACKETS < 0:
                forest.oracle_size_ratio = 1
            else:
                # NOTE(review): with two ints and no "from __future__ import
                # division" this truncates under Python 2, and a value of 0
                # for MAX_NUM_BRACKETS divides by zero -- confirm intent.
                forest.oracle_size_ratio = oracle_testbr / Decoder.MAX_NUM_BRACKETS

            total_time += time.time() - start_time

            yield forest

        NBestList.load_time = total_time
        # an empty input must not crash the summary with a division by zero
        avg_time = total_time / num_sents if num_sents else 0.0
        print >> logs, "%d nbest lists loaded in %.2lf secs (avg %.2lf per sent)" \
              % (num_sents, total_time, avg_time)
示例#3
0
def reduce_counts(kparses):
    '''Turn absolute feature counts into counts relative to the list average.

       kparses is a sequence of (fvector, _) pairs; the fvectors are modified
       in place and nothing is returned.

       For every feature the values are summed over all parses, divided by
       the number of parses and rounded to an int (after subtracting 0.0001).
       Each non-zero rounded average is then subtracted from that feature in
       every parse (a missing feature counts as 0); a difference whose
       magnitude is not above 1e-4 removes the feature from the parse.
    '''

    totals = FVector()

    for parse_fv, _ in kparses:
        for feat, val in parse_fv.items():
            totals [feat] = totals.get(feat, 0) + val

    num_parses = len(kparses)
    for feat, total in sorted(totals.items()):
        totals[feat] = int(round(float(total)/num_parses - 0.0001))

    # features with a zero average leave every parse untouched
    shifts = [(feat, mean) for feat, mean in totals.items() if mean != 0]

    for parse_fv, _ in kparses:
        for feat, mean in shifts:
            diff = parse_fv.get(feat, 0) - mean
            if math.fabs(diff) > 1e-4:
                parse_fv [feat] = diff
            else:
                del parse_fv [feat]
示例#4
0
    def load(filename, lower=True, sentid=0):
        '''Read forests from filename lazily: this is a generator
           (use load().next() for a singleton).

           Per forest the input holds
             - a "tag<TAB>sentence" line (the tag is often missing),
             - a line with the number of nodes,
             - for every node a line "iden<TAB>labelspan<TAB>size[ ||| features]"
               followed by `size` hyperedge lines "tails ||| features",
             - optionally a line starting with "(" holding the gold tree.
           Blank lines between forests are skipped.

           A leading "*" on the tails marks the oracle edge of its node.  A
           trailing "~" on the edge features means that, when cache_same is
           on and an identical short edge was seen before, the features of
           that earlier edge are added to this one.

           lower  -- lowercase all words (the cased sentence is kept as well)
           sentid -- starting value of a per-forest counter; it is only
                     incremented here

           The last node read becomes the root of the forest.  When the input
           is exhausted, the accumulated loading time is stored in
           Forest.load_time and a summary line is written to logs.
        '''

        infile = getfile(filename)
        line = None
        total_time = 0
        num_sents = 0

        while True:

            start_time = time.time()
            ##'\tThe complicated language in ...\n"
            ## tag is often missing
            try:
                if line is None or line == "\n":
                    line = "\n"
                    while line == "\n":
                        line = infile.readline()  # emulate seek
                tag, sent = line.split("\t")
            except ValueError:
                ## no more forests: at end of file readline() returns '',
                ## which cannot be unpacked into (tag, sent)
                break

            num_sents += 1

            sent = sent.split()
            cased_sent = sent [:]
            if lower:
                sent = [w.lower() for w in sent]   # mark johnson: lowercase all words
            num = int(infile.readline())

            forest = Forest(num, sent, cased_sent, tag)
            forest.labelspans = {}
            forest.short_edges = {}

            for i in xrange(1, num+1):

                ## '2\tDT* [0-1]\t1 ||| 1232=2 ...\n'
                ## node-based features here: wordedges, greedyheavy, word(1), [word(2)], ...
                line = infile.readline()
                try:
                    keys, fields = line.split(" ||| ")
                except ValueError:
                    ## no node-based features on this line
                    keys = line
                    fields = ""


                iden, labelspan, size = keys.split("\t") ## iden can be non-ints
                size = int(size)

                fvector = FVector.parse(fields)
                node = Node(iden, labelspan, size, fvector, sent)
                forest.add_node(node)

                if cache_same:
                    ## nodes with the same labelspan share one fvector
                    if labelspan in forest.labelspans:
                        node.same = forest.labelspans[labelspan]
                        node.fvector = node.same.fvector
                    else:
                        forest.labelspans[labelspan] = node

                for j in xrange(size):
                    is_oracle = False

                    ## '\t1 ||| 0=8.86276 1=2 3\n'
                    tails, fields = infile.readline().strip().split(" ||| ")

                    if tails[0] == "*":  #oracle edge
                        is_oracle = True
                        tails = tails[1:]

                    tails = tails.split() ## could be non-integers
                    tailnodes = []

                    for x in tails:
                        assert x in forest.nodes, "BAD TOPOL ORDER: node #%s is referred to " % x + \
                               "(in a hyperedge of node #%s) before being defined" % iden
                        ## topological ordering
                        tail = forest.nodes[x]
                        tailnodes.append(tail)

                    use_same = False
                    if fields[-1] == "~":
                        use_same = True
                        fields = fields[:-1]

                    fvector = FVector.parse(fields)
                    edge = Hyperedge(node, tailnodes, fvector)

                    if cache_same:

                        short_edge = edge.shorter()
                        if short_edge in forest.short_edges:
                            edge.same = forest.short_edges[short_edge]
                            if use_same:
                                edge.fvector += edge.same.fvector
                        else:
                            forest.short_edges[short_edge] = edge

                    node.add_edge(edge)
                    if is_oracle:
                        node.oracle_edge = edge


                if node.sp_terminal():
                    node.word = node.edges[0].subs[0].word

            ## splitted nodes 12-3-4 => (12, 3, 4)
            tmp = sorted([([int(part) for part in x.iden.split("-")], x) for x in forest.nodeorder])
            forest.nodeorder = [x for (_, x) in tmp]

            forest.rehash()
            sentid += 1

            ## the node read last is the root
            forest.root = node
            node.set_root(True)

            line = infile.readline()

            if line is not None and line.strip() != "":
                if line[0] == "(":
                    forest.goldtree = Tree.parse(line.strip(), trunc=True, lower=True)
                    line = infile.readline()
                ## otherwise the line is kept: it is the header of the next forest
            else:
                line = None

            total_time += time.time() - start_time

            if num_sents % 100 == 0:
                print >> logs, "... %d sents loaded (%.2lf secs per sent) ..." \
                      % (num_sents, total_time/num_sents)

            yield forest

        Forest.load_time = total_time
        # an empty input must not crash the summary with a division by zero
        avg_time = total_time / num_sents if num_sents else 0.0
        print >> logs, "%d forests loaded in %.2lf secs (avg %.2lf per sent)" \
              % (num_sents, total_time, avg_time)
示例#5
0
    optparser.add_option("-k", "", dest="N", type=int, help="first N-best only", metavar="N", default=50)
    optparser.add_option("-W", "", dest="weightsfile", help="read weights from", metavar="FILE", default=None)
    optparser.add_option("-w", "", dest="weights", help="read weights from str", metavar="W", default=None)
    optparser.add_option("-O", "--oracle", dest="oracle", action="store_true", \
                         help="compute nbest oracles (instead of decoding)", default=False)
    optparser.add_option("-R", "--reduce", dest="reduce", action="store_true", \
                         help="reduce absolute feature counts to relative", default=False)
    optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", \
                         help="print result for each sentence", default=False)
    optparser.add_option("-t", "--trees", dest="nbesttreesfile", help="read nbest trees", \
                         metavar="FILE", default=None)

    (opts, args) = optparser.parse_args()

    if opts.weights:
        weights = FVector.parse(opts.weights)
    elif opts.weightsfile:
        weights = FVector.readweights(opts.weightsfile)
    else:
        weights = FVector({0:1})

    if opts.nbesttreesfile is not None:
        from readkbest import NBestForest
        nbesttrees = NBestForest.load(opts.nbesttreesfile, read_gold=False)        

    decoder = NBestDecoder(opts.N)
    
    all_pp = Parseval()
    decode_time, parseval_time = 0, 0