def test():
    """Train a NetTable on the digits training file and report test accuracy.

    Uses fixed hyper-parameters hnn=30 and alpha=0.05. These are presumably
    the hidden-layer size and the learning rate; confirm against NetTable.
    In every record, values [0:64] are fed to the network as input, and
    value [64] is the expected class index.

    For each test record, prints the expected class, the predicted class and
    the raw outputs. Finishes with a "<correct> of <total>" summary line.
    """
    hnn, alpha = 30, 0.05
    nt = NetTable(hnn, alpha)
    records_tra = read_records('../docs/digitstra.txt')
    records_tes = read_records('../docs/digitstest.txt')

    print("training...")
    for record in records_tra:
        # NOTE(review): each record is back-propagated twice in a row. If two
        # full passes (epochs) over the data were intended, swap these loops.
        for _ in range(2):
            nt.back_propagate(record[:64], record[64])

    print("testing...")
    count = 0
    for record in records_tes:
        result = nt.feed_forward(record[:64])[1]
        # Argmax over the outputs. The running maximum must be updated along
        # with the index, otherwise the loop only finds "last value larger
        # than result[0]". On ties the first index wins.
        index, value = 0, result[0]
        for ind, val in enumerate(result):
            if val > value:
                index, value = ind, val
        # A single formatted string keeps the output identical under
        # Python 2 and Python 3.
        print("expect: %s estimate: %s %s" % (record[64], index, result))
        if record[64] == index:
            count += 1
    print("%d of %d" % (count, len(records_tes)))
def test(nh, alpha, beta):
    """Train a NetTable with the given hyper-parameters and print its accuracy.

    Author's note (translated): with nh, alpha, beta = 200, 0.05, 0.2 the
    accuracy reached about 92%.

    Trains on the module-level ``records_tra`` and evaluates on the
    module-level ``records_tes``, both defined elsewhere in this file. In each
    test record, values [0:64] are the network input and value [64] is
    compared against the predicted class index.

    Prints one CSV line: ``nh,alpha,beta,accuracy``. The accuracy is 0.0 when
    ``records_tes`` is empty.
    """
    nt = NetTable(nh, alpha, beta)
    nt.train(records_tra)
    count = 0
    for record in records_tes:
        result = nt.feed_forward(record[:64])[1]
        # Compute the maximum once. The scan covers every output, so the code
        # is not tied to exactly 10 classes. On ties the highest index wins,
        # matching the previous behavior.
        best = max(result)
        predicted = 0
        for index, value in enumerate(result):
            if value == best:
                predicted = index
        if predicted == record[64]:
            count += 1
    # Avoid ZeroDivisionError when there are no test records.
    accuracy = float(count) / len(records_tes) if records_tes else 0.0
    print("%d,%.4f,%.4f,%.4f" % (nh, alpha, beta, accuracy))