Esempio n. 1
0
    def one_pass_on_train(self, data):
        """Run one training pass over `data`, updating the weights in place.

        Each entry of `data` is parsed with DepTree.parse and decoded with
        early stopping enabled.  Whenever the decoder's similarity falls
        short of (numerically) 1, the returned feature delta is added to
        self.weights and, when FLAGS.avg is set, to self.allweights scaled
        by the running example counter self.c.

        Returns a pair (num_updates, early_updates): the number of examples
        that triggered an update, and how many of those had a similarity of
        (numerically) 0, i.e. were early updates.
        """
        perfect = 1 - 1e-8  # anything below this counts as a mistake
        update_count = 0
        early_count = 0

        for raw in data:
            gold = DepTree.parse(raw)
            self.c += 1

            sim, delta = self.decoder.decode(gold, early_stop=True)

            if not (sim < perfect):
                # decoded correctly: nothing to update for this example
                continue

            update_count += 1
            if FLAGS.debuglevel >= 2:
                print >> logs, "deltafv=", delta

            self.weights += delta
            if FLAGS.avg:
                # weight the delta by the example counter for averaging
                self.allweights += delta * self.c

            if sim < 1e-8:
                # early-update happened
                early_count += 1

        return update_count, early_count
Esempio n. 2
0
    def load(self, lines, shuffle=False):
        """Generate one parsed tree (DepTree.parse) per entry of `lines`.

        When `shuffle` is true, a notice is written to `logs` and `lines`
        is shuffled IN PLACE with random.shuffle before any tree is
        produced, so the caller's list is reordered too.  Because this is
        a generator, the shuffle only happens once iteration starts.
        """
        if shuffle:
            print >> logs, "Shuffling training set..."
            random.shuffle(lines)

        for raw in lines:
            yield DepTree.parse(raw)
Esempio n. 3
0
    def eval_worker(self, sentences):
        """Decode a batch of raw sentences and return the accumulated result.

        Each entry of `sentences` is parsed with DepTree.parse, decoded with
        trainer.decoder.decode, and the resulting similarity is added to a
        fresh accumulator created by decoder.evalclass().

        Returns the accumulator for the whole batch.  (The previous version
        returned only the last sentence's similarity and raised NameError
        on an empty batch.)
        """
        # NOTE(review): the accumulator comes from the module-level `decoder`
        # while decoding goes through the global `trainer.decoder` -- confirm
        # both names refer to the same decoder.
        sub = decoder.evalclass()  # global variable trainer

        for example in sentences:
            # have to parse here, not outside, because Tree.words is static
            tree = DepTree.parse(example)
            similarity, _ = trainer.decoder.decode(tree)
            sub += similarity  # do it inside instead of outside

        return sub
