示例#1
0
    def besteval(self, reflinks, cache=None):
        '''Sub-oracle: best (evaluation, tree) pair reachable from this state.

        reflinks -- mapping from a modifier index to its reference head index
        cache    -- optional memo dict keyed by (step, rank); a fresh dict is
                    created when omitted and shared with all recursive calls

        Returns a 2-tuple (evaluation, tree).  For a state whose action is 0
        this is an empty DepVal and the state's own tree; otherwise it is the
        best-scoring combination over all back-pointers (or (-1, None) when
        there are no back-pointers).
        '''
        memo = {} if cache is None else cache

        key = (self.step, self.rank)
        if key in memo:
            return memo[key]

        if self.action == 0:
            result = (DepVal(), self.tree())
        else:
            head = self.headidx
            best_eval, best_tree = -1, None
            for (left, right), action, _ in self.backptrs:
                # action 1 attaches the left child's head, anything else the right's
                if action == 1:
                    modifier = left.headidx
                else:
                    modifier = right.headidx

                in_reference = modifier in reflinks
                correct = 1 if (in_reference and reflinks[modifier] == head) else 0
                counted = 1 if in_reference else 0

                left_eval, left_tree = left.besteval(reflinks, memo)
                right_eval, right_tree = right.besteval(reflinks, memo)

                candidate = DepVal(yes=correct, tot=counted) + left_eval + right_eval

                if candidate > best_eval:
                    best_eval = candidate
                    best_tree = DepTree.combine(left_tree, right_tree, action)
            result = (best_eval, best_tree)

        memo[key] = result
        return result
示例#2
0
def main():

    if FLAGS.sim is not None:
        sequencefile = open(FLAGS.sim)

    if FLAGS.weights is None:
        if not FLAGS.sim:
            print >> logs, "Error: must specify a weights file" + str(FLAGS)
            sys.exit(1)
        else:
            model = None  # can simulate w/o a model
    else:
        model = Model(FLAGS.weights)  #FLAGS.model, FLAGS.weights)

    print >> logs, "knowns", len(model.knowns)

    DepTree.model = model

    global parser
    parser = Parser(FLAGS.newstate, model, b=FLAGS.beam)

    global totalscore, totalstates, totaledges, totaluniq, totaltime, totalprec
    totalscore = 0
    totalstates = 0
    totaluniq = 0
    totaledges = 0
    totaltime = 0

    totalprec = DepVal()
    totaloracle = DepVal()

    for i, line in enumerate(sys.stdin, 1):
        print work(line, i)

    if FLAGS.newstate:
        print >> logs, "feature constructions: tot= %d shared= %d (%.2f%%)" % (
            State.tot, State.shared, State.shared / State.tot * 100)

    print >> logs, "beam= {b}, avg {a} sents,\tmodelcost= {c:.2f}\tprec= {p:.2%}" \
          "\tstates= {ns:.1f} (uniq {uq:.1f})\tedges= {ne:.1f}\ttime= {t:.4f}\n{d:s}" \
          .format(b=FLAGS.b, a=i, c=totalscore/i, p=totalprec.prec(),
                  ns=totalstates/i, uq=totaluniq/i, ne=totaledges/i, t=totaltime/i,
                  d=totalprec.details())

    if FLAGS.uniqstat:
        for i in sorted(uniqstats):
            print >> logs, "%d\t%.1lf\t%d\t%d" % \
                  (i, sum(uniqstats[i]) / len(uniqstats[i]), \
                   min(uniqstats[i]), max(uniqstats[i]))

    if FLAGS.oracle:
        print >> logs, "oracle= ", totaloracle
示例#3
0
    def forestoracle(self, reftree):
        """Find the final-beam state whose best tree scores highest against reftree.

        For every state in ``self.beams[-1]`` the score is the root-link
        evaluation (counted as correct when the reference head of the state's
        head index is -1) plus the state's ``besteval`` against the reference
        links.  The winner is logged together with ``reftree.evaluate`` of its
        tree.

        Args:
            reftree: reference tree; ``links()`` and ``evaluate()`` are called
                on it.

        Returns:
            ``(oracle, oracletree)``: the best combined score and its tree.
            When no state scores above the initial 0 (for example when the
            final beam is empty) this is ``(0, None)``.
        """

        reflinks = reftree.links()
        oracle = 0
        # Bound up front: the log line and return below used to raise
        # UnboundLocalError whenever no state beat the initial score.
        oracletree = None

        for state in self.beams[-1]:
            h = state.headidx
            root = 1 if (h in reflinks and reflinks[h] == -1) else 0
            rooteval = DepVal(yes=root, tot=1)  # root link

            subeval, tree = state.besteval(reflinks)

            total = rooteval + subeval
            if total > oracle:
                oracle = total
                oracletree = tree

        # only evaluate a tree that was actually found
        evaluation = reftree.evaluate(oracletree) if oracletree is not None else None
        print >> logs, "oracle=", oracle, evaluation
        return oracle, oracletree
