# Example #1
0
def run(sentset, labelset, postagset, all_feats, info, weights, testdata, ad):
    """Run one shuffled perceptron training pass over the training set.

    Sentences are visited in random order. Each one is decoded with the
    current ``weights`` via ``execute``. When the predicted sequence differs
    from the gold one, ``update`` adjusts ``weights`` in place; ``ad`` is
    passed through to it (the caller labels it as AdaGrad state).

    After every sentence (not only after mistakes) the current ``weights``
    are folded into a running sum with ``add_weights``. The caller divides
    that sum by the number of training sentences seen to get the averaged
    weights.

    Every 10000 sentences the current weights are saved with
    ``framework.write_weights`` and evaluated on the test data with
    ``decode``. Once more at the end of the pass the current weights are
    evaluated.

    Args:
        sentset: training sentences, indexed in parallel with ``labelset``
            and ``postagset``.
        labelset: gold label sequences, compared with ``!=`` against the
            prediction.
        postagset: POS-tag sequences.
        all_feats: feature inventory handed to ``init_weights``.
        info: auxiliary training data forwarded to ``execute``/``update``.
        weights: current weights; mutated in place by ``update`` and
            also returned.
        testdata: tuple ``(tsents, tgoldtagseqs, tpostagseqs, tinfo)``.
        ad: forwarded unchanged to ``update``.

    Returns:
        ``(weights_sum, weights)``: the per-sentence running sum of the
        weights for this pass, and the updated weights.
    """
    tsents, tgoldtagseqs, tpostagseqs, tinfo = testdata

    weights_avg = init_weights(all_feats)
    order = list(range(len(sentset)))
    shuffle(order)
    k = 0
    for j in order:
        sys.stderr.write(str(k) + "\r")

        sent = sentset[j]
        labelseq = labelset[j]
        postagseq = postagset[j]

        # NOTE(review): ``all_labels`` is not a parameter of this function;
        # presumably a module-level global defined elsewhere -- confirm.
        predseq = execute(sent, all_labels, postagseq, weights, labelseq, info)
        if labelseq != predseq:
            update(weights, predseq, labelseq, sent, postagseq, info, ad)

        # Accumulate after EVERY sentence, not only on mistakes. The caller
        # divides the sum by num_iter * len(sentset) (one term per training
        # sentence), which is the averaged-perceptron average.
        add_weights(weights_avg, weights)

        k += 1
        if k % 10000 == 0:
            # Periodic checkpoint and held-out evaluation.
            framework.write_weights(weights, k)
            decode(tsents, tgoldtagseqs, tpostagseqs, tinfo, weights)

    decode(tsents, tgoldtagseqs, tpostagseqs, tinfo, weights)
    return weights_avg, weights
# Example #2
0
def learn_and_decode(trainfile, featlistfile, gazfile, brownfile, num_iter, testfile):
    """Train for ``num_iter`` passes, then evaluate averaged weights on test.

    Training data and features are loaded with ``framework.get_all``. Test
    data is read with ``framework.read_data`` and its auxiliary maps are
    built with ``framework.get_maps``.

    Each iteration calls ``run``, which returns the running weight sum for
    that pass plus the updated weights. The updated weights are saved with
    ``framework.write_weights`` and the pass sum is added into
    ``tot_weights``. After all iterations ``tot_weights`` is divided by the
    total number of training instances (``num_iter * len(sentset)``) and
    decoded on the test data.

    Args:
        trainfile: path of the training data, passed to ``framework.get_all``.
        featlistfile: feature list file, passed to ``framework.get_all``.
        gazfile: gazetteer file, passed to ``framework.get_all`` and
            ``framework.get_maps``.
        brownfile: Brown cluster file, passed to ``framework.get_all`` and
            ``framework.get_maps``.
        num_iter: number of training passes.
        testfile: path of the test data, passed to ``framework.read_data``.
    """
    sentset, labelset, postagset, all_feats, info = framework.get_all(trainfile, gazfile, featlistfile, brownfile)
    sys.stderr.write("\n" + str(len(all_feats)) + " features in all\n")

    sys.stderr.write("\nreading test data \n")
    tsents, tgoldtagseqs, tpostagseqs = framework.read_data(testfile)
    tinfo = framework.get_maps(tsents, tpostagseqs, gazfile, brownfile)

    testdata = (tsents, tgoldtagseqs, tpostagseqs, tinfo)
    weights = init_weights(all_feats)
    tot_weights = init_weights(all_feats)

    # ADAGRAD accumulator, forwarded to ``run`` (and from there to ``update``).
    ad = init_weights(all_feats)

    for ite in range(num_iter):
        sys.stderr.write("Iteration " + str(ite) + "\n---------------------------\ntotal train sentences = " + str(len(sentset)) + "\n")
        weights_a, weights = run(sentset, labelset, postagset, all_feats, info, weights, testdata, ad)  # ADAGRAD
        framework.write_weights(weights, ite)
        add_weights(tot_weights, weights_a)

    # Average over every training instance seen. Divide by a float so that,
    # under Python 2 with integer weights, this is not a truncating floor
    # division. Skip entirely when nothing was trained (avoids
    # ZeroDivisionError).
    num_instances = num_iter * len(sentset)
    if num_instances:
        denom = float(num_instances)
        # Plain key iteration works on Python 2 and 3; only values are
        # reassigned, so mutating during iteration is safe.
        for key in tot_weights:
            tot_weights[key] /= denom

    sys.stderr.write("\n\nfinal performance on test\n")
    decode(tsents, tgoldtagseqs, tpostagseqs, tinfo, tot_weights)