Esempio n. 4
0
def work(line, i):
    """Parse (or simulate) one input line and return its printable output.

    Args:
        line: one input line; either a gold tree (first character "(") or
            a whitespace-separated list of word/tag tokens.
        i: sentence number, used in log lines, the forest dump and the
            k-best header.

    Returns:
        str(tree) in 1-best mode; in k-best mode the header line
        "sent.<i>\\t<n>" followed by one "<score>\\t<tree>" line per state,
        joined with newlines; None when nothing is rendered (forest mode,
        or FLAGS.earlystop / FLAGS.profile set).

    Side effects: sets DepTree.sent, writes diagnostics to `logs`, and adds
    this sentence's figures to the module-level running totals.
    """
    # totaloracle must be declared too: it is updated with += below and
    # would otherwise be treated as an (unbound) local.
    global totalscore, totalstates, totaledges, totaluniq, totaltime, \
        totalprec, totaloracle

    line = line.strip()
    if line[0] == "(":
        # input is a gold tree (so that we can evaluate)
        reftree = DepTree.parse(line)
        sentence = DepTree.sent  # assigned in DepTree.parse()
    else:
        # input is word/tag list
        reftree = None
        sentence = [tuple(x.rsplit("/", 1))
                    for x in line.split()]  # split by default returns list
        DepTree.sent = sentence

    if FLAGS.debuglevel >= 1:
        print >> logs, sentence
        print >> logs, reftree

    mytime.zero()

    if FLAGS.sim is not None:  # simulation, not parsing
        actions = map(int, sequencefile.readline().split())
        goal, feats = parser.simulate(actions,
                                      sentence)  #if model is None score=0
        print >> logs, feats
        score, tree = goal.score, goal.top()
        (nstates, nedges, nuniq) = (0, 0, 0)
    else:
        # real parsing
        refseq = reftree.seq() if reftree is not None else None
        tree, myseq, score = parser.try_parse(sentence,
                                              refseq,
                                              early_stop=FLAGS.earlystop)
        if FLAGS.earlystop:
            print >> logs, "ref=", refseq
            print >> logs, "myt=", myseq

            # NOTE(review): refseq is None when the input had no gold tree;
            # the slice below then fails -- confirm early-stop mode always
            # receives gold trees.
            refseq = refseq[:len(myseq)]  # truncate
            _, reffeats = parser.simulate(refseq, sentence)
            _, myfeats = parser.simulate(myseq, sentence)
            print >> logs, "feat diff=", Model.trim(reffeats - myfeats)

        nstates, nedges, nuniq = parser.stats()

    dtime = mytime.period()

    # Stays None whenever no textual output is produced for this sentence.
    toprint = None
    if not FLAGS.earlystop and not FLAGS.profile:
        if FLAGS.forest:
            parser.dumpforest(i)
        elif not FLAGS.kbest:
            toprint = str(tree)
        else:
            stuff = parser.beams[-1][:FLAGS.kbest]
            rows = ["sent.%d\t%d" % (i, len(stuff))]
            rows.extend("%.2f\t%s" % (state.score, state.tree())
                        for state in stuff)
            # NOTE(review): no blank separator line is appended after the
            # list -- confirm what the consumer of this output expects.
            toprint = "\n".join(rows)

        if FLAGS.oracle:
            oracle, _ = parser.forestoracle(reftree)
            totaloracle += oracle

    prec = DepTree.compare(tree, reftree)  # OK if either is None

    searched = sum(x.derivation_count()
                   for x in parser.beams[-1]) if FLAGS.forest else 0
    print >> logs, "sent {i:-4} (len {l}):\tmodelcost= {c:.2f}\tprec= {p:.2%}"\
          "\tstates= {ns} (uniq {uq})\tedges= {ne}\ttime= {t:.3f}\tsearched= {sp}" \
          .format(i=i, l=len(sentence), c=score, p=prec.prec(), \
                  ns=nstates, uq=nuniq, ne=nedges, t=dtime, sp=searched)
    if FLAGS.seq:
        # NOTE(review): `goal` is only bound in simulation mode above, so
        # this branch raises UnboundLocalError after real parsing.
        actions = goal.all_actions()
        print >> logs, " ".join(actions)
        check = simulate(actions, sentence, model)  #if model is None score=0
        checkscore = check.score
        checktree = check.top()
        print >> logs, checktree
        checkprec = checktree.evaluate(reftree)
        print >> logs, "verify: tree:%s\tscore:%s\tprec:%s" % \
              (tree == checktree, score == checkscore, prec == checkprec)
        print >> logs, "sentence %-4d (len %d): modelcost= %.2lf\tprec= %.2lf\tstates= %d (uniq %d)\tedges= %d\ttime= %.3lf" % \
              (i, len(sentence), checkscore, checkprec.prec100(), nstates, nuniq, nedges, dtime)

    totalscore += score
    totalstates += nstates
    totaledges += nedges
    totaluniq += nuniq
    totaltime += dtime

    totalprec += prec

    return toprint
Esempio n. 5
0
    def load(self, lines):
        """Lazily yield DepTree.parse(entry) for each entry of `lines`, in order."""
        for raw in lines:
            yield DepTree.parse(raw)
Esempio n. 6
0
        print >> logs, "Usage: cat <kbest-lists> | ./kbest_oracles.py <goldtrees>"
        sys.exit(1)

    # tot[i] accumulates, over all sentences, the best evaluation found
    # among the first i candidates of each k-best list.
    tot = defaultdict(lambda: DepVal())

    # Gold trees come from the file named by argv[1]; the k-best lists are
    # read from stdin in the same sentence order.
    for sid, reftree in enumerate(DepTree.load(sys.argv[1]), 1):

        # Header line of a list: two whitespace-separated fields, the
        # second being the number of candidates that follow.
        sentid, k = sys.stdin.readline().split()
        k = int(k)

        best = -1
        besttree = None
        for i in range(1, k + 1):
            # Candidate line: "<score>\t<tree>".
            score, tree = sys.stdin.readline().split("\t")
            score = float(score)
            tree = DepTree.parse(tree)

            # Keep the candidate that evaluates best against the gold tree.
            ev = reftree.evaluate(tree)
            if ev > best:
                best = ev
                besttree = tree

            tot[i] += best

        # The list was shorter than FLAGS.maxk: carry the final best value
        # forward for the remaining ranks.
        for i in range(k + 1, FLAGS.maxk + 1):  # if short list
            tot[i] += best

        # Consume the line that follows each list (its content is ignored).
        sys.stdin.readline()
        print "%s\t%s\t%s" % (sid, best, besttree)
        sys.stdout.flush()