示例#4
0
def worker_process(data):
    """Run ``work`` on a batch of sentences and return the batch's summed statistics.

    Args:
        data: iterable of ``(line, index)`` pairs.

    Returns:
        The tuple ``(score, states, edges, uniq, time, prec)`` summed over the
        batch; ``prec`` is accumulated into a ``DepVal``.
    """
    score_sum = states_sum = edges_sum = uniq_sum = time_sum = 0
    prec_sum = DepVal()

    for line, i in data:
        # The module-level parser is passed explicitly; the original author
        # noted that this was faster than letting work() look it up itself.
        # The first element of the result (the parse) is not needed here.
        _, stats = work(line, i, parser)
        score, nstates, nedges, nuniq, dtime, prec = stats

        score_sum += score
        states_sum += nstates
        edges_sum += nedges
        uniq_sum += nuniq
        time_sum += dtime
        prec_sum += prec

    return score_sum, states_sum, edges_sum, uniq_sum, time_sum, prec_sum
示例#5
0
import gflags as flags
FLAGS = flags.FLAGS

# Script entry point: for each gold tree in the file named by argv[1], read
# that sentence's k-best list from stdin and evaluate every candidate tree
# against the gold tree.  (The remainder of this block is not visible here.)
if __name__ == "__main__":

    # TODO: automatically figure out maxk
    flags.DEFINE_integer("maxk", 128, "maxk")

    # Opening argv[1] doubles as the argument check: a missing argument or an
    # unreadable file both end in the usage message.
    # NOTE(review): the bare except also hides unrelated errors, and `file`
    # shadows the builtin and is not used in the visible part of the block.
    try:
        file = open(sys.argv[1])
    except:
        print >> logs, "Usage: cat <kbest-lists> | ./kbest_oracles.py <goldtrees>"
        sys.exit(1)

    # accumulator dict whose missing keys start as an empty DepVal
    tot = defaultdict(lambda: DepVal())

    for sid, reftree in enumerate(DepTree.load(sys.argv[1]), 1):

        # header line of a k-best list: two whitespace-separated fields,
        # the sentence id and the number of candidates that follow
        sentid, k = sys.stdin.readline().split()
        k = int(k)

        # best evaluation seen so far for this sentence and its tree
        best = -1
        besttree = None
        for i in range(1, k + 1):
            # candidate line: "<score>\t<tree>"
            score, tree = sys.stdin.readline().split("\t")
            score = float(score)
            tree = DepTree.parse(tree)

            ev = reftree.evaluate(tree)
            if ev > best:
示例#6
0
def main():

    if FLAGS.sim is not None:
        sequencefile = open(FLAGS.sim)

    parser = Parser(model, b=FLAGS.beam)

    print >> logs, "memory usage before parsing: ", human(memory(start_mem))

    totalscore = 0
    totalstates = 0
    totaluniq = 0
    totaledges = 0
    totaltime = 0

    totalprec = DepVal()    
    totaloracle = DepVal()

    print >> logs, "gc.collect unreachable: %d" % gc.collect()

    if FLAGS.manual_gc:
        gc.disable()
    
    i = 0
    gctime = 0
    for i, line in enumerate(shell_input(), 1):

        if FLAGS.manual_gc and i % FLAGS.gc == 0:
            print >> logs, "garbage collection...",
            tt = time.time()
            print >> logs, "gc.collect unreachable: %d" % gc.collect()
            tt = time.time() - tt
            print >> logs, "took %.1f seconds" % tt
            gctime += tt

        line = line.strip()
        if line[0]=="(":
            # input is a gold tree (so that we can evaluate)
            reftree = DepTree.parse(line)
            sentence = DepTree.sent # assigned in DepTree.parse()            
        else:
            # input is word/tag list
            reftree = None
            sentence = [tuple(x.rsplit("/", 1)) for x in line.split()]   # split by default returns list            
            DepTree.sent = sentence

        if FLAGS.debuglevel >= 1:
            print >> logs, sentence
            print >> logs, reftree

        mytime.zero()
        
        if FLAGS.sim is not None: # simulation, not parsing
            actions = map(int, sequencefile.readline().split())
            goal, feats = parser.simulate(actions, sentence) #if model is None score=0
            print >> logs, feats
            score, tree = goal.score, goal.top()
            (nstates, nedges, nuniq) = (0, 0, 0)
        else:
            # real parsing
            if True: #FLAGS.earlystop:
                refseq = reftree.seq() if reftree is not None else None
                tree, myseq, score, _ = parser.try_parse(sentence, refseq, update=False)
                if FLAGS.early:
                    print >> logs, "ref=", refseq
                    print >> logs, "myt=", myseq

                    refseq = refseq[:len(myseq)] # truncate
                    _, reffeats = parser.simulate(refseq, sentence) 
                    _, myfeats = parser.simulate(myseq, sentence)
                    print >> logs, "+feats", reffeats
                    print >> logs, "-feats", myfeats
                    
                nstates, nedges, nuniq = parser.stats()
            else:
                goal = parser.parse(sentence)
                nstates, nedges, nuniq = parser.stats()