Esempio n. 7
0
def main():
    """Parse (or simulate) every input sentence and log statistics.

    For each line produced by shell_input() the sentence is built either
    from a gold tree (line starts with "(") or from a word/tag list, then
    simulated (FLAGS.sim) or parsed.  Per-sentence figures are written to
    `logs`; trees or k-best lists are printed to stdout when FLAGS.output
    is set.  After the loop, averages over all sentences are logged.
    Exits with status 1 if the input was empty.
    """

    if FLAGS.sim is not None:
        # NOTE(review): this file is never closed inside this function.
        sequencefile = open(FLAGS.sim)

    parser = Parser(model, b=FLAGS.beam)

    print >> logs, "memory usage before parsing: ", human(memory(start_mem))

    # Running totals over all sentences.
    totalscore = 0
    totalstates = 0
    totaluniq = 0
    totaledges = 0
    totaltime = 0

    totalprec = DepVal()    
    totaloracle = DepVal()

    print >> logs, "gc.collect unreachable: %d" % gc.collect()

    # With manual GC, automatic collection is turned off and gc.collect()
    # is called every FLAGS.gc sentences inside the loop instead.
    if FLAGS.manual_gc:
        gc.disable()
    
    i = 0  # stays 0 if the input is empty (checked after the loop)
    gctime = 0
    for i, line in enumerate(shell_input(), 1):

        if FLAGS.manual_gc and i % FLAGS.gc == 0:
            print >> logs, "garbage collection...",
            tt = time.time()
            print >> logs, "gc.collect unreachable: %d" % gc.collect()
            tt = time.time() - tt
            print >> logs, "took %.1f seconds" % tt
            gctime += tt

        line = line.strip()
        # NOTE(review): a blank input line makes line[0] raise IndexError.
        if line[0]=="(":
            # input is a gold tree (so that we can evaluate)
            reftree = DepTree.parse(line)
            sentence = DepTree.sent # assigned in DepTree.parse()            
        else:
            # input is word/tag list
            reftree = None
            sentence = [tuple(x.rsplit("/", 1)) for x in line.split()]   # split by default returns list            
            DepTree.sent = sentence

        if FLAGS.debuglevel >= 1:
            print >> logs, sentence
            print >> logs, reftree

        # Start timing this sentence; read back via mytime.period() below.
        mytime.zero()
        
        if FLAGS.sim is not None: # simulation, not parsing
            # One line of integer actions per sentence from the sequence file.
            actions = map(int, sequencefile.readline().split())
            goal, feats = parser.simulate(actions, sentence) #if model is None score=0
            print >> logs, feats
            score, tree = goal.score, goal.top()
            (nstates, nedges, nuniq) = (0, 0, 0)
        else:
            # real parsing
            # NOTE(review): the condition is the constant True, so the
            # `else` branch below can never run.
            if True: #FLAGS.earlystop:
                refseq = reftree.seq() if reftree is not None else None
                tree, myseq, score, _ = parser.try_parse(sentence, refseq, update=False)
                if FLAGS.early:
                    print >> logs, "ref=", refseq
                    print >> logs, "myt=", myseq

                    # NOTE(review): refseq is None when the input had no
                    # gold tree; the slice below then raises TypeError.
                    refseq = refseq[:len(myseq)] # truncate
                    _, reffeats = parser.simulate(refseq, sentence) 
                    _, myfeats = parser.simulate(myseq, sentence)
                    print >> logs, "+feats", reffeats
                    print >> logs, "-feats", myfeats
                    
                nstates, nedges, nuniq = parser.stats()
            else:
                goal = parser.parse(sentence)
                nstates, nedges, nuniq = parser.stats()