##        score, tree = goal.score, goal.top()
#        score, tree = mytree
            
        dtime = mytime.period()

        if not FLAGS.early and not FLAGS.profile:
            if FLAGS.forest:
                parser.dumpforest(i)
            elif FLAGS.output:
                if not FLAGS.kbest:
                    print tree
                else:
                    stuff = parser.beams[-1][:FLAGS.kbest]
                    print "sent.%d\t%d" % (i, len(stuff))
                    for state in stuff:
                        print "%.2f\t%s" % (state.score, state.tree())
                    print
                    
            if FLAGS.oracle:
                oracle, oracletree = parser.forestoracle(reftree)
                totaloracle += oracle

        prec = DepTree.compare(tree, reftree) # OK if either is None

        searched = sum(x.derivation_count() for x in parser.beams[-1]) if FLAGS.forest else 0
        print >> logs, "sent {i:-4} (len {l}):\tmodelcost= {c:.2f}\tprec= {p:.2%}"\
              "\tstates= {ns} (uniq {uq})\tedges= {ne}\ttime= {t:.3f}\tsearched= {sp}" \
              .format(i=i, l=len(sentence), c=score, p=prec.prec(), \
                      ns=nstates, uq=nuniq, ne=nedges, t=dtime, sp=searched)
        if FLAGS.seq:
            actions = goal.all_actions()
            print >> logs, " ".join(actions)
            check = simulate(actions, sentence, model) #if model is None score=0
            checkscore = check.score
            checktree = check.top()
            print >> logs, checktree
            checkprec = checktree.evaluate(reftree)
            print >> logs, "verify: tree:%s\tscore:%s\tprec:%s" % (tree == checktree, score == checkscore, prec == checkprec)
            print >> logs, "sentence %-4d (len %d): modelcost= %.2lf\tprec= %.2lf\tstates= %d (uniq %d)\tedges= %d\ttime= %.3lf" % \
                  (i, len(sentence), checkscore, checkprec.prec100(), nstates, nuniq, nedges, dtime)

        totalscore += score
        totalstates += nstates
        totaledges += nedges
        totaluniq += nuniq
        totaltime += dtime

        totalprec += prec

    if i == 0:
        print >> logs, "Error: empty input."
        sys.exit(1)

    if FLAGS.featscache:
        print >> logs, "feature constructions: tot= %d shared= %d (%.2f%%)" % (State.tot, State.shared, State.shared / State.tot * 100)

    print >> logs, "beam= {b}, avg {a} sents,\tmodelcost= {c:.2f}\tprec= {p:.2%}" \
          "\tstates= {ns:.1f} (uniq {uq:.1f})\tedges= {ne:.1f}\ttime= {t:.4f}\n{d:s}" \
          .format(b=FLAGS.b, a=i, c=totalscore/i, p=totalprec.prec(), 
                  ns=totalstates/i, uq=totaluniq/i, ne=totaledges/i, t=totaltime/i, 
                  d=totalprec.details())
    
    if FLAGS.uniqstat:
        for i in sorted(uniqstats):
            print >> logs, "%d\t%.1lf\t%d\t%d" % \
                  (i, sum(uniqstats[i]) / len(uniqstats[i]), \
                   min(uniqstats[i]), max(uniqstats[i]))

    if FLAGS.oracle:
        print >> logs, "oracle= ", totaloracle

    if FLAGS.manual_gc:
        print >> logs, "garbage collection took %.1f seconds" % gctime

    print >> logs, "memory usage after parsing: ", human(memory(start_mem))
    if FLAGS.mydouble:
        from mydouble import counts
        print >> logs, "mydouble usage and freed: %d %d" % counts()