##        score, tree = goal.score, goal.top()
#        score, tree = mytree
            
        dtime = mytime.period()

        # Output to stdout is skipped entirely in early / profile mode.
        if not FLAGS.early and not FLAGS.profile:
            if FLAGS.forest:
                parser.dumpforest(i)
            elif FLAGS.output:
                if not FLAGS.kbest:
                    print tree
                else:
                    # k-best list: header line, one "score\ttree" line per
                    # state taken from the last beam, then a blank line.
                    stuff = parser.beams[-1][:FLAGS.kbest]
                    print "sent.%d\t%d" % (i, len(stuff))
                    for state in stuff:
                        print "%.2f\t%s" % (state.score, state.tree())
                    print
                    
            if FLAGS.oracle:
                oracle, oracletree = parser.forestoracle(reftree)
                totaloracle += oracle

        prec = DepTree.compare(tree, reftree) # OK if either is None

        searched = sum(x.derivation_count() for x in parser.beams[-1]) if FLAGS.forest else 0
        print >> logs, "sent {i:-4} (len {l}):\tmodelcost= {c:.2f}\tprec= {p:.2%}"\
              "\tstates= {ns} (uniq {uq})\tedges= {ne}\ttime= {t:.3f}\tsearched= {sp}" \
              .format(i=i, l=len(sentence), c=score, p=prec.prec(), \
                      ns=nstates, uq=nuniq, ne=nedges, t=dtime, sp=searched)
        if FLAGS.seq:
            # NOTE(review): `goal` is a local that is only bound in the
            # simulation branch (and the unreachable parse branch), so this
            # raises UnboundLocalError after real parsing.
            actions = goal.all_actions()
            print >> logs, " ".join(actions)
            check = simulate(actions, sentence, model) #if model is None score=0
            checkscore = check.score
            checktree = check.top()
            print >> logs, checktree
            checkprec = checktree.evaluate(reftree)
            print >> logs, "verify: tree:%s\tscore:%s\tprec:%s" % (tree == checktree, score == checkscore, prec == checkprec)
            print >> logs, "sentence %-4d (len %d): modelcost= %.2lf\tprec= %.2lf\tstates= %d (uniq %d)\tedges= %d\ttime= %.3lf" % \
                  (i, len(sentence), checkscore, checkprec.prec100(), nstates, nuniq, nedges, dtime)

        totalscore += score
        totalstates += nstates
        totaledges += nedges
        totaluniq += nuniq
        totaltime += dtime

        totalprec += prec

    if i == 0:
        print >> logs, "Error: empty input."
        sys.exit(1)

    if FLAGS.featscache:
        # NOTE(review): if State.shared and State.tot are ints and true
        # division is not enabled for this module, the ratio truncates
        # under Python 2 -- confirm.
        print >> logs, "feature constructions: tot= %d shared= %d (%.2f%%)" % (State.tot, State.shared, State.shared / State.tot * 100)

    # Summary: totals divided by the number of sentences read (i).
    print >> logs, "beam= {b}, avg {a} sents,\tmodelcost= {c:.2f}\tprec= {p:.2%}" \
          "\tstates= {ns:.1f} (uniq {uq:.1f})\tedges= {ne:.1f}\ttime= {t:.4f}\n{d:s}" \
          .format(b=FLAGS.b, a=i, c=totalscore/i, p=totalprec.prec(), 
                  ns=totalstates/i, uq=totaluniq/i, ne=totaledges/i, t=totaltime/i, 
                  d=totalprec.details())
    
    if FLAGS.uniqstat:
        # Note: `i` is reused here as the key of uniqstats; the sentence
        # count is no longer needed at this point.
        for i in sorted(uniqstats):
            print >> logs, "%d\t%.1lf\t%d\t%d" % \
                  (i, sum(uniqstats[i]) / len(uniqstats[i]), \
                   min(uniqstats[i]), max(uniqstats[i]))

    if FLAGS.oracle:
        print >> logs, "oracle= ", totaloracle

    if FLAGS.manual_gc:
        print >> logs, "garbage collection took %.1f seconds" % gctime

    print >> logs, "memory usage after parsing: ", human(memory(start_mem))
    if FLAGS.mydouble:
        from mydouble import counts
        print >> logs, "mydouble usage and freed: %d %d" % counts()
Esempio n. 8
0
    def load(self, filename):
        """Lazily yield one parsed tree (DepTree.parse) per line of `filename`.

        The file is opened when iteration starts and is closed as soon as
        the generator is exhausted, closed, or garbage-collected, instead of
        being left open as before.
        """
        with open(filename) as infile:
            for line in infile:
                yield DepTree.parse(line)
Esempio n. 9
0
if __name__ == "__main__":

    flags.DEFINE_integer("cutoff", 1, "cut off freq")
    flags.DEFINE_integer("cutoff_char", 1, "cut off freq for chars")

    argv = FLAGS(sys.argv)

    word_dict = defaultdict(lambda: [0, defaultdict(int)])
    unktags = defaultdict(int)

    for line in sys.stdin:
        line = line.strip()
        if line == "":
            continue

        reftree = DepTree.parse(line)
        words = DepTree.words
        tags = reftree.tagseq()
        for w, t in zip(words, tags):
            word_dict[w][0] += 1
            word_dict[w][1][t] += 1

    unkcnt = 0
    for w in sorted(word_dict):
        freq, tags = word_dict[w]
        if freq <= FLAGS.cutoff:
            for t in tags:
                unktags[t] += tags[t]
                unkcnt += freq
        else:
            print "%s\t%d\t%s" % (w, freq, sorttags(tags))