示例#7
0
def main():
    """Parse stdin in parallel across all CPUs and log the aggregated statistics.

    All of stdin is read up front and dealt round-robin into one batch per
    CPU.  Each batch is handed to ``worker_process`` through a process pool
    and the per-batch totals are summed here.  The parser is published as the
    module-level ``parser`` before the pool is created.

    A weights file is required unless ``FLAGS.sim`` is given, in which case
    the model is None.  Exits with status 1 when neither is given or when
    stdin is empty.
    """

    if FLAGS.sim is not None:
        # NOTE(review): this handle is local and never read in this function
        sequencefile = open(FLAGS.sim)

    if FLAGS.weights is None:
        if not FLAGS.sim:
            print >> logs, "Error: must specify a weights file" + str(FLAGS)
            sys.exit(1)
        else:
            model = None  # can simulate w/o a model
    else:
        model = Model(FLAGS.weights)  #FLAGS.model, FLAGS.weights)

    # model is None in simulation-only mode, so there is nothing to report then
    if model is not None:
        print >> logs, "knowns", len(model.knowns)

    # global totalscore, totalstates, totaledges, totaluniq, totaltime, totalprec

    totalscore = 0
    totalstates = 0
    totaluniq = 0
    totaledges = 0
    totaltime = 0
    totalprec = DepVal()

    # NOTE(review): never updated in this function, so the final "oracle="
    # line always shows a fresh DepVal -- confirm whether this is intended
    totaloracle = DepVal()

    global parser
    parser = Parser(FLAGS.newstate, model, b=FLAGS.beam)

    ncpus = cpu_count()

    datas = [[] for _ in range(ncpus)]
    # Count sentences with a dedicated name: in Python 2 a list-comprehension
    # variable leaks into the enclosing scope, so reusing one name for both
    # would leave a bogus count behind when stdin is empty.
    nsents = 0
    for nsents, line in enumerate(sys.stdin, 1):
        datas[nsents % ncpus].append((line, nsents))

    if nsents == 0:
        # every average below divides by the sentence count
        print >> logs, "Error: empty input."
        sys.exit(1)

    print >> logs, "using %d CPUs" % ncpus
    pool = Pool(processes=ncpus)

    l = pool.imap(worker_process, datas)
    pool.close()
    pool.join()
    #exit()
    for x in l:
        (score, nstates, nedges, nuniq, dtime, prec) = x
        totalscore += score
        totalstates += nstates
        totaledges += nedges
        totaluniq += nuniq
        totaltime += dtime
        totalprec += prec

    #if FLAGS.newstate:
    #print >> logs, "feature constructions: tot= %d shared= %d (%.2f%%)" % (State.tot, State.shared, State.shared / State.tot * 100)

    print >> logs, "beam= {b}, avg {a} sents,\tmodelcost= {c:.2f}\tprec= {p:.2%}" \
          "\tstates= {ns:.1f} (uniq {uq:.1f})\tedges= {ne:.1f}\ttime= {t:.4f}\n{d:s}" \
          .format(b=FLAGS.b, a=nsents, c=totalscore/nsents, p=totalprec.prec(),
                  ns=totalstates/nsents, uq=totaluniq/nsents, ne=totaledges/nsents,
                  t=totaltime/nsents, d=totalprec.details())

    if FLAGS.uniqstat:
        # one line per key: key, mean, min, max of the recorded values
        for key in sorted(uniqstats):
            print >> logs, "%d\t%.1lf\t%d\t%d" % \
                  (key, sum(uniqstats[key]) / len(uniqstats[key]), \
                   min(uniqstats[key]), max(uniqstats[key]))

    if FLAGS.oracle:
        print >> logs, "oracle= ", totaloracle
示例#8
0
FLAGS = flags.FLAGS

from deptree import DepTree, DepVal

# Script entry point: compare two files of dependency trees line by line and
# accumulate the evaluation of file1's trees against file2's trees.
if __name__ == "__main__":

    # -v / --senteval: also print one evaluation line per sentence
    flags.DEFINE_boolean("senteval",
                         False,
                         "sentence by sentence output",
                         short_name="v")
    # FLAGS(...) parses the flags and returns the remaining arguments;
    # exactly two file names are expected after the program name
    argv = FLAGS(sys.argv)
    if len(argv) != 3:
        print >> logs, "Usage: %s <file1> <file2>" % argv[0] + str(FLAGS)
        sys.exit(1)

    # running total over all sentence pairs
    totalprec = DepVal()

    # NOTE(review): neither file is closed in the visible part of this block
    filea, fileb = open(argv[1]), open(argv[2])

    # zip stops at the shorter file, so surplus lines in the longer one are ignored
    for i, (linea, lineb) in enumerate(zip(filea, fileb), 1):

        treea, treeb = map(DepTree.parse, (linea, lineb))
        prec = treea.evaluate(treeb)

        if FLAGS.senteval:
            print "sent {i:-4} (len {l}):\tprec= {p:.2%}".format(i=i,
                                                                 l=len(treea),
                                                                 p=prec.prec())

        totalprec += prec